change/sglang · Commits

Commit a837166e (unverified)
Authored Jan 21, 2024 by Lianmin Zheng; committed via GitHub, Jan 21, 2024.
Fix select and normalized logprobs (#67)
Parent: 11f3cca6

Showing 3 changed files with 52 additions and 101 deletions (+52 −101):
python/sglang/srt/layers/get_selected_logprob.py    +0  −79
python/sglang/srt/layers/logits_processor.py        +45 −15
python/sglang/srt/managers/router/model_runner.py   +7  −7
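For context: the normalized logprob that this commit fixes is the sum of a request's next-token logprobs divided by the number of predicted tokens (one less than the sequence length), i.e. the log of the geometric mean of the next-token probabilities. A standalone sketch with hypothetical probabilities, independent of the code below:

import math

# Hypothetical next-token probabilities for a 4-token sequence (3 predictions).
token_probs = [0.5, 0.25, 0.125]

sum_logp = sum(math.log(p) for p in token_probs)
normalized_logprob = sum_logp / len(token_probs)  # divide by seq_len - 1
print(normalized_logprob)  # -1.386..., i.e. log(0.25), the geometric-mean logprob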
python/sglang/srt/layers/get_selected_logprob.py   +0 −79 (file deleted, mode 100644 → 0)

The entire file is removed:
import torch
import triton
import triton.language as tl

from sglang.srt.utils import wrap_kernel_launcher


@triton.jit
def _fwd_segmented_gather(
    all_logits,
    len_add_1,
    cum_len,
    input_ids,
    logprobs,
    max_seq_len,
    voc_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request: gather the logit of each position's *next* token.
    cur_req = tl.program_id(0)
    cur_l = tl.load(len_add_1 + cur_req)
    cum_l = tl.load(cum_len + cur_req)

    for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
        off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = off < cur_l - 1
        idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
        data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
        tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)


cached_kernel = None


def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
    cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
    voc_size = all_logits.shape[1]
    grid = (len_add_1.shape[0], 1, 1)
    max_seq_len = len_add_1.max().item()

    global cached_kernel
    if cached_kernel:
        cached_kernel(
            grid,
            4,  # num_warps
            all_logits,
            len_add_1,
            cum_len,
            input_ids,
            logprobs,
            max_seq_len,
        )
        return

    _fwd_segmented_gather[grid](
        all_logits,
        len_add_1,
        cum_len,
        input_ids,
        logprobs,
        max_seq_len,
        voc_size,
        BLOCK_SIZE=128,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)


if __name__ == "__main__":
    all_logits = torch.tensor(
        # s s s
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
    logprobs = torch.empty((3,), dtype=torch.float32, device="cuda")

    get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
    print(logprobs)
    # assert logprobs == [2, 2, 4]
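The rest of the commit replaces this kernel with plain tensor indexing in logits_processor.py. For comparison, a minimal sketch of the same next-token gather in pure PyTorch, reusing the test tensors above (running on CPU here is an illustrative assumption; the real code stays on "cuda"):

import torch

# For each position i, pick the logprob of the token that follows it:
# picked[i] = all_logprobs[i, input_ids[i + 1]] (the last position gets a dummy 0 id).
all_logprobs = torch.tensor(
    [[0.0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]]
)
input_ids = torch.tensor([1, 2, 3, 0, 1])

shifted_ids = torch.cat([input_ids[1:], torch.tensor([0])])  # [2, 3, 0, 1, 0]
picked = all_logprobs[torch.arange(all_logprobs.shape[0]), shifted_ids]
print(picked)  # tensor([2., 4., 2., 4., 4.])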
python/sglang/srt/layers/logits_processor.py   +45 −15
 import torch
-from sglang.srt.layers.get_selected_logprob import get_selected_logprob
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
 from vllm.model_executor.parallel_utils.communication_op import (
...
@@ -54,25 +53,56 @@ class LogitsProcessor(nn.Module):
             normalized_logprobs = compute_normalized_logprobs(
                 all_logprobs,
-                input_metadata.seq_lens - input_metadata.prefix_lens,
                 input_ids,
+                input_metadata.extend_seq_lens,
+                input_metadata.extend_start_loc,
             )
             last_logits = logits[last_index]
             return last_logits, normalized_logprobs


-def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
-    # assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
-    logprobs = torch.zeros(
-        (all_logprobs.shape[0] - len_add_1.shape[0]),
-        dtype=torch.float32,
-        device="cuda",
-    )
-    get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
-    cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-    end = torch.cumsum(len_add_1.sub_(1), dim=0)
-    start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
-    end.sub_(1)
-    torch.cuda.synchronize()
-    sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
-    res = sum_logp / len_add_1
-    return res
+def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
+    logprobs = all_logprobs[
+        torch.arange(all_logprobs.shape[0], device="cuda"),
+        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+    ]
+    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+    start = start_loc.clone()
+    end = start + seq_lens - 2
+    start.clamp_(min=0, max=logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+    return sum_logp / ((seq_lens - 1).clamp(min=1))
+
+
+if __name__ == "__main__":
+    all_logprobs = torch.tensor(
+        # s s s
+        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
+        dtype=torch.float32,
+        device="cuda",
+    )
+    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
+    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
+
+    logprobs = all_logprobs[
+        torch.arange(all_logprobs.shape[0], device="cuda"),
+        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+    ]
+    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+    len_cumsum = torch.cumsum(seq_lens, dim=0)
+    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
+    end = start + seq_lens - 2
+    start.clamp_(min=0, max=logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+
+    # assert logprobs == [2, _, 2, 4, _]
+    print("logprobs", logprobs)
+    print("start", start)
+    print("end", end)
+    print("sum_logp", sum_logp)
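To see why sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start] yields per-request sums, here is an illustrative CPU trace of the test above, with the gathered logprobs written out by hand (not part of the commit; the zero-length entries in seq_lens are padding):

import torch

logprobs = torch.tensor([2.0, 4.0, 2.0, 4.0, 4.0])  # values gathered in the test
seq_lens = torch.tensor([2, 0, 3, 0])

logprobs_cumsum = torch.cumsum(logprobs, dim=0)             # [2, 6, 8, 12, 16]
len_cumsum = torch.cumsum(seq_lens, dim=0)                  # [2, 2, 5, 5]
start = torch.cat((torch.tensor([0]), len_cumsum[:-1]), 0)  # [0, 2, 2, 5]
end = start + seq_lens - 2                                  # [0, 0, 3, 3]
start.clamp_(min=0, max=logprobs.shape[0] - 1)              # [0, 2, 2, 4]
end.clamp_(min=0, max=logprobs.shape[0] - 1)                # [0, 0, 3, 3]

# cumsum[end] - cumsum[start] + logprobs[start] sums logprobs[start..end] inclusive.
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
print(sum_logp)  # tensor([2., -4., 6., 0.]); only entries 0 and 2 are real requests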
python/sglang/srt/managers/router/model_runner.py   +7 −7
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import List
-import logging
 import numpy as np
 import torch
...
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

 logger = logging.getLogger("model_runner")
...
@@ -112,7 +111,7 @@ class InputMetadata:
     def init_extend_args(self):
         self.extend_seq_lens = self.seq_lens - self.prefix_lens
         self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.max_extend_len = int(torch.max(self.extend_seq_lens))

     @classmethod
...
@@ -262,7 +261,7 @@ class ModelRunner:
         if model_class is None:
             raise ValueError(f"Unsupported architectures: {architectures}")

-        logger.info("load weight begin.")
+        logger.info(f"Rank {self.tp_rank}: load weight begin.")

         # Load weights
         linear_method = None
...
@@ -287,7 +286,7 @@ class ModelRunner:
         )
         self.model = model.eval()

-        logger.info("load weight end.")
+        logger.info(f"Rank {self.tp_rank}: load weight end.")

     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
...
@@ -308,8 +307,9 @@ class ModelRunner:
         self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
         if self.max_total_num_token <= 0:
-            raise RuntimeError("Not enought memory. "
-                               "Please try to increase --mem-fraction-static.")
+            raise RuntimeError(
+                "Not enought memory. "
+                "Please try to increase --mem-fraction-static."
+            )
         self.req_to_token_pool = ReqToTokenPool(
             int(self.max_total_num_token / self.model_config.context_len * 256),
...
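The init_extend_args hunk above builds each request's start offset with an exclusive cumulative sum over the extended lengths. A small CPU sketch with made-up lengths (illustrative, not from the commit):

import torch

extend_seq_lens = torch.tensor([3, 1, 4])  # tokens added per request

# Exclusive cumsum: request i starts where requests 0..i-1 end.
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
print(extend_start_loc)  # tensor([0, 3, 4])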