Unverified Commit ba2f0acc authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Reorganize inputs (#35182)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 678b3c99
......@@ -27,11 +27,9 @@ LLM Class.
- [vllm.LLM][]
LLM Inputs.
Prompt schema for LLM APIs.
- [vllm.inputs.PromptType][]
- [vllm.inputs.TextPrompt][]
- [vllm.inputs.TokensPrompt][]
- [vllm.inputs.llm][]
## vLLM Engines
......@@ -58,13 +56,7 @@ Looking to add your own multi-modal model? Please follow the instructions listed
- [vllm.multimodal.MULTIMODAL_REGISTRY][]
### Inputs
User-facing inputs.
- [vllm.multimodal.inputs.MultiModalDataDict][]
Internal data structures.
### Internal data structures
- [vllm.multimodal.inputs.PlaceholderRange][]
- [vllm.multimodal.inputs.NestedTensors][]
......@@ -72,7 +64,6 @@ Internal data structures.
- [vllm.multimodal.inputs.MultiModalFieldConfig][]
- [vllm.multimodal.inputs.MultiModalKwargsItem][]
- [vllm.multimodal.inputs.MultiModalKwargsItems][]
- [vllm.multimodal.inputs.MultiModalInputs][]
### Data Parsing
......
......@@ -23,7 +23,7 @@ Declare supported languages and capabilities:
from torch import nn
from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType
from vllm.inputs import PromptType
from vllm.model_executor.models.interfaces import SupportsTranscription
class YourASRModel(nn.Module, SupportsTranscription):
......@@ -66,7 +66,7 @@ This is for controlling general behavior of the API when serving your model:
See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls.
Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.data.PromptType]. There are two common patterns:
Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.llm.PromptType]. There are two common patterns:
#### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n)
......
......@@ -18,7 +18,7 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models](../
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
- `prompt`: The prompt should follow the format that is documented on HuggingFace.
- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][].
- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.inputs.MultiModalDataDict][].
### Image Inputs
......
......@@ -4,7 +4,7 @@
import torch
from vllm import LLM
from vllm.inputs.data import TextPrompt
from vllm.inputs import TextPrompt
from vllm.multimodal.utils import fetch_image
# Initialize model
......
......@@ -105,7 +105,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
)
async def _fake_preprocess_chat(*args, **kwargs):
# return conversation, engine_prompts
# return conversation, engine_inputs
return (
[{"role": "user", "content": "Test"}],
[{"prompt_token_ids": [1, 2, 3]}],
......
......@@ -958,14 +958,14 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
serving_chat = _build_serving_chat(mock_engine)
orig_render_chat_request = serving_chat.render_chat_request
captured_prompts = []
captured_inputs = []
async def render_chat_request(request):
result = await orig_render_chat_request(request)
assert isinstance(result, tuple)
conversation, engine_prompts = result
captured_prompts.extend(engine_prompts)
conversation, engine_inputs = result
captured_inputs.extend(engine_inputs)
return result
......@@ -981,18 +981,18 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert len(captured_prompts) == 1
assert "cache_salt" not in captured_prompts[0]
assert len(captured_inputs) == 1
assert "cache_salt" not in captured_inputs[0]
captured_prompts.clear()
captured_inputs.clear()
# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert len(captured_prompts) == 1
assert captured_prompts[0]["cache_salt"] == "test_salt"
assert len(captured_inputs) == 1
assert captured_inputs[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio
......
......@@ -37,7 +37,7 @@ from vllm.entrypoints.openai.responses.serving import (
from vllm.entrypoints.openai.responses.streaming_events import (
StreamingState,
)
from vllm.inputs.data import TokensPrompt
from vllm.inputs import tokens_input
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
......@@ -258,20 +258,20 @@ class TestValidateGeneratorInput:
"""Test _validate_generator_input with valid prompt length"""
# Create an engine prompt with valid length (less than max_model_len)
valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len
engine_prompt = TokensPrompt(prompt_token_ids=valid_prompt_token_ids)
engine_input = tokens_input(valid_prompt_token_ids)
# Call the method
result = serving_responses_instance._validate_generator_input(engine_prompt)
result = serving_responses_instance._validate_generator_input(engine_input)
# Should return None for valid input
assert result is None
# create an invalid engine prompt
invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len
engine_prompt = TokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
engine_input = tokens_input(invalid_prompt_token_ids)
# Call the method
result = serving_responses_instance._validate_generator_input(engine_prompt)
result = serving_responses_instance._validate_generator_input(engine_input)
# Should return an ErrorResponse
assert result is not None
......
......@@ -73,20 +73,6 @@ async def test_chat_render_multi_turn(client):
assert len(data["token_ids"]) > 0
@pytest.mark.asyncio
async def test_chat_render_invalid_model(client):
response = await client.post(
"/v1/chat/completions/render",
json={
"model": "nonexistent-model",
"messages": [{"role": "user", "content": "Hello"}],
},
)
assert response.status_code == 404
assert "error" in response.json()
# -- Completion Render --
......
......@@ -16,7 +16,7 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (
encode_audio_url,
encode_image_url,
......
......@@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor
from vllm import SamplingParams, TextPrompt, TokensPrompt
from vllm.inputs import MultiModalDataBuiltins
from vllm.logprobs import Logprob, SampleLogprobs
from vllm.multimodal import MultiModalDataBuiltins
from vllm.platforms import current_platform
from ....utils import VLLM_PATH, large_gpu_test
......
......@@ -15,13 +15,11 @@ from vllm.config.multimodal import (
ImageDummyOptions,
VideoDummyOptions,
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.inputs import MultiModalDataDict, MultiModalInput
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
InputProcessingContext,
)
from vllm.multimodal.inputs import batched_tensors_equal
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.utils.mistral import is_mistral_tokenizer
......@@ -420,8 +418,8 @@ def test_processing_correctness(
def _assert_inputs_equal(
a: MultiModalInputs,
b: MultiModalInputs,
a: MultiModalInput,
b: MultiModalInput,
*,
ignore_mm_keys: set[str] | None = None,
msg: str = "",
......
......@@ -6,11 +6,9 @@ from collections.abc import Sequence
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin
from vllm.inputs.data import PromptType
from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (
IOProcessor,
)
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
......
......@@ -18,7 +18,7 @@ from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig
from vllm.inputs.data import PromptType
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import IOProcessor
......
......@@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from vllm.config import VllmConfig
from vllm.inputs.data import PromptType
from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.plugins.io_processors.interface import IOProcessor
......
......@@ -15,13 +15,13 @@ def test_text_input():
assert prompt_to_seq(["foo", "bar"]) == ["foo", "bar"]
def test_token_input():
def test_tokens_input():
assert prompt_to_seq([1, 2]) == [[1, 2]]
assert prompt_to_seq([[1, 2]]) == [[1, 2]]
assert prompt_to_seq([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]
def test_text_token_input():
def test_text_tokens_input():
assert prompt_to_seq([[1, 2], "foo"]) == [[1, 2], "foo"]
assert prompt_to_seq(["foo", [1, 2]]) == ["foo", [1, 2]]
......
......@@ -129,7 +129,7 @@ class TestValidatePrompt:
class TestRenderPrompt:
def test_token_input(self):
def test_tokens_input(self):
renderer = _build_renderer(MockModelConfig())
tokens = [101, 7592, 2088]
......@@ -339,7 +339,7 @@ class TestRenderPrompt:
TokenizeParams(max_total_tokens=100),
)
def test_token_input_with_needs_detokenization(self):
def test_tokens_input_with_needs_detokenization(self):
renderer = _build_renderer(MockModelConfig())
tokens = [1, 2, 3, 4]
......
......@@ -9,7 +9,7 @@ import pytest
from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.inputs import TokensPrompt
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineGenerateError
......
......@@ -3,11 +3,16 @@
from dataclasses import dataclass
from vllm.inputs import EncoderDecoderInputs, TokenInputs, token_inputs
from vllm.inputs.data import DecoderInputs
from vllm.inputs import (
DecoderOnlyEngineInput,
EncoderDecoderInput,
MultiModalInput,
TokensInput,
mm_input,
tokens_input,
)
from vllm.logprobs import Logprob
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
@dataclass
......@@ -18,7 +23,7 @@ class BeamSearchSequence:
about to be returned to the user.
"""
orig_prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs
orig_prompt: TokensInput | MultiModalInput | EncoderDecoderInput
# NOTE: Tokens represents decoder tokens in the encoder / decoder case
tokens: list[int]
......@@ -40,13 +45,13 @@ class BeamSearchSequence:
cache_salt = prompt.get("cache_salt")
if prompt["type"] == "token":
return token_inputs(
return tokens_input(
self.tokens,
prompt=prompt_text,
cache_salt=cache_salt,
)
return mm_inputs(
return mm_input(
prompt_token_ids=self.tokens,
mm_kwargs=prompt["mm_kwargs"],
mm_hashes=prompt["mm_hashes"],
......@@ -56,8 +61,8 @@ class BeamSearchSequence:
)
def _build_encoder_decoder_inputs(
self, prompt: EncoderDecoderInputs
) -> EncoderDecoderInputs:
self, prompt: EncoderDecoderInput
) -> EncoderDecoderInput:
"""Rebuild the encoder-decoder inputs with the current beam search
sequence's tokens.
......@@ -70,9 +75,9 @@ class BeamSearchSequence:
# Rebuild decoder prompt with updated tokens,
# but keep everything else the same.
new_dec_prompt: DecoderInputs
new_dec_prompt: DecoderOnlyEngineInput
if dec_prompt["type"] == "multimodal":
new_dec_prompt = mm_inputs(
new_dec_prompt = mm_input(
self.tokens,
mm_kwargs=dec_prompt["mm_kwargs"],
mm_hashes=dec_prompt["mm_hashes"],
......@@ -81,13 +86,13 @@ class BeamSearchSequence:
cache_salt=dec_prompt.get("cache_salt"),
)
else:
new_dec_prompt = token_inputs(
new_dec_prompt = tokens_input(
self.tokens,
prompt=dec_prompt.get("prompt"),
cache_salt=dec_prompt.get("cache_salt"),
)
return EncoderDecoderInputs(
return EncoderDecoderInput(
type="enc_dec",
encoder_prompt=prompt["encoder_prompt"],
decoder_prompt=new_dec_prompt,
......@@ -107,7 +112,7 @@ class BeamSearchOutput:
class BeamSearchInstance:
def __init__(
self,
prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs,
prompt: TokensInput | MultiModalInput | EncoderDecoderInput,
lora_request: LoRARequest | None = None,
logprobs: list[dict[int, Logprob]] | None = None,
**kwargs,
......
......@@ -35,9 +35,9 @@ from huggingface_hub import snapshot_download
from PIL import Image
from typing_extensions import deprecated
from vllm.inputs import MultiModalDataDict
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.audio import get_audio_duration
from vllm.multimodal.image import convert_image_mode
from vllm.tokenizers import TokenizerLike
......
......@@ -11,7 +11,7 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.inputs.data import ProcessorInputs, PromptType
from vllm.inputs import EngineInput, PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
......@@ -34,7 +34,7 @@ class StreamingInput:
where inputs are provided via an async generator.
"""
prompt: ProcessorInputs
prompt: EngineInput
sampling_params: SamplingParams | None = None
......@@ -68,7 +68,7 @@ class EngineClient(ABC):
self,
prompt: EngineCoreRequest
| PromptType
| ProcessorInputs
| EngineInput
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
......@@ -87,7 +87,7 @@ class EngineClient(ABC):
@abstractmethod
def encode(
self,
prompt: PromptType | ProcessorInputs,
prompt: PromptType | EngineInput,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment