Unverified Commit a32cb49b authored by Mingliang Li's avatar Mingliang Li Committed by GitHub
Browse files

feat(frontend): early-fail tokenization guard for user requests (#31366)


Signed-off-by: default avatarlimingliang <limingliang@stepfun.com>
Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarlimingliang <limingliang@stepfun.com>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 20d7454c
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import io import io
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from unittest.mock import AsyncMock
import pybase64 import pybase64
import pytest import pytest
...@@ -28,7 +27,6 @@ class MockModelConfig: ...@@ -28,7 +27,6 @@ class MockModelConfig:
model: str = MODEL_NAME model: str = MODEL_NAME
tokenizer: str = MODEL_NAME tokenizer: str = MODEL_NAME
trust_remote_code: bool = False trust_remote_code: bool = False
max_model_len: int = 100
tokenizer_revision = None tokenizer_revision = None
tokenizer_mode = "auto" tokenizer_mode = "auto"
hf_config = MockHFConfig() hf_config = MockHFConfig()
...@@ -37,25 +35,50 @@ class MockModelConfig: ...@@ -37,25 +35,50 @@ class MockModelConfig:
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
@pytest.fixture @dataclass
def mock_model_config(): class DummyTokenizer:
return MockModelConfig() truncation_side: str = "left"
max_chars_per_token: int = 1
def __post_init__(self) -> None:
self._captured_encode_kwargs: dict = {}
def decode(self, tokens: list[int]):
return str(tokens)
def encode(self, text: str, **kwargs):
self._captured_encode_kwargs = kwargs
in_length = len(text)
truncation = kwargs.get("truncation")
max_length = kwargs.get("max_length")
if truncation and max_length is not None:
return list(range(min(in_length, max_length)))
@pytest.fixture return list(range(in_length))
def mock_async_tokenizer():
return AsyncMock()
@pytest.fixture def _build_renderer(
def renderer(mock_model_config): model_config: MockModelConfig,
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(mock_model_config) *,
truncation_side: str = "left",
max_chars_per_token: int = 1,
):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer( renderer = HfRenderer(
mock_model_config, model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
) )
if not model_config.skip_tokenizer_init:
renderer._tokenizer = DummyTokenizer(
truncation_side=truncation_side,
max_chars_per_token=max_chars_per_token,
)
return renderer
class TestValidatePrompt: class TestValidatePrompt:
STRING_INPUTS = [ STRING_INPUTS = [
...@@ -81,39 +104,50 @@ class TestValidatePrompt: ...@@ -81,39 +104,50 @@ class TestValidatePrompt:
] ]
# Test that a nested mixed-type list of lists raises a TypeError. # Test that a nested mixed-type list of lists raises a TypeError.
def test_empty_input(self, renderer): def test_empty_input(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(ValueError, match="at least one prompt"): with pytest.raises(ValueError, match="at least one prompt"):
renderer.render_completions([]) renderer.render_completions([])
def test_invalid_type(self, renderer): def test_invalid_type(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(TypeError, match="string or an array of tokens"): with pytest.raises(TypeError, match="string or an array of tokens"):
renderer.render_completions([[1, 2], ["foo", "bar"]]) renderer.render_completions([[1, 2], ["foo", "bar"]])
@pytest.mark.parametrize("string_input", STRING_INPUTS) @pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, renderer, string_input: str): def test_string_consistent(self, string_input: str):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(string_input) == renderer.render_completions( assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input] [string_input]
) )
@pytest.mark.parametrize("token_input", TOKEN_INPUTS) @pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, renderer, token_input: list[int]): def test_token_consistent(self, token_input: list[int]):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(token_input) == renderer.render_completions( assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input] [token_input]
) )
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES) @pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, renderer, inputs_slice: slice): def test_string_slice(self, inputs_slice: slice):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(self.STRING_INPUTS)[ assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice]) ] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
class TestRenderPrompt: class TestRenderPrompt:
@pytest.mark.asyncio def test_token_input(self):
async def test_token_input(self, renderer): renderer = _build_renderer(MockModelConfig())
tokens = [101, 7592, 2088] tokens = [101, 7592, 2088]
prompts = await renderer.render_completions_async(tokens) prompts = renderer.render_completions(tokens)
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
) )
...@@ -121,11 +155,12 @@ class TestRenderPrompt: ...@@ -121,11 +155,12 @@ class TestRenderPrompt:
assert len(results) == 1 assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens assert results[0]["prompt_token_ids"] == tokens
@pytest.mark.asyncio def test_token_list_input(self):
async def test_token_list_input(self, renderer): renderer = _build_renderer(MockModelConfig())
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
prompts = await renderer.render_completions_async(token_lists) prompts = renderer.render_completions(token_lists)
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
) )
...@@ -135,167 +170,178 @@ class TestRenderPrompt: ...@@ -135,167 +170,178 @@ class TestRenderPrompt:
assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012] assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
assert results[2]["prompt_token_ids"] == [103, 4567] assert results[2]["prompt_token_ids"] == [103, 4567]
@pytest.mark.asyncio def test_text_input(self):
async def test_text_input(self, renderer, mock_async_tokenizer): renderer = _build_renderer(MockModelConfig())
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world") text_input = "x" * 10
results = await renderer.tokenize_prompts_async( prompts = renderer.render_completions(text_input)
results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 1 assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088] assert len(results[0]["prompt_token_ids"]) == 10
mock_async_tokenizer.encode.assert_called_once()
@pytest.mark.asyncio def test_text_list_input(self):
async def test_text_list_input(self, renderer, mock_async_tokenizer): renderer = _build_renderer(MockModelConfig())
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
text_list_input = ["Hello world", "How are you?", "Good morning"] text_list_input = ["x" * 10, "x" * 12, "x" * 14]
prompts = await renderer.render_completions_async(text_list_input) prompts = renderer.render_completions(text_list_input)
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 3 assert len(results) == 3
for result in results: for text_input, result in zip(text_list_input, results):
assert result["prompt_token_ids"] == [101, 7592, 2088] assert len(result["prompt_token_ids"]) == len(text_input)
assert mock_async_tokenizer.encode.call_count == 3
@pytest.mark.asyncio def test_zero_truncation(self):
async def test_no_truncation(self, renderer, mock_async_tokenizer): renderer = _build_renderer(MockModelConfig())
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world") prompts = renderer.render_completions("x" * 200)
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
) )
assert len(results) == 1 assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args assert len(results[0]["prompt_token_ids"]) == 0
assert (
"truncation" not in call_args.kwargs
or call_args.kwargs["truncation"] is False
)
@pytest.mark.asyncio def test_pos_truncation(self):
async def test_truncation_positive(self, renderer, mock_async_tokenizer): renderer = _build_renderer(MockModelConfig())
mock_async_tokenizer.encode.return_value = [101, 7592, 2088] # Truncated
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world") prompts = renderer.render_completions("x" * 200)
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams( TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
max_total_tokens=200,
truncate_prompt_tokens=50,
),
) )
assert len(results) == 1 assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args assert len(results[0]["prompt_token_ids"]) == 50
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 50 def test_neg_truncation(self):
renderer = _build_renderer(MockModelConfig())
@pytest.mark.asyncio
async def test_truncation_negative(self, renderer, mock_async_tokenizer): prompts = renderer.render_completions("x" * 200)
# Test that negative truncation uses model's max_model_len results = renderer.tokenize_prompts(
mock_async_tokenizer.encode.return_value = [
101,
7592,
2088,
] # Truncated to max_model_len
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts, prompts,
TokenizeParams( TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
max_total_tokens=200,
truncate_prompt_tokens=-1,
),
) )
assert len(results) == 1 assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args assert len(results[0]["prompt_token_ids"]) == 100 # max_total_tokens
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 200 def test_truncation_left(self):
renderer = _build_renderer(MockModelConfig(), truncation_side="left")
@pytest.mark.asyncio
async def test_token_truncation_last_elements(self, renderer):
# Test that token truncation keeps the last N elements
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = await renderer.render_completions_async(long_tokens) prompts = renderer.render_completions(long_tokens)
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams( TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
max_total_tokens=100,
truncate_prompt_tokens=5,
),
) )
assert len(results) == 1 assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109] # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109] assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
@pytest.mark.asyncio def test_truncation_right(self):
async def test_max_length_exceeded(self, renderer): renderer = _build_renderer(MockModelConfig(), truncation_side="right")
long_tokens = list(range(150)) # Exceeds max_model_len=100
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
)
prompts = await renderer.render_completions_async(long_tokens) assert len(results) == 1
# Should keep the first 5 tokens: [100, 101, 102, 103, 104]
assert results[0]["prompt_token_ids"] == [100, 101, 102, 103, 104]
with pytest.raises(ValueError, match="context length is only"): def test_text_max_length_exceeded_obvious(self):
await renderer.tokenize_prompts_async( renderer = _build_renderer(MockModelConfig(), max_chars_per_token=1)
# Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
with pytest.raises(
ValueError,
match="input characters and requested .* context length is only",
):
renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
) )
@pytest.mark.asyncio # Should not even attempt tokenization
async def test_no_tokenizer_for_text(self, renderer): assert renderer._tokenizer._captured_encode_kwargs == {}
renderer_no_tokenizer = HfRenderer.from_config(
MockModelConfig(skip_tokenizer_init=True), def test_text_max_length_exceeded_nonobvious(self):
tokenizer_kwargs={}, renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)
)
# Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
with pytest.raises(
ValueError,
match="input tokens and requested .* context length is only",
):
renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
)
# Should only tokenize the first max_total_tokens + 1 tokens
assert renderer._tokenizer._captured_encode_kwargs["truncation"] is True
assert renderer._tokenizer._captured_encode_kwargs["max_length"] == 101
def test_token_max_length_exceeded(self):
renderer = _build_renderer(MockModelConfig())
long_tokens = list(range(150)) # Exceeds max_total_tokens=100
prompts = renderer.render_completions(long_tokens)
with pytest.raises(
ValueError,
match="input tokens and requested .* context length is only",
):
renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=None),
)
def test_no_tokenizer_for_text(self):
renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
prompts = await renderer_no_tokenizer.render_completions_async("Hello world") prompts = renderer.render_completions("Hello world")
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"): with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
await renderer_no_tokenizer.tokenize_prompts_async( renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
) )
@pytest.mark.asyncio def test_token_input_with_needs_detokenization(self):
async def test_token_input_with_needs_detokenization( renderer = _build_renderer(MockModelConfig())
self, renderer, mock_async_tokenizer
):
# When needs_detokenization=True for token inputs, renderer should
# use the async tokenizer to decode and include the original text
# in the returned prompt object.
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
renderer._async_tokenizer = mock_async_tokenizer
tokens = [1, 2, 3, 4] tokens = [1, 2, 3, 4]
prompts = await renderer.render_completions_async(tokens) prompts = renderer.render_completions(tokens)
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams( TokenizeParams(
max_total_tokens=renderer.config.max_model_len, max_total_tokens=100,
needs_detokenization=True, needs_detokenization=True,
), ),
) )
assert len(results) == 1 assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens assert results[0]["prompt_token_ids"] == tokens
assert results[0]["prompt"] == "decoded text" assert results[0]["prompt"] == "[1, 2, 3, 4]"
mock_async_tokenizer.decode.assert_awaited_once()
class TestRenderEmbedPrompt: class TestRenderEmbedPrompt:
...@@ -306,118 +352,121 @@ class TestRenderEmbedPrompt: ...@@ -306,118 +352,121 @@ class TestRenderEmbedPrompt:
buffer.seek(0) buffer.seek(0)
return pybase64.b64encode(buffer.read()) return pybase64.b64encode(buffer.read())
@pytest.mark.asyncio def test_single_prompt_embed(self):
async def test_single_prompt_embed(self, renderer): renderer = _build_renderer(MockModelConfig())
# Create a test tensor # Create a test tensor
test_tensor = torch.randn(10, 768, dtype=torch.float32) tensor_input = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor) embed_bytes = self._create_test_embed_bytes(tensor_input)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes) prompts = renderer.render_completions(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len), TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 1 assert len(results) == 1
assert torch.allclose(results[0]["prompt_embeds"], test_tensor) assert torch.equal(results[0]["prompt_embeds"], tensor_input)
def test_multiple_prompt_embeds(self):
renderer = _build_renderer(MockModelConfig())
@pytest.mark.asyncio
async def test_multiple_prompt_embeds(self, renderer):
# Create multiple test tensors # Create multiple test tensors
test_tensors = [ tensor_inputs = [
torch.randn(8, 512, dtype=torch.float32), torch.randn(8, 512, dtype=torch.float32),
torch.randn(12, 512, dtype=torch.float32), torch.randn(12, 512, dtype=torch.float32),
] ]
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
prompts = await renderer.render_completions_async( prompts = renderer.render_completions(
prompt_embeds=embed_bytes_list prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs],
) )
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len), TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 2 assert len(results) == 2
for i, result in enumerate(results): for i, result in enumerate(results):
assert torch.allclose(result["prompt_embeds"], test_tensors[i]) assert torch.allclose(result["prompt_embeds"], tensor_inputs[i])
def test_prompt_embed_truncation(self):
renderer = _build_renderer(MockModelConfig())
@pytest.mark.asyncio
async def test_prompt_embed_truncation(self, renderer):
# Create tensor with more tokens than truncation limit # Create tensor with more tokens than truncation limit
test_tensor = torch.randn(20, 768, dtype=torch.float32) tensor_input = torch.randn(20, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes) prompts = renderer.render_completions(
results = await renderer.tokenize_prompts_async( prompt_embeds=self._create_test_embed_bytes(tensor_input),
)
results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams( TokenizeParams(
max_total_tokens=renderer.config.max_model_len, max_total_tokens=100,
truncate_prompt_tokens=10, truncate_prompt_tokens=10,
), ),
) )
assert len(results) == 1 assert len(results) == 1
# Should keep last 10 tokens # Should keep last 10 tokens
expected = test_tensor[-10:] expected = tensor_input[-10:]
assert torch.allclose(results[0]["prompt_embeds"], expected) assert torch.equal(results[0]["prompt_embeds"], expected)
def test_prompt_embed_different_dtypes(self):
renderer = _build_renderer(MockModelConfig())
@pytest.mark.asyncio
async def test_prompt_embed_different_dtypes(self, renderer):
# Test different supported dtypes # Test different supported dtypes
dtypes = [torch.float32, torch.float16, torch.bfloat16] dtypes = [torch.float32, torch.float16, torch.bfloat16]
for dtype in dtypes: for dtype in dtypes:
test_tensor = torch.randn(5, 256, dtype=dtype) tensor_input = torch.randn(5, 256, dtype=dtype)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes) prompts = renderer.render_completions(
results = await renderer.tokenize_prompts_async( prompt_embeds=self._create_test_embed_bytes(tensor_input),
)
results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len), TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 1 assert len(results) == 1
assert results[0]["prompt_embeds"].dtype == dtype assert results[0]["prompt_embeds"].dtype == dtype
@pytest.mark.asyncio def test_prompt_embed_squeeze_batch_dim(self):
async def test_prompt_embed_squeeze_batch_dim(self, renderer): renderer = _build_renderer(MockModelConfig())
# Test tensor with batch dimension gets squeezed # Test tensor with batch dimension gets squeezed
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32) tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes) prompts = renderer.render_completions(
results = await renderer.tokenize_prompts_async( prompt_embeds=self._create_test_embed_bytes(tensor_input),
)
results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len), TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 1 assert len(results) == 1
# Should be squeezed to 2D # Should be squeezed to 2D
assert results[0]["prompt_embeds"].shape == (10, 768) assert results[0]["prompt_embeds"].shape == (10, 768)
@pytest.mark.asyncio def test_both_prompts_and_embeds(self):
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer): renderer = _build_renderer(MockModelConfig())
# Set up text tokenization
mock_async_tokenizer.encode.return_value = [101, 102, 103]
renderer._async_tokenizer = mock_async_tokenizer
# Create embed text_input = "Hello world"
test_tensor = torch.randn(5, 256, dtype=torch.float32) tensor_input = torch.randn(5, 256, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async( prompts = renderer.render_completions(
"Hello world", text_input,
prompt_embeds=embed_bytes, prompt_embeds=self._create_test_embed_bytes(tensor_input),
) )
results = await renderer.tokenize_prompts_async( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len), TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 2 assert len(results) == 2
# First should be embed prompt # First should be embed prompt
assert torch.allclose(results[0]["prompt_embeds"], test_tensor) assert torch.equal(results[0]["prompt_embeds"], tensor_input)
# Second should be tokens prompt # Second should be tokens prompt
assert "prompt_token_ids" in results[1] assert "prompt_token_ids" in results[1]
assert results[1]["prompt_token_ids"] == [101, 102, 103] assert len(results[1]["prompt_token_ids"]) == len(text_input)
...@@ -229,23 +229,53 @@ class TokenizeParams: ...@@ -229,23 +229,53 @@ class TokenizeParams:
max_length = self.truncate_prompt_tokens max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0: if max_length is not None and max_length < 0:
max_length = self.max_input_tokens max_length = self.max_input_tokens
elif max_length is None and self.max_input_tokens is not None:
# This prevents tokenization from taking up more resources than necessary
# while still failing `self._token_len_check` as expected by users
max_length = self.max_input_tokens + 1
return dict( return dict(
truncation=self.truncate_prompt_tokens is not None, truncation=max_length is not None,
max_length=max_length, max_length=max_length,
add_special_tokens=self.add_special_tokens, add_special_tokens=self.add_special_tokens,
) )
def _apply_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str: def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
if self.do_lower_case: """Apply length checks to prompt text if necessary."""
text = text.lower() max_input_tokens = self.max_input_tokens
if max_input_tokens is None:
return text
if self.truncate_prompt_tokens is None and tokenizer is not None:
max_input_chars = max_input_tokens * tokenizer.max_chars_per_token
if len(text) > max_input_chars:
# To save resources, fail the request outright without even
# attempting tokenization
raise VLLMValidationError(
f"You passed {len(text)} input characters "
f"and requested {self.max_output_tokens} output tokens. "
f"However, the model's context length is only "
f"{self.max_total_tokens} tokens, resulting in a maximum "
f"input length of {max_input_tokens} tokens "
f"(at most {max_input_chars} characters). "
f"Please reduce the length of the input prompt.",
parameter="input_text",
value=len(text),
)
return text return text
def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply lowercase to prompt text if necessary."""
return text.lower() if self.do_lower_case else text
def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str: def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply all validators to prompt text.""" """Apply all validators to prompt text."""
# TODO: Implement https://github.com/vllm-project/vllm/pull/31366 for validator in (
for validator in (self._apply_lowercase,): self._text_len_check,
self._text_lowercase,
):
text = validator(tokenizer, text) text = validator(tokenizer, text)
return text return text
...@@ -265,8 +295,8 @@ class TokenizeParams: ...@@ -265,8 +295,8 @@ class TokenizeParams:
return prompt return prompt
def _apply_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply padding to a token sequence.""" """Apply padding to prompt tokens if necessary."""
pad_length = self.pad_prompt_tokens pad_length = self.pad_prompt_tokens
if pad_length is not None and pad_length < 0: if pad_length is not None and pad_length < 0:
pad_length = self.max_input_tokens pad_length = self.max_input_tokens
...@@ -281,8 +311,8 @@ class TokenizeParams: ...@@ -281,8 +311,8 @@ class TokenizeParams:
return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens)) return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))
def _apply_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply truncation to a token sequence.""" """Apply truncation to prompt tokens if necessary."""
max_length = self.truncate_prompt_tokens max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0: if max_length is not None and max_length < 0:
max_length = self.max_input_tokens max_length = self.max_input_tokens
...@@ -297,18 +327,20 @@ class TokenizeParams: ...@@ -297,18 +327,20 @@ class TokenizeParams:
return tokens[:max_length] return tokens[:max_length]
def _apply_length_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply length checks to a token sequence.""" """Apply length checks to prompt tokens if necessary."""
max_input_tokens = self.max_input_tokens max_input_tokens = self.max_input_tokens
if max_input_tokens is None:
return tokens
if max_input_tokens is not None and len(tokens) > max_input_tokens: if len(tokens) > max_input_tokens:
raise VLLMValidationError( raise VLLMValidationError(
f"You passed {len(tokens)} input tokens and " f"You passed {len(tokens)} input tokens "
f"requested {self.max_output_tokens} output tokens. " f"and requested {self.max_output_tokens} output tokens. "
f"However, the model's context length is only " f"However, the model's context length is only "
f"{self.max_total_tokens}, resulting in a maximum " f"{self.max_total_tokens} tokens, resulting in a maximum "
f"input length of {max_input_tokens}. " f"input length of {max_input_tokens} tokens. "
f"Please reduce the length of the input messages.", f"Please reduce the length of the input prompt.",
parameter="input_tokens", parameter="input_tokens",
value=len(tokens), value=len(tokens),
) )
...@@ -318,9 +350,9 @@ class TokenizeParams: ...@@ -318,9 +350,9 @@ class TokenizeParams:
def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply all validators to a token sequence.""" """Apply all validators to a token sequence."""
for validator in ( for validator in (
self._apply_padding, self._token_padding,
self._apply_truncation, self._token_truncation,
self._apply_length_check, self._token_len_check,
): ):
tokens = validator(tokenizer, tokens) tokens = validator(tokenizer, tokens)
......
...@@ -115,6 +115,10 @@ class DeepseekV32Tokenizer(CachedHfTokenizer): ...@@ -115,6 +115,10 @@ class DeepseekV32Tokenizer(CachedHfTokenizer):
def max_token_id(self) -> int: def max_token_id(self) -> int:
return self.tokenizer.max_token_id return self.tokenizer.max_token_id
@property
def max_chars_per_token(self) -> int:
return self.tokenizer.max_chars_per_token
@property @property
def truncation_side(self) -> str: def truncation_side(self) -> str:
return self.tokenizer.truncation_side return self.tokenizer.truncation_side
......
...@@ -277,6 +277,8 @@ class Grok2Tokenizer(TokenizerLike): ...@@ -277,6 +277,8 @@ class Grok2Tokenizer(TokenizerLike):
self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id) self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
self._unk_token_id = self._pad_token_id self._unk_token_id = self._pad_token_id
self._max_chars_per_token = max(len(tok) for tok in self._token_to_id)
def num_special_tokens_to_add(self) -> int: def num_special_tokens_to_add(self) -> int:
return 0 return 0
...@@ -312,6 +314,10 @@ class Grok2Tokenizer(TokenizerLike): ...@@ -312,6 +314,10 @@ class Grok2Tokenizer(TokenizerLike):
def max_token_id(self) -> int: def max_token_id(self) -> int:
return self._tokenizer.n_vocab - 1 return self._tokenizer.n_vocab - 1
@property
def max_chars_per_token(self) -> int:
return self._max_chars_per_token
@property @property
def truncation_side(self) -> str: def truncation_side(self) -> str:
return self._truncation_side return self._truncation_side
......
...@@ -28,6 +28,8 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: ...@@ -28,6 +28,8 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
tokenizer_len = len(tokenizer) tokenizer_len = len(tokenizer)
max_token_id = max(tokenizer_vocab.values()) max_token_id = max(tokenizer_vocab.values())
max_chars_per_token = max(len(tok) for tok in tokenizer_vocab)
# Some tokenizers (e.g., QwenTokenizer) have special tokens that # Some tokenizers (e.g., QwenTokenizer) have special tokens that
# are added and included in the implementation of the vocab_size # are added and included in the implementation of the vocab_size
# property, but not in get_vocab(); if there is an implementation # property, but not in get_vocab(); if there is an implementation
...@@ -49,6 +51,10 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer: ...@@ -49,6 +51,10 @@ def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
def max_token_id(self) -> int: def max_token_id(self) -> int:
return max_token_id return max_token_id
@property
def max_chars_per_token(self) -> int:
return max_chars_per_token
def get_vocab(self) -> dict[str, int]: def get_vocab(self) -> dict[str, int]:
return tokenizer_vocab return tokenizer_vocab
......
...@@ -272,6 +272,7 @@ class MistralTokenizer(TokenizerLike): ...@@ -272,6 +272,7 @@ class MistralTokenizer(TokenizerLike):
# Vocab sorted by token id. # Vocab sorted by token id.
self._vocab = self.tokenizer.vocab() self._vocab = self.tokenizer.vocab()
self._max_token_id = self.vocab_size - 1 self._max_token_id = self.vocab_size - 1
self._max_chars_per_token = max(len(tok) for tok in self._vocab)
# Cache special tokens for faster access. # Cache special tokens for faster access.
self._special_token_ids = self._get_special_token_ids() self._special_token_ids = self._get_special_token_ids()
...@@ -325,6 +326,10 @@ class MistralTokenizer(TokenizerLike): ...@@ -325,6 +326,10 @@ class MistralTokenizer(TokenizerLike):
def max_token_id(self) -> int: def max_token_id(self) -> int:
return self._max_token_id return self._max_token_id
@property
def max_chars_per_token(self) -> int:
return self._max_chars_per_token
@property @property
def truncation_side(self) -> str: def truncation_side(self) -> str:
return self.transformers_tokenizer.truncation_side return self.transformers_tokenizer.truncation_side
......
...@@ -57,6 +57,10 @@ class TokenizerLike(Protocol): ...@@ -57,6 +57,10 @@ class TokenizerLike(Protocol):
def max_token_id(self) -> int: def max_token_id(self) -> int:
raise NotImplementedError raise NotImplementedError
@property
def max_chars_per_token(self) -> int:
raise NotImplementedError
@property @property
def truncation_side(self) -> str: def truncation_side(self) -> str:
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment