Unverified Commit a32cb49b authored by Mingliang Li's avatar Mingliang Li Committed by GitHub
Browse files

feat(frontend): early-fail tokenization guard for user requests (#31366)


Signed-off-by: default avatarlimingliang <limingliang@stepfun.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarlimingliang <limingliang@stepfun.com>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 20d7454c
This diff is collapsed.
......@@ -229,23 +229,53 @@ class TokenizeParams:
max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0:
max_length = self.max_input_tokens
elif max_length is None and self.max_input_tokens is not None:
# This prevents tokenization from taking up more resources than necessary
# while still failing `self._token_len_check` as expected by users
max_length = self.max_input_tokens + 1
return dict(
truncation=self.truncate_prompt_tokens is not None,
truncation=max_length is not None,
max_length=max_length,
add_special_tokens=self.add_special_tokens,
)
def _apply_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
if self.do_lower_case:
text = text.lower()
def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply length checks to prompt text if necessary."""
max_input_tokens = self.max_input_tokens
if max_input_tokens is None:
return text
if self.truncate_prompt_tokens is None and tokenizer is not None:
max_input_chars = max_input_tokens * tokenizer.max_chars_per_token
if len(text) > max_input_chars:
# To save resources, fail the request outright without even
# attempting tokenization
raise VLLMValidationError(
f"You passed {len(text)} input characters "
f"and requested {self.max_output_tokens} output tokens. "
f"However, the model's context length is only "
f"{self.max_total_tokens} tokens, resulting in a maximum "
f"input length of {max_input_tokens} tokens "
f"(at most {max_input_chars} characters). "
f"Please reduce the length of the input prompt.",
parameter="input_text",
value=len(text),
)
return text
def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply lowercase to prompt text if necessary."""
return text.lower() if self.do_lower_case else text
def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply all validators to prompt text."""
# TODO: Implement https://github.com/vllm-project/vllm/pull/31366
for validator in (self._apply_lowercase,):
for validator in (
self._text_len_check,
self._text_lowercase,
):
text = validator(tokenizer, text)
return text
......@@ -265,8 +295,8 @@ class TokenizeParams:
return prompt
def _apply_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply padding to a token sequence."""
def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply padding to prompt tokens if necessary."""
pad_length = self.pad_prompt_tokens
if pad_length is not None and pad_length < 0:
pad_length = self.max_input_tokens
......@@ -281,8 +311,8 @@ class TokenizeParams:
return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))
def _apply_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply truncation to a token sequence."""
def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply truncation to prompt tokens if necessary."""
max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0:
max_length = self.max_input_tokens
......@@ -297,18 +327,20 @@ class TokenizeParams:
return tokens[:max_length]
def _apply_length_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply length checks to a token sequence."""
def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply length checks to prompt tokens if necessary."""
max_input_tokens = self.max_input_tokens
if max_input_tokens is None:
return tokens
if max_input_tokens is not None and len(tokens) > max_input_tokens:
if len(tokens) > max_input_tokens:
raise VLLMValidationError(
f"You passed {len(tokens)} input tokens and "
f"requested {self.max_output_tokens} output tokens. "
f"You passed {len(tokens)} input tokens "
f"and requested {self.max_output_tokens} output tokens. "
f"However, the model's context length is only "
f"{self.max_total_tokens}, resulting in a maximum "
f"input length of {max_input_tokens}. "
f"Please reduce the length of the input messages.",
f"{self.max_total_tokens} tokens, resulting in a maximum "
f"input length of {max_input_tokens} tokens. "
f"Please reduce the length of the input prompt.",
parameter="input_tokens",
value=len(tokens),
)
......@@ -318,9 +350,9 @@ class TokenizeParams:
def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply all validators to a token sequence."""
for validator in (
self._apply_padding,
self._apply_truncation,
self._apply_length_check,
self._token_padding,
self._token_truncation,
self._token_len_check,
):
tokens = validator(tokenizer, tokens)
......
......@@ -115,6 +115,10 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
def max_token_id(self) -> int:
return self.tokenizer.max_token_id
@property
def max_chars_per_token(self) -> int:
return self.tokenizer.max_chars_per_token
@property
def truncation_side(self) -> str:
return self.tokenizer.truncation_side
......
......@@ -277,6 +277,8 @@ class Grok2Tokenizer(TokenizerLike):
self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
self._unk_token_id = self._pad_token_id
self._max_chars_per_token = max(len(tok) for tok in self._token_to_id)
def num_special_tokens_to_add(self) -> int:
return 0
......@@ -312,6 +314,10 @@ class Grok2Tokenizer(TokenizerLike):
def max_token_id(self) -> int:
return self._tokenizer.n_vocab - 1
@property
def max_chars_per_token(self) -> int:
return self._max_chars_per_token
@property
def truncation_side(self) -> str:
return self._truncation_side
......
......@@ -28,6 +28,8 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
tokenizer_len = len(tokenizer)
max_token_id = max(tokenizer_vocab.values())
max_chars_per_token = max(len(tok) for tok in tokenizer_vocab)
# Some tokenizers (e.g., QwenTokenizer) have special tokens that
# are added and included in the implementation of the vocab_size
# property, but not in get_vocab(); if there is an implementation
......@@ -49,6 +51,10 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
def max_token_id(self) -> int:
return max_token_id
@property
def max_chars_per_token(self) -> int:
return max_chars_per_token
def get_vocab(self) -> dict[str, int]:
return tokenizer_vocab
......
......@@ -272,6 +272,7 @@ class MistralTokenizer(TokenizerLike):
# Vocab sorted by token id.
self._vocab = self.tokenizer.vocab()
self._max_token_id = self.vocab_size - 1
self._max_chars_per_token = max(len(tok) for tok in self._vocab)
# Cache special tokens for faster access.
self._special_token_ids = self._get_special_token_ids()
......@@ -325,6 +326,10 @@ class MistralTokenizer(TokenizerLike):
def max_token_id(self) -> int:
return self._max_token_id
@property
def max_chars_per_token(self) -> int:
return self._max_chars_per_token
@property
def truncation_side(self) -> str:
return self.transformers_tokenizer.truncation_side
......
......@@ -57,6 +57,10 @@ class TokenizerLike(Protocol):
def max_token_id(self) -> int:
raise NotImplementedError
@property
def max_chars_per_token(self) -> int:
raise NotImplementedError
@property
def truncation_side(self) -> str:
raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment