"examples/git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "fbd4cef9a575b5f77ca05d4b7c3ad3adb11141ac"
Unverified Commit 08ebdf79 authored by Liangsheng Yin, committed by GitHub

Fix the `--allow-auto-truncate` argument in tokenizer manager. (#9391)


Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 42c87045
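
For context: this change makes the tokenizer manager actually honor the `--allow-auto-truncate` server flag. When the flag is set, an over-long input is cut down to `context_len - 1` tokens and `max_new_tokens` is clamped so the total fits the context window; when it is unset, the request is rejected with a `ValueError` as before. Below is a minimal standalone sketch of that logic; the free-standing function and its signature are illustrative, not the real `TokenizerManager` API, which reads these settings from `self.server_args`.

```python
import logging
from typing import List, Optional, Tuple

logger = logging.getLogger(__name__)


def clamp_request(
    input_ids: List[int],
    max_new_tokens: Optional[int],
    context_len: int,
    allow_auto_truncate: bool,
) -> Tuple[List[int], Optional[int]]:
    """Illustrative sketch of the validation branches in the diff below."""
    _max_req_len = context_len - 1
    input_token_num = len(input_ids) if input_ids is not None else 0

    # Branch 1: the input alone exceeds the context length.
    if input_token_num >= context_len:
        if allow_auto_truncate:
            logger.warning("Input too long; truncating to %d tokens.", _max_req_len)
            input_ids = input_ids[:_max_req_len]
            input_token_num = len(input_ids)
        else:
            raise ValueError("The input is longer than the model's context length.")

    # Branch 2: input plus requested completion exceeds the limit.
    if max_new_tokens is not None and (max_new_tokens + input_token_num) >= _max_req_len:
        if allow_auto_truncate:
            logger.warning("Requested token count too large; clamping max_new_tokens.")
            max_new_tokens = max(0, _max_req_len - input_token_num)
        else:
            raise ValueError("Requested token count exceeds the context length.")

    return input_ids, max_new_tokens
```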
@@ -565,14 +565,24 @@ class TokenizerManager:
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
+        # FIXME: unify the length validation logic with the one in the scheduler.
+        _max_req_len = self.context_len - 1
 
         input_token_num = len(input_ids) if input_ids is not None else 0
-        # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
-            raise ValueError(
-                f"The input ({input_token_num} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens). "
+                    "Truncating the input."
+                )
+                input_ids = input_ids[:_max_req_len]
+                input_token_num = len(input_ids)
+            else:
+                raise ValueError(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
+                )
 
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -584,17 +594,27 @@ class TokenizerManager:
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
             max_new_tokens is not None
-            and (max_new_tokens + input_token_num) >= self.context_len
+            and (max_new_tokens + input_token_num) >= _max_req_len
         ):
-            total_tokens = max_new_tokens + input_token_num
-            error_msg = (
-                f"Requested token count exceeds the model's maximum context length "
-                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
-                f"tokens: {input_token_num} tokens from the input messages and "
-                f"{max_new_tokens} tokens for the completion. Please reduce the number "
-                f"of tokens in the input messages or the completion to fit within the limit."
-            )
-            raise ValueError(error_msg)
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
+                    f"exceeds the model's context length ({self.context_len} tokens). "
+                    "Truncating max_new_tokens."
+                )
+                obj.sampling_params["max_new_tokens"] = max(
+                    0, _max_req_len - input_token_num
+                )
+            else:
+                total_tokens = max_new_tokens + input_token_num
+                error_msg = (
+                    f"Requested token count exceeds the model's maximum context length "
+                    f"of {self.context_len} tokens. You requested a total of {total_tokens} "
+                    f"tokens: {input_token_num} tokens from the input messages and "
+                    f"{max_new_tokens} tokens for the completion. Please reduce the number "
+                    f"of tokens in the input messages or the completion to fit within the limit."
+                )
+                raise ValueError(error_msg)
 
         if isinstance(obj, GenerateReqInput):
             if (
...
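
A quick worked example of the clamp, with illustrative numbers: since `_max_req_len = context_len - 1` reserves one position, a request that would overflow has its completion budget shortened rather than being rejected.

```python
# Illustrative values; only the arithmetic mirrors the diff above.
context_len = 4096
_max_req_len = context_len - 1                 # 4095
input_token_num, max_new_tokens = 4000, 512    # 4000 + 512 >= 4095, so clamp
max_new_tokens = max(0, _max_req_len - input_token_num)
print(max_new_tokens)                          # 95: the completion budget left
```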