Unverified Commit cd8b405b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Consolidate sequence normalization and enc-dec parsing (#33928)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 4707f7eb
......@@ -54,6 +54,7 @@ class MockModelConfig:
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......
......@@ -53,6 +53,7 @@ class MockModelConfig:
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......
......@@ -52,6 +52,7 @@ class MockModelConfig:
encoder_config = None
generation_config: str = "auto"
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......
......@@ -529,6 +529,7 @@ class MockModelConfig:
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.renderers.inputs.preprocess import prompt_to_seq
def test_empty_input():
assert prompt_to_seq([]) == []
assert prompt_to_seq([[]]) == [[]]
assert prompt_to_seq([[], []]) == [[], []]
def test_text_input():
assert prompt_to_seq("foo") == ["foo"]
assert prompt_to_seq(["foo"]) == ["foo"]
assert prompt_to_seq(["foo", "bar"]) == ["foo", "bar"]
def test_token_input():
assert prompt_to_seq([1, 2]) == [[1, 2]]
assert prompt_to_seq([[1, 2]]) == [[1, 2]]
assert prompt_to_seq([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]
def test_text_token_input():
assert prompt_to_seq([[1, 2], "foo"]) == [[1, 2], "foo"]
assert prompt_to_seq(["foo", [1, 2]]) == ["foo", [1, 2]]
def test_bytes_input():
assert prompt_to_seq(b"foo") == [b"foo"]
assert prompt_to_seq([b"foo"]) == [b"foo"]
assert prompt_to_seq([b"foo", b"bar"]) == [b"foo", b"bar"]
def test_dict_input():
assert prompt_to_seq({"prompt": "foo"}) == [{"prompt": "foo"}]
assert prompt_to_seq([{"prompt": "foo"}]) == [{"prompt": "foo"}]
assert prompt_to_seq([{"prompt": "foo"}, {"prompt_token_ids": [1, 2]}]) == [
{"prompt": "foo"},
{"prompt_token_ids": [1, 2]},
]
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
......@@ -9,8 +10,11 @@ import pybase64
import pytest
import torch
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers.registry import tokenizer_args_from_config
MODEL_NAME = "openai-community/gpt2"
......@@ -33,6 +37,7 @@ class MockModelConfig:
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@dataclass
......@@ -80,65 +85,34 @@ def _build_renderer(
return renderer
class TestValidatePrompt:
STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
def _preprocess_prompt(
mdoel_config: ModelConfig,
prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
return [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(mdoel_config, prompt)
)
for prompt in prompt_to_seq(prompt_or_prompts)
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
class TestValidatePrompt:
def test_empty_input(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(ValueError, match="at least one prompt"):
renderer.render_completions([])
renderer.render_prompts(_preprocess_prompt(renderer.config, []))
def test_invalid_type(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(TypeError, match="string or an array of tokens"):
renderer.render_completions([[1, 2], ["foo", "bar"]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, string_input: str):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input]
)
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, token_input: list[int]):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input]
)
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, inputs_slice: slice):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
with pytest.raises(TypeError, match="should be a list of integers"):
renderer.render_prompts(
_preprocess_prompt(renderer.config, [[1, 2], ["foo", "bar"]]) # type: ignore[arg-type]
)
class TestRenderPrompt:
......@@ -146,7 +120,7 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
tokens = [101, 7592, 2088]
prompts = renderer.render_completions(tokens)
prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -159,7 +133,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
prompts = renderer.render_completions(token_lists)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, token_lists)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -174,7 +150,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
text_input = "x" * 10
prompts = renderer.render_completions(text_input)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_input)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -187,7 +165,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
text_list_input = ["x" * 10, "x" * 12, "x" * 14]
prompts = renderer.render_completions(text_list_input)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_list_input)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -200,7 +180,9 @@ class TestRenderPrompt:
def test_zero_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
......@@ -212,7 +194,9 @@ class TestRenderPrompt:
def test_pos_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
......@@ -224,7 +208,9 @@ class TestRenderPrompt:
def test_neg_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
......@@ -237,7 +223,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="left")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
......@@ -251,7 +239,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="right")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
......@@ -266,7 +256,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises(
ValueError,
......@@ -285,7 +277,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises(
ValueError,
......@@ -304,7 +298,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
long_tokens = list(range(150)) # Exceeds max_total_tokens=100
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises(
ValueError,
......@@ -318,7 +314,9 @@ class TestRenderPrompt:
def test_no_tokenizer_for_text(self):
renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
prompts = renderer.render_completions("Hello world")
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "Hello world")
)
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
renderer.tokenize_prompts(
......@@ -330,7 +328,7 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
tokens = [1, 2, 3, 4]
prompts = renderer.render_completions(tokens)
prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(
......@@ -359,7 +357,9 @@ class TestRenderEmbedPrompt:
tensor_input = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(tensor_input)
prompts = renderer.render_completions(prompt_embeds=embed_bytes)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, embed_bytes)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -377,8 +377,11 @@ class TestRenderEmbedPrompt:
torch.randn(12, 512, dtype=torch.float32),
]
prompts = renderer.render_completions(
prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs],
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config,
[self._create_test_embed_bytes(t) for t in tensor_inputs],
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -395,8 +398,10 @@ class TestRenderEmbedPrompt:
# Create tensor with more tokens than truncation limit
tensor_input = torch.randn(20, 768, dtype=torch.float32)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -420,8 +425,10 @@ class TestRenderEmbedPrompt:
for dtype in dtypes:
tensor_input = torch.randn(5, 256, dtype=dtype)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -437,8 +444,10 @@ class TestRenderEmbedPrompt:
# Test tensor with batch dimension gets squeezed
tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -455,9 +464,11 @@ class TestRenderEmbedPrompt:
text_input = "Hello world"
tensor_input = torch.randn(5, 256, dtype=torch.float32)
prompts = renderer.render_completions(
text_input,
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config,
[text_input, self._create_test_embed_bytes(tensor_input)],
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -465,8 +476,8 @@ class TestRenderEmbedPrompt:
)
assert len(results) == 2
# First should be embed prompt
assert torch.equal(results[0]["prompt_embeds"], tensor_input)
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert len(results[1]["prompt_token_ids"]) == len(text_input)
# First should be tokens prompt
assert "prompt_token_ids" in results[0]
assert len(results[0]["prompt_token_ids"]) == len(text_input)
# Second should be embed prompt
assert torch.equal(results[1]["prompt_embeds"], tensor_input)
......@@ -3,16 +3,40 @@
import asyncio
import time
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock
import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.config import ModelConfig
from vllm.renderers import ChatParams
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.tokenizers.mistral import MistralTokenizer
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
runner_type = "generate"
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
max_model_len: int = 100
tokenizer_revision = None
tokenizer_mode = "mistral"
hf_config = MockHFConfig()
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@pytest.mark.asyncio
async def test_async_mistral_tokenizer_does_not_block_event_loop():
......@@ -23,9 +47,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
time.sleep(2)
return expected_tokens
mock_model_config = MockModelConfig(skip_tokenizer_init=True)
mock_tokenizer = Mock(spec=MistralTokenizer)
mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={})
mock_renderer = MistralRenderer(mock_model_config, tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([], ChatParams())
......
......@@ -4,52 +4,13 @@
import pytest
from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test
@pytest.mark.parametrize(
"mm_processor_kwargs,expected_mm_kwargs",
[
(None, [{}, {}]),
({}, [{}, {}]),
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
],
)
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
encoder_prompts = ["An encoder prompt", "Another encoder prompt"]
decoder_prompts = ["A decoder prompt", "Another decoder prompt"]
zipped_prompts = zip_enc_dec_prompts(
encoder_prompts, decoder_prompts, mm_processor_kwargs
)
assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
for enc, dec, exp_kwargs, zipped in zip(
encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts
):
assert isinstance(zipped, dict)
assert len(zipped.keys()) == 3
assert zipped["encoder_prompt"] == enc
assert zipped["decoder_prompt"] == dec
assert zipped["mm_processor_kwargs"] == exp_kwargs
@pytest.mark.parametrize(
"model_id",
[
"facebook/chameleon-7b",
],
)
@pytest.mark.parametrize(
"prompt",
[
"",
{"prompt_token_ids": []},
],
)
@pytest.mark.parametrize("model_id", ["facebook/chameleon-7b"])
@pytest.mark.parametrize("prompt", ["", {"prompt_token_ids": []}])
@pytest.mark.skip(
reason=(
"Applying huggingface processor on text inputs results in "
......
......@@ -16,6 +16,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest
......@@ -53,7 +54,11 @@ class EngineClient(ABC):
@abstractmethod
def generate(
self,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
*,
......@@ -70,7 +75,7 @@ class EngineClient(ABC):
@abstractmethod
def encode(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,
......
This diff is collapsed.
......@@ -67,12 +67,13 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
......@@ -218,10 +219,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> (
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
"""
render chat request by validating and preprocessing inputs.
......@@ -380,7 +378,7 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
......@@ -389,10 +387,10 @@ class OpenAIServingChat(OpenAIServing):
)
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
self.max_model_len,
request,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)
sampling_params: SamplingParams | BeamSearchParams
......
......@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
......@@ -78,7 +78,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request(
self,
request: CompletionRequest,
) -> list[TokensPrompt | EmbedsPrompt] | ErrorResponse:
) -> list[TokPrompt] | ErrorResponse:
"""
render completion request by validating and preprocessing inputs.
......@@ -160,13 +160,13 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
self.max_model_len,
request,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)
sampling_params: SamplingParams | BeamSearchParams
......@@ -277,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServing):
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = engine_prompt.get("prompt")
final_res.prompt = self._extract_prompt_text(engine_prompt)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
......@@ -313,7 +313,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
engine_prompts: list[TokensPrompt | EmbedsPrompt],
engine_prompts: list[TokPrompt],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
......@@ -347,7 +347,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
......
......@@ -96,11 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
)
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
from vllm.inputs.parse import (
get_prompt_components,
is_explicit_encoder_decoder_prompt,
)
from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
......@@ -108,6 +104,14 @@ from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import (
SingletonDictPrompt,
extract_prompt_components,
extract_prompt_len,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
......@@ -203,7 +207,7 @@ class ServeContext(Generic[RequestT]):
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None
engine_prompts: list[TokPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
......@@ -247,7 +251,7 @@ class OpenAIServing:
async def beam_search(
self,
prompt: PromptType,
prompt: TokPrompt,
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
......@@ -271,20 +275,12 @@ class OpenAIServing:
eos_token_id: int = tokenizer.eos_token_id # type: ignore
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
raise NotImplementedError("Encoder-decoder prompt not supported")
prompt_text: str | None
prompt_token_ids: list[int]
multi_modal_data: MultiModalDataDict | None
if isinstance(prompt, str):
prompt_text = prompt
prompt_token_ids = []
multi_modal_data = None
else:
prompt_text = prompt.get("prompt") # type: ignore
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
prompt_text: str | None = prompt.get("prompt") # type: ignore
prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = None
......@@ -963,22 +959,40 @@ class OpenAIServing:
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokensPrompt | EmbedsPrompt]:
) -> list[TokPrompt]:
renderer = self.renderer
tok_params = request.build_tok_params(self.model_config)
model_config = self.model_config
in_prompts = await renderer.render_completions_async(
prompt_input, prompt_embeds
)
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
tok_params = request.build_tok_params(model_config)
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
in_prompts = await renderer.render_prompts_async(parsed_prompts)
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
for prompt in engine_prompts:
prompt.update(extra_items) # type: ignore
for in_prompt in in_prompts:
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
target_prompt.update(extra_items) # type: ignore
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
return engine_prompts
......@@ -991,7 +1005,7 @@ class OpenAIServing:
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer
......@@ -1009,17 +1023,21 @@ class OpenAIServing:
default_template, default_template_content_format
).with_defaults(default_template_kwargs)
conversation, prompt = await renderer.render_messages_async(
conversation, in_prompt = await renderer.render_messages_async(
messages, chat_params
)
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params)
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
engine_prompt.update(extra_items) # type: ignore
target_prompt.update(extra_items) # type: ignore
engine_prompt = await renderer.tokenize_prompt_async(target_prompt, tok_params)
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
......@@ -1040,6 +1058,15 @@ class OpenAIServing:
return conversation, [engine_prompt]
def _extract_prompt_components(self, prompt: object):
return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: object):
return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: object):
return extract_prompt_len(self.model_config, prompt)
async def _render_next_turn(
self,
request: ResponsesRequest,
......@@ -1067,7 +1094,7 @@ class OpenAIServing:
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: TokensPrompt | EmbedsPrompt,
engine_prompt: TokPrompt,
sampling_params: SamplingParams,
tok_params: TokenizeParams,
context: ConversationContext,
......@@ -1075,7 +1102,7 @@ class OpenAIServing:
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
):
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
orig_priority = priority
sub_request = 0
......@@ -1145,12 +1172,12 @@ class OpenAIServing:
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
context.request,
engine_prompt,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore
)
......@@ -1161,20 +1188,20 @@ class OpenAIServing:
def _log_inputs(
self,
request_id: str,
inputs: PromptType,
inputs: PromptType | TokPrompt,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
components = self._extract_prompt_components(inputs)
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
prompt_embeds,
components.text,
components.token_ids,
components.embeds,
params=params,
lora_request=lora_request,
)
......
......@@ -116,13 +116,13 @@ from vllm.entrypoints.openai.responses.utils import (
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
......@@ -292,10 +292,10 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input(
self,
engine_prompt: TokensPrompt | EmbedsPrompt,
engine_prompt: TokPrompt,
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
prompt_len = get_prompt_len(engine_prompt)
prompt_len = self._extract_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len:
error_message = (
f"The engine prompt length {prompt_len} "
......@@ -442,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing):
default_max_tokens = get_max_tokens(
self.max_model_len,
request,
engine_prompt,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)
......
......@@ -7,7 +7,7 @@ import time
import zlib
from collections.abc import AsyncGenerator, Callable
from functools import cached_property
from typing import Literal, TypeAlias, TypeVar, cast
from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
......@@ -37,12 +37,13 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationStreamResponse,
)
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
......@@ -94,7 +95,7 @@ class OpenAISpeechToText(OpenAIServing):
)
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.task_type = task_type
self.task_type: Final = task_type
self.asr_config = self.model_cls.get_speech_to_text_config(
self.model_config, task_type
......@@ -298,35 +299,26 @@ class OpenAISpeechToText(OpenAIServing):
to_language=to_language,
)
if request.response_format == "verbose_json":
if not is_explicit_encoder_decoder_prompt(prompt):
raise VLLMValidationError(
"Expected prompt to be an encoder-decoder prompt",
parameter="prompt",
value=type(prompt).__name__,
)
prompt = self._preprocess_verbose_prompt(prompt)
prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt))
prompts.append(prompt)
return prompts, duration
def _repl_verbose_text(self, text: str):
return text.replace("<|notimestamps|>", "<|0.00|>")
return prompts, duration
def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt):
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"]
if isinstance(dec_prompt, str):
prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt)
elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt:
dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"])
else:
if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt):
raise VLLMValidationError(
"Expected decoder_prompt to contain text",
parameter="decoder_prompt",
value=type(dec_prompt).__name__,
)
dec_prompt["prompt"] = dec_prompt["prompt"].replace(
"<|notimestamps|>", "<|0.00|>"
)
return prompt
def _get_verbose_segments(
......
......@@ -28,10 +28,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import TokPrompt
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import EmbedDType, Endianness
......@@ -369,7 +370,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt | EmbedsPrompt,
engine_prompt: TokPrompt,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
......
......@@ -33,8 +33,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
......@@ -91,6 +94,7 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported"
)
engine_prompts: Sequence[PromptType | TokPrompt]
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
......@@ -102,14 +106,10 @@ class OpenAIServingPooling(OpenAIServing):
validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async(
raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
)
if not isinstance(engine_prompts, Sequence) or isinstance(
engine_prompts, (str, bytes, bytearray)
):
engine_prompts = [engine_prompts]
engine_prompts = prompt_to_seq(raw_prompts)
elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,
......
......@@ -17,8 +17,6 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
......@@ -189,7 +187,7 @@ def cli_env_setup():
def get_max_tokens(
max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
prompt: TokensPrompt | EmbedsPrompt,
input_length: int,
default_sampling_params: dict,
) -> int:
# NOTE: Avoid isinstance() for better efficiency
......@@ -204,7 +202,6 @@ def get_max_tokens(
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = get_prompt_len(prompt)
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
......@@ -16,11 +16,8 @@ from .data import (
TextPrompt,
TokenInputs,
TokensPrompt,
build_explicit_enc_dec_prompt,
embeds_inputs,
to_enc_dec_tuple_list,
token_inputs,
zip_enc_dec_prompts,
)
__all__ = [
......@@ -39,8 +36,5 @@ __all__ = [
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"StreamingInput",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment