Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8f8f96a6
Unverified
Commit
8f8f96a6
authored
Oct 23, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 23, 2024
Browse files
Fix the perf regression due to additional_stop_token_ids (#1773)
parent
05b3bf5e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
20 additions
and
16 deletions
+20
-16
python/sglang/srt/hf_transformers_utils.py
python/sglang/srt/hf_transformers_utils.py
+3
-3
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+1
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+7
-2
python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
...lang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
+6
-3
python/sglang/srt/sampling/sampling_params.py
python/sglang/srt/sampling/sampling_params.py
+3
-7
No files found.
python/sglang/srt/hf_transformers_utils.py
View file @
8f8f96a6
...
@@ -164,7 +164,7 @@ def get_tokenizer(
...
@@ -164,7 +164,7 @@ def get_tokenizer(
"slowdown. Consider using a fast tokenizer instead."
"slowdown. Consider using a fast tokenizer instead."
)
)
handle
_additional_stop_token_ids
(
tokenizer
)
attach
_additional_stop_token_ids
(
tokenizer
)
return
tokenizer
return
tokenizer
...
@@ -184,11 +184,11 @@ def get_processor(
...
@@ -184,11 +184,11 @@ def get_processor(
**
kwargs
,
**
kwargs
,
)
)
handle
_additional_stop_token_ids
(
processor
.
tokenizer
)
attach
_additional_stop_token_ids
(
processor
.
tokenizer
)
return
processor
return
processor
def
handle
_additional_stop_token_ids
(
tokenizer
):
def
attach
_additional_stop_token_ids
(
tokenizer
):
# Special handling for stop token <|eom_id|> generated by llama 3 tool use.
# Special handling for stop token <|eom_id|> generated by llama 3 tool use.
if
"<|eom_id|>"
in
tokenizer
.
get_added_vocab
():
if
"<|eom_id|>"
in
tokenizer
.
get_added_vocab
():
tokenizer
.
additional_stop_token_ids
=
set
(
tokenizer
.
additional_stop_token_ids
=
set
(
...
...
python/sglang/srt/layers/sampler.py
View file @
8f8f96a6
...
@@ -42,11 +42,11 @@ class Sampler(nn.Module):
...
@@ -42,11 +42,11 @@ class Sampler(nn.Module):
logits
=
logits
.
contiguous
()
logits
=
logits
.
contiguous
()
if
self
.
use_nan_detectioin
and
torch
.
any
(
torch
.
isnan
(
logits
)):
if
self
.
use_nan_detectioin
and
torch
.
any
(
torch
.
isnan
(
logits
)):
exit
(
1
)
if
crash_on_warning
else
None
logger
.
warning
(
"Detected errors during sampling! NaN in the logits."
)
logger
.
warning
(
"Detected errors during sampling! NaN in the logits."
)
logits
=
torch
.
where
(
logits
=
torch
.
where
(
torch
.
isnan
(
logits
),
torch
.
full_like
(
logits
,
-
1e5
),
logits
torch
.
isnan
(
logits
),
torch
.
full_like
(
logits
,
-
1e5
),
logits
)
)
exit
(
1
)
if
crash_on_warning
else
None
if
sampling_info
.
is_all_greedy
:
if
sampling_info
.
is_all_greedy
:
# Use torch.argmax if all requests use greedy sampling
# Use torch.argmax if all requests use greedy sampling
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
8f8f96a6
...
@@ -334,15 +334,20 @@ class Req:
...
@@ -334,15 +334,20 @@ class Req:
last_token_id
=
self
.
output_ids
[
-
1
]
last_token_id
=
self
.
output_ids
[
-
1
]
matched_eos
=
last_token_id
in
self
.
sampling_params
.
stop_token_ids
matched_eos
=
False
# Check stop token ids
if
self
.
sampling_params
.
stop_token_ids
:
matched_eos
=
last_token_id
in
self
.
sampling_params
.
stop_token_ids
if
self
.
tokenizer
is
not
None
:
if
self
.
tokenizer
is
not
None
:
matched_eos
|=
last_token_id
==
self
.
tokenizer
.
eos_token_id
matched_eos
|=
last_token_id
==
self
.
tokenizer
.
eos_token_id
if
self
.
tokenizer
.
additional_stop_token_ids
:
matched_eos
|=
last_token_id
in
self
.
tokenizer
.
additional_stop_token_ids
if
matched_eos
and
not
self
.
sampling_params
.
ignore_eos
:
if
matched_eos
and
not
self
.
sampling_params
.
ignore_eos
:
self
.
finished_reason
=
FINISH_MATCHED_TOKEN
(
matched
=
last_token_id
)
self
.
finished_reason
=
FINISH_MATCHED_TOKEN
(
matched
=
last_token_id
)
return
return
# Check stop strings
if
len
(
self
.
sampling_params
.
stop_strs
)
>
0
:
if
len
(
self
.
sampling_params
.
stop_strs
)
>
0
:
tail_str
=
self
.
tokenizer
.
decode
(
tail_str
=
self
.
tokenizer
.
decode
(
self
.
output_ids
[
-
(
self
.
sampling_params
.
stop_str_max_len
+
1
)
:]
self
.
output_ids
[
-
(
self
.
sampling_params
.
stop_str_max_len
+
1
)
:]
...
...
python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
View file @
8f8f96a6
...
@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
...
@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
padded_stop_token_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
padded_stop_token_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
sequences
=
[
sequences
=
[
torch
.
tensor
(
torch
.
tensor
(
data
=
list
(
data
=
(
req
.
sampling_params
.
stop_token_ids
list
(
|
{
req
.
tokenizer
.
eos_token_id
}
(
req
.
sampling_params
.
stop_token_ids
or
set
())
|
(
req
.
tokenizer
.
additional_stop_token_ids
or
set
())
|
{
req
.
tokenizer
.
eos_token_id
}
)
),
),
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
self
.
orchestrator
.
device
,
device
=
self
.
orchestrator
.
device
,
...
...
python/sglang/srt/sampling/sampling_params.py
View file @
8f8f96a6
...
@@ -50,10 +50,10 @@ class SamplingParams:
...
@@ -50,10 +50,10 @@ class SamplingParams:
self
.
presence_penalty
=
presence_penalty
self
.
presence_penalty
=
presence_penalty
self
.
repetition_penalty
=
repetition_penalty
self
.
repetition_penalty
=
repetition_penalty
self
.
stop_strs
=
stop
self
.
stop_strs
=
stop
if
stop_token_ids
is
None
:
if
stop_token_ids
:
self
.
stop_token_ids
=
set
()
else
:
self
.
stop_token_ids
=
set
(
stop_token_ids
)
self
.
stop_token_ids
=
set
(
stop_token_ids
)
else
:
self
.
stop_token_ids
=
None
self
.
max_new_tokens
=
max_new_tokens
self
.
max_new_tokens
=
max_new_tokens
self
.
min_new_tokens
=
min_new_tokens
self
.
min_new_tokens
=
min_new_tokens
self
.
ignore_eos
=
ignore_eos
self
.
ignore_eos
=
ignore_eos
...
@@ -134,10 +134,6 @@ class SamplingParams:
...
@@ -134,10 +134,6 @@ class SamplingParams:
stop_str_max_len
=
max
(
stop_str_max_len
,
len
(
stop_str
))
stop_str_max_len
=
max
(
stop_str_max_len
,
len
(
stop_str
))
self
.
stop_str_max_len
=
stop_str_max_len
self
.
stop_str_max_len
=
stop_str_max_len
# Process stop token ids
if
tokenizer
and
tokenizer
.
additional_stop_token_ids
:
self
.
stop_token_ids
.
update
(
tokenizer
.
additional_stop_token_ids
)
def
to_srt_kwargs
(
self
):
def
to_srt_kwargs
(
self
):
return
{
return
{
"max_new_tokens"
:
self
.
max_new_tokens
,
"max_new_tokens"
:
self
.
max_new_tokens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment