Unverified Commit 08ebdf79 authored by Liangsheng Yin, committed by GitHub

Fix the `--allow-auto-truncate` argument in tokenizer manager. (#9391)


Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 42c87045
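
Before this change, `_validate_one_request` in the tokenizer manager raised a `ValueError` whenever the input (or input plus `max_new_tokens`) reached the context length, even when the server was launched with `--allow-auto-truncate`. The patch below honors the flag: overlong inputs are truncated to `context_len - 1` tokens, and `max_new_tokens` is clamped to the remaining budget. A minimal standalone sketch of the resulting control flow (illustration only, not the actual `TokenizerManager` method; the function name and return convention here are invented for the example):

```python
from typing import List, Optional, Tuple


def validate_and_truncate(
    input_ids: List[int],
    max_new_tokens: Optional[int],
    context_len: int,
    allow_auto_truncate: bool,
) -> Tuple[List[int], Optional[int]]:
    """Sketch of the patched validation: truncate/clamp instead of raising
    when auto-truncation is enabled."""
    # Reserve one token, mirroring `_max_req_len = self.context_len - 1`.
    max_req_len = context_len - 1

    # Case 1: the prompt alone overflows the context window.
    if len(input_ids) >= context_len:
        if allow_auto_truncate:
            input_ids = input_ids[:max_req_len]  # drop overflowing prompt tokens
        else:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({context_len} tokens)."
            )

    # Case 2: prompt + requested completion overflows the window.
    if max_new_tokens is not None and max_new_tokens + len(input_ids) >= max_req_len:
        if allow_auto_truncate:
            # Clamp the completion budget to whatever space remains.
            max_new_tokens = max(0, max_req_len - len(input_ids))
        else:
            raise ValueError(
                "Requested token count exceeds the model's context length."
            )

    return input_ids, max_new_tokens
```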
@@ -565,14 +565,24 @@ class TokenizerManager:
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
         # FIXME: unify the length validation logic with the one in the scheduler.
+        _max_req_len = self.context_len - 1
         input_token_num = len(input_ids) if input_ids is not None else 0
         # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
-            raise ValueError(
-                f"The input ({input_token_num} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens). "
+                    "Truncating the input."
+                )
+                input_ids = input_ids[:_max_req_len]
+                input_token_num = len(input_ids)
+            else:
+                raise ValueError(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
+                )
 
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -584,17 +594,27 @@ class TokenizerManager:
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
             max_new_tokens is not None
-            and (max_new_tokens + input_token_num) >= self.context_len
+            and (max_new_tokens + input_token_num) >= _max_req_len
         ):
-            total_tokens = max_new_tokens + input_token_num
-            error_msg = (
-                f"Requested token count exceeds the model's maximum context length "
-                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
-                f"tokens: {input_token_num} tokens from the input messages and "
-                f"{max_new_tokens} tokens for the completion. Please reduce the number "
-                f"of tokens in the input messages or the completion to fit within the limit."
-            )
-            raise ValueError(error_msg)
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
+                    f"exceeds the model's context length ({self.context_len} tokens). "
+                    "Truncating max_new_tokens."
+                )
+                obj.sampling_params["max_new_tokens"] = max(
+                    0, _max_req_len - input_token_num
+                )
+            else:
+                total_tokens = max_new_tokens + input_token_num
+                error_msg = (
+                    f"Requested token count exceeds the model's maximum context length "
+                    f"of {self.context_len} tokens. You requested a total of {total_tokens} "
+                    f"tokens: {input_token_num} tokens from the input messages and "
+                    f"{max_new_tokens} tokens for the completion. Please reduce the number "
+                    f"of tokens in the input messages or the completion to fit within the limit."
+                )
+                raise ValueError(error_msg)
 
         if isinstance(obj, GenerateReqInput):
             if (
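
A worked example of the clamping arithmetic in the second hunk (the numbers are chosen for illustration):

```python
# Suppose the model's context length is 8192, so _max_req_len = 8191.
context_len = 8192
_max_req_len = context_len - 1  # 8191

input_token_num = 8000
max_new_tokens = 500

# 8000 + 500 = 8500 >= 8191, so the over-length check fires.
if max_new_tokens + input_token_num >= _max_req_len:
    # With --allow-auto-truncate, the completion budget is clamped
    # instead of the request being rejected:
    max_new_tokens = max(0, _max_req_len - input_token_num)

print(max_new_tokens)  # 191
```

If the input alone already fills the window, the clamp bottoms out at 0 new tokens, which is why the first hunk truncates the prompt itself to `_max_req_len` before this check runs.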