Unverified Commit ba2f0acc authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Reorganize inputs (#35182)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 678b3c99
...@@ -27,11 +27,9 @@ LLM Class. ...@@ -27,11 +27,9 @@ LLM Class.
- [vllm.LLM][] - [vllm.LLM][]
LLM Inputs. Prompt schema for LLM APIs.
- [vllm.inputs.PromptType][] - [vllm.inputs.llm][]
- [vllm.inputs.TextPrompt][]
- [vllm.inputs.TokensPrompt][]
## vLLM Engines ## vLLM Engines
...@@ -58,13 +56,7 @@ Looking to add your own multi-modal model? Please follow the instructions listed ...@@ -58,13 +56,7 @@ Looking to add your own multi-modal model? Please follow the instructions listed
- [vllm.multimodal.MULTIMODAL_REGISTRY][] - [vllm.multimodal.MULTIMODAL_REGISTRY][]
### Inputs ### Internal data structures
User-facing inputs.
- [vllm.multimodal.inputs.MultiModalDataDict][]
Internal data structures.
- [vllm.multimodal.inputs.PlaceholderRange][] - [vllm.multimodal.inputs.PlaceholderRange][]
- [vllm.multimodal.inputs.NestedTensors][] - [vllm.multimodal.inputs.NestedTensors][]
...@@ -72,7 +64,6 @@ Internal data structures. ...@@ -72,7 +64,6 @@ Internal data structures.
- [vllm.multimodal.inputs.MultiModalFieldConfig][] - [vllm.multimodal.inputs.MultiModalFieldConfig][]
- [vllm.multimodal.inputs.MultiModalKwargsItem][] - [vllm.multimodal.inputs.MultiModalKwargsItem][]
- [vllm.multimodal.inputs.MultiModalKwargsItems][] - [vllm.multimodal.inputs.MultiModalKwargsItems][]
- [vllm.multimodal.inputs.MultiModalInputs][]
### Data Parsing ### Data Parsing
......
...@@ -23,7 +23,7 @@ Declare supported languages and capabilities: ...@@ -23,7 +23,7 @@ Declare supported languages and capabilities:
from torch import nn from torch import nn
from vllm.config import ModelConfig, SpeechToTextConfig from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType from vllm.inputs import PromptType
from vllm.model_executor.models.interfaces import SupportsTranscription from vllm.model_executor.models.interfaces import SupportsTranscription
class YourASRModel(nn.Module, SupportsTranscription): class YourASRModel(nn.Module, SupportsTranscription):
...@@ -66,7 +66,7 @@ This is for controlling general behavior of the API when serving your model: ...@@ -66,7 +66,7 @@ This is for controlling general behavior of the API when serving your model:
See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls. See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls.
Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.data.PromptType]. There are two common patterns: Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.llm.PromptType]. There are two common patterns:
#### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n) #### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n)
......
...@@ -18,7 +18,7 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models](../ ...@@ -18,7 +18,7 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models](../
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
- `prompt`: The prompt should follow the format that is documented on HuggingFace. - `prompt`: The prompt should follow the format that is documented on HuggingFace.
- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][]. - `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.inputs.MultiModalDataDict][].
### Image Inputs ### Image Inputs
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import torch import torch
from vllm import LLM from vllm import LLM
from vllm.inputs.data import TextPrompt from vllm.inputs import TextPrompt
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
# Initialize model # Initialize model
......
...@@ -105,7 +105,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -105,7 +105,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
) )
async def _fake_preprocess_chat(*args, **kwargs): async def _fake_preprocess_chat(*args, **kwargs):
# return conversation, engine_prompts # return conversation, engine_inputs
return ( return (
[{"role": "user", "content": "Test"}], [{"role": "user", "content": "Test"}],
[{"prompt_token_ids": [1, 2, 3]}], [{"prompt_token_ids": [1, 2, 3]}],
......
...@@ -958,14 +958,14 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): ...@@ -958,14 +958,14 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
orig_render_chat_request = serving_chat.render_chat_request orig_render_chat_request = serving_chat.render_chat_request
captured_prompts = [] captured_inputs = []
async def render_chat_request(request): async def render_chat_request(request):
result = await orig_render_chat_request(request) result = await orig_render_chat_request(request)
assert isinstance(result, tuple) assert isinstance(result, tuple)
conversation, engine_prompts = result conversation, engine_inputs = result
captured_prompts.extend(engine_prompts) captured_inputs.extend(engine_inputs)
return result return result
...@@ -981,18 +981,18 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): ...@@ -981,18 +981,18 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
with suppress(Exception): with suppress(Exception):
await serving_chat.create_chat_completion(req) await serving_chat.create_chat_completion(req)
assert len(captured_prompts) == 1 assert len(captured_inputs) == 1
assert "cache_salt" not in captured_prompts[0] assert "cache_salt" not in captured_inputs[0]
captured_prompts.clear() captured_inputs.clear()
# Test with certain cache_salt # Test with certain cache_salt
req.cache_salt = "test_salt" req.cache_salt = "test_salt"
with suppress(Exception): with suppress(Exception):
await serving_chat.create_chat_completion(req) await serving_chat.create_chat_completion(req)
assert len(captured_prompts) == 1 assert len(captured_inputs) == 1
assert captured_prompts[0]["cache_salt"] == "test_salt" assert captured_inputs[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -37,7 +37,7 @@ from vllm.entrypoints.openai.responses.serving import ( ...@@ -37,7 +37,7 @@ from vllm.entrypoints.openai.responses.serving import (
from vllm.entrypoints.openai.responses.streaming_events import ( from vllm.entrypoints.openai.responses.streaming_events import (
StreamingState, StreamingState,
) )
from vllm.inputs.data import TokensPrompt from vllm.inputs import tokens_input
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -258,20 +258,20 @@ class TestValidateGeneratorInput: ...@@ -258,20 +258,20 @@ class TestValidateGeneratorInput:
"""Test _validate_generator_input with valid prompt length""" """Test _validate_generator_input with valid prompt length"""
# Create an engine prompt with valid length (less than max_model_len) # Create an engine prompt with valid length (less than max_model_len)
valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len
engine_prompt = TokensPrompt(prompt_token_ids=valid_prompt_token_ids) engine_input = tokens_input(valid_prompt_token_ids)
# Call the method # Call the method
result = serving_responses_instance._validate_generator_input(engine_prompt) result = serving_responses_instance._validate_generator_input(engine_input)
# Should return None for valid input # Should return None for valid input
assert result is None assert result is None
# create an invalid engine prompt # create an invalid engine prompt
invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len
engine_prompt = TokensPrompt(prompt_token_ids=invalid_prompt_token_ids) engine_input = tokens_input(invalid_prompt_token_ids)
# Call the method # Call the method
result = serving_responses_instance._validate_generator_input(engine_prompt) result = serving_responses_instance._validate_generator_input(engine_input)
# Should return an ErrorResponse # Should return an ErrorResponse
assert result is not None assert result is not None
......
...@@ -73,20 +73,6 @@ async def test_chat_render_multi_turn(client): ...@@ -73,20 +73,6 @@ async def test_chat_render_multi_turn(client):
assert len(data["token_ids"]) > 0 assert len(data["token_ids"]) > 0
@pytest.mark.asyncio
async def test_chat_render_invalid_model(client):
response = await client.post(
"/v1/chat/completions/render",
json={
"model": "nonexistent-model",
"messages": [{"role": "user", "content": "Hello"}],
},
)
assert response.status_code == 404
assert "error" in response.json()
# -- Completion Render -- # -- Completion Render --
......
...@@ -16,7 +16,7 @@ from vllm.entrypoints.chat_utils import ( ...@@ -16,7 +16,7 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages, parse_chat_messages,
parse_chat_messages_async, parse_chat_messages_async,
) )
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import ( from vllm.multimodal.utils import (
encode_audio_url, encode_audio_url,
encode_image_url, encode_image_url,
......
...@@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk ...@@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor from transformers import AutoProcessor
from vllm import SamplingParams, TextPrompt, TokensPrompt from vllm import SamplingParams, TextPrompt, TokensPrompt
from vllm.inputs import MultiModalDataBuiltins
from vllm.logprobs import Logprob, SampleLogprobs from vllm.logprobs import Logprob, SampleLogprobs
from vllm.multimodal import MultiModalDataBuiltins
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ....utils import VLLM_PATH, large_gpu_test from ....utils import VLLM_PATH, large_gpu_test
......
...@@ -15,13 +15,11 @@ from vllm.config.multimodal import ( ...@@ -15,13 +15,11 @@ from vllm.config.multimodal import (
ImageDummyOptions, ImageDummyOptions,
VideoDummyOptions, VideoDummyOptions,
) )
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.inputs import MultiModalDataDict, MultiModalInput
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal from vllm.multimodal.inputs import batched_tensors_equal
from vllm.multimodal.processing import ( from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
BaseMultiModalProcessor,
InputProcessingContext,
)
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
...@@ -420,8 +418,8 @@ def test_processing_correctness( ...@@ -420,8 +418,8 @@ def test_processing_correctness(
def _assert_inputs_equal( def _assert_inputs_equal(
a: MultiModalInputs, a: MultiModalInput,
b: MultiModalInputs, b: MultiModalInput,
*, *,
ignore_mm_keys: set[str] | None = None, ignore_mm_keys: set[str] | None = None,
msg: str = "", msg: str = "",
......
...@@ -6,11 +6,9 @@ from collections.abc import Sequence ...@@ -6,11 +6,9 @@ from collections.abc import Sequence
from vllm.config import ModelConfig, PoolerConfig, VllmConfig from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin
from vllm.inputs.data import PromptType from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import ( from vllm.plugins.io_processors.interface import IOProcessor
IOProcessor,
)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
......
...@@ -18,7 +18,7 @@ from einops import rearrange ...@@ -18,7 +18,7 @@ from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs.data import PromptType from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import IOProcessor from vllm.plugins.io_processors.interface import IOProcessor
......
...@@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch ...@@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs.data import PromptType from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.plugins.io_processors.interface import IOProcessor from vllm.plugins.io_processors.interface import IOProcessor
......
...@@ -15,13 +15,13 @@ def test_text_input(): ...@@ -15,13 +15,13 @@ def test_text_input():
assert prompt_to_seq(["foo", "bar"]) == ["foo", "bar"] assert prompt_to_seq(["foo", "bar"]) == ["foo", "bar"]
def test_token_input(): def test_tokens_input():
assert prompt_to_seq([1, 2]) == [[1, 2]] assert prompt_to_seq([1, 2]) == [[1, 2]]
assert prompt_to_seq([[1, 2]]) == [[1, 2]] assert prompt_to_seq([[1, 2]]) == [[1, 2]]
assert prompt_to_seq([[1, 2], [3, 4]]) == [[1, 2], [3, 4]] assert prompt_to_seq([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]
def test_text_token_input(): def test_text_tokens_input():
assert prompt_to_seq([[1, 2], "foo"]) == [[1, 2], "foo"] assert prompt_to_seq([[1, 2], "foo"]) == [[1, 2], "foo"]
assert prompt_to_seq(["foo", [1, 2]]) == ["foo", [1, 2]] assert prompt_to_seq(["foo", [1, 2]]) == ["foo", [1, 2]]
......
...@@ -129,7 +129,7 @@ class TestValidatePrompt: ...@@ -129,7 +129,7 @@ class TestValidatePrompt:
class TestRenderPrompt: class TestRenderPrompt:
def test_token_input(self): def test_tokens_input(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
tokens = [101, 7592, 2088] tokens = [101, 7592, 2088]
...@@ -339,7 +339,7 @@ class TestRenderPrompt: ...@@ -339,7 +339,7 @@ class TestRenderPrompt:
TokenizeParams(max_total_tokens=100), TokenizeParams(max_total_tokens=100),
) )
def test_token_input_with_needs_detokenization(self): def test_tokens_input_with_needs_detokenization(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
tokens = [1, 2, 3, 4] tokens = [1, 2, 3, 4]
......
...@@ -9,7 +9,7 @@ import pytest ...@@ -9,7 +9,7 @@ import pytest
from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC
from vllm import SamplingParams from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineGenerateError from vllm.v1.engine.exceptions import EngineGenerateError
......
...@@ -3,11 +3,16 @@ ...@@ -3,11 +3,16 @@
from dataclasses import dataclass from dataclasses import dataclass
from vllm.inputs import EncoderDecoderInputs, TokenInputs, token_inputs from vllm.inputs import (
from vllm.inputs.data import DecoderInputs DecoderOnlyEngineInput,
EncoderDecoderInput,
MultiModalInput,
TokensInput,
mm_input,
tokens_input,
)
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
@dataclass @dataclass
...@@ -18,7 +23,7 @@ class BeamSearchSequence: ...@@ -18,7 +23,7 @@ class BeamSearchSequence:
about to be returned to the user. about to be returned to the user.
""" """
orig_prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs orig_prompt: TokensInput | MultiModalInput | EncoderDecoderInput
# NOTE: Tokens represents decoder tokens in the encoder / decoder case # NOTE: Tokens represents decoder tokens in the encoder / decoder case
tokens: list[int] tokens: list[int]
...@@ -40,13 +45,13 @@ class BeamSearchSequence: ...@@ -40,13 +45,13 @@ class BeamSearchSequence:
cache_salt = prompt.get("cache_salt") cache_salt = prompt.get("cache_salt")
if prompt["type"] == "token": if prompt["type"] == "token":
return token_inputs( return tokens_input(
self.tokens, self.tokens,
prompt=prompt_text, prompt=prompt_text,
cache_salt=cache_salt, cache_salt=cache_salt,
) )
return mm_inputs( return mm_input(
prompt_token_ids=self.tokens, prompt_token_ids=self.tokens,
mm_kwargs=prompt["mm_kwargs"], mm_kwargs=prompt["mm_kwargs"],
mm_hashes=prompt["mm_hashes"], mm_hashes=prompt["mm_hashes"],
...@@ -56,8 +61,8 @@ class BeamSearchSequence: ...@@ -56,8 +61,8 @@ class BeamSearchSequence:
) )
def _build_encoder_decoder_inputs( def _build_encoder_decoder_inputs(
self, prompt: EncoderDecoderInputs self, prompt: EncoderDecoderInput
) -> EncoderDecoderInputs: ) -> EncoderDecoderInput:
"""Rebuild the encoder-decoder inputs with the current beam search """Rebuild the encoder-decoder inputs with the current beam search
sequence's tokens. sequence's tokens.
...@@ -70,9 +75,9 @@ class BeamSearchSequence: ...@@ -70,9 +75,9 @@ class BeamSearchSequence:
# Rebuild decoder prompt with updated tokens, # Rebuild decoder prompt with updated tokens,
# but keep everything else the same. # but keep everything else the same.
new_dec_prompt: DecoderInputs new_dec_prompt: DecoderOnlyEngineInput
if dec_prompt["type"] == "multimodal": if dec_prompt["type"] == "multimodal":
new_dec_prompt = mm_inputs( new_dec_prompt = mm_input(
self.tokens, self.tokens,
mm_kwargs=dec_prompt["mm_kwargs"], mm_kwargs=dec_prompt["mm_kwargs"],
mm_hashes=dec_prompt["mm_hashes"], mm_hashes=dec_prompt["mm_hashes"],
...@@ -81,13 +86,13 @@ class BeamSearchSequence: ...@@ -81,13 +86,13 @@ class BeamSearchSequence:
cache_salt=dec_prompt.get("cache_salt"), cache_salt=dec_prompt.get("cache_salt"),
) )
else: else:
new_dec_prompt = token_inputs( new_dec_prompt = tokens_input(
self.tokens, self.tokens,
prompt=dec_prompt.get("prompt"), prompt=dec_prompt.get("prompt"),
cache_salt=dec_prompt.get("cache_salt"), cache_salt=dec_prompt.get("cache_salt"),
) )
return EncoderDecoderInputs( return EncoderDecoderInput(
type="enc_dec", type="enc_dec",
encoder_prompt=prompt["encoder_prompt"], encoder_prompt=prompt["encoder_prompt"],
decoder_prompt=new_dec_prompt, decoder_prompt=new_dec_prompt,
...@@ -107,7 +112,7 @@ class BeamSearchOutput: ...@@ -107,7 +112,7 @@ class BeamSearchOutput:
class BeamSearchInstance: class BeamSearchInstance:
def __init__( def __init__(
self, self,
prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs, prompt: TokensInput | MultiModalInput | EncoderDecoderInput,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
logprobs: list[dict[int, Logprob]] | None = None, logprobs: list[dict[int, Logprob]] | None = None,
**kwargs, **kwargs,
......
...@@ -35,9 +35,9 @@ from huggingface_hub import snapshot_download ...@@ -35,9 +35,9 @@ from huggingface_hub import snapshot_download
from PIL import Image from PIL import Image
from typing_extensions import deprecated from typing_extensions import deprecated
from vllm.inputs import MultiModalDataDict
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.audio import get_audio_duration from vllm.multimodal.audio import get_audio_duration
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
......
...@@ -11,7 +11,7 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -11,7 +11,7 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest, WeightTransferInitRequest,
WeightTransferUpdateRequest, WeightTransferUpdateRequest,
) )
from vllm.inputs.data import ProcessorInputs, PromptType from vllm.inputs import EngineInput, PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor from vllm.plugins.io_processors import IOProcessor
...@@ -34,7 +34,7 @@ class StreamingInput: ...@@ -34,7 +34,7 @@ class StreamingInput:
where inputs are provided via an async generator. where inputs are provided via an async generator.
""" """
prompt: ProcessorInputs prompt: EngineInput
sampling_params: SamplingParams | None = None sampling_params: SamplingParams | None = None
...@@ -68,7 +68,7 @@ class EngineClient(ABC): ...@@ -68,7 +68,7 @@ class EngineClient(ABC):
self, self,
prompt: EngineCoreRequest prompt: EngineCoreRequest
| PromptType | PromptType
| ProcessorInputs | EngineInput
| AsyncGenerator[StreamingInput, None], | AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
...@@ -87,7 +87,7 @@ class EngineClient(ABC): ...@@ -87,7 +87,7 @@ class EngineClient(ABC):
@abstractmethod @abstractmethod
def encode( def encode(
self, self,
prompt: PromptType | ProcessorInputs, prompt: PromptType | EngineInput,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment