Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8f8f96a6
"examples/community/pipeline_stable_diffusion_boxdiff.py" did not exist on "7caa3682e440ce506dc4674373052739f9d80303"
Unverified
Commit
8f8f96a6
authored
Oct 23, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 23, 2024
Browse files
Fix the perf regression due to additional_stop_token_ids (#1773)
parent
05b3bf5e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
20 additions
and
16 deletions
+20
-16
python/sglang/srt/hf_transformers_utils.py
python/sglang/srt/hf_transformers_utils.py
+3
-3
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+1
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+7
-2
python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
...lang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
+6
-3
python/sglang/srt/sampling/sampling_params.py
python/sglang/srt/sampling/sampling_params.py
+3
-7
No files found.
python/sglang/srt/hf_transformers_utils.py
View file @
8f8f96a6
...
...
@@ -164,7 +164,7 @@ def get_tokenizer(
"slowdown. Consider using a fast tokenizer instead."
)
handle
_additional_stop_token_ids
(
tokenizer
)
attach
_additional_stop_token_ids
(
tokenizer
)
return
tokenizer
...
...
@@ -184,11 +184,11 @@ def get_processor(
**
kwargs
,
)
handle
_additional_stop_token_ids
(
processor
.
tokenizer
)
attach
_additional_stop_token_ids
(
processor
.
tokenizer
)
return
processor
def
handle
_additional_stop_token_ids
(
tokenizer
):
def
attach
_additional_stop_token_ids
(
tokenizer
):
# Special handling for stop token <|eom_id|> generated by llama 3 tool use.
if
"<|eom_id|>"
in
tokenizer
.
get_added_vocab
():
tokenizer
.
additional_stop_token_ids
=
set
(
...
...
python/sglang/srt/layers/sampler.py
View file @
8f8f96a6
...
...
@@ -42,11 +42,11 @@ class Sampler(nn.Module):
logits
=
logits
.
contiguous
()
if
self
.
use_nan_detectioin
and
torch
.
any
(
torch
.
isnan
(
logits
)):
exit
(
1
)
if
crash_on_warning
else
None
logger
.
warning
(
"Detected errors during sampling! NaN in the logits."
)
logits
=
torch
.
where
(
torch
.
isnan
(
logits
),
torch
.
full_like
(
logits
,
-
1e5
),
logits
)
exit
(
1
)
if
crash_on_warning
else
None
if
sampling_info
.
is_all_greedy
:
# Use torch.argmax if all requests use greedy sampling
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
8f8f96a6
...
...
@@ -334,15 +334,20 @@ class Req:
last_token_id
=
self
.
output_ids
[
-
1
]
matched_eos
=
last_token_id
in
self
.
sampling_params
.
stop_token_ids
matched_eos
=
False
# Check stop token ids
if
self
.
sampling_params
.
stop_token_ids
:
matched_eos
=
last_token_id
in
self
.
sampling_params
.
stop_token_ids
if
self
.
tokenizer
is
not
None
:
matched_eos
|=
last_token_id
==
self
.
tokenizer
.
eos_token_id
if
self
.
tokenizer
.
additional_stop_token_ids
:
matched_eos
|=
last_token_id
in
self
.
tokenizer
.
additional_stop_token_ids
if
matched_eos
and
not
self
.
sampling_params
.
ignore_eos
:
self
.
finished_reason
=
FINISH_MATCHED_TOKEN
(
matched
=
last_token_id
)
return
# Check stop strings
if
len
(
self
.
sampling_params
.
stop_strs
)
>
0
:
tail_str
=
self
.
tokenizer
.
decode
(
self
.
output_ids
[
-
(
self
.
sampling_params
.
stop_str_max_len
+
1
)
:]
...
...
python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
View file @
8f8f96a6
...
...
@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
padded_stop_token_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
sequences
=
[
torch
.
tensor
(
data
=
list
(
req
.
sampling_params
.
stop_token_ids
|
{
req
.
tokenizer
.
eos_token_id
}
data
=
(
list
(
(
req
.
sampling_params
.
stop_token_ids
or
set
())
|
(
req
.
tokenizer
.
additional_stop_token_ids
or
set
())
|
{
req
.
tokenizer
.
eos_token_id
}
)
),
dtype
=
torch
.
int64
,
device
=
self
.
orchestrator
.
device
,
...
...
python/sglang/srt/sampling/sampling_params.py
View file @
8f8f96a6
...
...
@@ -50,10 +50,10 @@ class SamplingParams:
self
.
presence_penalty
=
presence_penalty
self
.
repetition_penalty
=
repetition_penalty
self
.
stop_strs
=
stop
if
stop_token_ids
is
None
:
self
.
stop_token_ids
=
set
()
else
:
if
stop_token_ids
:
self
.
stop_token_ids
=
set
(
stop_token_ids
)
else
:
self
.
stop_token_ids
=
None
self
.
max_new_tokens
=
max_new_tokens
self
.
min_new_tokens
=
min_new_tokens
self
.
ignore_eos
=
ignore_eos
...
...
@@ -134,10 +134,6 @@ class SamplingParams:
stop_str_max_len
=
max
(
stop_str_max_len
,
len
(
stop_str
))
self
.
stop_str_max_len
=
stop_str_max_len
# Process stop token ids
if
tokenizer
and
tokenizer
.
additional_stop_token_ids
:
self
.
stop_token_ids
.
update
(
tokenizer
.
additional_stop_token_ids
)
def
to_srt_kwargs
(
self
):
return
{
"max_new_tokens"
:
self
.
max_new_tokens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment