Unverified Commit 08ebdf79 authored by Liangsheng Yin, committed by GitHub

Fix the `--allow-auto-truncate` argument in tokenizer manager. (#9391)


Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 42c87045
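
For orientation before the diff: a minimal, self-contained Python sketch of the behavior this commit introduces. Only context_len, allow_auto_truncate, max_new_tokens, and the _max_req_len arithmetic mirror the diff; the standalone function, its signature, and the logging setup are illustrative assumptions, not SGLang's actual API.

# --- Illustrative sketch (not part of the commit) ---
import logging
from typing import List, Optional

logger = logging.getLogger(__name__)

def validate_request_length(
    input_ids: List[int],
    sampling_params: dict,
    context_len: int,
    allow_auto_truncate: bool,
) -> List[int]:
    # Reserve one slot below the context length, mirroring _max_req_len in the diff.
    max_req_len = context_len - 1

    input_token_num = len(input_ids) if input_ids is not None else 0
    if input_token_num >= context_len:
        if allow_auto_truncate:
            # Drop trailing prompt tokens instead of rejecting the request.
            logger.warning("Input of %d tokens exceeds context length %d; truncating.",
                           input_token_num, context_len)
            input_ids = input_ids[:max_req_len]
            input_token_num = len(input_ids)
        else:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({context_len} tokens)."
            )

    max_new_tokens: Optional[int] = sampling_params.get("max_new_tokens")
    if max_new_tokens is not None and (max_new_tokens + input_token_num) >= max_req_len:
        if allow_auto_truncate:
            # Clamp generation so input + output stays within the window.
            sampling_params["max_new_tokens"] = max(0, max_req_len - input_token_num)
        else:
            raise ValueError("Requested token count exceeds the model's context length.")
    return input_ids

# Worked example: context_len=8 gives max_req_len=7, so a 5-token prompt
# may generate at most 7 - 5 = 2 new tokens.
params = {"max_new_tokens": 10}
ids = validate_request_length(list(range(5)), params, context_len=8, allow_auto_truncate=True)
assert len(ids) == 5 and params["max_new_tokens"] == 2
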
@@ -565,10 +565,20 @@ class TokenizerManager:
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
+        # FIXME: unify the length validation logic with the one in the scheduler.
+        _max_req_len = self.context_len - 1
         input_token_num = len(input_ids) if input_ids is not None else 0
+        # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
-            raise ValueError(
-                f"The input ({input_token_num} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens). "
+                    "Truncating the input."
+                )
+                input_ids = input_ids[:_max_req_len]
+                input_token_num = len(input_ids)
+            else:
+                raise ValueError(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
@@ -584,8 +594,18 @@ class TokenizerManager:
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
             max_new_tokens is not None
-            and (max_new_tokens + input_token_num) >= self.context_len
+            and (max_new_tokens + input_token_num) >= _max_req_len
         ):
-            total_tokens = max_new_tokens + input_token_num
-            error_msg = (
-                f"Requested token count exceeds the model's maximum context length "
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
+                    f"exceeds the model's context length ({self.context_len} tokens). "
+                    "Truncating max_new_tokens."
+                )
+                obj.sampling_params["max_new_tokens"] = max(
+                    0, _max_req_len - input_token_num
+                )
+            else:
+                total_tokens = max_new_tokens + input_token_num
+                error_msg = (
+                    f"Requested token count exceeds the model's maximum context length "
...
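
Design note: when allow_auto_truncate is unset, both checks fall through to the original ValueError paths, so the strict behavior is unchanged. The flag itself is read from server_args at request time; below is a hypothetical configuration sketch, assuming ServerArgs in sglang.srt.server_args exposes the allow_auto_truncate field the diff reads (the model path is a placeholder):

# Hypothetical configuration sketch; field names beyond allow_auto_truncate
# (which the diff reads from self.server_args) are assumptions.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="your/model-path",  # placeholder
    allow_auto_truncate=True,      # truncate over-long requests instead of raising
)
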