Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7f24ea95
Unverified
Commit
7f24ea95
authored
Sep 18, 2024
by
Lianmin Zheng
Committed by
GitHub
Sep 18, 2024
Browse files
Fuse top_k and top_k in the sampler (#1457)
parent
1acccb36
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
4 deletions
+12
-4
docs/en/sampling_params.md
docs/en/sampling_params.md
+1
-0
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+9
-2
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+2
-2
No files found.
docs/en/sampling_params.md
View file @
7f24ea95
...
...
@@ -23,6 +23,7 @@ class GenerateReqInput:
# Whether to return logprobs.
return_logprob
:
Optional
[
Union
[
List
[
bool
],
bool
]]
=
None
# The start location of the prompt for return_logprob.
# By default, this value is "-1", which means it will only return logprobs for output tokens.
logprob_start_len
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
# The number of top logprobs to return.
top_logprobs_num
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
...
...
python/sglang/srt/layers/sampler.py
View file @
7f24ea95
...
...
@@ -31,8 +31,11 @@ class Sampler(nn.Module):
logits
=
logits
.
next_token_logits
# Post process logits
logits
=
logits
.
contiguous
()
logits
.
div_
(
sampling_info
.
temperatures
)
probs
=
logits
[:]
=
torch
.
softmax
(
logits
,
dim
=-
1
)
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
logits
=
None
del
logits
if
torch
.
any
(
torch
.
isnan
(
probs
)):
logger
.
warning
(
"Detected errors during sampling! NaN in the probability."
)
...
...
@@ -53,7 +56,11 @@ class Sampler(nn.Module):
)
else
:
batch_next_token_ids
,
success
=
top_k_top_p_sampling_from_probs
(
probs
,
uniform_samples
,
sampling_info
.
top_ks
,
sampling_info
.
top_ps
probs
,
uniform_samples
,
sampling_info
.
top_ks
,
sampling_info
.
top_ps
,
filter_apply_order
=
"joint"
,
)
if
not
torch
.
all
(
success
):
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
7f24ea95
...
...
@@ -400,8 +400,8 @@ class ModelRunner:
)
self
.
req_to_token_pool
=
ReqToTokenPool
(
max_num_reqs
,
self
.
model_config
.
context_len
+
8
,
max_num_reqs
+
1
,
self
.
model_config
.
context_len
+
4
,
)
if
(
self
.
model_config
.
attention_arch
==
AttentionArch
.
MLA
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment