"tests/vscode:/vscode.git/clone" did not exist on "ae002924e96bd17cfc690c266623c340ff28a70f"
Unverified Commit cd8b405b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Consolidate sequence normalization and enc-dec parsing (#33928)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 4707f7eb
......@@ -54,6 +54,7 @@ class MockModelConfig:
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......
......@@ -53,6 +53,7 @@ class MockModelConfig:
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......
......@@ -52,6 +52,7 @@ class MockModelConfig:
encoder_config = None
generation_config: str = "auto"
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......
......@@ -529,6 +529,7 @@ class MockModelConfig:
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.renderers.inputs.preprocess import prompt_to_seq
def test_empty_input():
assert prompt_to_seq([]) == []
assert prompt_to_seq([[]]) == [[]]
assert prompt_to_seq([[], []]) == [[], []]
def test_text_input():
assert prompt_to_seq("foo") == ["foo"]
assert prompt_to_seq(["foo"]) == ["foo"]
assert prompt_to_seq(["foo", "bar"]) == ["foo", "bar"]
def test_token_input():
assert prompt_to_seq([1, 2]) == [[1, 2]]
assert prompt_to_seq([[1, 2]]) == [[1, 2]]
assert prompt_to_seq([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]
def test_text_token_input():
assert prompt_to_seq([[1, 2], "foo"]) == [[1, 2], "foo"]
assert prompt_to_seq(["foo", [1, 2]]) == ["foo", [1, 2]]
def test_bytes_input():
assert prompt_to_seq(b"foo") == [b"foo"]
assert prompt_to_seq([b"foo"]) == [b"foo"]
assert prompt_to_seq([b"foo", b"bar"]) == [b"foo", b"bar"]
def test_dict_input():
assert prompt_to_seq({"prompt": "foo"}) == [{"prompt": "foo"}]
assert prompt_to_seq([{"prompt": "foo"}]) == [{"prompt": "foo"}]
assert prompt_to_seq([{"prompt": "foo"}, {"prompt_token_ids": [1, 2]}]) == [
{"prompt": "foo"},
{"prompt_token_ids": [1, 2]},
]
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
......@@ -9,8 +10,11 @@ import pybase64
import pytest
import torch
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers.registry import tokenizer_args_from_config
MODEL_NAME = "openai-community/gpt2"
......@@ -33,6 +37,7 @@ class MockModelConfig:
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@dataclass
......@@ -80,65 +85,34 @@ def _build_renderer(
return renderer
class TestValidatePrompt:
STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
def _preprocess_prompt(
mdoel_config: ModelConfig,
prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
return [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(mdoel_config, prompt)
)
for prompt in prompt_to_seq(prompt_or_prompts)
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
class TestValidatePrompt:
def test_empty_input(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(ValueError, match="at least one prompt"):
renderer.render_completions([])
renderer.render_prompts(_preprocess_prompt(renderer.config, []))
def test_invalid_type(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(TypeError, match="string or an array of tokens"):
renderer.render_completions([[1, 2], ["foo", "bar"]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, string_input: str):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input]
)
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, token_input: list[int]):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input]
)
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, inputs_slice: slice):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
with pytest.raises(TypeError, match="should be a list of integers"):
renderer.render_prompts(
_preprocess_prompt(renderer.config, [[1, 2], ["foo", "bar"]]) # type: ignore[arg-type]
)
class TestRenderPrompt:
......@@ -146,7 +120,7 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
tokens = [101, 7592, 2088]
prompts = renderer.render_completions(tokens)
prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -159,7 +133,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
prompts = renderer.render_completions(token_lists)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, token_lists)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -174,7 +150,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
text_input = "x" * 10
prompts = renderer.render_completions(text_input)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_input)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -187,7 +165,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
text_list_input = ["x" * 10, "x" * 12, "x" * 14]
prompts = renderer.render_completions(text_list_input)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_list_input)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -200,7 +180,9 @@ class TestRenderPrompt:
def test_zero_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
......@@ -212,7 +194,9 @@ class TestRenderPrompt:
def test_pos_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
......@@ -224,7 +208,9 @@ class TestRenderPrompt:
def test_neg_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
......@@ -237,7 +223,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="left")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
......@@ -251,7 +239,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="right")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
......@@ -266,7 +256,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises(
ValueError,
......@@ -285,7 +277,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises(
ValueError,
......@@ -304,7 +298,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
long_tokens = list(range(150)) # Exceeds max_total_tokens=100
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises(
ValueError,
......@@ -318,7 +314,9 @@ class TestRenderPrompt:
def test_no_tokenizer_for_text(self):
renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
prompts = renderer.render_completions("Hello world")
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "Hello world")
)
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
renderer.tokenize_prompts(
......@@ -330,7 +328,7 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
tokens = [1, 2, 3, 4]
prompts = renderer.render_completions(tokens)
prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(
......@@ -359,7 +357,9 @@ class TestRenderEmbedPrompt:
tensor_input = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(tensor_input)
prompts = renderer.render_completions(prompt_embeds=embed_bytes)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, embed_bytes)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
......@@ -377,8 +377,11 @@ class TestRenderEmbedPrompt:
torch.randn(12, 512, dtype=torch.float32),
]
prompts = renderer.render_completions(
prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs],
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config,
[self._create_test_embed_bytes(t) for t in tensor_inputs],
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -395,8 +398,10 @@ class TestRenderEmbedPrompt:
# Create tensor with more tokens than truncation limit
tensor_input = torch.randn(20, 768, dtype=torch.float32)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -420,8 +425,10 @@ class TestRenderEmbedPrompt:
for dtype in dtypes:
tensor_input = torch.randn(5, 256, dtype=dtype)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -437,8 +444,10 @@ class TestRenderEmbedPrompt:
# Test tensor with batch dimension gets squeezed
tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -455,9 +464,11 @@ class TestRenderEmbedPrompt:
text_input = "Hello world"
tensor_input = torch.randn(5, 256, dtype=torch.float32)
prompts = renderer.render_completions(
text_input,
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config,
[text_input, self._create_test_embed_bytes(tensor_input)],
)
)
results = renderer.tokenize_prompts(
prompts,
......@@ -465,8 +476,8 @@ class TestRenderEmbedPrompt:
)
assert len(results) == 2
# First should be embed prompt
assert torch.equal(results[0]["prompt_embeds"], tensor_input)
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert len(results[1]["prompt_token_ids"]) == len(text_input)
# First should be tokens prompt
assert "prompt_token_ids" in results[0]
assert len(results[0]["prompt_token_ids"]) == len(text_input)
# Second should be embed prompt
assert torch.equal(results[1]["prompt_embeds"], tensor_input)
......@@ -3,16 +3,40 @@
import asyncio
import time
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock
import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.config import ModelConfig
from vllm.renderers import ChatParams
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.tokenizers.mistral import MistralTokenizer
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
runner_type = "generate"
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
max_model_len: int = 100
tokenizer_revision = None
tokenizer_mode = "mistral"
hf_config = MockHFConfig()
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@pytest.mark.asyncio
async def test_async_mistral_tokenizer_does_not_block_event_loop():
......@@ -23,9 +47,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
time.sleep(2)
return expected_tokens
mock_model_config = MockModelConfig(skip_tokenizer_init=True)
mock_tokenizer = Mock(spec=MistralTokenizer)
mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={})
mock_renderer = MistralRenderer(mock_model_config, tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([], ChatParams())
......
......@@ -4,52 +4,13 @@
import pytest
from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test
@pytest.mark.parametrize(
"mm_processor_kwargs,expected_mm_kwargs",
[
(None, [{}, {}]),
({}, [{}, {}]),
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
],
)
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
encoder_prompts = ["An encoder prompt", "Another encoder prompt"]
decoder_prompts = ["A decoder prompt", "Another decoder prompt"]
zipped_prompts = zip_enc_dec_prompts(
encoder_prompts, decoder_prompts, mm_processor_kwargs
)
assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
for enc, dec, exp_kwargs, zipped in zip(
encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts
):
assert isinstance(zipped, dict)
assert len(zipped.keys()) == 3
assert zipped["encoder_prompt"] == enc
assert zipped["decoder_prompt"] == dec
assert zipped["mm_processor_kwargs"] == exp_kwargs
@pytest.mark.parametrize(
"model_id",
[
"facebook/chameleon-7b",
],
)
@pytest.mark.parametrize(
"prompt",
[
"",
{"prompt_token_ids": []},
],
)
@pytest.mark.parametrize("model_id", ["facebook/chameleon-7b"])
@pytest.mark.parametrize("prompt", ["", {"prompt_token_ids": []}])
@pytest.mark.skip(
reason=(
"Applying huggingface processor on text inputs results in "
......
......@@ -16,6 +16,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest
......@@ -53,7 +54,11 @@ class EngineClient(ABC):
@abstractmethod
def generate(
self,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
*,
......@@ -70,7 +75,7 @@ class EngineClient(ABC):
@abstractmethod
def encode(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,
......
......@@ -4,7 +4,7 @@
import itertools
import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from typing import TYPE_CHECKING, Any, cast
import cloudpickle
import torch.nn as nn
......@@ -53,16 +53,13 @@ from vllm.entrypoints.pooling.score.utils import (
validate_score_input,
)
from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs import (
from vllm.inputs.data import (
DataPrompt,
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
PromptType,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -76,6 +73,13 @@ from vllm.outputs import (
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, SingletonDictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import (
conversation_to_seq,
extract_prompt_components,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike
......@@ -93,9 +97,6 @@ logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
......@@ -445,21 +446,20 @@ class LLM:
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
self._validate_and_add_requests(
outputs = self._run_completion(
prompts=prompts,
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request),
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)
def _get_modality_specific_lora_reqs(
self,
prompts: PromptType | Sequence[PromptType],
prompts: Sequence[DictPrompt | TokPrompt],
lora_request: list[LoRARequest] | LoRARequest | None,
):
# Grab the lora config off the vllm config on the engine,
......@@ -475,9 +475,6 @@ class LLM:
):
return lora_request
if not isinstance(prompts, Sequence) or isinstance(prompts, str):
prompts = [prompts]
optional_loras = (
[lora_request] * len(prompts)
if not isinstance(lora_request, Sequence)
......@@ -495,14 +492,12 @@ class LLM:
def _resolve_single_prompt_mm_lora(
self,
prompt: PromptType,
prompt: DictPrompt | TokPrompt,
lora_request: LoRARequest | None,
default_mm_loras: dict[str, str] | None,
):
if (
not default_mm_loras
or not isinstance(prompt, dict)
or not (mm_data := prompt.get("multi_modal_data") or {})
if not default_mm_loras or not (
mm_data := prompt.get("multi_modal_data") or {}
):
return lora_request
......@@ -806,61 +801,11 @@ class LLM:
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _normalize_prompts(
self,
prompts: PromptType | Sequence[PromptType],
) -> list[EnginePrompt | EngineEncDecPrompt]:
if isinstance(prompts, str):
prompts = TextPrompt(prompt=prompts)
return prompts if isinstance(prompts, Sequence) else [prompts] # type: ignore[return-value]
def _preprocess_cmpl_singleton(
self,
prompt: SingletonPrompt,
tok_params: TokenizeParams,
*,
tokenize: bool,
) -> EnginePrompt:
renderer = self.llm_engine.renderer
if not isinstance(prompt, dict):
prompt = renderer.render_completion(prompt)
return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt
def _preprocess_cmpl_enc_dec(
self,
prompt: ExplicitEncoderDecoderPrompt,
tok_params: TokenizeParams,
) -> EngineEncDecPrompt:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return EngineEncDecPrompt(
encoder_prompt=self._preprocess_cmpl_singleton(
enc_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
),
decoder_prompt=(
None
if dec_prompt is None
else self._preprocess_cmpl_singleton(
dec_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
)
),
)
def _preprocess_completion(
self,
prompts: PromptType | Sequence[PromptType],
prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[EnginePrompt | EngineEncDecPrompt]:
) -> list[DictPrompt | TokPrompt]:
"""
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`.
......@@ -871,32 +816,26 @@ class LLM:
A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
renderer = self.llm_engine.renderer
model_config = self.model_config
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt | EngineEncDecPrompt]()
for prompt in self._normalize_prompts(prompts):
if is_explicit_encoder_decoder_prompt(prompt):
engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params))
else:
# Some MM models have non-default `add_special_tokens`
# TODO: Move multi-modal processor into tokenization
engine_prompts.append(
self._preprocess_cmpl_singleton(
prompt,
tok_params,
tokenize=not self.model_config.is_multimodal_model,
)
)
engine_prompts = list[DictPrompt | TokPrompt]()
for prompt in prompts:
parsed_prompt = parse_model_prompt(model_config, prompt)
in_prompt = renderer.render_prompt(parsed_prompt)
# Some MM models have non-default `add_special_tokens`
# TODO: Move multi-modal processor into tokenization
engine_prompts.append(
in_prompt
if model_config.is_multimodal_model
else renderer.tokenize_prompt(in_prompt, tok_params)
)
return engine_prompts
def _normalize_conversations(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
) -> list[list[ChatCompletionMessageParam]]:
return conversations if is_list_of(conversations, list) else [conversations] # type: ignore[list-item,return-value]
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
......@@ -909,8 +848,7 @@ class LLM:
def _preprocess_chat(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
conversations: Sequence[list[ChatCompletionMessageParam]],
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
chat_template_kwargs: dict[str, Any] | None = None,
......@@ -919,7 +857,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[EnginePrompt]:
) -> list[DictPrompt | TokPrompt]:
"""
Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs.
......@@ -947,11 +885,14 @@ class LLM:
)
tok_params = self._get_chat_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt]()
for conversation in self._normalize_conversations(conversations):
engine_prompts = list[DictPrompt | TokPrompt]()
for conversation in conversations:
_, in_prompt = renderer.render_messages(conversation, chat_params)
if mm_processor_kwargs is not None:
in_prompt["mm_processor_kwargs"] = mm_processor_kwargs
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
target_prompt["mm_processor_kwargs"] = mm_processor_kwargs # type: ignore
engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))
......@@ -960,8 +901,8 @@ class LLM:
def chat(
self,
messages: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
sampling_params: SamplingParams | list[SamplingParams] | None = None,
| Sequence[list[ChatCompletionMessageParam]],
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None,
chat_template: str | None = None,
......@@ -984,7 +925,7 @@ class LLM:
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
messages: A sequence of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
......@@ -1023,8 +964,23 @@ class LLM:
A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages.
"""
prompts = self._preprocess_chat(
messages,
model_config = self.model_config
runner_type = model_config.runner_type
if runner_type != "generate":
raise ValueError(
"LLM.chat() is only supported for generative models. "
"Try passing `--runner generate` to use the model as a "
"generative model."
)
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
outputs = self._run_chat(
messages=messages,
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
......@@ -1035,13 +991,7 @@ class LLM:
mm_processor_kwargs=mm_processor_kwargs,
)
return self.generate(
prompts,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return self.engine_class.validate_outputs(outputs, RequestOutput)
def encode(
self,
......@@ -1163,7 +1113,7 @@ class LLM:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
raise ValueError(msg)
self._validate_and_add_requests(
outputs = self._run_completion(
prompts=prompts,
params=pooling_params,
use_tqdm=use_tqdm,
......@@ -1171,8 +1121,6 @@ class LLM:
tokenization_kwargs=tokenization_kwargs,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
model_outputs = self.engine_class.validate_outputs(
outputs, PoolingRequestOutput
)
......@@ -1523,14 +1471,13 @@ class LLM:
prompts.append(engine_prompt)
self._validate_and_add_requests(
outputs = self._run_completion(
prompts=prompts,
params=pooling_params_list,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
......@@ -1727,33 +1674,29 @@ class LLM:
"""
return self.llm_engine.get_metrics()
def _validate_and_add_requests(
def _params_to_seq(
self,
prompts: PromptType | Sequence[PromptType],
params: SamplingParams
| Sequence[SamplingParams]
| PoolingParams
| Sequence[PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None:
in_prompts = self._normalize_prompts(prompts)
num_requests = len(in_prompts)
| Sequence[SamplingParams | PoolingParams],
num_requests: int,
) -> Sequence[SamplingParams | PoolingParams]:
if isinstance(params, Sequence):
if len(params) != num_requests:
raise ValueError(
f"The lengths of prompts ({params}) "
f"and lora_request ({len(params)}) must be the same."
f"and params ({len(params)}) must be the same."
)
engine_params = params
else:
engine_params = [params] * num_requests
return params
return [params] * num_requests
def _lora_request_to_seq(
self,
lora_request: LoRARequest | None | Sequence[LoRARequest | None],
num_requests: int,
) -> Sequence[LoRARequest | None]:
if isinstance(lora_request, Sequence):
if len(lora_request) != num_requests:
raise ValueError(
......@@ -1761,28 +1704,50 @@ class LLM:
f"and lora_request ({len(lora_request)}) must be the same."
)
engine_lora_requests: Sequence[LoRARequest | None] = lora_request
else:
engine_lora_requests = [lora_request] * num_requests
return lora_request
return [lora_request] * num_requests
def _priority_to_seq(
self,
priority: list[int] | None,
num_requests: int,
) -> Sequence[int]:
if priority is not None:
if len(priority) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and priority ({len(priority)}) must be the same."
)
else:
priority = [0] * num_requests
if any(param.truncate_prompt_tokens is not None for param in engine_params):
return priority
return [0] * num_requests
def _run_completion(
self,
prompts: PromptType | Sequence[PromptType],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
):
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
if any(param.truncate_prompt_tokens is not None for param in seq_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization
engine_prompts = [
engine_prompt
for in_prompt, param in zip(in_prompts, engine_params)
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion(
[in_prompt],
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
......@@ -1791,17 +1756,90 @@ class LLM:
]
else:
engine_prompts = self._preprocess_completion(
in_prompts,
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
)
for sp in engine_params:
self._validate_and_add_requests(
prompts=engine_prompts,
params=seq_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
return self._run_engine(use_tqdm=use_tqdm)
def _run_chat(
self,
messages: list[ChatCompletionMessageParam]
| Sequence[list[ChatCompletionMessageParam]],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
):
engine_prompts = self._preprocess_chat(
conversation_to_seq(messages),
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
self._validate_and_add_requests(
prompts=engine_prompts,
params=params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
tokenization_kwargs=tokenization_kwargs,
)
return self._run_engine(use_tqdm=use_tqdm)
def _validate_and_add_requests(
self,
prompts: Sequence[DictPrompt | TokPrompt],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None:
num_requests = len(prompts)
seq_params = self._params_to_seq(params, num_requests)
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
seq_priority = self._priority_to_seq(priority, num_requests)
for sp in seq_params:
if isinstance(sp, SamplingParams):
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
it = engine_prompts
it = prompts
if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
......@@ -1812,10 +1850,10 @@ class LLM:
for i, prompt in enumerate(it):
request_id = self._add_request(
prompt,
engine_params[i],
lora_request=engine_lora_requests[i],
seq_params[i],
lora_request=seq_lora_requests[i],
tokenization_kwargs=tokenization_kwargs,
priority=priority[i],
priority=seq_priority[i],
)
added_request_ids.append(request_id)
except Exception as e:
......@@ -1825,13 +1863,13 @@ class LLM:
def _add_request(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: int = 0,
) -> str:
prompt_text, _, _ = get_prompt_components(prompt)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
request_id = str(next(self.request_counter))
if params.truncate_prompt_tokens is not None:
......
......@@ -67,12 +67,13 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
......@@ -218,10 +219,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> (
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
"""
render chat request by validating and preprocessing inputs.
......@@ -380,7 +378,7 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
......@@ -389,10 +387,10 @@ class OpenAIServingChat(OpenAIServing):
)
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
self.max_model_len,
request,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)
sampling_params: SamplingParams | BeamSearchParams
......
......@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
......@@ -78,7 +78,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request(
self,
request: CompletionRequest,
) -> list[TokensPrompt | EmbedsPrompt] | ErrorResponse:
) -> list[TokPrompt] | ErrorResponse:
"""
render completion request by validating and preprocessing inputs.
......@@ -160,13 +160,13 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
self.max_model_len,
request,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)
sampling_params: SamplingParams | BeamSearchParams
......@@ -277,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServing):
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = engine_prompt.get("prompt")
final_res.prompt = self._extract_prompt_text(engine_prompt)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
......@@ -313,7 +313,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
engine_prompts: list[TokensPrompt | EmbedsPrompt],
engine_prompts: list[TokPrompt],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
......@@ -347,7 +347,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
......
......@@ -96,11 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
)
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
from vllm.inputs.parse import (
get_prompt_components,
is_explicit_encoder_decoder_prompt,
)
from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
......@@ -108,6 +104,14 @@ from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import (
SingletonDictPrompt,
extract_prompt_components,
extract_prompt_len,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
......@@ -203,7 +207,7 @@ class ServeContext(Generic[RequestT]):
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None
engine_prompts: list[TokPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
......@@ -247,7 +251,7 @@ class OpenAIServing:
async def beam_search(
self,
prompt: PromptType,
prompt: TokPrompt,
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
......@@ -271,20 +275,12 @@ class OpenAIServing:
eos_token_id: int = tokenizer.eos_token_id # type: ignore
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
raise NotImplementedError("Encoder-decoder prompt not supported")
prompt_text: str | None
prompt_token_ids: list[int]
multi_modal_data: MultiModalDataDict | None
if isinstance(prompt, str):
prompt_text = prompt
prompt_token_ids = []
multi_modal_data = None
else:
prompt_text = prompt.get("prompt") # type: ignore
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
prompt_text: str | None = prompt.get("prompt") # type: ignore
prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = None
......@@ -963,22 +959,40 @@ class OpenAIServing:
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokensPrompt | EmbedsPrompt]:
) -> list[TokPrompt]:
renderer = self.renderer
tok_params = request.build_tok_params(self.model_config)
model_config = self.model_config
in_prompts = await renderer.render_completions_async(
prompt_input, prompt_embeds
)
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
tok_params = request.build_tok_params(model_config)
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
in_prompts = await renderer.render_prompts_async(parsed_prompts)
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
for prompt in engine_prompts:
prompt.update(extra_items) # type: ignore
for in_prompt in in_prompts:
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
target_prompt.update(extra_items) # type: ignore
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
return engine_prompts
......@@ -991,7 +1005,7 @@ class OpenAIServing:
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer
......@@ -1009,17 +1023,21 @@ class OpenAIServing:
default_template, default_template_content_format
).with_defaults(default_template_kwargs)
conversation, prompt = await renderer.render_messages_async(
conversation, in_prompt = await renderer.render_messages_async(
messages, chat_params
)
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params)
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
engine_prompt.update(extra_items) # type: ignore
target_prompt.update(extra_items) # type: ignore
engine_prompt = await renderer.tokenize_prompt_async(target_prompt, tok_params)
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
......@@ -1040,6 +1058,15 @@ class OpenAIServing:
return conversation, [engine_prompt]
def _extract_prompt_components(self, prompt: object):
return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: object):
return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: object):
return extract_prompt_len(self.model_config, prompt)
async def _render_next_turn(
self,
request: ResponsesRequest,
......@@ -1067,7 +1094,7 @@ class OpenAIServing:
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: TokensPrompt | EmbedsPrompt,
engine_prompt: TokPrompt,
sampling_params: SamplingParams,
tok_params: TokenizeParams,
context: ConversationContext,
......@@ -1075,7 +1102,7 @@ class OpenAIServing:
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
):
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
orig_priority = priority
sub_request = 0
......@@ -1145,12 +1172,12 @@ class OpenAIServing:
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
context.request,
engine_prompt,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore
)
......@@ -1161,20 +1188,20 @@ class OpenAIServing:
def _log_inputs(
self,
request_id: str,
inputs: PromptType,
inputs: PromptType | TokPrompt,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
components = self._extract_prompt_components(inputs)
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
prompt_embeds,
components.text,
components.token_ids,
components.embeds,
params=params,
lora_request=lora_request,
)
......
......@@ -116,13 +116,13 @@ from vllm.entrypoints.openai.responses.utils import (
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
......@@ -292,10 +292,10 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input(
self,
engine_prompt: TokensPrompt | EmbedsPrompt,
engine_prompt: TokPrompt,
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
prompt_len = get_prompt_len(engine_prompt)
prompt_len = self._extract_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len:
error_message = (
f"The engine prompt length {prompt_len} "
......@@ -442,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing):
default_max_tokens = get_max_tokens(
self.max_model_len,
request,
engine_prompt,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)
......
......@@ -7,7 +7,7 @@ import time
import zlib
from collections.abc import AsyncGenerator, Callable
from functools import cached_property
from typing import Literal, TypeAlias, TypeVar, cast
from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
......@@ -37,12 +37,13 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationStreamResponse,
)
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
......@@ -94,7 +95,7 @@ class OpenAISpeechToText(OpenAIServing):
)
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.task_type = task_type
self.task_type: Final = task_type
self.asr_config = self.model_cls.get_speech_to_text_config(
self.model_config, task_type
......@@ -298,35 +299,26 @@ class OpenAISpeechToText(OpenAIServing):
to_language=to_language,
)
if request.response_format == "verbose_json":
if not is_explicit_encoder_decoder_prompt(prompt):
raise VLLMValidationError(
"Expected prompt to be an encoder-decoder prompt",
parameter="prompt",
value=type(prompt).__name__,
)
prompt = self._preprocess_verbose_prompt(prompt)
prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt))
prompts.append(prompt)
return prompts, duration
def _repl_verbose_text(self, text: str):
return text.replace("<|notimestamps|>", "<|0.00|>")
return prompts, duration
def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt):
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"]
if isinstance(dec_prompt, str):
prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt)
elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt:
dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"])
else:
if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt):
raise VLLMValidationError(
"Expected decoder_prompt to contain text",
parameter="decoder_prompt",
value=type(dec_prompt).__name__,
)
dec_prompt["prompt"] = dec_prompt["prompt"].replace(
"<|notimestamps|>", "<|0.00|>"
)
return prompt
def _get_verbose_segments(
......
......@@ -28,10 +28,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import TokPrompt
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import EmbedDType, Endianness
......@@ -369,7 +370,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt | EmbedsPrompt,
engine_prompt: TokPrompt,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
......
......@@ -33,8 +33,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
......@@ -91,6 +94,7 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported"
)
engine_prompts: Sequence[PromptType | TokPrompt]
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
......@@ -102,14 +106,10 @@ class OpenAIServingPooling(OpenAIServing):
validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async(
raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
)
if not isinstance(engine_prompts, Sequence) or isinstance(
engine_prompts, (str, bytes, bytearray)
):
engine_prompts = [engine_prompts]
engine_prompts = prompt_to_seq(raw_prompts)
elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,
......
......@@ -17,8 +17,6 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
......@@ -189,7 +187,7 @@ def cli_env_setup():
def get_max_tokens(
max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
prompt: TokensPrompt | EmbedsPrompt,
input_length: int,
default_sampling_params: dict,
) -> int:
# NOTE: Avoid isinstance() for better efficiency
......@@ -204,7 +202,6 @@ def get_max_tokens(
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = get_prompt_len(prompt)
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
......@@ -16,11 +16,8 @@ from .data import (
TextPrompt,
TokenInputs,
TokensPrompt,
build_explicit_enc_dec_prompt,
embeds_inputs,
to_enc_dec_tuple_list,
token_inputs,
zip_enc_dec_prompts,
)
__all__ = [
......@@ -39,8 +36,5 @@ __all__ = [
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"StreamingInput",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment