Unverified Commit 0d800090 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix missing additional_stop_token_ids (#1769)

parent b7d05594
...@@ -9,5 +9,5 @@ ...@@ -9,5 +9,5 @@
/python/sglang/srt/models @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu /python/sglang/srt/models @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu
/python/sglang/srt/openai_api @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu /python/sglang/srt/openai_api @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu
/python/sglang/srt/sampling @merrymercy @hnyls2002 /python/sglang/srt/sampling @merrymercy @hnyls2002
/test/lang @merrymercy @Ying1123 @hnyls2002 @ByronHsu /test/lang @merrymercy @Ying1123 @ByronHsu
/test/srt @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu /test/srt @merrymercy @Ying1123 @zhyncs
...@@ -164,14 +164,7 @@ def get_tokenizer( ...@@ -164,14 +164,7 @@ def get_tokenizer(
"slowdown. Consider using a fast tokenizer instead." "slowdown. Consider using a fast tokenizer instead."
) )
# Special handling for stop token <|eom_id|> generated by llama 3 tool use. handle_additional_stop_token_ids(tokenizer)
if "<|eom_id|>" in tokenizer.get_added_vocab():
tokenizer.additional_stop_token_ids = set(
[tokenizer.get_added_vocab()["<|eom_id|>"]]
)
else:
tokenizer.additional_stop_token_ids = None
return tokenizer return tokenizer
...@@ -190,4 +183,16 @@ def get_processor( ...@@ -190,4 +183,16 @@ def get_processor(
tokenizer_revision=tokenizer_revision, tokenizer_revision=tokenizer_revision,
**kwargs, **kwargs,
) )
handle_additional_stop_token_ids(processor.tokenizer)
return processor return processor
def handle_additional_stop_token_ids(tokenizer):
    """Attach an `additional_stop_token_ids` attribute to *tokenizer*.

    Special handling for the stop token <|eom_id|> generated by llama 3
    tool use: if that token is present in the tokenizer's added vocab,
    store its id in a one-element set; otherwise set the attribute to
    None so downstream code (e.g. SamplingParams) can test it uniformly.
    """
    # Fetch the added vocab once instead of calling get_added_vocab() twice.
    added_vocab = tokenizer.get_added_vocab()
    if "<|eom_id|>" in added_vocab:
        tokenizer.additional_stop_token_ids = {added_vocab["<|eom_id|>"]}
    else:
        tokenizer.additional_stop_token_ids = None
...@@ -135,7 +135,7 @@ class SamplingParams: ...@@ -135,7 +135,7 @@ class SamplingParams:
self.stop_str_max_len = stop_str_max_len self.stop_str_max_len = stop_str_max_len
# Process stop token ids # Process stop token ids
if tokenizer.additional_stop_token_ids: if tokenizer and tokenizer.additional_stop_token_ids:
self.stop_token_ids.update(tokenizer.additional_stop_token_ids) self.stop_token_ids.update(tokenizer.additional_stop_token_ids)
def to_srt_kwargs(self): def to_srt_kwargs(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment