sglang · Commits · 7a1aecb9

Commit 7a1aecb9 (unverified), authored Dec 16, 2024 by Lianmin Zheng, committed by GitHub on Dec 16, 2024
Simplify pytorch sampling kernel and logit processor (#2491)
parent 82699474

Showing 5 changed files with 136 additions and 107 deletions
python/sglang/srt/layers/logits_processor.py            +98  -84
python/sglang/srt/layers/sampler.py                      +27   -5
python/sglang/srt/managers/schedule_batch.py              +3   -7
python/sglang/srt/model_executor/cuda_graph_runner.py     +8   -6
python/sglang/srt/server_args.py                          +0   -5
python/sglang/srt/layers/logits_processor.py
@@ -100,81 +100,8 @@ class LogitsProcessor(nn.Module):
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
-
-    def _get_normalized_prompt_logprobs(
-        self,
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(
-            input_token_logprobs, dim=0, dtype=torch.float32
-        )
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-        )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
-    @staticmethod
-    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-        max_k = max(logits_metadata.top_logprobs_nums)
-        ret = all_logprobs.topk(max_k, dim=1)
-        values = ret.values.tolist()
-        indices = ret.indices.tolist()
-
-        if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs_val = []
-            output_top_logprobs_idx = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs_val.append(values[i][:k])
-                output_top_logprobs_idx.append(indices[i][:k])
-            return None, None, output_top_logprobs_val, output_top_logprobs_idx
-        else:
-            input_top_logprobs_val, input_top_logprobs_idx = [], []
-            output_top_logprobs_val, output_top_logprobs_idx = [], []
-
-            pt = 0
-            for k, pruned_len in zip(
-                logits_metadata.top_logprobs_nums,
-                logits_metadata.extend_logprob_pruned_lens_cpu,
-            ):
-                if pruned_len <= 0:
-                    input_top_logprobs_val.append([])
-                    input_top_logprobs_idx.append([])
-                    output_top_logprobs_val.append([])
-                    output_top_logprobs_idx.append([])
-                    continue
-
-                input_top_logprobs_val.append(
-                    [values[pt + j][:k] for j in range(pruned_len - 1)]
-                )
-                input_top_logprobs_idx.append(
-                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
-                )
-                output_top_logprobs_val.append(
-                    list(
-                        values[pt + pruned_len - 1][:k],
-                    )
-                )
-                output_top_logprobs_idx.append(
-                    list(
-                        indices[pt + pruned_len - 1][:k],
-                    )
-                )
-                pt += pruned_len
-
-            return (
-                input_top_logprobs_val,
-                input_top_logprobs_idx,
-                output_top_logprobs_val,
-                output_top_logprobs_idx,
-            )
+        self.final_logit_softcapping = getattr(
+            self.config, "final_logit_softcapping", None
+        )
 
     def forward(
@@ -201,10 +128,10 @@ class LogitsProcessor(nn.Module):
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
 
-        if hasattr(self.config, "final_logit_softcapping"):
-            last_logits.div_(self.config.final_logit_softcapping)
+        if self.final_logit_softcapping:
+            last_logits.div_(self.final_logit_softcapping)
             torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.config.final_logit_softcapping)
+            last_logits.mul_(self.final_logit_softcapping)
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
@@ -212,7 +139,9 @@ class LogitsProcessor(nn.Module):
                 next_token_logits=last_logits,
             )
         else:
-            last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+            last_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                last_logits, logits_metadata
+            )
             if logits_metadata.forward_mode.is_decode():
                 if logits_metadata.return_top_logprob:
@@ -248,14 +177,17 @@ class LogitsProcessor(nn.Module):
             # extra logits that this padding may have produced.
             all_logits = all_logits[:, : self.config.vocab_size].float()
 
-            if hasattr(self.config, "final_logit_softcapping"):
-                all_logits.div_(self.config.final_logit_softcapping)
+            if self.final_logit_softcapping:
+                all_logits.div_(self.final_logit_softcapping)
                 torch.tanh(all_logits, out=all_logits)
-                all_logits.mul_(self.config.final_logit_softcapping)
+                all_logits.mul_(self.final_logit_softcapping)
 
             all_logprobs = all_logits
             del all_logits, hidden_states
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+            all_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                all_logprobs, logits_metadata
+            )
 
             # Get the logprob of top-k tokens
             if logits_metadata.return_top_logprob:
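Note (reviewer sketch, not part of the diff): the two hunks above replace the hasattr(self.config, "final_logit_softcapping") check with the value cached in __init__. The in-place div_ / tanh / mul_ sequence is the usual tanh soft-capping of logits; a minimal out-of-place equivalent, with a made-up cap value:

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # cap * tanh(logits / cap): squashes logits into (-cap, cap)
    return cap * torch.tanh(logits / cap)

x = torch.tensor([[-50.0, 0.0, 50.0]])
cap = 30.0  # hypothetical final_logit_softcapping value

y = softcap(x, cap)
z = x.clone()
z.div_(cap)
torch.tanh(z, out=z)
z.mul_(cap)  # same result as the in-place sequence in the diff
assert torch.allclose(y, z)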
@@ -309,11 +241,93 @@ class LogitsProcessor(nn.Module):
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
 
-        # Optional scaling factor, backported from vLLM 0.4
+        # Optional scaling factor
         if self.logit_scale is not None:
             logits.mul_(self.logit_scale)  # In-place multiply
 
         return logits
 
+    @staticmethod
+    def _get_normalized_prompt_logprobs(
+        input_token_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+    ):
+        logprobs_cumsum = torch.cumsum(
+            input_token_logprobs, dim=0, dtype=torch.float32
+        )
+        pruned_lens = torch.tensor(
+            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+        )
+
+        start = torch.zeros_like(pruned_lens)
+        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+        end = torch.clamp(
+            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+        )
+        sum_logp = (
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
+        return normalized_prompt_logprobs
+
+    @staticmethod
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+        max_k = max(logits_metadata.top_logprobs_nums)
+        ret = all_logprobs.topk(max_k, dim=1)
+        values = ret.values.tolist()
+        indices = ret.indices.tolist()
+
+        if logits_metadata.forward_mode.is_decode():
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
+            for i, k in enumerate(logits_metadata.top_logprobs_nums):
+                output_top_logprobs_val.append(values[i][:k])
+                output_top_logprobs_idx.append(indices[i][:k])
+            return None, None, output_top_logprobs_val, output_top_logprobs_idx
+        else:
+            input_top_logprobs_val, input_top_logprobs_idx = [], []
+            output_top_logprobs_val, output_top_logprobs_idx = [], []
+
+            pt = 0
+            for k, pruned_len in zip(
+                logits_metadata.top_logprobs_nums,
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+            ):
+                if pruned_len <= 0:
+                    input_top_logprobs_val.append([])
+                    input_top_logprobs_idx.append([])
+                    output_top_logprobs_val.append([])
+                    output_top_logprobs_idx.append([])
+                    continue
+
+                input_top_logprobs_val.append(
+                    [values[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                input_top_logprobs_idx.append(
+                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                output_top_logprobs_val.append(
+                    list(
+                        values[pt + pruned_len - 1][:k],
+                    )
+                )
+                output_top_logprobs_idx.append(
+                    list(
+                        indices[pt + pruned_len - 1][:k],
+                    )
+                )
+                pt += pruned_len
+
+            return (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+                output_top_logprobs_val,
+                output_top_logprobs_idx,
+            )
+
+    @staticmethod
+    def compute_temp_top_p_normalized_logprobs(
+        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+    ) -> torch.Tensor:
+        return torch.nn.functional.log_softmax(last_logits, dim=-1)
+
 
 def test():
     all_logprobs = torch.tensor(
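Note (reviewer sketch, outside the diff): _get_normalized_prompt_logprobs, now a @staticmethod, sums each request's prompt-token logprobs (excluding the last position) with a single cumulative sum and divides by pruned_len - 1. A small CPU example with made-up numbers:

import torch

input_token_logprobs = torch.tensor([-1.0, -2.0, -3.0, -0.5, -1.5])  # two requests, pruned lens 3 and 2
pruned_lens = torch.tensor([3, 2])

cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(start + pruned_lens - 2, min=0, max=cumsum.shape[0] - 1)
sum_logp = cumsum[end] - cumsum[start] + input_token_logprobs[start]
normalized = sum_logp / (pruned_lens - 1).clamp(min=1)
# request 0: (-1.0 + -2.0) / 2 = -1.5 ; request 1: -0.5 / 1 = -0.5
print(normalized)  # tensor([-1.5000, -0.5000])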
python/sglang/srt/layers/sampler.py
@@ -51,7 +51,6 @@ class Sampler(nn.Module):
         # Post process logits
         logits.div_(sampling_info.temperatures)
         probs = torch.softmax(logits, dim=-1)
-        logits = None
         del logits
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
                 )
             else:
                 raise ValueError(
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ks: torch.Tensor,
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
+    need_min_p_sampling: bool,
 ):
     """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    if need_min_p_sampling:
+        min_p_thresholds = probs_sort[:, 0] * min_ps
+        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
+
+
+def top_p_normalize_probs(
+    probs: torch.Tensor,
+    top_ps: torch.Tensor,
+):
+    if global_server_args_dict["sampling_backend"] == "flashinfer":
+        return top_p_renorm_prob(probs, top_ps)
+    elif global_server_args_dict["sampling_backend"] == "pytorch":
+        # See also top_k_top_p_min_p_sampling_from_probs_torch
+        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+    else:
+        raise ValueError(
+            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+        )
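Note (reviewer sketch, outside the diff): the pytorch branch of the new top_p_normalize_probs zeroes the tail of the sorted distribution once the probability mass before a token exceeds top_p, renormalizes, and scatters the result back to the original token order. With made-up numbers:

import torch

probs = torch.tensor([[0.10, 0.50, 0.15, 0.25]])
top_ps = torch.tensor([0.8])  # hypothetical per-request top-p

probs_sort, probs_idx = probs.sort(dim=-1, descending=True)      # [[0.50, 0.25, 0.15, 0.10]]
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0  # mass before 0.10 is 0.90 > 0.8, so it is dropped
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
renormed = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
print(renormed)  # tensor([[0.0000, 0.5556, 0.1667, 0.2778]])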
python/sglang/srt/managers/schedule_batch.py
@@ -1086,9 +1086,9 @@ class ScheduleBatch:
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.reqs.extend(other.reqs)
 
-        self.return_logprob = self.return_logprob or other.return_logprob
-        self.has_stream = self.has_stream or other.has_stream
-        self.has_grammar = self.has_grammar or other.has_grammar
+        self.return_logprob |= other.return_logprob
+        self.has_stream |= other.has_stream
+        self.has_grammar |= other.has_grammar
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
@@ -1115,7 +1115,6 @@ class ScheduleBatch:
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_sum=self.seq_lens_sum,
-            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
@@ -1170,9 +1169,6 @@ class ModelWorkerBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
-    # The memory pool operation records
-    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
-
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -387,8 +387,14 @@ class CudaGraphRunner:
         # Extract logprobs
         if forward_batch.return_logprob:
-            next_token_logprobs = torch.nn.functional.log_softmax(
-                next_token_logits, dim=-1
-            )
+            logits_metadata = LogitsMetadata(
+                forward_mode=ForwardMode.DECODE,
+                top_logprobs_nums=forward_batch.top_logprobs_nums,
+            )
+            next_token_logprobs = (
+                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
+                    next_token_logits, logits_metadata
+                )
+            )
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
@@ -396,10 +402,6 @@ class CudaGraphRunner:
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
-                logits_metadata = LogitsMetadata(
-                    forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=forward_batch.top_logprobs_nums,
-                )
                 (
                     logits_output.output_top_logprobs_val,
                     logits_output.output_top_logprobs_idx,
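Note (reviewer sketch, outside the diff; import paths and tensor shapes are assumptions): with the helpers exposed as static methods, the CUDA-graph replay path builds one decode LogitsMetadata and reuses it for both logprob normalization and top-logprob extraction, along these lines:

import torch
from sglang.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor
from sglang.srt.model_executor.forward_batch_info import ForwardMode

next_token_logits = torch.randn(2, 32000)  # hypothetical decode-step logits
top_logprobs_nums = [5, 0]                 # hypothetical per-request top-k

logits_metadata = LogitsMetadata(
    forward_mode=ForwardMode.DECODE,
    top_logprobs_nums=top_logprobs_nums,
)
next_token_logprobs = LogitsProcessor.compute_temp_top_p_normalized_logprobs(
    next_token_logits, logits_metadata
)
if any(x > 0 for x in top_logprobs_nums):
    _, _, top_vals, top_idxs = LogitsProcessor.get_top_logprobs(
        next_token_logprobs, logits_metadata
    )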
python/sglang/srt/server_args.py
@@ -698,11 +698,6 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
-        parser.add_argument(
-            "--disable-nan-detection",
-            action="store_true",
-            help="Disable the NaN detection for better performance.",
-        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",