Unverified Commit cee711fd authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Rename input data types (#8688)

parent 1de76a0e
...@@ -25,7 +25,7 @@ Module Contents ...@@ -25,7 +25,7 @@ Module Contents
LLM Engine Inputs LLM Engine Inputs
----------------- -----------------
.. autoclass:: vllm.inputs.LLMInputs .. autoclass:: vllm.inputs.DecoderOnlyInputs
:members: :members:
:show-inheritance: :show-inheritance:
......
import os import os
import re import re
from typing import Callable, List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import pytest import pytest
import torch import torch
from transformers import AutoImageProcessor, AutoTokenizer from transformers import AutoImageProcessor, AutoTokenizer
from vllm.inputs import InputContext, LLMInputs from vllm.inputs import InputContext, token_inputs
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MultiModalRegistry from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size from vllm.multimodal.utils import rescale_image_size
...@@ -311,7 +311,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets, ...@@ -311,7 +311,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets,
(4, 781), (4, 781),
(16, 2653), (16, 2653),
]) ])
def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
num_crops: int, expected_max_tokens: int): num_crops: int, expected_max_tokens: int):
"""Ensure get_max_phi3v_image_tokens handles num_crops properly.""" """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
# NOTE: mm_processor_kwargs on the context in this test is unused, since # NOTE: mm_processor_kwargs on the context in this test is unused, since
...@@ -343,8 +343,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, ...@@ -343,8 +343,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
(16, 2653, 1), (16, 2653, 1),
(16, 2653, 2), (16, 2653, 2),
]) ])
def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
num_crops: int, toks_per_img: int, num_imgs: int): toks_per_img: int, num_imgs: int):
"""Ensure dummy_data_for_phi3v handles num_crops properly.""" """Ensure dummy_data_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs # Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by # in this test and assume that the kwargs will be correctly expanded by
...@@ -374,7 +374,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, ...@@ -374,7 +374,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
(16, 1921, 1), (16, 1921, 1),
(16, 1921, 2), (16, 1921, 2),
]) ])
def test_input_processor_override(input_processor_for_phi3v: Callable, def test_input_processor_override(input_processor_for_phi3v,
image_assets: _ImageAssets, model: str, image_assets: _ImageAssets, model: str,
num_crops: int, expected_toks_per_img: int, num_crops: int, expected_toks_per_img: int,
num_imgs: int): num_imgs: int):
...@@ -393,16 +393,14 @@ def test_input_processor_override(input_processor_for_phi3v: Callable, ...@@ -393,16 +393,14 @@ def test_input_processor_override(input_processor_for_phi3v: Callable,
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
images = [image_assets[0].pil_image] * num_imgs images = [image_assets[0].pil_image] * num_imgs
llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt), inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt, prompt=prompt,
multi_modal_data={"image": images}) multi_modal_data={"image": images})
proc_llm_inputs = input_processor_for_phi3v( processed_inputs = input_processor_for_phi3v(ctx,
ctx=ctx, inputs,
llm_inputs=llm_inputs, num_crops=num_crops)
num_crops=num_crops,
)
# Ensure we have the right number of placeholders per num_crops size # Ensure we have the right number of placeholders per num_crops size
img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs assert img_tok_count == expected_toks_per_img * num_imgs
...@@ -5,7 +5,7 @@ import pytest ...@@ -5,7 +5,7 @@ import pytest
import torch import torch
from PIL.Image import Image from PIL.Image import Image
from vllm.inputs import InputContext, LLMInputs from vllm.inputs import InputContext, token_inputs
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size
...@@ -71,12 +71,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen, ...@@ -71,12 +71,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen,
"""Happy cases for image inputs to Qwen's multimodal input processor.""" """Happy cases for image inputs to Qwen's multimodal input processor."""
prompt = "".join( prompt = "".join(
[f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)]) [f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
inputs = LLMInputs( inputs = token_inputs(
prompt=prompt, prompt=prompt,
# When processing multimodal data for a multimodal model, the qwen # When processing multimodal data for a multimodal model, the qwen
# input processor will overwrite the provided prompt_token_ids with # input processor will overwrite the provided prompt_token_ids with
# the image prompts # the image prompts
prompt_token_ids=None, prompt_token_ids=[],
multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)}, multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
) )
proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs) proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
...@@ -134,9 +134,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen, ...@@ -134,9 +134,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen,
trust_remote_code=True) trust_remote_code=True)
prompt = "Picture 1: <img></img>\n" prompt = "Picture 1: <img></img>\n"
prompt_token_ids = tokenizer.encode(prompt) prompt_token_ids = tokenizer.encode(prompt)
inputs = LLMInputs(prompt=prompt, inputs = token_inputs(prompt=prompt,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_data) multi_modal_data=mm_data)
# Should fail since we have too many or too few dimensions for embeddings # Should fail since we have too many or too few dimensions for embeddings
with pytest.raises(ValueError): with pytest.raises(ValueError):
input_processor_for_qwen(qwen_vl_context, inputs) input_processor_for_qwen(qwen_vl_context, inputs)
......
...@@ -5,7 +5,7 @@ from unittest.mock import patch ...@@ -5,7 +5,7 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.inputs import InputContext, LLMInputs from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
from vllm.inputs.registry import InputRegistry from vllm.inputs.registry import InputRegistry
from vllm.multimodal import MultiModalRegistry from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
...@@ -31,7 +31,7 @@ def use_processor_mock(): ...@@ -31,7 +31,7 @@ def use_processor_mock():
"""Patches the internal model input processor with an override callable.""" """Patches the internal model input processor with an override callable."""
def custom_processor(ctx: InputContext, def custom_processor(ctx: InputContext,
llm_inputs: LLMInputs, inputs: DecoderOnlyInputs,
*, *,
num_crops=DEFAULT_NUM_CROPS): num_crops=DEFAULT_NUM_CROPS):
# For testing purposes, we don't worry about the llm inputs / return # For testing purposes, we don't worry about the llm inputs / return
...@@ -84,7 +84,7 @@ def test_default_processor_is_a_noop(): ...@@ -84,7 +84,7 @@ def test_default_processor_is_a_noop():
dummy_registry = InputRegistry() dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID) ctx = build_model_context(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(ctx.model_config) processor = dummy_registry.create_input_processor(ctx.model_config)
proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
proc_outputs = processor(inputs=proc_inputs) proc_outputs = processor(inputs=proc_inputs)
assert proc_inputs is proc_outputs assert proc_inputs is proc_outputs
...@@ -125,9 +125,9 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops, ...@@ -125,9 +125,9 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config) processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor( num_crops_val = processor(
LLMInputs(prompt_token_ids=[], token_inputs(prompt_token_ids=[],
prompt="", prompt="",
mm_processor_kwargs=inference_kwargs)) mm_processor_kwargs=inference_kwargs))
assert num_crops_val == expected_seq_count assert num_crops_val == expected_seq_count
...@@ -154,9 +154,9 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, ...@@ -154,9 +154,9 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
processor = dummy_registry.create_input_processor(ctx.model_config) processor = dummy_registry.create_input_processor(ctx.model_config)
# Should filter out the inference time kwargs # Should filter out the inference time kwargs
num_crops_val = processor( num_crops_val = processor(
LLMInputs(prompt_token_ids=[], token_inputs(prompt_token_ids=[],
prompt="", prompt="",
mm_processor_kwargs=mm_processor_kwargs)) mm_processor_kwargs=mm_processor_kwargs))
assert num_crops_val == DEFAULT_NUM_CROPS assert num_crops_val == DEFAULT_NUM_CROPS
......
...@@ -29,8 +29,8 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors ...@@ -29,8 +29,8 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
InputRegistry, LLMInputs, PromptType) EncoderDecoderInputs, InputRegistry, PromptType)
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -635,7 +635,7 @@ class LLMEngine: ...@@ -635,7 +635,7 @@ class LLMEngine:
def _add_processed_request( def _add_processed_request(
self, self,
request_id: str, request_id: str,
processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
...@@ -1855,8 +1855,8 @@ class LLMEngine: ...@@ -1855,8 +1855,8 @@ class LLMEngine:
def is_embedding_model(self): def is_embedding_model(self):
return self.model_config.is_embedding_model return self.model_config.is_embedding_model
def _validate_model_inputs(self, inputs: Union[LLMInputs, def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderLLMInputs]): EncoderDecoderInputs]):
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len # For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length # restricts the decoder prompt length
......
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
LLMInputs, PromptType, SingletonPrompt, TextPrompt, ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
TokensPrompt, build_explicit_enc_dec_prompt, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts) build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
...@@ -19,8 +20,11 @@ __all__ = [ ...@@ -19,8 +20,11 @@ __all__ = [
"PromptType", "PromptType",
"SingletonPrompt", "SingletonPrompt",
"ExplicitEncoderDecoderPrompt", "ExplicitEncoderDecoderPrompt",
"LLMInputs", "TokenInputs",
"EncoderDecoderLLMInputs", "token_inputs",
"SingletonInputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"build_explicit_enc_dec_prompt", "build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list", "to_enc_dec_tuple_list",
"zip_enc_dec_prompts", "zip_enc_dec_prompts",
...@@ -31,9 +35,9 @@ __all__ = [ ...@@ -31,9 +35,9 @@ __all__ = [
def __getattr__(name: str): def __getattr__(name: str):
if name == "PromptInput": import warnings
import warnings
if name == "PromptInput":
msg = ("PromptInput has been renamed to PromptType. " msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.") "The original name will be removed in an upcoming version.")
...@@ -41,4 +45,21 @@ def __getattr__(name: str): ...@@ -41,4 +45,21 @@ def __getattr__(name: str):
return PromptType return PromptType
if name == "LLMInputs":
msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return DecoderOnlyInputs
if name == "EncoderDecoderLLMInputs":
msg = (
"EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return EncoderDecoderInputs
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Tuple, Union) Optional, Tuple, Union, cast)
from typing_extensions import NotRequired, TypedDict, TypeVar from typing_extensions import NotRequired, TypedDict, TypeVar
...@@ -51,7 +51,7 @@ class TokensPrompt(TypedDict): ...@@ -51,7 +51,7 @@ class TokensPrompt(TypedDict):
SingletonPrompt = Union[str, TextPrompt, TokensPrompt] SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
""" """
Set of possible schemas for a single LLM input: Set of possible schemas for a single prompt:
- A text prompt (:class:`str` or :class:`TextPrompt`) - A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
...@@ -120,13 +120,8 @@ both decoder-only and encoder/decoder input types: ...@@ -120,13 +120,8 @@ both decoder-only and encoder/decoder input types:
""" """
class LLMInputs(TypedDict): class TokenInputs(TypedDict):
""" """Represents token-based inputs."""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
prompt_token_ids: List[int] prompt_token_ids: List[int]
"""The token IDs of the prompt.""" """The token IDs of the prompt."""
...@@ -150,7 +145,40 @@ class LLMInputs(TypedDict): ...@@ -150,7 +145,40 @@ class LLMInputs(TypedDict):
""" """
class EncoderDecoderLLMInputs(LLMInputs): def token_inputs(
prompt_token_ids: List[int],
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values."""
inputs = TokenInputs(prompt_token_ids=prompt_token_ids)
if prompt is not None:
inputs["prompt"] = prompt
if multi_modal_data is not None:
inputs["multi_modal_data"] = multi_modal_data
if mm_processor_kwargs is not None:
inputs["mm_processor_kwargs"] = mm_processor_kwargs
return inputs
SingletonInputs = TokenInputs
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""
DecoderOnlyInputs = TokenInputs
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
class EncoderDecoderInputs(TokenInputs):
""" """
The inputs in :class:`~vllm.LLMEngine` before they are The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor. passed to the model executor.
...@@ -204,11 +232,12 @@ def zip_enc_dec_prompts( ...@@ -204,11 +232,12 @@ def zip_enc_dec_prompts(
be zipped with the encoder/decoder prompts. be zipped with the encoder/decoder prompts.
""" """
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = cast(Dict[str, Any], {})
if isinstance(mm_processor_kwargs, Dict): if isinstance(mm_processor_kwargs, dict):
return [ return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, build_explicit_enc_dec_prompt(
mm_processor_kwargs) encoder_prompt, decoder_prompt,
cast(Dict[str, Any], mm_processor_kwargs))
for (encoder_prompt, for (encoder_prompt,
decoder_prompt) in zip(enc_prompts, dec_prompts) decoder_prompt) in zip(enc_prompts, dec_prompts)
] ]
...@@ -229,9 +258,9 @@ def to_enc_dec_tuple_list( ...@@ -229,9 +258,9 @@ def to_enc_dec_tuple_list(
def __getattr__(name: str): def __getattr__(name: str):
if name == "PromptInput": import warnings
import warnings
if name == "PromptInput":
msg = ("PromptInput has been renamed to PromptType. " msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.") "The original name will be removed in an upcoming version.")
...@@ -239,4 +268,21 @@ def __getattr__(name: str): ...@@ -239,4 +268,21 @@ def __getattr__(name: str):
return PromptType return PromptType
if name == "LLMInputs":
msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return DecoderOnlyInputs
if name == "EncoderDecoderLLMInputs":
msg = (
"EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return EncoderDecoderInputs
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
...@@ -4,9 +4,9 @@ from typing_extensions import TypeIs ...@@ -4,9 +4,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
LLMInputs, PromptType, SingletonPrompt, TextPrompt, ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
TokensPrompt) TextPrompt, TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
...@@ -100,7 +100,7 @@ def is_explicit_encoder_decoder_prompt( ...@@ -100,7 +100,7 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_valid_encoder_decoder_llm_inputs( def is_encoder_decoder_inputs(
inputs: Union[LLMInputs, EncoderDecoderLLMInputs], inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
) -> TypeIs[EncoderDecoderLLMInputs]: ) -> TypeIs[EncoderDecoderInputs]:
return "encoder_prompt_token_ids" in inputs return "encoder_prompt_token_ids" in inputs
...@@ -10,7 +10,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -10,7 +10,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
SingletonPrompt) SingletonPrompt)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
...@@ -306,7 +306,7 @@ class InputPreprocessor: ...@@ -306,7 +306,7 @@ class InputPreprocessor:
encoder_comps: PromptComponents, encoder_comps: PromptComponents,
decoder_comps: DecoderPromptComponents, decoder_comps: DecoderPromptComponents,
mm_processor_kwargs: Dict[str, Any], mm_processor_kwargs: Dict[str, Any],
) -> EncoderDecoderLLMInputs: ) -> EncoderDecoderInputs:
encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
...@@ -324,7 +324,7 @@ class InputPreprocessor: ...@@ -324,7 +324,7 @@ class InputPreprocessor:
decoder_prompt_ids, decoder_prompt_ids,
force_bos=(encoder_mm_data is None and decoder_mm_data is None))) force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
return EncoderDecoderLLMInputs( return EncoderDecoderInputs(
prompt_token_ids=decoder_prompt_ids, prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt, prompt=decoder_prompt,
multi_modal_data=decoder_mm_data, multi_modal_data=decoder_mm_data,
...@@ -338,11 +338,11 @@ class InputPreprocessor: ...@@ -338,11 +338,11 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str, request_id: str,
) -> EncoderDecoderLLMInputs: ) -> EncoderDecoderInputs:
''' '''
For encoder/decoder models only: For encoder/decoder models only:
Process an input prompt into an Process an input prompt into an
:class:`EncoderDecoderLLMInputs` instance. :class:`EncoderDecoderInputs` instance.
There are two types of input prompts: There are two types of input prompts:
singleton prompts which carry only the singleton prompts which carry only the
...@@ -369,7 +369,7 @@ class InputPreprocessor: ...@@ -369,7 +369,7 @@ class InputPreprocessor:
Returns: Returns:
* :class:`EncoderDecoderLLMInputs` instance * :class:`EncoderDecoderInputs` instance
''' '''
encoder_comps: PromptComponents encoder_comps: PromptComponents
...@@ -411,7 +411,7 @@ class InputPreprocessor: ...@@ -411,7 +411,7 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
request_id: str, request_id: str,
) -> EncoderDecoderLLMInputs: ) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`.""" """Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents decoder_comps: DecoderPromptComponents
...@@ -455,17 +455,17 @@ class InputPreprocessor: ...@@ -455,17 +455,17 @@ class InputPreprocessor:
self, self,
prompt_comps: PromptComponents, prompt_comps: PromptComponents,
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs: ) -> DecoderOnlyInputs:
(prompt, prompt_token_ids, multi_modal_data, (prompt, prompt_token_ids, multi_modal_data,
mm_processor_kwargs) = prompt_comps mm_processor_kwargs) = prompt_comps
prompt_token_ids = self._apply_prompt_adapter( prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request) prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
return LLMInputs(prompt_token_ids=prompt_token_ids, return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt, prompt=prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs) mm_processor_kwargs=mm_processor_kwargs)
def _process_decoder_only_prompt( def _process_decoder_only_prompt(
self, self,
...@@ -473,10 +473,10 @@ class InputPreprocessor: ...@@ -473,10 +473,10 @@ class InputPreprocessor:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> DecoderOnlyInputs:
''' '''
For decoder-only models: For decoder-only models:
Process an input prompt into an :class:`LLMInputs` instance. Process an input prompt into an :class:`DecoderOnlyInputs` instance.
Arguments: Arguments:
...@@ -487,7 +487,7 @@ class InputPreprocessor: ...@@ -487,7 +487,7 @@ class InputPreprocessor:
Returns: Returns:
* :class:`LLMInputs` instance * :class:`DecoderOnlyInputs` instance
''' '''
prompt_comps = self._extract_prompt_components( prompt_comps = self._extract_prompt_components(
...@@ -507,7 +507,7 @@ class InputPreprocessor: ...@@ -507,7 +507,7 @@ class InputPreprocessor:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs: ) -> DecoderOnlyInputs:
"""Async version of :meth:`_process_decoder_only_prompt`.""" """Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async( prompt_comps = await self._extract_prompt_components_async(
prompt, prompt,
...@@ -526,7 +526,7 @@ class InputPreprocessor: ...@@ -526,7 +526,7 @@ class InputPreprocessor:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]: ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
if self.is_encoder_decoder_model(): if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
...@@ -554,7 +554,7 @@ class InputPreprocessor: ...@@ -554,7 +554,7 @@ class InputPreprocessor:
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]: ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
"""Async version of :meth:`preprocess`.""" """Async version of :meth:`preprocess`."""
if self.is_encoder_decoder_model(): if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
......
...@@ -12,7 +12,7 @@ from vllm.logger import init_logger ...@@ -12,7 +12,7 @@ from vllm.logger import init_logger
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
resolve_mm_processor_kwargs) resolve_mm_processor_kwargs)
from .data import LLMInputs from .data import DecoderOnlyInputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -100,7 +100,7 @@ class _MultiModalCounts(UserDict): ...@@ -100,7 +100,7 @@ class _MultiModalCounts(UserDict):
raise KeyError(msg) from exc raise KeyError(msg) from exc
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs] InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs]
"""Preprocess the inputs to the model.""" """Preprocess the inputs to the model."""
...@@ -134,7 +134,7 @@ class InputRegistry: ...@@ -134,7 +134,7 @@ class InputRegistry:
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
dummy_seq_data = SequenceData.from_token_counts((0, seq_len)) dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
dummy_multi_modal_data = None dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data return dummy_seq_data, dummy_multi_modal_data
...@@ -245,8 +245,11 @@ class InputRegistry: ...@@ -245,8 +245,11 @@ class InputRegistry:
return seq_data, mm_data return seq_data, mm_data
def _default_input_processor(self, ctx: InputContext, def _default_input_processor(
inputs: LLMInputs) -> LLMInputs: self,
ctx: InputContext,
inputs: DecoderOnlyInputs,
) -> DecoderOnlyInputs:
"""The default input processor is a no-op.""" """The default input processor is a no-op."""
return inputs return inputs
...@@ -279,7 +282,7 @@ class InputRegistry: ...@@ -279,7 +282,7 @@ class InputRegistry:
.get(model_cls, self._default_input_processor) .get(model_cls, self._default_input_processor)
def process_input(self, model_config: "ModelConfig", def process_input(self, model_config: "ModelConfig",
inputs: LLMInputs) -> LLMInputs: inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
""" """
Apply an input processor to an instance of model inputs. Apply an input processor to an instance of model inputs.
......
...@@ -10,7 +10,7 @@ from transformers.models.blip.modeling_blip import BlipAttention ...@@ -10,7 +10,7 @@ from transformers.models.blip.modeling_blip import BlipAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -63,7 +63,7 @@ def dummy_seq_data_for_blip( ...@@ -63,7 +63,7 @@ def dummy_seq_data_for_blip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
return SequenceData.from_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) )
...@@ -89,14 +89,14 @@ def dummy_image_for_blip( ...@@ -89,14 +89,14 @@ def dummy_image_for_blip(
def input_processor_for_blip( def input_processor_for_blip(
model_config: ModelConfig, model_config: ModelConfig,
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
llm_inputs: LLMInputs, inputs: DecoderOnlyInputs,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[int] = None,
): ):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
...@@ -107,16 +107,16 @@ def input_processor_for_blip( ...@@ -107,16 +107,16 @@ def input_processor_for_blip(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=image_token_id, placeholder_token_id=image_token_id,
repeat_count=image_feature_size, repeat_count=image_feature_size,
) )
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
......
...@@ -9,7 +9,8 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, ...@@ -9,7 +9,8 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -421,7 +422,7 @@ def dummy_seq_data_for_blip2( ...@@ -421,7 +422,7 @@ def dummy_seq_data_for_blip2(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
return SequenceData.from_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) )
...@@ -449,10 +450,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, ...@@ -449,10 +450,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
raise NotImplementedError(msg) raise NotImplementedError(msg)
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
hf_config = ctx.get_hf_config(Blip2Config) hf_config = ctx.get_hf_config(Blip2Config)
image_feature_size = get_blip2_image_feature_size(hf_config) image_feature_size = get_blip2_image_feature_size(hf_config)
...@@ -460,15 +461,15 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -460,15 +461,15 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
# The original model places image tokens at the front # The original model places image tokens at the front
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514 # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
new_token_ids += llm_inputs["prompt_token_ids"] new_token_ids += inputs["prompt_token_ids"]
new_prompt = llm_inputs.get("prompt") new_prompt = inputs.get("prompt")
if new_prompt is not None: if new_prompt is not None:
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_image_input_mapper()
......
...@@ -11,7 +11,8 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig ...@@ -11,7 +11,8 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -69,7 +70,7 @@ def dummy_seq_data_for_chameleon( ...@@ -69,7 +70,7 @@ def dummy_seq_data_for_chameleon(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
return SequenceData.from_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) )
...@@ -106,7 +107,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, ...@@ -106,7 +107,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
return seq_data, mm_data return seq_data, mm_data
def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_chameleon(ctx: InputContext,
inputs: DecoderOnlyInputs):
""" """
Processing input prompt to insert required tokens for image placeholder. Processing input prompt to insert required tokens for image placeholder.
...@@ -114,16 +116,16 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -114,16 +116,16 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58 See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
""" # noqa """ # noqa
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
...@@ -137,9 +139,9 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -137,9 +139,9 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
new_token_ids += [CHAMELEON_SEP_TOKEN_ID] new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
class ChameleonLayerNorm(nn.LayerNorm): class ChameleonLayerNorm(nn.LayerNorm):
......
...@@ -14,7 +14,7 @@ from torch.nn import LayerNorm ...@@ -14,7 +14,7 @@ from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -149,20 +149,20 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]: ...@@ -149,20 +149,20 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
return [index for index, value in enumerate(input_ids) if value == target] return [index for index, value in enumerate(input_ids) if value == target]
def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
hf_config = ctx.get_hf_config(ChatGLMConfig) hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None) vision_config = getattr(hf_config, 'vision_config', None)
if vision_config is None: if vision_config is None:
return llm_inputs return inputs
elif isinstance(vision_config, dict): elif isinstance(vision_config, dict):
image_placeholder_length = calculate_image_placeholder(vision_config) image_placeholder_length = calculate_image_placeholder(vision_config)
else: else:
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
input_ids = llm_inputs.get("prompt_token_ids") input_ids = inputs.get("prompt_token_ids")
position_ids = llm_inputs.get("position_ids") position_ids = inputs.get("position_ids")
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
ctx.model_config.model, ctx.model_config.model,
trust_remote_code=ctx.model_config.trust_remote_code) trust_remote_code=ctx.model_config.trust_remote_code)
...@@ -171,15 +171,15 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -171,15 +171,15 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
raw_batch_data = tokenizer.apply_chat_template( raw_batch_data = tokenizer.apply_chat_template(
conversation=[{ conversation=[{
"role": "user", "role": "user",
"image": llm_inputs['multi_modal_data']["image"], "image": inputs['multi_modal_data']["image"],
"content": llm_inputs['prompt'] "content": inputs['prompt']
}], }],
add_generation_prompt=True, add_generation_prompt=True,
tokenize=True, tokenize=True,
return_tensors="pt", return_tensors="pt",
return_dict=True).data return_dict=True).data
except Exception: except Exception:
logger.error("Failed to process content (%s)", llm_inputs['prompt']) logger.error("Failed to process content (%s)", inputs['prompt'])
raise raise
input_ids = raw_batch_data['input_ids'][0].tolist() input_ids = raw_batch_data['input_ids'][0].tolist()
...@@ -214,9 +214,9 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -214,9 +214,9 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
assert len(new_input_ids) == len(new_position_ids) assert len(new_input_ids) == len(new_position_ids)
llm_inputs["prompt_token_ids"] = new_input_ids inputs["prompt_token_ids"] = new_input_ids
llm_inputs["position_ids"] = new_position_ids inputs["position_ids"] = new_position_ids
return llm_inputs return inputs
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
......
...@@ -11,7 +11,7 @@ from transformers.models.clip.modeling_clip import CLIPSdpaAttention ...@@ -11,7 +11,7 @@ from transformers.models.clip.modeling_clip import CLIPSdpaAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -62,7 +62,7 @@ def dummy_seq_data_for_clip( ...@@ -62,7 +62,7 @@ def dummy_seq_data_for_clip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
return SequenceData.from_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) )
...@@ -106,14 +106,14 @@ def dummy_video_for_clip( ...@@ -106,14 +106,14 @@ def dummy_video_for_clip(
def input_processor_for_clip( def input_processor_for_clip(
model_config: ModelConfig, model_config: ModelConfig,
hf_config: CLIPVisionConfig, hf_config: CLIPVisionConfig,
llm_inputs: LLMInputs, inputs: DecoderOnlyInputs,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None, image_feature_size_override: Optional[Union[int, List[int]]] = None,
): ):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
...@@ -130,16 +130,16 @@ def input_processor_for_clip( ...@@ -130,16 +130,16 @@ def input_processor_for_clip(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=image_token_id, placeholder_token_id=image_token_id,
repeat_count=image_feature_size, repeat_count=image_feature_size,
) )
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
......
...@@ -27,7 +27,8 @@ from transformers import FuyuConfig, FuyuImageProcessor ...@@ -27,7 +27,8 @@ from transformers import FuyuConfig, FuyuImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -149,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, ...@@ -149,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
return model_image_input return model_image_input
def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
image_data = multi_modal_data["image"] image_data = multi_modal_data["image"]
...@@ -176,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -176,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
# process prompts # process prompts
prompt = llm_inputs.get("prompt") prompt = inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"] prompt_token_ids = inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(model_config.model) tokenizer = cached_get_tokenizer(model_config.model)
# dim0 is batch_size, dim1 is subseq_size which will always be 1 # dim0 is batch_size, dim1 is subseq_size which will always be 1
image_input_ids: List[List[ image_input_ids: List[List[
...@@ -190,9 +191,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -190,9 +191,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
1:] + boa_token 1:] + boa_token
return LLMInputs(prompt=new_prompt, return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids, prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data) multi_modal_data=new_multi_modal_data)
def input_mapper_for_fuyu(ctx: InputContext, data: object): def input_mapper_for_fuyu(ctx: InputContext, data: object):
......
...@@ -17,7 +17,8 @@ from transformers import PretrainedConfig ...@@ -17,7 +17,8 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.models.intern_vit import InternVisionModel
...@@ -276,13 +277,13 @@ class InternVLInputPipeline: ...@@ -276,13 +277,13 @@ class InternVLInputPipeline:
def input_processor( def input_processor(
self, self,
ctx: InputContext, ctx: InputContext,
llm_inputs: LLMInputs, inputs: DecoderOnlyInputs,
*, *,
max_dynamic_patch: Optional[int] = None, max_dynamic_patch: Optional[int] = None,
) -> LLMInputs: ) -> DecoderOnlyInputs:
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config() hf_config = ctx.get_hf_config()
...@@ -311,8 +312,8 @@ class InternVLInputPipeline: ...@@ -311,8 +312,8 @@ class InternVLInputPipeline:
model_config.tokenizer, model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code) trust_remote_code=model_config.trust_remote_code)
prompt = llm_inputs.get("prompt") prompt = inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"] prompt_token_ids = inputs["prompt_token_ids"]
if prompt is None: if prompt is None:
prompt = tokenizer.decode(prompt_token_ids) prompt = tokenizer.decode(prompt_token_ids)
...@@ -320,9 +321,9 @@ class InternVLInputPipeline: ...@@ -320,9 +321,9 @@ class InternVLInputPipeline:
num_patches) num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt) new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=prompt, return token_inputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids, prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
def input_mapper( def input_mapper(
self, self,
......
...@@ -9,7 +9,7 @@ from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig ...@@ -9,7 +9,7 @@ from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -125,10 +125,10 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ...@@ -125,10 +125,10 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
raise NotImplementedError(msg) raise NotImplementedError(msg)
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaConfig) hf_config = ctx.get_hf_config(LlavaConfig)
...@@ -151,7 +151,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -151,7 +151,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_clip( return input_processor_for_clip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -159,7 +159,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -159,7 +159,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_siglip( return input_processor_for_siglip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
......
...@@ -12,7 +12,7 @@ from typing_extensions import NotRequired ...@@ -12,7 +12,7 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -201,10 +201,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, ...@@ -201,10 +201,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
raise NotImplementedError(msg) raise NotImplementedError(msg)
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_llava_next(ctx: InputContext,
multi_modal_data = llm_inputs.get("multi_modal_data") inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextConfig) hf_config = ctx.get_hf_config(LlavaNextConfig)
...@@ -239,7 +240,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -239,7 +240,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_clip( return input_processor_for_clip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -247,7 +248,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -247,7 +248,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
return input_processor_for_siglip( return input_processor_for_siglip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
......
...@@ -11,7 +11,8 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, ...@@ -11,7 +11,8 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -139,10 +140,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, ...@@ -139,10 +140,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
def input_processor_for_llava_next_video(ctx: InputContext, def input_processor_for_llava_next_video(ctx: InputContext,
llm_inputs: LLMInputs): inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data: if multi_modal_data is None or "video" not in multi_modal_data:
return llm_inputs return inputs
video_data = multi_modal_data["video"] video_data = multi_modal_data["video"]
model_config = ctx.model_config model_config = ctx.model_config
...@@ -160,15 +161,15 @@ def input_processor_for_llava_next_video(ctx: InputContext, ...@@ -160,15 +161,15 @@ def input_processor_for_llava_next_video(ctx: InputContext,
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index, placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size, repeat_count=video_feature_size,
) )
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
elif is_list_of(video_data, np.ndarray): elif is_list_of(video_data, np.ndarray):
raise NotImplementedError( raise NotImplementedError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment