Unverified commit cee711fd, authored by Cyrus Leung and committed by GitHub
Browse files

[Core] Rename input data types (#8688)

parent 1de76a0e
......@@ -25,7 +25,7 @@ Module Contents
LLM Engine Inputs
-----------------
.. autoclass:: vllm.inputs.LLMInputs
.. autoclass:: vllm.inputs.DecoderOnlyInputs
:members:
:show-inheritance:
......
import os
import re
from typing import Callable, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import pytest
import torch
from transformers import AutoImageProcessor, AutoTokenizer
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs import InputContext, token_inputs
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
......@@ -311,7 +311,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets,
(4, 781),
(16, 2653),
])
def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
num_crops: int, expected_max_tokens: int):
"""Ensure get_max_phi3v_image_tokens handles num_crops properly."""
# NOTE: mm_processor_kwargs on the context in this test is unused, since
......@@ -343,8 +343,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
(16, 2653, 1),
(16, 2653, 2),
])
def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
num_crops: int, toks_per_img: int, num_imgs: int):
def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
toks_per_img: int, num_imgs: int):
"""Ensure dummy_data_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
......@@ -374,7 +374,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
(16, 1921, 1),
(16, 1921, 2),
])
def test_input_processor_override(input_processor_for_phi3v: Callable,
def test_input_processor_override(input_processor_for_phi3v,
image_assets: _ImageAssets, model: str,
num_crops: int, expected_toks_per_img: int,
num_imgs: int):
......@@ -393,16 +393,14 @@ def test_input_processor_override(input_processor_for_phi3v: Callable,
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
images = [image_assets[0].pil_image] * num_imgs
llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt),
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})
proc_llm_inputs = input_processor_for_phi3v(
ctx=ctx,
llm_inputs=llm_inputs,
num_crops=num_crops,
)
processed_inputs = input_processor_for_phi3v(ctx,
inputs,
num_crops=num_crops)
# Ensure we have the right number of placeholders per num_crops size
img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs
......@@ -5,7 +5,7 @@ import pytest
import torch
from PIL.Image import Image
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs import InputContext, token_inputs
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size
......@@ -71,12 +71,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen,
"""Happy cases for image inputs to Qwen's multimodal input processor."""
prompt = "".join(
[f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
inputs = LLMInputs(
inputs = token_inputs(
prompt=prompt,
# When processing multimodal data for a multimodal model, the qwen
# input processor will overwrite the provided prompt_token_ids with
# the image prompts
prompt_token_ids=None,
prompt_token_ids=[],
multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
)
proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
......@@ -134,7 +134,7 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen,
trust_remote_code=True)
prompt = "Picture 1: <img></img>\n"
prompt_token_ids = tokenizer.encode(prompt)
inputs = LLMInputs(prompt=prompt,
inputs = token_inputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_data)
# Should fail since we have too many or too few dimensions for embeddings
......
......@@ -5,7 +5,7 @@ from unittest.mock import patch
import pytest
import torch
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
from vllm.inputs.registry import InputRegistry
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
......@@ -31,7 +31,7 @@ def use_processor_mock():
"""Patches the internal model input processor with an override callable."""
def custom_processor(ctx: InputContext,
llm_inputs: LLMInputs,
inputs: DecoderOnlyInputs,
*,
num_crops=DEFAULT_NUM_CROPS):
# For testing purposes, we don't worry about the llm inputs / return
......@@ -84,7 +84,7 @@ def test_default_processor_is_a_noop():
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(ctx.model_config)
proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
proc_outputs = processor(inputs=proc_inputs)
assert proc_inputs is proc_outputs
......@@ -125,7 +125,7 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(
LLMInputs(prompt_token_ids=[],
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
assert num_crops_val == expected_seq_count
......@@ -154,7 +154,7 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
processor = dummy_registry.create_input_processor(ctx.model_config)
# Should filter out the inference time kwargs
num_crops_val = processor(
LLMInputs(prompt_token_ids=[],
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=mm_processor_kwargs))
assert num_crops_val == DEFAULT_NUM_CROPS
......
......@@ -29,8 +29,8 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptType)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -635,7 +635,7 @@ class LLMEngine:
def _add_processed_request(
self,
request_id: str,
processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
......@@ -1855,8 +1855,8 @@ class LLMEngine:
def is_embedding_model(self):
return self.model_config.is_embedding_model
def _validate_model_inputs(self, inputs: Union[LLMInputs,
EncoderDecoderLLMInputs]):
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
......
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
......@@ -19,8 +20,11 @@ __all__ = [
"PromptType",
"SingletonPrompt",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",
"TokenInputs",
"token_inputs",
"SingletonInputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
......@@ -31,9 +35,9 @@ __all__ = [
def __getattr__(name: str):
if name == "PromptInput":
import warnings
if name == "PromptInput":
msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.")
......@@ -41,4 +45,21 @@ def __getattr__(name: str):
return PromptType
if name == "LLMInputs":
msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return DecoderOnlyInputs
if name == "EncoderDecoderLLMInputs":
msg = (
"EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return EncoderDecoderInputs
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Tuple, Union)
Optional, Tuple, Union, cast)
from typing_extensions import NotRequired, TypedDict, TypeVar
......@@ -51,7 +51,7 @@ class TokensPrompt(TypedDict):
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single LLM input:
Set of possible schemas for a single prompt:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
......@@ -120,13 +120,8 @@ both decoder-only and encoder/decoder input types:
"""
class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
class TokenInputs(TypedDict):
"""Represents token-based inputs."""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
......@@ -150,7 +145,40 @@ class LLMInputs(TypedDict):
"""
class EncoderDecoderLLMInputs(LLMInputs):
def token_inputs(
    prompt_token_ids: List[int],
    prompt: Optional[str] = None,
    multi_modal_data: Optional["MultiModalDataDict"] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
    """Build a :class:`TokenInputs` dict, leaving unset optional keys absent.

    ``prompt_token_ids`` is always stored. Each of ``prompt``,
    ``multi_modal_data`` and ``mm_processor_kwargs`` is stored only when the
    argument is not ``None``, so an omitted argument results in a missing key
    rather than a key mapped to ``None``.
    """
    token_data = TokenInputs(prompt_token_ids=prompt_token_ids)

    # Literal keys are kept (instead of a loop over names) so that static
    # type checkers can still validate each assignment against the TypedDict.
    if prompt is not None:
        token_data["prompt"] = prompt

    if multi_modal_data is not None:
        token_data["multi_modal_data"] = multi_modal_data

    if mm_processor_kwargs is not None:
        token_data["mm_processor_kwargs"] = mm_processor_kwargs

    return token_data
SingletonInputs = TokenInputs
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""
DecoderOnlyInputs = TokenInputs
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
class EncoderDecoderInputs(TokenInputs):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
......@@ -204,11 +232,12 @@ def zip_enc_dec_prompts(
be zipped with the encoder/decoder prompts.
"""
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
if isinstance(mm_processor_kwargs, Dict):
mm_processor_kwargs = cast(Dict[str, Any], {})
if isinstance(mm_processor_kwargs, dict):
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
mm_processor_kwargs)
build_explicit_enc_dec_prompt(
encoder_prompt, decoder_prompt,
cast(Dict[str, Any], mm_processor_kwargs))
for (encoder_prompt,
decoder_prompt) in zip(enc_prompts, dec_prompts)
]
......@@ -229,9 +258,9 @@ def to_enc_dec_tuple_list(
def __getattr__(name: str):
if name == "PromptInput":
import warnings
if name == "PromptInput":
msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.")
......@@ -239,4 +268,21 @@ def __getattr__(name: str):
return PromptType
if name == "LLMInputs":
msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return DecoderOnlyInputs
if name == "EncoderDecoderLLMInputs":
msg = (
"EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return EncoderDecoderInputs
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
......@@ -4,9 +4,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
TextPrompt, TokensPrompt)
class ParsedText(TypedDict):
......@@ -100,7 +100,7 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_valid_encoder_decoder_llm_inputs(
inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
) -> TypeIs[EncoderDecoderLLMInputs]:
def is_encoder_decoder_inputs(
inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
) -> TypeIs[EncoderDecoderInputs]:
return "encoder_prompt_token_ids" in inputs
......@@ -10,7 +10,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_warning_once
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
SingletonPrompt)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
......@@ -306,7 +306,7 @@ class InputPreprocessor:
encoder_comps: PromptComponents,
decoder_comps: DecoderPromptComponents,
mm_processor_kwargs: Dict[str, Any],
) -> EncoderDecoderLLMInputs:
) -> EncoderDecoderInputs:
encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
......@@ -324,7 +324,7 @@ class InputPreprocessor:
decoder_prompt_ids,
force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
return EncoderDecoderLLMInputs(
return EncoderDecoderInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
multi_modal_data=decoder_mm_data,
......@@ -338,11 +338,11 @@ class InputPreprocessor:
self,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderLLMInputs:
) -> EncoderDecoderInputs:
'''
For encoder/decoder models only:
Process an input prompt into an
:class:`EncoderDecoderLLMInputs` instance.
:class:`EncoderDecoderInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
......@@ -369,7 +369,7 @@ class InputPreprocessor:
Returns:
* :class:`EncoderDecoderLLMInputs` instance
* :class:`EncoderDecoderInputs` instance
'''
encoder_comps: PromptComponents
......@@ -411,7 +411,7 @@ class InputPreprocessor:
self,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderLLMInputs:
) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
......@@ -455,14 +455,14 @@ class InputPreprocessor:
self,
prompt_comps: PromptComponents,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
) -> DecoderOnlyInputs:
(prompt, prompt_token_ids, multi_modal_data,
mm_processor_kwargs) = prompt_comps
prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
return LLMInputs(prompt_token_ids=prompt_token_ids,
return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs)
......@@ -473,10 +473,10 @@ class InputPreprocessor:
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
) -> DecoderOnlyInputs:
'''
For decoder-only models:
Process an input prompt into an :class:`LLMInputs` instance.
Process an input prompt into a :class:`DecoderOnlyInputs` instance.
Arguments:
......@@ -487,7 +487,7 @@ class InputPreprocessor:
Returns:
* :class:`LLMInputs` instance
* :class:`DecoderOnlyInputs` instance
'''
prompt_comps = self._extract_prompt_components(
......@@ -507,7 +507,7 @@ class InputPreprocessor:
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
) -> DecoderOnlyInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
prompt,
......@@ -526,7 +526,7 @@ class InputPreprocessor:
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
"""Preprocess the input prompt."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
......@@ -554,7 +554,7 @@ class InputPreprocessor:
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
"""Async version of :meth:`preprocess`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
......
......@@ -12,7 +12,7 @@ from vllm.logger import init_logger
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
resolve_mm_processor_kwargs)
from .data import LLMInputs
from .data import DecoderOnlyInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig
......@@ -100,7 +100,7 @@ class _MultiModalCounts(UserDict):
raise KeyError(msg) from exc
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs]
"""Preprocess the inputs to the model."""
......@@ -134,7 +134,7 @@ class InputRegistry:
# Avoid circular import
from vllm.sequence import SequenceData
dummy_seq_data = SequenceData.from_token_counts((0, seq_len))
dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data
......@@ -245,8 +245,11 @@ class InputRegistry:
return seq_data, mm_data
def _default_input_processor(self, ctx: InputContext,
inputs: LLMInputs) -> LLMInputs:
def _default_input_processor(
self,
ctx: InputContext,
inputs: DecoderOnlyInputs,
) -> DecoderOnlyInputs:
"""The default input processor is a no-op."""
return inputs
......@@ -279,7 +282,7 @@ class InputRegistry:
.get(model_cls, self._default_input_processor)
def process_input(self, model_config: "ModelConfig",
inputs: LLMInputs) -> LLMInputs:
inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
"""
Apply an input processor to an instance of model inputs.
......
......@@ -10,7 +10,7 @@ from transformers.models.blip.modeling_blip import BlipAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
......@@ -63,7 +63,7 @@ def dummy_seq_data_for_blip(
else:
image_feature_size = image_feature_size_override
return SequenceData.from_token_counts(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
......@@ -89,14 +89,14 @@ def dummy_image_for_blip(
def input_processor_for_blip(
model_config: ModelConfig,
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
llm_inputs: LLMInputs,
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
......@@ -107,14 +107,14 @@ def input_processor_for_blip(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
......
......@@ -9,7 +9,8 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -421,7 +422,7 @@ def dummy_seq_data_for_blip2(
else:
image_feature_size = image_feature_size_override
return SequenceData.from_token_counts(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
......@@ -449,10 +450,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
raise NotImplementedError(msg)
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
hf_config = ctx.get_hf_config(Blip2Config)
image_feature_size = get_blip2_image_feature_size(hf_config)
......@@ -460,13 +461,13 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
# The original model places image tokens at the front
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
new_token_ids += llm_inputs["prompt_token_ids"]
new_token_ids += inputs["prompt_token_ids"]
new_prompt = llm_inputs.get("prompt")
new_prompt = inputs.get("prompt")
if new_prompt is not None:
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
return LLMInputs(prompt_token_ids=new_token_ids,
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
......
......@@ -11,7 +11,8 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -69,7 +70,7 @@ def dummy_seq_data_for_chameleon(
else:
image_feature_size = image_feature_size_override
return SequenceData.from_token_counts(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
......@@ -106,7 +107,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
return seq_data, mm_data
def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
def input_processor_for_chameleon(ctx: InputContext,
inputs: DecoderOnlyInputs):
"""
Processing input prompt to insert required tokens for image placeholder.
......@@ -114,16 +116,16 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
""" # noqa
multi_modal_data = llm_inputs.get("multi_modal_data")
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
......@@ -137,7 +139,7 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
......
......@@ -14,7 +14,7 @@ from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -149,20 +149,20 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
return [index for index, value in enumerate(input_ids) if value == target]
def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None:
return llm_inputs
return inputs
elif isinstance(vision_config, dict):
image_placeholder_length = calculate_image_placeholder(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
input_ids = llm_inputs.get("prompt_token_ids")
position_ids = llm_inputs.get("position_ids")
input_ids = inputs.get("prompt_token_ids")
position_ids = inputs.get("position_ids")
tokenizer = cached_get_tokenizer(
ctx.model_config.model,
trust_remote_code=ctx.model_config.trust_remote_code)
......@@ -171,15 +171,15 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
raw_batch_data = tokenizer.apply_chat_template(
conversation=[{
"role": "user",
"image": llm_inputs['multi_modal_data']["image"],
"content": llm_inputs['prompt']
"image": inputs['multi_modal_data']["image"],
"content": inputs['prompt']
}],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True).data
except Exception:
logger.error("Failed to process content (%s)", llm_inputs['prompt'])
logger.error("Failed to process content (%s)", inputs['prompt'])
raise
input_ids = raw_batch_data['input_ids'][0].tolist()
......@@ -214,9 +214,9 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
assert len(new_input_ids) == len(new_position_ids)
llm_inputs["prompt_token_ids"] = new_input_ids
llm_inputs["position_ids"] = new_position_ids
return llm_inputs
inputs["prompt_token_ids"] = new_input_ids
inputs["position_ids"] = new_position_ids
return inputs
class GLMAttention(nn.Module):
......
......@@ -11,7 +11,7 @@ from transformers.models.clip.modeling_clip import CLIPSdpaAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
......@@ -62,7 +62,7 @@ def dummy_seq_data_for_clip(
else:
image_feature_size = image_feature_size_override
return SequenceData.from_token_counts(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
......@@ -106,14 +106,14 @@ def dummy_video_for_clip(
def input_processor_for_clip(
model_config: ModelConfig,
hf_config: CLIPVisionConfig,
llm_inputs: LLMInputs,
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
......@@ -130,14 +130,14 @@ def input_processor_for_clip(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
......
......@@ -27,7 +27,8 @@ from transformers import FuyuConfig, FuyuImageProcessor
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
......@@ -149,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
return model_image_input
def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
model_config = ctx.model_config
image_data = multi_modal_data["image"]
......@@ -176,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
raise TypeError(f"Invalid image type: {type(image_data)}")
# process prompts
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(model_config.model)
# dim0 is batch_size, dim1 is subseq_size which will always be 1
image_input_ids: List[List[
......@@ -190,7 +191,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
1:] + boa_token
return LLMInputs(prompt=new_prompt,
return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data)
......
......@@ -17,7 +17,8 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.models.intern_vit import InternVisionModel
......@@ -276,13 +277,13 @@ class InternVLInputPipeline:
def input_processor(
self,
ctx: InputContext,
llm_inputs: LLMInputs,
inputs: DecoderOnlyInputs,
*,
max_dynamic_patch: Optional[int] = None,
) -> LLMInputs:
multi_modal_data = llm_inputs.get("multi_modal_data")
) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config()
......@@ -311,8 +312,8 @@ class InternVLInputPipeline:
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
......@@ -320,7 +321,7 @@ class InternVLInputPipeline:
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=prompt,
return token_inputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
......
......@@ -9,7 +9,7 @@ from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -125,10 +125,10 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
raise NotImplementedError(msg)
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaConfig)
......@@ -151,7 +151,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
......@@ -159,7 +159,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_siglip(
model_config,
vision_config,
llm_inputs,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
......
......@@ -12,7 +12,7 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -201,10 +201,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
raise NotImplementedError(msg)
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
def input_processor_for_llava_next(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextConfig)
......@@ -239,7 +240,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
......@@ -247,7 +248,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_siglip(
model_config,
vision_config,
llm_inputs,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
......
......@@ -11,7 +11,8 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -139,10 +140,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
def input_processor_for_llava_next_video(ctx: InputContext,
llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return llm_inputs
return inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
......@@ -160,13 +161,13 @@ def input_processor_for_llava_next_video(ctx: InputContext,
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return LLMInputs(prompt_token_ids=new_token_ids,
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment