Unverified Commit a32cb49b authored by Mingliang Li's avatar Mingliang Li Committed by GitHub
Browse files

feat(frontend): early-fail tokenization guard for user requests (#31366)


Signed-off-by: default avatarlimingliang <limingliang@stepfun.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarlimingliang <limingliang@stepfun.com>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 20d7454c
This diff is collapsed.
...@@ -229,23 +229,53 @@ class TokenizeParams: ...@@ -229,23 +229,53 @@ class TokenizeParams:
max_length = self.truncate_prompt_tokens max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0: if max_length is not None and max_length < 0:
max_length = self.max_input_tokens max_length = self.max_input_tokens
elif max_length is None and self.max_input_tokens is not None:
# This prevents tokenization from taking up more resources than necessary
# while still failing `self._token_len_check` as expected by users
max_length = self.max_input_tokens + 1
return dict( return dict(
truncation=self.truncate_prompt_tokens is not None, truncation=max_length is not None,
max_length=max_length, max_length=max_length,
add_special_tokens=self.add_special_tokens, add_special_tokens=self.add_special_tokens,
) )
def _apply_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str: def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
if self.do_lower_case: """Apply length checks to prompt text if necessary."""
text = text.lower() max_input_tokens = self.max_input_tokens
if max_input_tokens is None:
return text
if self.truncate_prompt_tokens is None and tokenizer is not None:
max_input_chars = max_input_tokens * tokenizer.max_chars_per_token
if len(text) > max_input_chars:
# To save resources, fail the request outright without even
# attempting tokenization
raise VLLMValidationError(
f"You passed {len(text)} input characters "
f"and requested {self.max_output_tokens} output tokens. "
f"However, the model's context length is only "
f"{self.max_total_tokens} tokens, resulting in a maximum "
f"input length of {max_input_tokens} tokens "
f"(at most {max_input_chars} characters). "
f"Please reduce the length of the input prompt.",
parameter="input_text",
value=len(text),
)
return text return text
def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply lowercase to prompt text if necessary."""
return text.lower() if self.do_lower_case else text
def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str: def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply all validators to prompt text.""" """Apply all validators to prompt text."""
# TODO: Implement https://github.com/vllm-project/vllm/pull/31366 for validator in (
for validator in (self._apply_lowercase,): self._text_len_check,
self._text_lowercase,
):
text = validator(tokenizer, text) text = validator(tokenizer, text)
return text return text
...@@ -265,8 +295,8 @@ class TokenizeParams: ...@@ -265,8 +295,8 @@ class TokenizeParams:
return prompt return prompt
def _apply_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply padding to a token sequence.""" """Apply padding to prompt tokens if necessary."""
pad_length = self.pad_prompt_tokens pad_length = self.pad_prompt_tokens
if pad_length is not None and pad_length < 0: if pad_length is not None and pad_length < 0:
pad_length = self.max_input_tokens pad_length = self.max_input_tokens
...@@ -281,8 +311,8 @@ class TokenizeParams: ...@@ -281,8 +311,8 @@ class TokenizeParams:
return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens)) return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))
def _apply_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply truncation to a token sequence.""" """Apply truncation to prompt tokens if necessary."""
max_length = self.truncate_prompt_tokens max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0: if max_length is not None and max_length < 0:
max_length = self.max_input_tokens max_length = self.max_input_tokens
...@@ -297,18 +327,20 @@ class TokenizeParams: ...@@ -297,18 +327,20 @@ class TokenizeParams:
return tokens[:max_length] return tokens[:max_length]
def _apply_length_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply length checks to a token sequence.""" """Apply length checks to prompt tokens if necessary."""
max_input_tokens = self.max_input_tokens max_input_tokens = self.max_input_tokens
if max_input_tokens is None:
return tokens
if max_input_tokens is not None and len(tokens) > max_input_tokens: if len(tokens) > max_input_tokens:
raise VLLMValidationError( raise VLLMValidationError(
f"You passed {len(tokens)} input tokens and " f"You passed {len(tokens)} input tokens "
f"requested {self.max_output_tokens} output tokens. " f"and requested {self.max_output_tokens} output tokens. "
f"However, the model's context length is only " f"However, the model's context length is only "
f"{self.max_total_tokens}, resulting in a maximum " f"{self.max_total_tokens} tokens, resulting in a maximum "
f"input length of {max_input_tokens}. " f"input length of {max_input_tokens} tokens. "
f"Please reduce the length of the input messages.", f"Please reduce the length of the input prompt.",
parameter="input_tokens", parameter="input_tokens",
value=len(tokens), value=len(tokens),
) )
...@@ -318,9 +350,9 @@ class TokenizeParams: ...@@ -318,9 +350,9 @@ class TokenizeParams:
def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply all validators to a token sequence.""" """Apply all validators to a token sequence."""
for validator in ( for validator in (
self._apply_padding, self._token_padding,
self._apply_truncation, self._token_truncation,
self._apply_length_check, self._token_len_check,
): ):
tokens = validator(tokenizer, tokens) tokens = validator(tokenizer, tokens)
......
...@@ -115,6 +115,10 @@ class DeepseekV32Tokenizer(CachedHfTokenizer): ...@@ -115,6 +115,10 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
def max_token_id(self) -> int: def max_token_id(self) -> int:
return self.tokenizer.max_token_id return self.tokenizer.max_token_id
@property
def max_chars_per_token(self) -> int:
return self.tokenizer.max_chars_per_token
@property @property
def truncation_side(self) -> str: def truncation_side(self) -> str:
return self.tokenizer.truncation_side return self.tokenizer.truncation_side
......
...@@ -277,6 +277,8 @@ class Grok2Tokenizer(TokenizerLike): ...@@ -277,6 +277,8 @@ class Grok2Tokenizer(TokenizerLike):
self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id) self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
self._unk_token_id = self._pad_token_id self._unk_token_id = self._pad_token_id
self._max_chars_per_token = max(len(tok) for tok in self._token_to_id)
def num_special_tokens_to_add(self) -> int: def num_special_tokens_to_add(self) -> int:
return 0 return 0
...@@ -312,6 +314,10 @@ class Grok2Tokenizer(TokenizerLike): ...@@ -312,6 +314,10 @@ class Grok2Tokenizer(TokenizerLike):
def max_token_id(self) -> int: def max_token_id(self) -> int:
return self._tokenizer.n_vocab - 1 return self._tokenizer.n_vocab - 1
@property
def max_chars_per_token(self) -> int:
return self._max_chars_per_token
@property @property
def truncation_side(self) -> str: def truncation_side(self) -> str:
return self._truncation_side return self._truncation_side
......
...@@ -28,6 +28,8 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: ...@@ -28,6 +28,8 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
tokenizer_len = len(tokenizer) tokenizer_len = len(tokenizer)
max_token_id = max(tokenizer_vocab.values()) max_token_id = max(tokenizer_vocab.values())
max_chars_per_token = max(len(tok) for tok in tokenizer_vocab)
# Some tokenizers (e.g., QwenTokenizer) have special tokens that # Some tokenizers (e.g., QwenTokenizer) have special tokens that
# are added and included in the implementation of the vocab_size # are added and included in the implementation of the vocab_size
# property, but not in get_vocab(); if there is an implementation # property, but not in get_vocab(); if there is an implementation
...@@ -49,6 +51,10 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: ...@@ -49,6 +51,10 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
def max_token_id(self) -> int: def max_token_id(self) -> int:
return max_token_id return max_token_id
@property
def max_chars_per_token(self) -> int:
return max_chars_per_token
def get_vocab(self) -> dict[str, int]: def get_vocab(self) -> dict[str, int]:
return tokenizer_vocab return tokenizer_vocab
......
...@@ -272,6 +272,7 @@ class MistralTokenizer(TokenizerLike): ...@@ -272,6 +272,7 @@ class MistralTokenizer(TokenizerLike):
# Vocab sorted by token id. # Vocab sorted by token id.
self._vocab = self.tokenizer.vocab() self._vocab = self.tokenizer.vocab()
self._max_token_id = self.vocab_size - 1 self._max_token_id = self.vocab_size - 1
self._max_chars_per_token = max(len(tok) for tok in self._vocab)
# Cache special tokens for faster access. # Cache special tokens for faster access.
self._special_token_ids = self._get_special_token_ids() self._special_token_ids = self._get_special_token_ids()
...@@ -325,6 +326,10 @@ class MistralTokenizer(TokenizerLike): ...@@ -325,6 +326,10 @@ class MistralTokenizer(TokenizerLike):
def max_token_id(self) -> int: def max_token_id(self) -> int:
return self._max_token_id return self._max_token_id
@property
def max_chars_per_token(self) -> int:
return self._max_chars_per_token
@property @property
def truncation_side(self) -> str: def truncation_side(self) -> str:
return self.transformers_tokenizer.truncation_side return self.transformers_tokenizer.truncation_side
......
...@@ -57,6 +57,10 @@ class TokenizerLike(Protocol): ...@@ -57,6 +57,10 @@ class TokenizerLike(Protocol):
def max_token_id(self) -> int: def max_token_id(self) -> int:
raise NotImplementedError raise NotImplementedError
@property
def max_chars_per_token(self) -> int:
raise NotImplementedError
@property @property
def truncation_side(self) -> str: def truncation_side(self) -> str:
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment