Unverified commit 80a90547 authored by Lianmin Zheng, committed by GitHub

Fix stop condition for <|eom_id|> (#1766)

parent 9af7b88e
@@ -163,6 +163,15 @@ def get_tokenizer(
             "Using a slow tokenizer. This might cause a significant "
             "slowdown. Consider using a fast tokenizer instead."
         )
+
+    # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
+    if "<|eom_id|>" in tokenizer.get_added_vocab():
+        tokenizer.additional_stop_token_ids = set(
+            [tokenizer.get_added_vocab()["<|eom_id|>"]]
+        )
+    else:
+        tokenizer.additional_stop_token_ids = None
+
     return tokenizer
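Llama 3 tool use ends an assistant turn that expects a tool result with the added special token <|eom_id|> ("end of message") rather than the usual end-of-turn token, so generation would previously run past it. The hunk above records that token's id on the tokenizer when it exists. A minimal standalone sketch of the same lookup, assuming a Llama 3.1 checkpoint (the model name is illustrative, not part of this commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

    # get_added_vocab() maps tokens added on top of the base vocabulary to
    # their ids, e.g. {"<|eom_id|>": 128008, ...} for Llama 3.1 tokenizers.
    added_vocab = tokenizer.get_added_vocab()
    eom_id = added_vocab.get("<|eom_id|>")
    if eom_id is not None:
        print("extra stop token id:", eom_id)  # treat this id as a stop token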
@@ -51,8 +51,9 @@ class SamplingParams:
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
         if stop_token_ids is None:
-            stop_token_ids = []
-        self.stop_token_ids = {*stop_token_ids}
+            self.stop_token_ids = set()
+        else:
+            self.stop_token_ids = set(stop_token_ids)
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
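The old code rebound the stop_token_ids parameter and built the set by unpacking; the new form is behaviorally equivalent (both paths yield a set) but keeps the parameter untouched and makes the two cases explicit. Ending up with a mutable set matters for the merge added later in this commit, which calls set.update() on it. A tiny equivalent sketch:

    def init_stop_token_ids(stop_token_ids=None):
        # Mirror of the constructor logic: always end up with a mutable set.
        if stop_token_ids is None:
            return set()
        return set(stop_token_ids)

    assert init_stop_token_ids() == set()
    assert init_stop_token_ids([128008, 128008]) == {128008}  # deduplicated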
@@ -119,10 +120,7 @@ class SamplingParams:
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            if self.stop_token_ids is None:
-                self.stop_str_max_len = 0
-            else:
-                self.stop_str_max_len = 1
+            self.stop_str_max_len = 0
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]
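With stop token ids now kept in their own set and checked directly (see the next hunk), the branch that forced stop_str_max_len to 1 when only token-id stops were present is dropped: a token-id stop no longer needs to piggyback on the string-matching path. A hypothetical illustration of what stop_str_max_len governs in a streaming detokenizer (the helper name is made up, not sglang's API):

    def held_back_tail(emitted_text: str, stop_str_max_len: int) -> str:
        # Keep only enough trailing characters to detect a stop string that
        # straddles a streaming chunk boundary; zero means stream everything.
        if stop_str_max_len <= 1:
            return ""
        return emitted_text[-(stop_str_max_len - 1):]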
@@ -136,6 +134,10 @@ class SamplingParams:
             stop_str_max_len = max(stop_str_max_len, len(stop_str))
         self.stop_str_max_len = stop_str_max_len

+        # Process stop token ids
+        if tokenizer.additional_stop_token_ids:
+            self.stop_token_ids.update(tokenizer.additional_stop_token_ids)
+
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,