sglang · Commit 7a1aecb9 (unverified)

Simplify pytorch sampling kernel and logit processor (#2491)

Authored Dec 16, 2024 by Lianmin Zheng; committed by GitHub on Dec 16, 2024.
Parent: 82699474
Showing 5 changed files with 136 additions and 107 deletions (+136 −107):

  python/sglang/srt/layers/logits_processor.py            +98  −84
  python/sglang/srt/layers/sampler.py                      +27  −5
  python/sglang/srt/managers/schedule_batch.py             +3   −7
  python/sglang/srt/model_executor/cuda_graph_runner.py    +8   −6
  python/sglang/srt/server_args.py                         +0   −5
python/sglang/srt/layers/logits_processor.py
@@ -100,82 +100,9 @@ class LogitsProcessor(nn.Module):
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
-
-    def _get_normalized_prompt_logprobs(
-        self,
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(
-            input_token_logprobs, dim=0, dtype=torch.float32
-        )
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-        )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
-    @staticmethod
-    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-        max_k = max(logits_metadata.top_logprobs_nums)
-        ret = all_logprobs.topk(max_k, dim=1)
-        values = ret.values.tolist()
-        indices = ret.indices.tolist()
-
-        if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs_val = []
-            output_top_logprobs_idx = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs_val.append(values[i][:k])
-                output_top_logprobs_idx.append(indices[i][:k])
-            return None, None, output_top_logprobs_val, output_top_logprobs_idx
-        else:
-            input_top_logprobs_val, input_top_logprobs_idx = [], []
-            output_top_logprobs_val, output_top_logprobs_idx = [], []
-
-            pt = 0
-            for k, pruned_len in zip(
-                logits_metadata.top_logprobs_nums,
-                logits_metadata.extend_logprob_pruned_lens_cpu,
-            ):
-                if pruned_len <= 0:
-                    input_top_logprobs_val.append([])
-                    input_top_logprobs_idx.append([])
-                    output_top_logprobs_val.append([])
-                    output_top_logprobs_idx.append([])
-                    continue
-
-                input_top_logprobs_val.append(
-                    [values[pt + j][:k] for j in range(pruned_len - 1)]
-                )
-                input_top_logprobs_idx.append(
-                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
-                )
-                output_top_logprobs_val.append(
-                    list(values[pt + pruned_len - 1][:k])
-                )
-                output_top_logprobs_idx.append(
-                    list(indices[pt + pruned_len - 1][:k])
-                )
-                pt += pruned_len
-
-            return (
-                input_top_logprobs_val,
-                input_top_logprobs_idx,
-                output_top_logprobs_val,
-                output_top_logprobs_idx,
-            )
+        self.final_logit_softcapping = getattr(
+            self.config, "final_logit_softcapping", None
+        )

     def forward(
         self,
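The only lines added in this hunk cache the optional config value; the two deleted methods are not dropped but reappear as staticmethods at the end of the file (see the last hunk of this file below). A minimal sketch of the caching pattern, using a hypothetical SimpleNamespace stand-in for the model config:

from types import SimpleNamespace

# Hypothetical stand-in for a model config; real configs may or may not
# define `final_logit_softcapping` (e.g. Gemma-2-style configs do).
config = SimpleNamespace(vocab_size=32000)

# Read the optional attribute once at construction time ...
final_logit_softcapping = getattr(config, "final_logit_softcapping", None)

# ... so the forward path can use a plain truthiness check instead of
# calling hasattr(self.config, "final_logit_softcapping") on every step.
if final_logit_softcapping:
    print("softcapping enabled with cap", final_logit_softcapping)
else:
    print("softcapping disabled")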
@@ -201,10 +128,10 @@ class LogitsProcessor(nn.Module):
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()

-        if hasattr(self.config, "final_logit_softcapping"):
-            last_logits.div_(self.config.final_logit_softcapping)
+        if self.final_logit_softcapping:
+            last_logits.div_(self.final_logit_softcapping)
             torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.config.final_logit_softcapping)
+            last_logits.mul_(self.final_logit_softcapping)

         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
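For reference, the branch above applies the usual tanh soft cap, logits ← cap · tanh(logits / cap), in place. A standalone sketch with toy values (the cap of 30.0 is made up, not taken from any model config):

import torch

cap = 30.0  # hypothetical final_logit_softcapping value
logits = torch.tensor([[-100.0, 0.0, 55.0, 400.0]])

# In-place sequence mirroring the diff: div_, tanh(out=...), mul_
capped = logits.clone()
capped.div_(cap)
torch.tanh(capped, out=capped)
capped.mul_(cap)

# Equivalent out-of-place form; every value is squashed into (-cap, cap).
assert torch.allclose(capped, cap * torch.tanh(logits / cap))
print(capped)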
@@ -212,7 +139,9 @@ class LogitsProcessor(nn.Module):
                 next_token_logits=last_logits,
             )
         else:
-            last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+            last_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                last_logits, logits_metadata
+            )

             if logits_metadata.forward_mode.is_decode():
                 if logits_metadata.return_top_logprob:
@@ -248,14 +177,17 @@ class LogitsProcessor(nn.Module):
# extra logits that this padding may have produced.
all_logits
=
all_logits
[:,
:
self
.
config
.
vocab_size
].
float
()
if
hasattr
(
self
.
config
,
"
final_logit_softcapping
"
)
:
all_logits
.
div_
(
self
.
config
.
final_logit_softcapping
)
if
self
.
final_logit_softcapping
:
all_logits
.
div_
(
self
.
final_logit_softcapping
)
torch
.
tanh
(
all_logits
,
out
=
all_logits
)
all_logits
.
mul_
(
self
.
config
.
final_logit_softcapping
)
all_logits
.
mul_
(
self
.
final_logit_softcapping
)
all_logprobs
=
all_logits
del
all_logits
,
hidden_states
all_logprobs
[:]
=
torch
.
nn
.
functional
.
log_softmax
(
all_logprobs
,
dim
=-
1
)
all_logprobs
=
self
.
compute_temp_top_p_normalized_logprobs
(
all_logprobs
,
logits_metadata
)
# Get the logprob of top-k tokens
if
logits_metadata
.
return_top_logprob
:
...
...
@@ -309,11 +241,93 @@ class LogitsProcessor(nn.Module):
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

-        # Optional scaling factor, backported from vLLM 0.4
+        # Optional scaling factor
         if self.logit_scale is not None:
             logits.mul_(self.logit_scale)  # In-place multiply

         return logits

+    @staticmethod
+    def _get_normalized_prompt_logprobs(
+        input_token_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+    ):
+        logprobs_cumsum = torch.cumsum(
+            input_token_logprobs, dim=0, dtype=torch.float32
+        )
+        pruned_lens = torch.tensor(
+            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+        )
+
+        start = torch.zeros_like(pruned_lens)
+        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+        end = torch.clamp(
+            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+        )
+        sum_logp = (
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
+        return normalized_prompt_logprobs
+
+    @staticmethod
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+        max_k = max(logits_metadata.top_logprobs_nums)
+        ret = all_logprobs.topk(max_k, dim=1)
+        values = ret.values.tolist()
+        indices = ret.indices.tolist()
+
+        if logits_metadata.forward_mode.is_decode():
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
+            for i, k in enumerate(logits_metadata.top_logprobs_nums):
+                output_top_logprobs_val.append(values[i][:k])
+                output_top_logprobs_idx.append(indices[i][:k])
+            return None, None, output_top_logprobs_val, output_top_logprobs_idx
+        else:
+            input_top_logprobs_val, input_top_logprobs_idx = [], []
+            output_top_logprobs_val, output_top_logprobs_idx = [], []
+
+            pt = 0
+            for k, pruned_len in zip(
+                logits_metadata.top_logprobs_nums,
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+            ):
+                if pruned_len <= 0:
+                    input_top_logprobs_val.append([])
+                    input_top_logprobs_idx.append([])
+                    output_top_logprobs_val.append([])
+                    output_top_logprobs_idx.append([])
+                    continue
+
+                input_top_logprobs_val.append(
+                    [values[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                input_top_logprobs_idx.append(
+                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                output_top_logprobs_val.append(
+                    list(values[pt + pruned_len - 1][:k])
+                )
+                output_top_logprobs_idx.append(
+                    list(indices[pt + pruned_len - 1][:k])
+                )
+                pt += pruned_len
+
+            return (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+                output_top_logprobs_val,
+                output_top_logprobs_idx,
+            )
+
+    @staticmethod
+    def compute_temp_top_p_normalized_logprobs(
+        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+    ) -> torch.Tensor:
+        return torch.nn.functional.log_softmax(last_logits, dim=-1)
+

 def test():
     all_logprobs = torch.tensor(
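The two relocated staticmethods are the same implementations removed in the first hunk, and compute_temp_top_p_normalized_logprobs is, for now, a plain log_softmax over the last dimension despite the forward-looking name. A small CPU-only worked example of the prompt-logprob normalization indexing, with made-up logprob values in place of real model output:

import torch

# Toy data: pruned prompt-token logprobs of two requests, concatenated
# along dim 0, with pruned lengths 3 and 4.
input_token_logprobs = torch.tensor([-1.0, -2.0, -3.0, -0.5, -1.5, -2.5, -3.5])
pruned_lens = torch.tensor([3, 4])

# Same arithmetic as _get_normalized_prompt_logprobs, minus the CUDA placement.
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
normalized = sum_logp / (pruned_lens - 1).clamp(min=1)

# Each entry is the mean logprob over the first (pruned_len - 1) tokens of
# that request; both toy requests come out to -1.5 here.
print(normalized)  # tensor([-1.5000, -1.5000])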
python/sglang/srt/layers/sampler.py
@@ -51,7 +51,6 @@ class Sampler(nn.Module):
         # Post process logits
         logits.div_(sampling_info.temperatures)
         probs = torch.softmax(logits, dim=-1)
-        logits = None
         del logits

         if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
                 )
             else:
                 raise ValueError(
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ks: torch.Tensor,
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
+    need_min_p_sampling: bool,
 ):
     """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+    if need_min_p_sampling:
+        min_p_thresholds = probs_sort[:, 0] * min_ps
+        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
+
+
+def top_p_normalize_probs(
+    probs: torch.Tensor,
+    top_ps: torch.Tensor,
+):
+    if global_server_args_dict["sampling_backend"] == "flashinfer":
+        return top_p_renorm_prob(probs, top_ps)
+    elif global_server_args_dict["sampling_backend"] == "pytorch":
+        # See also top_k_top_p_min_p_sampling_from_probs_torch
+        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+    else:
+        raise ValueError(
+            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+        )
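A self-contained sketch of the filtering order the rewritten pytorch backend now applies (top-k, then top-p, then min-p only when requested), run on a toy batch. It mirrors the body of top_k_top_p_min_p_sampling_from_probs_torch instead of importing sglang, and the batch size and thresholds are made up:

import torch

torch.manual_seed(0)
probs = torch.softmax(torch.randn(2, 8), dim=-1)  # toy batch of 2 requests
top_ks = torch.tensor([3, 5])
top_ps = torch.tensor([0.9, 0.8])
min_ps = torch.tensor([0.0, 0.05])
need_min_p_sampling = True

probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)

# 1) top-k: zero everything past the k most likely tokens of each row
probs_sort[
    torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
    >= top_ks.view(-1, 1)
] = 0.0
# 2) top-p: zero the tail once the cumulative mass before a token exceeds top_p
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
# 3) optional min-p: drop tokens below min_p * (largest prob in the row)
if need_min_p_sampling:
    min_p_thresholds = probs_sort[:, 0] * min_ps
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0

# torch.multinomial accepts unnormalized non-negative weights, so the old
# probs_sort.div_(row max) step is not required for correctness.
sampled_index = torch.multinomial(probs_sort, num_samples=1)
batch_next_token_ids = torch.gather(
    probs_idx.to(torch.int32), dim=1, index=sampled_index
).view(-1)
print(batch_next_token_ids)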
python/sglang/srt/managers/schedule_batch.py
@@ -1086,9 +1086,9 @@ class ScheduleBatch:
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.reqs.extend(other.reqs)

-        self.return_logprob = self.return_logprob or other.return_logprob
-        self.has_stream = self.has_stream or other.has_stream
-        self.has_grammar = self.has_grammar or other.has_grammar
+        self.return_logprob |= other.return_logprob
+        self.has_stream |= other.has_stream
+        self.has_grammar |= other.has_grammar

     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
@@ -1115,7 +1115,6 @@ class ScheduleBatch:
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_sum=self.seq_lens_sum,
-            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
@@ -1170,9 +1169,6 @@ class ModelWorkerBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

-    # The memory pool operation records
-    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
-
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -387,8 +387,14 @@ class CudaGraphRunner:
         # Extract logprobs
         if forward_batch.return_logprob:
-            next_token_logprobs = torch.nn.functional.log_softmax(
-                next_token_logits, dim=-1
-            )
+            logits_metadata = LogitsMetadata(
+                forward_mode=ForwardMode.DECODE,
+                top_logprobs_nums=forward_batch.top_logprobs_nums,
+            )
+            next_token_logprobs = (
+                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
+                    next_token_logits, logits_metadata
+                )
+            )
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
@@ -396,10 +402,6 @@ class CudaGraphRunner:
             )

             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
-                logits_metadata = LogitsMetadata(
-                    forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=forward_batch.top_logprobs_nums,
-                )
                 (
                     logits_output.output_top_logprobs_val,
                     logits_output.output_top_logprobs_idx,
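For context, here is the static helper the replay path above now funnels through, written as a standalone usage sketch. The import paths and the toy shapes below are assumptions (they are not shown in this diff); with the current implementation the call reduces to a log_softmax over the logits.

import torch

# Assumed import locations (not confirmed by this diff):
from sglang.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor
from sglang.srt.model_executor.forward_batch_info import ForwardMode

next_token_logits = torch.randn(4, 32000)  # hypothetical decode batch of 4 requests
logits_metadata = LogitsMetadata(
    forward_mode=ForwardMode.DECODE,
    top_logprobs_nums=[0, 2, 5, 0],  # hypothetical per-request top-k counts
)
next_token_logprobs = LogitsProcessor.compute_temp_top_p_normalized_logprobs(
    next_token_logits, logits_metadata
)
# Currently equivalent to torch.nn.functional.log_softmax(next_token_logits, dim=-1)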
python/sglang/srt/server_args.py
@@ -698,11 +698,6 @@ class ServerArgs:
action
=
"store_true"
,
help
=
"Disable Multi-head Latent Attention (MLA) for DeepSeek-V2."
,
)
parser
.
add_argument
(
"--disable-nan-detection"
,
action
=
"store_true"
,
help
=
"Disable the NaN detection for better performance."
,
)
parser
.
add_argument
(
"--disable-overlap-schedule"
,
action
=
"store_true"
,
...
...