Unverified Commit cd8b405b authored by Cyrus Leung Committed by GitHub
Browse files

[Refactor] Consolidate sequence normalization and enc-dec parsing (#33928)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 4707f7eb
...@@ -54,6 +54,7 @@ class MockModelConfig: ...@@ -54,6 +54,7 @@ class MockModelConfig:
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
...@@ -53,6 +53,7 @@ class MockModelConfig: ...@@ -53,6 +53,7 @@ class MockModelConfig:
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
...@@ -52,6 +52,7 @@ class MockModelConfig: ...@@ -52,6 +52,7 @@ class MockModelConfig:
encoder_config = None encoder_config = None
generation_config: str = "auto" generation_config: str = "auto"
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
...@@ -529,6 +529,7 @@ class MockModelConfig: ...@@ -529,6 +529,7 @@ class MockModelConfig:
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.renderers.inputs.preprocess import prompt_to_seq
def test_empty_input():
    """Empty and nested-empty lists come back from `prompt_to_seq` unchanged."""
    cases = [
        ([], []),
        ([[]], [[]]),
        ([[], []], [[], []]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
def test_text_input():
    """A bare string is wrapped in a list; lists of strings are kept as-is."""
    cases = [
        ("foo", ["foo"]),
        (["foo"], ["foo"]),
        (["foo", "bar"], ["foo", "bar"]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
def test_token_input():
    """A flat token list is one prompt; lists of token lists are kept as-is."""
    cases = [
        ([1, 2], [[1, 2]]),
        ([[1, 2]], [[1, 2]]),
        ([[1, 2], [3, 4]], [[1, 2], [3, 4]]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
def test_text_token_input():
    """Lists mixing token lists and strings keep their order and contents."""
    cases = [
        ([[1, 2], "foo"], [[1, 2], "foo"]),
        (["foo", [1, 2]], ["foo", [1, 2]]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
def test_bytes_input():
    """A bare bytes object is wrapped in a list; lists of bytes are kept as-is."""
    cases = [
        (b"foo", [b"foo"]),
        ([b"foo"], [b"foo"]),
        ([b"foo", b"bar"], [b"foo", b"bar"]),
    ]
    for given, expected in cases:
        assert prompt_to_seq(given) == expected
def test_dict_input():
    """A single dict prompt is wrapped in a list; lists of dicts are kept as-is."""
    text_prompt = {"prompt": "foo"}
    tokens_prompt = {"prompt_token_ids": [1, 2]}

    single = prompt_to_seq({"prompt": "foo"})
    assert single == [text_prompt]

    wrapped = prompt_to_seq([{"prompt": "foo"}])
    assert wrapped == [text_prompt]

    mixed = prompt_to_seq([{"prompt": "foo"}, {"prompt_token_ids": [1, 2]}])
    assert mixed == [text_prompt, tokens_prompt]
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io import io
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
...@@ -9,8 +10,11 @@ import pybase64 ...@@ -9,8 +10,11 @@ import pybase64
import pytest import pytest
import torch import torch
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer from vllm.renderers.hf import HfRenderer
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers.registry import tokenizer_args_from_config from vllm.tokenizers.registry import tokenizer_args_from_config
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
...@@ -33,6 +37,7 @@ class MockModelConfig: ...@@ -33,6 +37,7 @@ class MockModelConfig:
encoder_config: dict[str, Any] | None = None encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@dataclass @dataclass
...@@ -80,73 +85,42 @@ def _build_renderer( ...@@ -80,73 +85,42 @@ def _build_renderer(
return renderer return renderer
class TestValidatePrompt: def _preprocess_prompt(
STRING_INPUTS = [ model_config: ModelConfig,
"", prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
"foo", ):
"foo bar", return [
"foo baz bar", (
"foo bar qux baz", prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompt_to_seq(prompt_or_prompts)
] ]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [ class TestValidatePrompt:
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
def test_empty_input(self): def test_empty_input(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
with pytest.raises(ValueError, match="at least one prompt"): with pytest.raises(ValueError, match="at least one prompt"):
renderer.render_completions([]) renderer.render_prompts(_preprocess_prompt(renderer.config, []))
def test_invalid_type(self): def test_invalid_type(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
with pytest.raises(TypeError, match="string or an array of tokens"): with pytest.raises(TypeError, match="should be a list of integers"):
renderer.render_completions([[1, 2], ["foo", "bar"]]) renderer.render_prompts(
_preprocess_prompt(renderer.config, [[1, 2], ["foo", "bar"]]) # type: ignore[arg-type]
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, string_input: str):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input]
) )
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, token_input: list[int]):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input]
)
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, inputs_slice: slice):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
class TestRenderPrompt: class TestRenderPrompt:
def test_token_input(self): def test_token_input(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
tokens = [101, 7592, 2088] tokens = [101, 7592, 2088]
prompts = renderer.render_completions(tokens) prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -159,7 +133,9 @@ class TestRenderPrompt: ...@@ -159,7 +133,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
prompts = renderer.render_completions(token_lists) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, token_lists)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -174,7 +150,9 @@ class TestRenderPrompt: ...@@ -174,7 +150,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
text_input = "x" * 10 text_input = "x" * 10
prompts = renderer.render_completions(text_input) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_input)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -187,7 +165,9 @@ class TestRenderPrompt: ...@@ -187,7 +165,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
text_list_input = ["x" * 10, "x" * 12, "x" * 14] text_list_input = ["x" * 10, "x" * 12, "x" * 14]
prompts = renderer.render_completions(text_list_input) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_list_input)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -200,7 +180,9 @@ class TestRenderPrompt: ...@@ -200,7 +180,9 @@ class TestRenderPrompt:
def test_zero_truncation(self): def test_zero_truncation(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
...@@ -212,7 +194,9 @@ class TestRenderPrompt: ...@@ -212,7 +194,9 @@ class TestRenderPrompt:
def test_pos_truncation(self): def test_pos_truncation(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
...@@ -224,7 +208,9 @@ class TestRenderPrompt: ...@@ -224,7 +208,9 @@ class TestRenderPrompt:
def test_neg_truncation(self): def test_neg_truncation(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
...@@ -237,7 +223,9 @@ class TestRenderPrompt: ...@@ -237,7 +223,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="left") renderer = _build_renderer(MockModelConfig(), truncation_side="left")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
...@@ -251,7 +239,9 @@ class TestRenderPrompt: ...@@ -251,7 +239,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="right") renderer = _build_renderer(MockModelConfig(), truncation_side="right")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5), TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
...@@ -266,7 +256,9 @@ class TestRenderPrompt: ...@@ -266,7 +256,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150 long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises( with pytest.raises(
ValueError, ValueError,
...@@ -285,7 +277,9 @@ class TestRenderPrompt: ...@@ -285,7 +277,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150 long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises( with pytest.raises(
ValueError, ValueError,
...@@ -304,7 +298,9 @@ class TestRenderPrompt: ...@@ -304,7 +298,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
long_tokens = list(range(150)) # Exceeds max_total_tokens=100 long_tokens = list(range(150)) # Exceeds max_total_tokens=100
prompts = renderer.render_completions(long_tokens) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises( with pytest.raises(
ValueError, ValueError,
...@@ -318,7 +314,9 @@ class TestRenderPrompt: ...@@ -318,7 +314,9 @@ class TestRenderPrompt:
def test_no_tokenizer_for_text(self): def test_no_tokenizer_for_text(self):
renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True)) renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
prompts = renderer.render_completions("Hello world") prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "Hello world")
)
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"): with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
renderer.tokenize_prompts( renderer.tokenize_prompts(
...@@ -330,7 +328,7 @@ class TestRenderPrompt: ...@@ -330,7 +328,7 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
tokens = [1, 2, 3, 4] tokens = [1, 2, 3, 4]
prompts = renderer.render_completions(tokens) prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams( TokenizeParams(
...@@ -359,7 +357,9 @@ class TestRenderEmbedPrompt: ...@@ -359,7 +357,9 @@ class TestRenderEmbedPrompt:
tensor_input = torch.randn(10, 768, dtype=torch.float32) tensor_input = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(tensor_input) embed_bytes = self._create_test_embed_bytes(tensor_input)
prompts = renderer.render_completions(prompt_embeds=embed_bytes) prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, embed_bytes)
)
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
...@@ -377,8 +377,11 @@ class TestRenderEmbedPrompt: ...@@ -377,8 +377,11 @@ class TestRenderEmbedPrompt:
torch.randn(12, 512, dtype=torch.float32), torch.randn(12, 512, dtype=torch.float32),
] ]
prompts = renderer.render_completions( prompts = renderer.render_prompts(
prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs], _preprocess_prompt(
renderer.config,
[self._create_test_embed_bytes(t) for t in tensor_inputs],
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -395,8 +398,10 @@ class TestRenderEmbedPrompt: ...@@ -395,8 +398,10 @@ class TestRenderEmbedPrompt:
# Create tensor with more tokens than truncation limit # Create tensor with more tokens than truncation limit
tensor_input = torch.randn(20, 768, dtype=torch.float32) tensor_input = torch.randn(20, 768, dtype=torch.float32)
prompts = renderer.render_completions( prompts = renderer.render_prompts(
prompt_embeds=self._create_test_embed_bytes(tensor_input), _preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -420,8 +425,10 @@ class TestRenderEmbedPrompt: ...@@ -420,8 +425,10 @@ class TestRenderEmbedPrompt:
for dtype in dtypes: for dtype in dtypes:
tensor_input = torch.randn(5, 256, dtype=dtype) tensor_input = torch.randn(5, 256, dtype=dtype)
prompts = renderer.render_completions( prompts = renderer.render_prompts(
prompt_embeds=self._create_test_embed_bytes(tensor_input), _preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -437,8 +444,10 @@ class TestRenderEmbedPrompt: ...@@ -437,8 +444,10 @@ class TestRenderEmbedPrompt:
# Test tensor with batch dimension gets squeezed # Test tensor with batch dimension gets squeezed
tensor_input = torch.randn(1, 10, 768, dtype=torch.float32) tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)
prompts = renderer.render_completions( prompts = renderer.render_prompts(
prompt_embeds=self._create_test_embed_bytes(tensor_input), _preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -455,9 +464,11 @@ class TestRenderEmbedPrompt: ...@@ -455,9 +464,11 @@ class TestRenderEmbedPrompt:
text_input = "Hello world" text_input = "Hello world"
tensor_input = torch.randn(5, 256, dtype=torch.float32) tensor_input = torch.randn(5, 256, dtype=torch.float32)
prompts = renderer.render_completions( prompts = renderer.render_prompts(
text_input, _preprocess_prompt(
prompt_embeds=self._create_test_embed_bytes(tensor_input), renderer.config,
[text_input, self._create_test_embed_bytes(tensor_input)],
)
) )
results = renderer.tokenize_prompts( results = renderer.tokenize_prompts(
prompts, prompts,
...@@ -465,8 +476,8 @@ class TestRenderEmbedPrompt: ...@@ -465,8 +476,8 @@ class TestRenderEmbedPrompt:
) )
assert len(results) == 2 assert len(results) == 2
# First should be embed prompt # First should be tokens prompt
assert torch.equal(results[0]["prompt_embeds"], tensor_input) assert "prompt_token_ids" in results[0]
# Second should be tokens prompt assert len(results[0]["prompt_token_ids"]) == len(text_input)
assert "prompt_token_ids" in results[1] # Second should be embed prompt
assert len(results[1]["prompt_token_ids"]) == len(text_input) assert torch.equal(results[1]["prompt_embeds"], tensor_input)
...@@ -3,16 +3,40 @@ ...@@ -3,16 +3,40 @@
import asyncio import asyncio
import time import time
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.config import ModelConfig
from vllm.renderers import ChatParams from vllm.renderers import ChatParams
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
@dataclass
class MockHFConfig:
    """Minimal stand-in for the `hf_config` attribute of a mock model config."""

    # Presumably an arbitrary placeholder; confirm nothing keys off this value.
    model_type: str = "any"
@dataclass
class MockModelConfig:
    """Lightweight substitute for `ModelConfig` used to build a `MistralRenderer`."""

    # Annotated names are dataclass fields and can be overridden through the
    # constructor, e.g. `MockModelConfig(skip_tokenizer_init=True)`.
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    max_model_len: int = 100
    encoder_config: dict[str, Any] | None = None
    enable_prompt_embeds: bool = True
    skip_tokenizer_init: bool = False
    is_encoder_decoder: bool = False

    # Unannotated names are plain class attributes shared by every instance;
    # they are not dataclass fields and are not constructor parameters.
    runner_type = "generate"
    tokenizer_revision = None
    tokenizer_mode = "mistral"
    hf_config = MockHFConfig()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_mistral_tokenizer_does_not_block_event_loop(): async def test_async_mistral_tokenizer_does_not_block_event_loop():
...@@ -23,9 +47,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(): ...@@ -23,9 +47,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
time.sleep(2) time.sleep(2)
return expected_tokens return expected_tokens
mock_model_config = MockModelConfig(skip_tokenizer_init=True)
mock_tokenizer = Mock(spec=MistralTokenizer) mock_tokenizer = Mock(spec=MistralTokenizer)
mock_tokenizer.apply_chat_template = mocked_apply_chat_template mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={}) mock_renderer = MistralRenderer(mock_model_config, tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([], ChatParams()) task = mock_renderer.render_messages_async([], ChatParams())
......
...@@ -4,52 +4,13 @@ ...@@ -4,52 +4,13 @@
import pytest import pytest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
@pytest.mark.parametrize( @pytest.mark.parametrize("model_id", ["facebook/chameleon-7b"])
"mm_processor_kwargs,expected_mm_kwargs", @pytest.mark.parametrize("prompt", ["", {"prompt_token_ids": []}])
[
(None, [{}, {}]),
({}, [{}, {}]),
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
],
)
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
encoder_prompts = ["An encoder prompt", "Another encoder prompt"]
decoder_prompts = ["A decoder prompt", "Another decoder prompt"]
zipped_prompts = zip_enc_dec_prompts(
encoder_prompts, decoder_prompts, mm_processor_kwargs
)
assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
for enc, dec, exp_kwargs, zipped in zip(
encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts
):
assert isinstance(zipped, dict)
assert len(zipped.keys()) == 3
assert zipped["encoder_prompt"] == enc
assert zipped["decoder_prompt"] == dec
assert zipped["mm_processor_kwargs"] == exp_kwargs
@pytest.mark.parametrize(
"model_id",
[
"facebook/chameleon-7b",
],
)
@pytest.mark.parametrize(
"prompt",
[
"",
{"prompt_token_ids": []},
],
)
@pytest.mark.skip( @pytest.mark.skip(
reason=( reason=(
"Applying huggingface processor on text inputs results in " "Applying huggingface processor on text inputs results in "
......
...@@ -16,6 +16,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput ...@@ -16,6 +16,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
...@@ -53,7 +54,11 @@ class EngineClient(ABC): ...@@ -53,7 +54,11 @@ class EngineClient(ABC):
@abstractmethod @abstractmethod
def generate( def generate(
self, self,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None], prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
*, *,
...@@ -70,7 +75,7 @@ class EngineClient(ABC): ...@@ -70,7 +75,7 @@ class EngineClient(ABC):
@abstractmethod @abstractmethod
def encode( def encode(
self, self,
prompt: PromptType, prompt: PromptType | DictPrompt | TokPrompt,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import itertools import itertools
import warnings import warnings
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias, cast from typing import TYPE_CHECKING, Any, cast
import cloudpickle import cloudpickle
import torch.nn as nn import torch.nn as nn
...@@ -53,16 +53,13 @@ from vllm.entrypoints.pooling.score.utils import ( ...@@ -53,16 +53,13 @@ from vllm.entrypoints.pooling.score.utils import (
validate_score_input, validate_score_input,
) )
from vllm.entrypoints.utils import log_non_default_args from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs import ( from vllm.inputs.data import (
DataPrompt, DataPrompt,
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
PromptType, PromptType,
SingletonPrompt, SingletonPrompt,
TextPrompt, TextPrompt,
TokensPrompt, TokensPrompt,
) )
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -76,6 +73,13 @@ from vllm.outputs import ( ...@@ -76,6 +73,13 @@ from vllm.outputs import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, SingletonDictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import (
conversation_to_seq,
extract_prompt_components,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -93,9 +97,6 @@ logger = init_logger(__name__) ...@@ -93,9 +97,6 @@ logger = init_logger(__name__)
_R = TypeVar("_R", default=Any) _R = TypeVar("_R", default=Any)
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]
class LLM: class LLM:
"""An LLM for generating texts from given prompts and sampling parameters. """An LLM for generating texts from given prompts and sampling parameters.
...@@ -445,21 +446,20 @@ class LLM: ...@@ -445,21 +446,20 @@ class LLM:
if sampling_params is None: if sampling_params is None:
sampling_params = self.get_default_sampling_params() sampling_params = self.get_default_sampling_params()
self._validate_and_add_requests( outputs = self._run_completion(
prompts=prompts, prompts=prompts,
params=sampling_params, params=sampling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request), lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority, priority=priority,
) )
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput) return self.engine_class.validate_outputs(outputs, RequestOutput)
def _get_modality_specific_lora_reqs( def _get_modality_specific_lora_reqs(
self, self,
prompts: PromptType | Sequence[PromptType], prompts: Sequence[DictPrompt | TokPrompt],
lora_request: list[LoRARequest] | LoRARequest | None, lora_request: list[LoRARequest] | LoRARequest | None,
): ):
# Grab the lora config off the vllm config on the engine, # Grab the lora config off the vllm config on the engine,
...@@ -475,9 +475,6 @@ class LLM: ...@@ -475,9 +475,6 @@ class LLM:
): ):
return lora_request return lora_request
if not isinstance(prompts, Sequence) or isinstance(prompts, str):
prompts = [prompts]
optional_loras = ( optional_loras = (
[lora_request] * len(prompts) [lora_request] * len(prompts)
if not isinstance(lora_request, Sequence) if not isinstance(lora_request, Sequence)
...@@ -495,14 +492,12 @@ class LLM: ...@@ -495,14 +492,12 @@ class LLM:
def _resolve_single_prompt_mm_lora( def _resolve_single_prompt_mm_lora(
self, self,
prompt: PromptType, prompt: DictPrompt | TokPrompt,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
default_mm_loras: dict[str, str] | None, default_mm_loras: dict[str, str] | None,
): ):
if ( if not default_mm_loras or not (
not default_mm_loras mm_data := prompt.get("multi_modal_data") or {}
or not isinstance(prompt, dict)
or not (mm_data := prompt.get("multi_modal_data") or {})
): ):
return lora_request return lora_request
...@@ -806,61 +801,11 @@ class LLM: ...@@ -806,61 +801,11 @@ class LLM:
add_special_tokens=not model_config.is_encoder_decoder, add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs) ).with_kwargs(tokenization_kwargs)
def _normalize_prompts(
self,
prompts: PromptType | Sequence[PromptType],
) -> list[EnginePrompt | EngineEncDecPrompt]:
if isinstance(prompts, str):
prompts = TextPrompt(prompt=prompts)
return prompts if isinstance(prompts, Sequence) else [prompts] # type: ignore[return-value]
def _preprocess_cmpl_singleton(
self,
prompt: SingletonPrompt,
tok_params: TokenizeParams,
*,
tokenize: bool,
) -> EnginePrompt:
renderer = self.llm_engine.renderer
if not isinstance(prompt, dict):
prompt = renderer.render_completion(prompt)
return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt
def _preprocess_cmpl_enc_dec(
self,
prompt: ExplicitEncoderDecoderPrompt,
tok_params: TokenizeParams,
) -> EngineEncDecPrompt:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return EngineEncDecPrompt(
encoder_prompt=self._preprocess_cmpl_singleton(
enc_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
),
decoder_prompt=(
None
if dec_prompt is None
else self._preprocess_cmpl_singleton(
dec_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
)
),
)
def _preprocess_completion( def _preprocess_completion(
self, self,
prompts: PromptType | Sequence[PromptType], prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> list[EnginePrompt | EngineEncDecPrompt]: ) -> list[DictPrompt | TokPrompt]:
""" """
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`. a format that can be passed to `_add_request`.
...@@ -871,32 +816,26 @@ class LLM: ...@@ -871,32 +816,26 @@ class LLM:
A list of `TokensPrompts` objects containing the tokenized prompt A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs. after chat template interpolation, and the raw multi-modal inputs.
""" """
renderer = self.llm_engine.renderer
model_config = self.model_config
tok_params = self._get_cmpl_tok_params(tokenization_kwargs) tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt | EngineEncDecPrompt]() engine_prompts = list[DictPrompt | TokPrompt]()
for prompt in self._normalize_prompts(prompts): for prompt in prompts:
if is_explicit_encoder_decoder_prompt(prompt): parsed_prompt = parse_model_prompt(model_config, prompt)
engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params)) in_prompt = renderer.render_prompt(parsed_prompt)
else:
# Some MM models have non-default `add_special_tokens` # Some MM models have non-default `add_special_tokens`
# TODO: Move multi-modal processor into tokenization # TODO: Move multi-modal processor into tokenization
engine_prompts.append( engine_prompts.append(
self._preprocess_cmpl_singleton( in_prompt
prompt, if model_config.is_multimodal_model
tok_params, else renderer.tokenize_prompt(in_prompt, tok_params)
tokenize=not self.model_config.is_multimodal_model,
)
) )
return engine_prompts return engine_prompts
def _normalize_conversations(
    self,
    conversations: list[ChatCompletionMessageParam]
    | list[list[ChatCompletionMessageParam]],
) -> list[list[ChatCompletionMessageParam]]:
    """Return the input as a list of conversations.

    When `is_list_of(conversations, list)` holds, the input is returned
    unchanged; otherwise it is treated as a single conversation and
    wrapped in a one-element list.
    """
    if is_list_of(conversations, list):
        return conversations  # type: ignore[return-value]

    return [conversations]  # type: ignore[list-item]
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None): def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config model_config = self.model_config
encoder_config = model_config.encoder_config or {} encoder_config = model_config.encoder_config or {}
...@@ -909,8 +848,7 @@ class LLM: ...@@ -909,8 +848,7 @@ class LLM:
def _preprocess_chat( def _preprocess_chat(
self, self,
conversations: list[ChatCompletionMessageParam] conversations: Sequence[list[ChatCompletionMessageParam]],
| list[list[ChatCompletionMessageParam]],
chat_template: str | None = None, chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto", chat_template_content_format: ChatTemplateContentFormatOption = "auto",
chat_template_kwargs: dict[str, Any] | None = None, chat_template_kwargs: dict[str, Any] | None = None,
...@@ -919,7 +857,7 @@ class LLM: ...@@ -919,7 +857,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[EnginePrompt]: ) -> list[DictPrompt | TokPrompt]:
""" """
Convert a list of conversations into prompts so that they can then Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs. be used as input for other LLM APIs.
...@@ -947,11 +885,14 @@ class LLM: ...@@ -947,11 +885,14 @@ class LLM:
) )
tok_params = self._get_chat_tok_params(tokenization_kwargs) tok_params = self._get_chat_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt]() engine_prompts = list[DictPrompt | TokPrompt]()
for conversation in self._normalize_conversations(conversations): for conversation in conversations:
_, in_prompt = renderer.render_messages(conversation, chat_params) _, in_prompt = renderer.render_messages(conversation, chat_params)
if mm_processor_kwargs is not None: if mm_processor_kwargs is not None:
in_prompt["mm_processor_kwargs"] = mm_processor_kwargs target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
target_prompt["mm_processor_kwargs"] = mm_processor_kwargs # type: ignore
engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params)) engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))
...@@ -960,8 +901,8 @@ class LLM: ...@@ -960,8 +901,8 @@ class LLM:
def chat( def chat(
self, self,
messages: list[ChatCompletionMessageParam] messages: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]], | Sequence[list[ChatCompletionMessageParam]],
sampling_params: SamplingParams | list[SamplingParams] | None = None, sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True, use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
chat_template: str | None = None, chat_template: str | None = None,
...@@ -984,7 +925,7 @@ class LLM: ...@@ -984,7 +925,7 @@ class LLM:
to the OpenAI API. to the OpenAI API.
Args: Args:
messages: A list of conversations or a single conversation. messages: A sequence of conversations or a single conversation.
- Each conversation is represented as a list of messages. - Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys. - Each message is a dictionary with 'role' and 'content' keys.
...@@ -1023,8 +964,23 @@ class LLM: ...@@ -1023,8 +964,23 @@ class LLM:
A list of `RequestOutput` objects containing the generated A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages. responses in the same order as the input messages.
""" """
prompts = self._preprocess_chat( model_config = self.model_config
messages, runner_type = model_config.runner_type
if runner_type != "generate":
raise ValueError(
"LLM.chat() is only supported for generative models. "
"Try passing `--runner generate` to use the model as a "
"generative model."
)
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
outputs = self._run_chat(
messages=messages,
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
chat_template=chat_template, chat_template=chat_template,
chat_template_content_format=chat_template_content_format, chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs, chat_template_kwargs=chat_template_kwargs,
...@@ -1035,13 +991,7 @@ class LLM: ...@@ -1035,13 +991,7 @@ class LLM:
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
) )
return self.generate( return self.engine_class.validate_outputs(outputs, RequestOutput)
prompts,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
def encode( def encode(
self, self,
...@@ -1163,7 +1113,7 @@ class LLM: ...@@ -1163,7 +1113,7 @@ class LLM:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!" msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
raise ValueError(msg) raise ValueError(msg)
self._validate_and_add_requests( outputs = self._run_completion(
prompts=prompts, prompts=prompts,
params=pooling_params, params=pooling_params,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
...@@ -1171,8 +1121,6 @@ class LLM: ...@@ -1171,8 +1121,6 @@ class LLM:
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
outputs = self._run_engine(use_tqdm=use_tqdm)
model_outputs = self.engine_class.validate_outputs( model_outputs = self.engine_class.validate_outputs(
outputs, PoolingRequestOutput outputs, PoolingRequestOutput
) )
...@@ -1523,14 +1471,13 @@ class LLM: ...@@ -1523,14 +1471,13 @@ class LLM:
prompts.append(engine_prompt) prompts.append(engine_prompt)
self._validate_and_add_requests( outputs = self._run_completion(
prompts=prompts, prompts=prompts,
params=pooling_params_list, params=pooling_params_list,
use_tqdm=use_tqdm, use_tqdm=use_tqdm,
lora_request=lora_request, lora_request=lora_request,
) )
outputs = self._run_engine(use_tqdm=use_tqdm)
items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput) items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items] return [ScoringRequestOutput.from_base(item) for item in items]
...@@ -1727,33 +1674,29 @@ class LLM: ...@@ -1727,33 +1674,29 @@ class LLM:
""" """
return self.llm_engine.get_metrics() return self.llm_engine.get_metrics()
def _validate_and_add_requests( def _params_to_seq(
self, self,
prompts: PromptType | Sequence[PromptType],
params: SamplingParams params: SamplingParams
| Sequence[SamplingParams]
| PoolingParams | PoolingParams
| Sequence[PoolingParams], | Sequence[SamplingParams | PoolingParams],
*, num_requests: int,
use_tqdm: bool | Callable[..., tqdm] = True, ) -> Sequence[SamplingParams | PoolingParams]:
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None:
in_prompts = self._normalize_prompts(prompts)
num_requests = len(in_prompts)
if isinstance(params, Sequence): if isinstance(params, Sequence):
if len(params) != num_requests: if len(params) != num_requests:
raise ValueError( raise ValueError(
f"The lengths of prompts ({params}) " f"The lengths of prompts ({params}) "
f"and lora_request ({len(params)}) must be the same." f"and params ({len(params)}) must be the same."
) )
engine_params = params return params
else:
engine_params = [params] * num_requests return [params] * num_requests
def _lora_request_to_seq(
self,
lora_request: LoRARequest | None | Sequence[LoRARequest | None],
num_requests: int,
) -> Sequence[LoRARequest | None]:
if isinstance(lora_request, Sequence): if isinstance(lora_request, Sequence):
if len(lora_request) != num_requests: if len(lora_request) != num_requests:
raise ValueError( raise ValueError(
...@@ -1761,28 +1704,50 @@ class LLM: ...@@ -1761,28 +1704,50 @@ class LLM:
f"and lora_request ({len(lora_request)}) must be the same." f"and lora_request ({len(lora_request)}) must be the same."
) )
engine_lora_requests: Sequence[LoRARequest | None] = lora_request return lora_request
else:
engine_lora_requests = [lora_request] * num_requests
return [lora_request] * num_requests
def _priority_to_seq(
self,
priority: list[int] | None,
num_requests: int,
) -> Sequence[int]:
if priority is not None: if priority is not None:
if len(priority) != num_requests: if len(priority) != num_requests:
raise ValueError( raise ValueError(
f"The lengths of prompts ({num_requests}) " f"The lengths of prompts ({num_requests}) "
f"and priority ({len(priority)}) must be the same." f"and priority ({len(priority)}) must be the same."
) )
else:
priority = [0] * num_requests
if any(param.truncate_prompt_tokens is not None for param in engine_params): return priority
return [0] * num_requests
def _run_completion(
self,
prompts: PromptType | Sequence[PromptType],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
):
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
if any(param.truncate_prompt_tokens is not None for param in seq_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens` # TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let # Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization # `self._preprocess_completion` handle prompt normalization
engine_prompts = [ engine_prompts = [
engine_prompt engine_prompt
for in_prompt, param in zip(in_prompts, engine_params) for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion( for engine_prompt in self._preprocess_completion(
[in_prompt], [prompt],
tokenization_kwargs=merge_kwargs( tokenization_kwargs=merge_kwargs(
tokenization_kwargs, tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens), dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
...@@ -1791,17 +1756,90 @@ class LLM: ...@@ -1791,17 +1756,90 @@ class LLM:
] ]
else: else:
engine_prompts = self._preprocess_completion( engine_prompts = self._preprocess_completion(
in_prompts, seq_prompts,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
for sp in engine_params: self._validate_and_add_requests(
prompts=engine_prompts,
params=seq_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
return self._run_engine(use_tqdm=use_tqdm)
def _run_chat(
    self,
    messages: list[ChatCompletionMessageParam]
    | Sequence[list[ChatCompletionMessageParam]],
    params: SamplingParams
    | PoolingParams
    | Sequence[SamplingParams | PoolingParams],
    *,
    use_tqdm: bool | Callable[..., tqdm] = True,
    lora_request: LoRARequest | None = None,
    chat_template: str | None = None,
    chat_template_content_format: ChatTemplateContentFormatOption = "auto",
    add_generation_prompt: bool = True,
    continue_final_message: bool = False,
    tools: list[dict[str, Any]] | None = None,
    chat_template_kwargs: dict[str, Any] | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
    mm_processor_kwargs: dict[str, Any] | None = None,
):
    """Preprocess chat conversations, enqueue them, and run the engine.

    The messages are normalized with `conversation_to_seq`, converted to
    engine prompts by `_preprocess_chat`, registered through
    `_validate_and_add_requests` (with LoRA requests resolved by
    `_get_modality_specific_lora_reqs`), and finally executed.

    Returns:
        Whatever `_run_engine` returns for the queued requests.
    """
    conversations = conversation_to_seq(messages)

    chat_prompts = self._preprocess_chat(
        conversations,
        chat_template=chat_template,
        chat_template_content_format=chat_template_content_format,
        chat_template_kwargs=chat_template_kwargs,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final_message,
        tools=tools,
        tokenization_kwargs=tokenization_kwargs,
        mm_processor_kwargs=mm_processor_kwargs,
    )

    # Resolve the LoRA request(s) against the rendered prompts before
    # handing everything to the engine.
    resolved_lora = self._get_modality_specific_lora_reqs(chat_prompts, lora_request)

    self._validate_and_add_requests(
        prompts=chat_prompts,
        params=params,
        use_tqdm=use_tqdm,
        lora_request=resolved_lora,
        tokenization_kwargs=tokenization_kwargs,
    )

    return self._run_engine(use_tqdm=use_tqdm)
def _validate_and_add_requests(
self,
prompts: Sequence[DictPrompt | TokPrompt],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None:
num_requests = len(prompts)
seq_params = self._params_to_seq(params, num_requests)
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
seq_priority = self._priority_to_seq(priority, num_requests)
for sp in seq_params:
if isinstance(sp, SamplingParams): if isinstance(sp, SamplingParams):
# We only care about the final output # We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine. # Add requests to the engine.
it = engine_prompts it = prompts
if use_tqdm: if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests") it = tqdm_func(it, desc="Adding requests")
...@@ -1812,10 +1850,10 @@ class LLM: ...@@ -1812,10 +1850,10 @@ class LLM:
for i, prompt in enumerate(it): for i, prompt in enumerate(it):
request_id = self._add_request( request_id = self._add_request(
prompt, prompt,
engine_params[i], seq_params[i],
lora_request=engine_lora_requests[i], lora_request=seq_lora_requests[i],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
priority=priority[i], priority=seq_priority[i],
) )
added_request_ids.append(request_id) added_request_ids.append(request_id)
except Exception as e: except Exception as e:
...@@ -1825,13 +1863,13 @@ class LLM: ...@@ -1825,13 +1863,13 @@ class LLM:
def _add_request( def _add_request(
self, self,
prompt: PromptType, prompt: PromptType | DictPrompt | TokPrompt,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
priority: int = 0, priority: int = 0,
) -> str: ) -> str:
prompt_text, _, _ = get_prompt_components(prompt) prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
if params.truncate_prompt_tokens is not None: if params.truncate_prompt_tokens is not None:
......
...@@ -67,12 +67,13 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -67,12 +67,13 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
) )
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import ( from vllm.tokenizers.mistral import (
...@@ -218,10 +219,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -218,10 +219,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request( async def render_chat_request(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> ( ) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
""" """
render chat request by validating and preprocessing inputs. render chat request by validating and preprocessing inputs.
...@@ -380,7 +378,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -380,7 +378,7 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they # If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids. # have unique request ids.
...@@ -389,10 +387,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -389,10 +387,10 @@ class OpenAIServingChat(OpenAIServing):
) )
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, self.max_model_len,
request=request, request,
prompt=engine_prompt, self._extract_prompt_len(engine_prompt),
default_sampling_params=self.default_sampling_params, self.default_sampling_params,
) )
sampling_params: SamplingParams | BeamSearchParams sampling_params: SamplingParams | BeamSearchParams
......
...@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import ( ...@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
...@@ -78,7 +78,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -78,7 +78,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request( async def render_completion_request(
self, self,
request: CompletionRequest, request: CompletionRequest,
) -> list[TokensPrompt | EmbedsPrompt] | ErrorResponse: ) -> list[TokPrompt] | ErrorResponse:
""" """
render completion request by validating and preprocessing inputs. render completion request by validating and preprocessing inputs.
...@@ -160,13 +160,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -160,13 +160,13 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, self.max_model_len,
request=request, request,
prompt=engine_prompt, self._extract_prompt_len(engine_prompt),
default_sampling_params=self.default_sampling_params, self.default_sampling_params,
) )
sampling_params: SamplingParams | BeamSearchParams sampling_params: SamplingParams | BeamSearchParams
...@@ -277,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -277,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServing):
# with the inputs token IDs # with the inputs token IDs
if final_res.prompt is None: if final_res.prompt is None:
engine_prompt = engine_prompts[i] engine_prompt = engine_prompts[i]
final_res.prompt = engine_prompt.get("prompt") final_res.prompt = self._extract_prompt_text(engine_prompt)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch) final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
...@@ -313,7 +313,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -313,7 +313,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator( async def completion_stream_generator(
self, self,
request: CompletionRequest, request: CompletionRequest,
engine_prompts: list[TokensPrompt | EmbedsPrompt], engine_prompts: list[TokPrompt],
result_generator: AsyncIterator[tuple[int, RequestOutput]], result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str, request_id: str,
created_time: int, created_time: int,
...@@ -347,7 +347,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -347,7 +347,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt prompt_text = res.prompt
if prompt_text is None: if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx] engine_prompt = engine_prompts[prompt_idx]
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
# Prompt details are excluded from later streamed outputs # Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None: if prompt_token_ids is not None:
......
...@@ -96,11 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -96,11 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
) )
from vllm.entrypoints.utils import get_max_tokens, sanitize_message from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt
from vllm.inputs.parse import (
get_prompt_components,
is_explicit_encoder_decoder_prompt,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -108,6 +104,14 @@ from vllm.multimodal import MultiModalDataDict ...@@ -108,6 +104,14 @@ from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import (
SingletonDictPrompt,
extract_prompt_components,
extract_prompt_len,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
...@@ -203,7 +207,7 @@ class ServeContext(Generic[RequestT]): ...@@ -203,7 +207,7 @@ class ServeContext(Generic[RequestT]):
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None engine_prompts: list[TokPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None None
...@@ -247,7 +251,7 @@ class OpenAIServing: ...@@ -247,7 +251,7 @@ class OpenAIServing:
async def beam_search( async def beam_search(
self, self,
prompt: PromptType, prompt: TokPrompt,
request_id: str, request_id: str,
params: BeamSearchParams, params: BeamSearchParams,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
...@@ -271,20 +275,12 @@ class OpenAIServing: ...@@ -271,20 +275,12 @@ class OpenAIServing:
eos_token_id: int = tokenizer.eos_token_id # type: ignore eos_token_id: int = tokenizer.eos_token_id # type: ignore
if is_explicit_encoder_decoder_prompt(prompt): if isinstance(prompt, dict) and "encoder_prompt" in prompt:
raise NotImplementedError raise NotImplementedError("Encoder-decoder prompt not supported")
prompt_text: str | None prompt_text: str | None = prompt.get("prompt") # type: ignore
prompt_token_ids: list[int] prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data: MultiModalDataDict | None multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
if isinstance(prompt, str):
prompt_text = prompt
prompt_token_ids = []
multi_modal_data = None
else:
prompt_text = prompt.get("prompt") # type: ignore
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = None mm_processor_kwargs: dict[str, Any] | None = None
...@@ -963,22 +959,40 @@ class OpenAIServing: ...@@ -963,22 +959,40 @@ class OpenAIServing:
request: RendererRequest, request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None, prompt_embeds: bytes | list[bytes] | None,
) -> list[TokensPrompt | EmbedsPrompt]: ) -> list[TokPrompt]:
renderer = self.renderer renderer = self.renderer
tok_params = request.build_tok_params(self.model_config) model_config = self.model_config
in_prompts = await renderer.render_completions_async( tok_params = request.build_tok_params(model_config)
prompt_input, prompt_embeds
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
) )
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params) for prompt in prompts
]
in_prompts = await renderer.render_prompts_async(parsed_prompts)
extra_items = { extra_items = {
k: v k: v
for k in ("mm_processor_kwargs", "cache_salt") for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None if (v := getattr(request, k, None)) is not None
} }
for prompt in engine_prompts: for in_prompt in in_prompts:
prompt.update(extra_items) # type: ignore target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
target_prompt.update(extra_items) # type: ignore
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
return engine_prompts return engine_prompts
...@@ -991,7 +1005,7 @@ class OpenAIServing: ...@@ -991,7 +1005,7 @@ class OpenAIServing:
default_template_kwargs: dict[str, Any] | None, default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]: ) -> tuple[list[ConversationMessage], list[TokPrompt]]:
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer renderer = self.renderer
...@@ -1009,17 +1023,21 @@ class OpenAIServing: ...@@ -1009,17 +1023,21 @@ class OpenAIServing:
default_template, default_template_content_format default_template, default_template_content_format
).with_defaults(default_template_kwargs) ).with_defaults(default_template_kwargs)
conversation, prompt = await renderer.render_messages_async( conversation, in_prompt = await renderer.render_messages_async(
messages, chat_params messages, chat_params
) )
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params) target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
extra_items = { extra_items = {
k: v k: v
for k in ("mm_processor_kwargs", "cache_salt") for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None if (v := getattr(request, k, None)) is not None
} }
engine_prompt.update(extra_items) # type: ignore target_prompt.update(extra_items) # type: ignore
engine_prompt = await renderer.tokenize_prompt_async(target_prompt, tok_params)
# tool parsing is done only if a tool_parser has been set and if # tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
...@@ -1040,6 +1058,15 @@ class OpenAIServing: ...@@ -1040,6 +1058,15 @@ class OpenAIServing:
return conversation, [engine_prompt] return conversation, [engine_prompt]
def _extract_prompt_components(self, prompt: object):
return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: object):
return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: object):
return extract_prompt_len(self.model_config, prompt)
async def _render_next_turn( async def _render_next_turn(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
...@@ -1067,7 +1094,7 @@ class OpenAIServing: ...@@ -1067,7 +1094,7 @@ class OpenAIServing:
async def _generate_with_builtin_tools( async def _generate_with_builtin_tools(
self, self,
request_id: str, request_id: str,
engine_prompt: TokensPrompt | EmbedsPrompt, engine_prompt: TokPrompt,
sampling_params: SamplingParams, sampling_params: SamplingParams,
tok_params: TokenizeParams, tok_params: TokenizeParams,
context: ConversationContext, context: ConversationContext,
...@@ -1075,7 +1102,7 @@ class OpenAIServing: ...@@ -1075,7 +1102,7 @@ class OpenAIServing:
priority: int = 0, priority: int = 0,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
): ):
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
orig_priority = priority orig_priority = priority
sub_request = 0 sub_request = 0
...@@ -1145,12 +1172,12 @@ class OpenAIServing: ...@@ -1145,12 +1172,12 @@ class OpenAIServing:
context.chat_template_content_format, context.chat_template_content_format,
) )
engine_prompt = engine_prompts[0] engine_prompt = engine_prompts[0]
prompt_text = engine_prompt.get("prompt") prompt_text = self._extract_prompt_text(engine_prompt)
sampling_params.max_tokens = get_max_tokens( sampling_params.max_tokens = get_max_tokens(
self.max_model_len, self.max_model_len,
context.request, context.request,
engine_prompt, self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore self.default_sampling_params, # type: ignore
) )
...@@ -1161,20 +1188,20 @@ class OpenAIServing: ...@@ -1161,20 +1188,20 @@ class OpenAIServing:
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
inputs: PromptType, inputs: PromptType | TokPrompt,
params: SamplingParams | PoolingParams | BeamSearchParams | None, params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
) -> None: ) -> None:
if self.request_logger is None: if self.request_logger is None:
return return
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs) components = self._extract_prompt_components(inputs)
self.request_logger.log_inputs( self.request_logger.log_inputs(
request_id, request_id,
prompt, components.text,
prompt_token_ids, components.token_ids,
prompt_embeds, components.embeds,
params=params, params=params,
lora_request=lora_request, lora_request=lora_request,
) )
......
...@@ -116,13 +116,13 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -116,13 +116,13 @@ from vllm.entrypoints.openai.responses.utils import (
) )
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -292,10 +292,10 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -292,10 +292,10 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input( def _validate_generator_input(
self, self,
engine_prompt: TokensPrompt | EmbedsPrompt, engine_prompt: TokPrompt,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Add validations to the input to the generator here.""" """Add validations to the input to the generator here."""
prompt_len = get_prompt_len(engine_prompt) prompt_len = self._extract_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len: if self.max_model_len <= prompt_len:
error_message = ( error_message = (
f"The engine prompt length {prompt_len} " f"The engine prompt length {prompt_len} "
...@@ -442,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -442,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing):
default_max_tokens = get_max_tokens( default_max_tokens = get_max_tokens(
self.max_model_len, self.max_model_len,
request, request,
engine_prompt, self._extract_prompt_len(engine_prompt),
self.default_sampling_params, self.default_sampling_params,
) )
......
...@@ -7,7 +7,7 @@ import time ...@@ -7,7 +7,7 @@ import time
import zlib import zlib
from collections.abc import AsyncGenerator, Callable from collections.abc import AsyncGenerator, Callable
from functools import cached_property from functools import cached_property
from typing import Literal, TypeAlias, TypeVar, cast from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np import numpy as np
from fastapi import Request from fastapi import Request
...@@ -37,12 +37,13 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( ...@@ -37,12 +37,13 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationStreamResponse, TranslationStreamResponse,
) )
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType from vllm.inputs.data import PromptType
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.renderers.inputs import EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
...@@ -94,7 +95,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -94,7 +95,7 @@ class OpenAISpeechToText(OpenAIServing):
) )
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.task_type = task_type self.task_type: Final = task_type
self.asr_config = self.model_cls.get_speech_to_text_config( self.asr_config = self.model_cls.get_speech_to_text_config(
self.model_config, task_type self.model_config, task_type
...@@ -298,35 +299,26 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -298,35 +299,26 @@ class OpenAISpeechToText(OpenAIServing):
to_language=to_language, to_language=to_language,
) )
if request.response_format == "verbose_json": if request.response_format == "verbose_json":
if not is_explicit_encoder_decoder_prompt(prompt): prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt))
raise VLLMValidationError(
"Expected prompt to be an encoder-decoder prompt",
parameter="prompt",
value=type(prompt).__name__,
)
prompt = self._preprocess_verbose_prompt(prompt)
prompts.append(prompt) prompts.append(prompt)
return prompts, duration
def _repl_verbose_text(self, text: str): return prompts, duration
return text.replace("<|notimestamps|>", "<|0.00|>")
def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt): def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"] dec_prompt = prompt["decoder_prompt"]
if isinstance(dec_prompt, str): if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt):
prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt)
elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt:
dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"])
else:
raise VLLMValidationError( raise VLLMValidationError(
"Expected decoder_prompt to contain text", "Expected decoder_prompt to contain text",
parameter="decoder_prompt", parameter="decoder_prompt",
value=type(dec_prompt).__name__, value=type(dec_prompt).__name__,
) )
dec_prompt["prompt"] = dec_prompt["prompt"].replace(
"<|notimestamps|>", "<|0.00|>"
)
return prompt return prompt
def _get_verbose_segments( def _get_verbose_segments(
......
...@@ -28,10 +28,11 @@ from vllm.entrypoints.pooling.utils import ( ...@@ -28,10 +28,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64, encode_pooling_output_base64,
encode_pooling_output_float, encode_pooling_output_float,
) )
from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import TokPrompt
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import EmbedDType, Endianness from vllm.utils.serial_utils import EmbedDType, Endianness
...@@ -369,7 +370,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -369,7 +370,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator( async def _create_single_prompt_generator(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt | EmbedsPrompt, engine_prompt: TokPrompt,
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
......
...@@ -33,8 +33,11 @@ from vllm.entrypoints.pooling.utils import ( ...@@ -33,8 +33,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64, encode_pooling_output_base64,
encode_pooling_output_float, encode_pooling_output_float,
) )
from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
...@@ -91,6 +94,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -91,6 +94,7 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported" "dimensions is currently not supported"
) )
engine_prompts: Sequence[PromptType | TokPrompt]
if is_io_processor_request: if is_io_processor_request:
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
...@@ -102,14 +106,10 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -102,14 +106,10 @@ class OpenAIServingPooling(OpenAIServing):
validated_prompt = self.io_processor.parse_request(request) validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async( raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id prompt=validated_prompt, request_id=request_id
) )
if not isinstance(engine_prompts, Sequence) or isinstance( engine_prompts = prompt_to_seq(raw_prompts)
engine_prompts, (str, bytes, bytearray)
):
engine_prompts = [engine_prompts]
elif isinstance(request, PoolingChatRequest): elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template( error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template, request_chat_template=request.chat_template,
......
...@@ -17,8 +17,6 @@ from starlette.background import BackgroundTask, BackgroundTasks ...@@ -17,8 +17,6 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs from vllm import envs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import current_formatter_type, init_logger from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
...@@ -189,7 +187,7 @@ def cli_env_setup(): ...@@ -189,7 +187,7 @@ def cli_env_setup():
def get_max_tokens( def get_max_tokens(
max_model_len: int, max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest", request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
prompt: TokensPrompt | EmbedsPrompt, input_length: int,
default_sampling_params: dict, default_sampling_params: dict,
) -> int: ) -> int:
# NOTE: Avoid isinstance() for better efficiency # NOTE: Avoid isinstance() for better efficiency
...@@ -204,7 +202,6 @@ def get_max_tokens( ...@@ -204,7 +202,6 @@ def get_max_tokens(
# CompletionRequest (also a fallback for ChatCompletionRequest) # CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None) max_tokens = getattr(request, "max_tokens", None)
input_length = get_prompt_len(prompt)
default_max_tokens = max_model_len - input_length default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length) max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
...@@ -16,11 +16,8 @@ from .data import ( ...@@ -16,11 +16,8 @@ from .data import (
TextPrompt, TextPrompt,
TokenInputs, TokenInputs,
TokensPrompt, TokensPrompt,
build_explicit_enc_dec_prompt,
embeds_inputs, embeds_inputs,
to_enc_dec_tuple_list,
token_inputs, token_inputs,
zip_enc_dec_prompts,
) )
__all__ = [ __all__ = [
...@@ -39,8 +36,5 @@ __all__ = [ ...@@ -39,8 +36,5 @@ __all__ = [
"EncoderDecoderInputs", "EncoderDecoderInputs",
"ProcessorInputs", "ProcessorInputs",
"SingletonInputs", "SingletonInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"StreamingInput", "StreamingInput",
] ]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment