Unverified Commit cd8b405b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Consolidate sequence normalization and enc-dec parsing (#33928)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 4707f7eb
...@@ -54,6 +54,7 @@ class MockModelConfig: ...@@ -54,6 +54,7 @@ class MockModelConfig:
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
...@@ -53,6 +53,7 @@ class MockModelConfig: ...@@ -53,6 +53,7 @@ class MockModelConfig:
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
...@@ -52,6 +52,7 @@ class MockModelConfig: ...@@ -52,6 +52,7 @@ class MockModelConfig:
encoder_config = None encoder_config = None
generation_config: str = "auto" generation_config: str = "auto"
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
...@@ -529,6 +529,7 @@ class MockModelConfig: ...@@ -529,6 +529,7 @@ class MockModelConfig:
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.renderers.inputs.preprocess import prompt_to_seq
def test_empty_input():
assert prompt_to_seq([]) == []
assert prompt_to_seq([[]]) == [[]]
assert prompt_to_seq([[], []]) == [[], []]
def test_text_input():
assert prompt_to_seq("foo") == ["foo"]
assert prompt_to_seq(["foo"]) == ["foo"]
assert prompt_to_seq(["foo", "bar"]) == ["foo", "bar"]
def test_token_input():
assert prompt_to_seq([1, 2]) == [[1, 2]]
assert prompt_to_seq([[1, 2]]) == [[1, 2]]
assert prompt_to_seq([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]
def test_text_token_input():
assert prompt_to_seq([[1, 2], "foo"]) == [[1, 2], "foo"]
assert prompt_to_seq(["foo", [1, 2]]) == ["foo", [1, 2]]
def test_bytes_input():
assert prompt_to_seq(b"foo") == [b"foo"]
assert prompt_to_seq([b"foo"]) == [b"foo"]
assert prompt_to_seq([b"foo", b"bar"]) == [b"foo", b"bar"]
def test_dict_input():
assert prompt_to_seq({"prompt": "foo"}) == [{"prompt": "foo"}]
assert prompt_to_seq([{"prompt": "foo"}]) == [{"prompt": "foo"}]
assert prompt_to_seq([{"prompt": "foo"}, {"prompt_token_ids": [1, 2]}]) == [
{"prompt": "foo"},
{"prompt_token_ids": [1, 2]},
]
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io import io
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
...@@ -9,8 +10,11 @@ import pybase64 ...@@ -9,8 +10,11 @@ import pybase64
import pytest import pytest
import torch import torch
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer from vllm.renderers.hf import HfRenderer
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers.registry import tokenizer_args_from_config from vllm.tokenizers.registry import tokenizer_args_from_config
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
...@@ -33,6 +37,7 @@ class MockModelConfig: ...@@ -33,6 +37,7 @@ class MockModelConfig:
encoder_config: dict[str, Any] | None = None encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@dataclass @dataclass
...@@ -80,73 +85,42 @@ def _build_renderer( ...@@ -80,73 +85,42 @@ def _build_renderer(
return renderer return renderer
class TestValidatePrompt: def _preprocess_prompt(
STRING_INPUTS = [ mdoel_config: ModelConfig,
"", prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
"foo", ):
"foo bar", return [
"foo baz bar", (
"foo bar qux baz", prompt
if isinstance(prompt, bytes)
else parse_model_prompt(mdoel_config, prompt)
)
for prompt in prompt_to_seq(prompt_or_prompts)
] ]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [ class TestValidatePrompt:
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
def test_empty_input(self): def test_empty_input(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
with pytest.raises(ValueError, match="at least one prompt"): with pytest.raises(ValueError, match="at least one prompt"):
renderer.render_completions([]) renderer.render_prompts(_preprocess_prompt(renderer.config, []))
def test_invalid_type(self): def test_invalid_type(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
with pytest.raises(TypeError, match="string or an array of tokens"): with pytest.raises(TypeError, match="should be a list of integers"):
renderer.render_completions([[1, 2], ["foo", "bar"]]) renderer.render_prompts(
_preprocess_prompt(renderer.config, [[1, 2], ["foo", "bar"]]) # type: ignore[arg-type]
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, string_input: str):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input]
) )
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, token_input: list[int]):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input]
)
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, inputs_slice: slice):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
class TestRenderPrompt: class TestRenderPrompt:
def test_token_input(self): def test_token_input(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
tokens = [101, 7592, 2088] tokens = [101, 7592, 2088]
prompts = renderer.render_completions(tokens) prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -159,7 +133,9 @@ class TestRenderPrompt: ...@@ -159,7 +133,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
prompts = renderer.render_completions(token_lists) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, token_lists)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -174,7 +150,9 @@ class TestRenderPrompt: ...@@ -174,7 +150,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
text_input = "x" * 10 text_input = "x" * 10
prompts = renderer.render_completions(text_input) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_input)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -187,7 +165,9 @@ class TestRenderPrompt: ...@@ -187,7 +165,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
text_list_input = ["x" * 10, "x" * 12, "x" * 14] text_list_input = ["x" * 10, "x" * 12, "x" * 14]
prompts = renderer.render_completions(text_list_input) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_list_input)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -200,7 +180,9 @@ class TestRenderPrompt: ...@@ -200,7 +180,9 @@ class TestRenderPrompt:
def test_zero_truncation(self): def test_zero_truncation(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
...@@ -212,7 +194,9 @@ class TestRenderPrompt: ...@@ -212,7 +194,9 @@ class TestRenderPrompt:
def test_pos_truncation(self): def test_pos_truncation(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
...@@ -224,7 +208,9 @@ class TestRenderPrompt: ...@@ -224,7 +208,9 @@ class TestRenderPrompt:
def test_neg_truncation(self): def test_neg_truncation(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
...@@ -237,7 +223,9 @@ class TestRenderPrompt: ...@@ -237,7 +223,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="left") renderer = _build_renderer(MockModelConfig(), truncation_side="left")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
...@@ -251,7 +239,9 @@ class TestRenderPrompt: ...@@ -251,7 +239,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="right") renderer = _build_renderer(MockModelConfig(), truncation_side="right")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
...@@ -266,7 +256,9 @@ class TestRenderPrompt: ...@@ -266,7 +256,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150 long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises( with pytest.raises(
ValueError, ValueError,
...@@ -285,7 +277,9 @@ class TestRenderPrompt: ...@@ -285,7 +277,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150 long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises( with pytest.raises(
ValueError, ValueError,
...@@ -304,7 +298,9 @@ class TestRenderPrompt: ...@@ -304,7 +298,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
long_tokens = list(range(150)) # Exceeds max_total_tokens=100 long_tokens = list(range(150)) # Exceeds max_total_tokens=100
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises( with pytest.raises(
ValueError, ValueError,
...@@ -318,7 +314,9 @@ class TestRenderPrompt: ...@@ -318,7 +314,9 @@ class TestRenderPrompt:
def test_no_tokenizer_for_text(self): def test_no_tokenizer_for_text(self):
renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True)) renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
prompts = renderer.render_completions("Hello world") prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "Hello world")
)
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"): with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
renderer.tokenize_prompts( renderer.tokenize_prompts(
...@@ -330,7 +328,7 @@ class TestRenderPrompt: ...@@ -330,7 +328,7 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
tokens = [1, 2, 3, 4] tokens = [1, 2, 3, 4]
prompts = renderer.render_completions(tokens) prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams( TokenizeParams(
...@@ -359,7 +357,9 @@ class TestRenderEmbedPrompt: ...@@ -359,7 +357,9 @@ class TestRenderEmbedPrompt:
tensor_input = torch.randn(10, 768, dtype=torch.float32) tensor_input = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(tensor_input) embed_bytes = self._create_test_embed_bytes(tensor_input)
prompts = renderer.render_completions(prompt_embeds=embed_bytes) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, embed_bytes)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -377,8 +377,11 @@ class TestRenderEmbedPrompt: ...@@ -377,8 +377,11 @@ class TestRenderEmbedPrompt:
torch.randn(12, 512, dtype=torch.float32), torch.randn(12, 512, dtype=torch.float32),
] ]
prompts = renderer.render_completions( prompts = renderer.render_prompts(
prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs], _preprocess_prompt(
renderer.config,
[self._create_test_embed_bytes(t) for t in tensor_inputs],
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -395,8 +398,10 @@ class TestRenderEmbedPrompt: ...@@ -395,8 +398,10 @@ class TestRenderEmbedPrompt:
# Create tensor with more tokens than truncation limit # Create tensor with more tokens than truncation limit
tensor_input = torch.randn(20, 768, dtype=torch.float32) tensor_input = torch.randn(20, 768, dtype=torch.float32)
prompts = renderer.render_completions( prompts = renderer.render_prompts(
prompt_embeds=self._create_test_embed_bytes(tensor_input), _preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -420,8 +425,10 @@ class TestRenderEmbedPrompt: ...@@ -420,8 +425,10 @@ class TestRenderEmbedPrompt:
for dtype in dtypes: for dtype in dtypes:
tensor_input = torch.randn(5, 256, dtype=dtype) tensor_input = torch.randn(5, 256, dtype=dtype)
prompts = renderer.render_completions( prompts = renderer.render_prompts(
prompt_embeds=self._create_test_embed_bytes(tensor_input), _preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -437,8 +444,10 @@ class TestRenderEmbedPrompt: ...@@ -437,8 +444,10 @@ class TestRenderEmbedPrompt:
# Test tensor with batch dimension gets squeezed # Test tensor with batch dimension gets squeezed
tensor_input = torch.randn(1, 10, 768, dtype=torch.float32) tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)
prompts = renderer.render_completions( prompts = renderer.render_prompts(
prompt_embeds=self._create_test_embed_bytes(tensor_input), _preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -455,9 +464,11 @@ class TestRenderEmbedPrompt: ...@@ -455,9 +464,11 @@ class TestRenderEmbedPrompt:
text_input = "Hello world" text_input = "Hello world"
tensor_input = torch.randn(5, 256, dtype=torch.float32) tensor_input = torch.randn(5, 256, dtype=torch.float32)
prompts = renderer.render_completions( prompts = renderer.render_prompts(
text_input, _preprocess_prompt(
prompt_embeds=self._create_test_embed_bytes(tensor_input), renderer.config,
[text_input, self._create_test_embed_bytes(tensor_input)],
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -465,8 +476,8 @@ class TestRenderEmbedPrompt: ...@@ -465,8 +476,8 @@ class TestRenderEmbedPrompt:
) )
assert len(results) == 2 assert len(results) == 2
# First should be embed prompt # First should be tokens prompt
assert torch.equal(results[0]["prompt_embeds"], tensor_input) assert "prompt_token_ids" in results[0]
# Second should be tokens prompt assert len(results[0]["prompt_token_ids"]) == len(text_input)
assert "prompt_token_ids" in results[1] # Second should be embed prompt
assert len(results[1]["prompt_token_ids"]) == len(text_input) assert torch.equal(results[1]["prompt_embeds"], tensor_input)
...@@ -3,16 +3,40 @@ ...@@ -3,16 +3,40 @@
import asyncio import asyncio
import time import time
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.config import ModelConfig
from vllm.renderers import ChatParams from vllm.renderers import ChatParams
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
runner_type = "generate"
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
max_model_len: int = 100
tokenizer_revision = None
tokenizer_mode = "mistral"
hf_config = MockHFConfig()
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_mistral_tokenizer_does_not_block_event_loop(): async def test_async_mistral_tokenizer_does_not_block_event_loop():
...@@ -23,9 +47,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(): ...@@ -23,9 +47,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
time.sleep(2) time.sleep(2)
return expected_tokens return expected_tokens
mock_model_config = MockModelConfig(skip_tokenizer_init=True)
mock_tokenizer = Mock(spec=MistralTokenizer) mock_tokenizer = Mock(spec=MistralTokenizer)
mock_tokenizer.apply_chat_template = mocked_apply_chat_template mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={}) mock_renderer = MistralRenderer(mock_model_config, tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([], ChatParams()) task = mock_renderer.render_messages_async([], ChatParams())
......
...@@ -4,52 +4,13 @@ ...@@ -4,52 +4,13 @@
import pytest import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
@pytest.mark.parametrize( @pytest.mark.parametrize("model_id", ["facebook/chameleon-7b"])
"mm_processor_kwargs,expected_mm_kwargs", @pytest.mark.parametrize("prompt", ["", {"prompt_token_ids": []}])
[
(None, [{}, {}]),
({}, [{}, {}]),
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
],
)
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
encoder_prompts = ["An encoder prompt", "Another encoder prompt"]
decoder_prompts = ["A decoder prompt", "Another decoder prompt"]
zipped_prompts = zip_enc_dec_prompts(
encoder_prompts, decoder_prompts, mm_processor_kwargs
)
assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
for enc, dec, exp_kwargs, zipped in zip(
encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts
):
assert isinstance(zipped, dict)
assert len(zipped.keys()) == 3
assert zipped["encoder_prompt"] == enc
assert zipped["decoder_prompt"] == dec
assert zipped["mm_processor_kwargs"] == exp_kwargs
@pytest.mark.parametrize(
"model_id",
[
"facebook/chameleon-7b",
],
)
@pytest.mark.parametrize(
"prompt",
[
"",
{"prompt_token_ids": []},
],
)
@pytest.mark.skip( @pytest.mark.skip(
reason=( reason=(
"Applying huggingface processor on text inputs results in " "Applying huggingface processor on text inputs results in "
......
...@@ -16,6 +16,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput ...@@ -16,6 +16,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
...@@ -53,7 +54,11 @@ class EngineClient(ABC): ...@@ -53,7 +54,11 @@ class EngineClient(ABC):
@abstractmethod @abstractmethod
def generate( def generate(
self, self,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None], prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
*, *,
...@@ -70,7 +75,7 @@ class EngineClient(ABC): ...@@ -70,7 +75,7 @@ class EngineClient(ABC):
@abstractmethod @abstractmethod
def encode( def encode(
self, self,
prompt: PromptType, prompt: PromptType | DictPrompt | TokPrompt,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
......
This diff is collapsed.
...@@ -67,12 +67,13 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -67,12 +67,13 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
) )
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import ( from vllm.tokenizers.mistral import (
...@@ -218,10 +219,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -218,10 +219,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request( async def render_chat_request(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> ( ) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
""" """
render chat request by validating and preprocessing inputs. render chat request by validating and preprocessing inputs.
...@@ -380,7 +378,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -380,7 +378,7 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they # If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids. # have unique request ids.
...@@ -389,10 +387,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -389,10 +387,10 @@ class OpenAIServingChat(OpenAIServing):
) )
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, self.max_model_len,
request=request, request,
prompt=engine_prompt, self._extract_prompt_len(engine_prompt),
default_sampling_params=self.default_sampling_params, self.default_sampling_params,
) )
sampling_params: SamplingParams | BeamSearchParams sampling_params: SamplingParams | BeamSearchParams
......
...@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import ( ...@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
...@@ -78,7 +78,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -78,7 +78,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request( async def render_completion_request(
self, self,
request: CompletionRequest, request: CompletionRequest,
) -> list[TokensPrompt | EmbedsPrompt] | ErrorResponse: ) -> list[TokPrompt] | ErrorResponse:
""" """
render completion request by validating and preprocessing inputs. render completion request by validating and preprocessing inputs.
...@@ -160,13 +160,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -160,13 +160,13 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, self.max_model_len,
request=request, request,
prompt=engine_prompt, self._extract_prompt_len(engine_prompt),
default_sampling_params=self.default_sampling_params, self.default_sampling_params,
) )
sampling_params: SamplingParams | BeamSearchParams sampling_params: SamplingParams | BeamSearchParams
...@@ -277,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -277,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServing):
# with the inputs token IDs # with the inputs token IDs
if final_res.prompt is None: if final_res.prompt is None:
engine_prompt = engine_prompts[i] engine_prompt = engine_prompts[i]
final_res.prompt = engine_prompt.get("prompt") final_res.prompt = self._extract_prompt_text(engine_prompt)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch) final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
...@@ -313,7 +313,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -313,7 +313,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator( async def completion_stream_generator(
self, self,
request: CompletionRequest, request: CompletionRequest,
engine_prompts: list[TokensPrompt | EmbedsPrompt], engine_prompts: list[TokPrompt],
result_generator: AsyncIterator[tuple[int, RequestOutput]], result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str, request_id: str,
created_time: int, created_time: int,
...@@ -347,7 +347,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -347,7 +347,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt prompt_text = res.prompt
if prompt_text is None: if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx] engine_prompt = engine_prompts[prompt_idx]
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
# Prompt details are excluded from later streamed outputs # Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None: if prompt_token_ids is not None:
......
...@@ -96,11 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -96,11 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
) )
from vllm.entrypoints.utils import get_max_tokens, sanitize_message from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt
from vllm.inputs.parse import (
get_prompt_components,
is_explicit_encoder_decoder_prompt,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -108,6 +104,14 @@ from vllm.multimodal import MultiModalDataDict ...@@ -108,6 +104,14 @@ from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import (
SingletonDictPrompt,
extract_prompt_components,
extract_prompt_len,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
...@@ -203,7 +207,7 @@ class ServeContext(Generic[RequestT]): ...@@ -203,7 +207,7 @@ class ServeContext(Generic[RequestT]):
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None engine_prompts: list[TokPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None None
...@@ -247,7 +251,7 @@ class OpenAIServing: ...@@ -247,7 +251,7 @@ class OpenAIServing:
async def beam_search( async def beam_search(
self, self,
prompt: PromptType, prompt: TokPrompt,
request_id: str, request_id: str,
params: BeamSearchParams, params: BeamSearchParams,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
...@@ -271,20 +275,12 @@ class OpenAIServing: ...@@ -271,20 +275,12 @@ class OpenAIServing:
eos_token_id: int = tokenizer.eos_token_id # type: ignore eos_token_id: int = tokenizer.eos_token_id # type: ignore
if is_explicit_encoder_decoder_prompt(prompt): if isinstance(prompt, dict) and "encoder_prompt" in prompt:
raise NotImplementedError raise NotImplementedError("Encoder-decoder prompt not supported")
prompt_text: str | None prompt_text: str | None = prompt.get("prompt") # type: ignore
prompt_token_ids: list[int] prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data: MultiModalDataDict | None multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
if isinstance(prompt, str):
prompt_text = prompt
prompt_token_ids = []
multi_modal_data = None
else:
prompt_text = prompt.get("prompt") # type: ignore
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = None mm_processor_kwargs: dict[str, Any] | None = None
...@@ -963,22 +959,40 @@ class OpenAIServing: ...@@ -963,22 +959,40 @@ class OpenAIServing:
request: RendererRequest, request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None, prompt_embeds: bytes | list[bytes] | None,
) -> list[TokensPrompt | EmbedsPrompt]: ) -> list[TokPrompt]:
renderer = self.renderer renderer = self.renderer
tok_params = request.build_tok_params(self.model_config) model_config = self.model_config
in_prompts = await renderer.render_completions_async( tok_params = request.build_tok_params(model_config)
prompt_input, prompt_embeds
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
) )
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params) for prompt in prompts
]
in_prompts = await renderer.render_prompts_async(parsed_prompts)
extra_items = { extra_items = {
k: v k: v
for k in ("mm_processor_kwargs", "cache_salt") for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None if (v := getattr(request, k, None)) is not None
} }
for prompt in engine_prompts: for in_prompt in in_prompts:
prompt.update(extra_items) # type: ignore target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
target_prompt.update(extra_items) # type: ignore
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
return engine_prompts return engine_prompts
...@@ -991,7 +1005,7 @@ class OpenAIServing: ...@@ -991,7 +1005,7 @@ class OpenAIServing:
default_template_kwargs: dict[str, Any] | None, default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]: ) -> tuple[list[ConversationMessage], list[TokPrompt]]:
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer renderer = self.renderer
...@@ -1009,17 +1023,21 @@ class OpenAIServing: ...@@ -1009,17 +1023,21 @@ class OpenAIServing:
default_template, default_template_content_format default_template, default_template_content_format
).with_defaults(default_template_kwargs) ).with_defaults(default_template_kwargs)
conversation, prompt = await renderer.render_messages_async( conversation, in_prompt = await renderer.render_messages_async(
messages, chat_params messages, chat_params
) )
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params) target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
extra_items = { extra_items = {
k: v k: v
for k in ("mm_processor_kwargs", "cache_salt") for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None if (v := getattr(request, k, None)) is not None
} }
engine_prompt.update(extra_items) # type: ignore target_prompt.update(extra_items) # type: ignore
engine_prompt = await renderer.tokenize_prompt_async(target_prompt, tok_params)
# tool parsing is done only if a tool_parser has been set and if # tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
...@@ -1040,6 +1058,15 @@ class OpenAIServing: ...@@ -1040,6 +1058,15 @@ class OpenAIServing:
return conversation, [engine_prompt] return conversation, [engine_prompt]
def _extract_prompt_components(self, prompt: object):
return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: object):
return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: object):
return extract_prompt_len(self.model_config, prompt)
async def _render_next_turn( async def _render_next_turn(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
...@@ -1067,7 +1094,7 @@ class OpenAIServing: ...@@ -1067,7 +1094,7 @@ class OpenAIServing:
async def _generate_with_builtin_tools( async def _generate_with_builtin_tools(
self, self,
request_id: str, request_id: str,
engine_prompt: TokensPrompt | EmbedsPrompt, engine_prompt: TokPrompt,
sampling_params: SamplingParams, sampling_params: SamplingParams,
tok_params: TokenizeParams, tok_params: TokenizeParams,
context: ConversationContext, context: ConversationContext,
...@@ -1075,7 +1102,7 @@ class OpenAIServing: ...@@ -1075,7 +1102,7 @@ class OpenAIServing:
priority: int = 0, priority: int = 0,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
): ):
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
orig_priority = priority orig_priority = priority
sub_request = 0 sub_request = 0
...@@ -1145,12 +1172,12 @@ class OpenAIServing: ...@@ -1145,12 +1172,12 @@ class OpenAIServing:
context.chat_template_content_format, context.chat_template_content_format,
) )
engine_prompt = engine_prompts[0] engine_prompt = engine_prompts[0]
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
sampling_params.max_tokens = get_max_tokens( sampling_params.max_tokens = get_max_tokens(
self.max_model_len, self.max_model_len,
context.request, context.request,
engine_prompt, self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore self.default_sampling_params, # type: ignore
) )
...@@ -1161,20 +1188,20 @@ class OpenAIServing: ...@@ -1161,20 +1188,20 @@ class OpenAIServing:
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
inputs: PromptType, inputs: PromptType | TokPrompt,
params: SamplingParams | PoolingParams | BeamSearchParams | None, params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
) -> None: ) -> None:
if self.request_logger is None: if self.request_logger is None:
return return
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs) components = self._extract_prompt_components(inputs)
self.request_logger.log_inputs( self.request_logger.log_inputs(
request_id, request_id,
prompt, components.text,
prompt_token_ids, components.token_ids,
prompt_embeds, components.embeds,
params=params, params=params,
lora_request=lora_request, lora_request=lora_request,
) )
......
...@@ -116,13 +116,13 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -116,13 +116,13 @@ from vllm.entrypoints.openai.responses.utils import (
) )
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -292,10 +292,10 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -292,10 +292,10 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input( def _validate_generator_input(
self, self,
engine_prompt: TokensPrompt | EmbedsPrompt, engine_prompt: TokPrompt,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Add validations to the input to the generator here.""" """Add validations to the input to the generator here."""
prompt_len = get_prompt_len(engine_prompt) prompt_len = self._extract_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len: if self.max_model_len <= prompt_len:
error_message = ( error_message = (
f"The engine prompt length {prompt_len} " f"The engine prompt length {prompt_len} "
...@@ -442,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -442,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing):
default_max_tokens = get_max_tokens( default_max_tokens = get_max_tokens(
self.max_model_len, self.max_model_len,
request, request,
engine_prompt, self._extract_prompt_len(engine_prompt),
self.default_sampling_params, self.default_sampling_params,
) )
......
...@@ -7,7 +7,7 @@ import time ...@@ -7,7 +7,7 @@ import time
import zlib import zlib
from collections.abc import AsyncGenerator, Callable from collections.abc import AsyncGenerator, Callable
from functools import cached_property from functools import cached_property
from typing import Literal, TypeAlias, TypeVar, cast from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np import numpy as np
from fastapi import Request from fastapi import Request
...@@ -37,12 +37,13 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( ...@@ -37,12 +37,13 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationStreamResponse, TranslationStreamResponse,
) )
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType from vllm.inputs.data import PromptType
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.renderers.inputs import EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
...@@ -94,7 +95,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -94,7 +95,7 @@ class OpenAISpeechToText(OpenAIServing):
) )
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.task_type = task_type self.task_type: Final = task_type
self.asr_config = self.model_cls.get_speech_to_text_config( self.asr_config = self.model_cls.get_speech_to_text_config(
self.model_config, task_type self.model_config, task_type
...@@ -298,35 +299,26 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -298,35 +299,26 @@ class OpenAISpeechToText(OpenAIServing):
to_language=to_language, to_language=to_language,
) )
if request.response_format == "verbose_json": if request.response_format == "verbose_json":
if not is_explicit_encoder_decoder_prompt(prompt): prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt))
raise VLLMValidationError(
"Expected prompt to be an encoder-decoder prompt",
parameter="prompt",
value=type(prompt).__name__,
)
prompt = self._preprocess_verbose_prompt(prompt)
prompts.append(prompt) prompts.append(prompt)
return prompts, duration
def _repl_verbose_text(self, text: str): return prompts, duration
return text.replace("<|notimestamps|>", "<|0.00|>")
def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt): def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"] dec_prompt = prompt["decoder_prompt"]
if isinstance(dec_prompt, str): if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt):
prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt)
elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt:
dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"])
else:
raise VLLMValidationError( raise VLLMValidationError(
"Expected decoder_prompt to contain text", "Expected decoder_prompt to contain text",
parameter="decoder_prompt", parameter="decoder_prompt",
value=type(dec_prompt).__name__, value=type(dec_prompt).__name__,
) )
dec_prompt["prompt"] = dec_prompt["prompt"].replace(
"<|notimestamps|>", "<|0.00|>"
)
return prompt return prompt
def _get_verbose_segments( def _get_verbose_segments(
......
...@@ -28,10 +28,11 @@ from vllm.entrypoints.pooling.utils import ( ...@@ -28,10 +28,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64, encode_pooling_output_base64,
encode_pooling_output_float, encode_pooling_output_float,
) )
from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import TokPrompt
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import EmbedDType, Endianness from vllm.utils.serial_utils import EmbedDType, Endianness
...@@ -369,7 +370,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -369,7 +370,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator( async def _create_single_prompt_generator(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt | EmbedsPrompt, engine_prompt: TokPrompt,
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
......
...@@ -33,8 +33,11 @@ from vllm.entrypoints.pooling.utils import ( ...@@ -33,8 +33,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64, encode_pooling_output_base64,
encode_pooling_output_float, encode_pooling_output_float,
) )
from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
...@@ -91,6 +94,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -91,6 +94,7 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported" "dimensions is currently not supported"
) )
engine_prompts: Sequence[PromptType | TokPrompt]
if is_io_processor_request: if is_io_processor_request:
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
...@@ -102,14 +106,10 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -102,14 +106,10 @@ class OpenAIServingPooling(OpenAIServing):
validated_prompt = self.io_processor.parse_request(request) validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async( raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id prompt=validated_prompt, request_id=request_id
) )
if not isinstance(engine_prompts, Sequence) or isinstance( engine_prompts = prompt_to_seq(raw_prompts)
engine_prompts, (str, bytes, bytearray)
):
engine_prompts = [engine_prompts]
elif isinstance(request, PoolingChatRequest): elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template( error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template, request_chat_template=request.chat_template,
......
...@@ -17,8 +17,6 @@ from starlette.background import BackgroundTask, BackgroundTasks ...@@ -17,8 +17,6 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs from vllm import envs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import current_formatter_type, init_logger from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -189,7 +187,7 @@ def cli_env_setup(): ...@@ -189,7 +187,7 @@ def cli_env_setup():
def get_max_tokens( def get_max_tokens(
max_model_len: int, max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest", request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
prompt: TokensPrompt | EmbedsPrompt, input_length: int,
default_sampling_params: dict, default_sampling_params: dict,
) -> int: ) -> int:
# NOTE: Avoid isinstance() for better efficiency # NOTE: Avoid isinstance() for better efficiency
...@@ -204,7 +202,6 @@ def get_max_tokens( ...@@ -204,7 +202,6 @@ def get_max_tokens(
# CompletionRequest (also a fallback for ChatCompletionRequest) # CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None) max_tokens = getattr(request, "max_tokens", None)
input_length = get_prompt_len(prompt)
default_max_tokens = max_model_len - input_length default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length) max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
...@@ -16,11 +16,8 @@ from .data import ( ...@@ -16,11 +16,8 @@ from .data import (
TextPrompt, TextPrompt,
TokenInputs, TokenInputs,
TokensPrompt, TokensPrompt,
build_explicit_enc_dec_prompt,
embeds_inputs, embeds_inputs,
to_enc_dec_tuple_list,
token_inputs, token_inputs,
zip_enc_dec_prompts,
) )
__all__ = [ __all__ = [
...@@ -39,8 +36,5 @@ __all__ = [ ...@@ -39,8 +36,5 @@ __all__ = [
"EncoderDecoderInputs", "EncoderDecoderInputs",
"ProcessorInputs", "ProcessorInputs",
"SingletonInputs", "SingletonInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"StreamingInput", "StreamingInput",
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment