Unverified Commit 0b8bb86b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[1/N] Initial prototype for multi-modal processor (#10044)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent bb7991aa
...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import NestedTensors from vllm.multimodal.inputs import NestedTensors
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
......
...@@ -51,8 +51,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys ...@@ -51,8 +51,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import LLMWrapper from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.base import MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
......
...@@ -39,7 +39,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM ...@@ -39,7 +39,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import NestedTensors, PlaceholderRange from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.utils import is_list_of from vllm.utils import is_list_of
......
...@@ -29,8 +29,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -29,8 +29,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.base import MultiModalKwargs
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges) consecutive_placeholder_ranges)
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
......
...@@ -42,8 +42,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -42,8 +42,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.base import MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of from vllm.utils import is_list_of
......
...@@ -60,10 +60,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead ...@@ -60,10 +60,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, from vllm.multimodal import MULTIMODAL_REGISTRY
MultiModalKwargs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs)
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
......
...@@ -15,7 +15,7 @@ from vllm.config import VllmConfig ...@@ -15,7 +15,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
......
from .base import (BatchedTensorInputs, MultiModalDataBuiltins, from .base import MultiModalPlaceholderMap, MultiModalPlugin
MultiModalDataDict, MultiModalKwargs, from .inputs import (BatchedTensorInputs, MultiModalData,
MultiModalPlaceholderDict, MultiModalPlaceholderMap, MultiModalDataBuiltins, MultiModalDataDict,
MultiModalPlugin, NestedTensors) MultiModalKwargs, MultiModalPlaceholderDict,
NestedTensors)
from .registry import MultiModalRegistry from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry() MULTIMODAL_REGISTRY = MultiModalRegistry()
...@@ -15,6 +16,7 @@ See also: ...@@ -15,6 +16,7 @@ See also:
__all__ = [ __all__ = [
"BatchedTensorInputs", "BatchedTensorInputs",
"MultiModalData",
"MultiModalDataBuiltins", "MultiModalDataBuiltins",
"MultiModalDataDict", "MultiModalDataDict",
"MultiModalKwargs", "MultiModalKwargs",
......
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.multimodal.base import MultiModalKwargs, MultiModalPlugin
from .base import MultiModalPlugin
from .inputs import AudioItem, MultiModalData, MultiModalKwargs
class AudioPlugin(MultiModalPlugin): class AudioPlugin(MultiModalPlugin):
...@@ -8,8 +10,12 @@ class AudioPlugin(MultiModalPlugin): ...@@ -8,8 +10,12 @@ class AudioPlugin(MultiModalPlugin):
def get_data_key(self) -> str: def get_data_key(self) -> str:
return "audio" return "audio"
def _default_input_mapper(self, ctx: InputContext, data: object, def _default_input_mapper(
**mm_processor_kwargs) -> MultiModalKwargs: self,
ctx: InputContext,
data: MultiModalData[AudioItem],
**mm_processor_kwargs,
) -> MultiModalKwargs:
raise NotImplementedError("There is no default audio input mapper") raise NotImplementedError("There is no default audio input mapper")
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar, Optional, Sequence, Tuple, Type, TypeVar, Union)
Union, cast, final)
import numpy as np
import torch
import torch.types
from PIL import Image
from torch import nn from torch import nn
from typing_extensions import TypeAlias
from vllm.inputs import InputContext from vllm.inputs import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, from vllm.utils import (get_allowed_kwarg_only_overrides,
json_map_leaves, resolve_mm_processor_kwargs) resolve_mm_processor_kwargs)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__) from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs,
PlaceholderRange)
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""
class _MultiModalKwargsBase(UserDict[str, NestedTensors]):
pass
class MultiModalKwargs(_MultiModalKwargsBase):
"""
A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`.
"""
@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
"""
Recursively stacks lists of tensors when they all have the same shape.
"""
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
if isinstance(nested_tensors, np.ndarray):
return torch.from_numpy(nested_tensors)
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)
stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors] logger = init_logger(__name__)
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
return stacked
tensors_ = cast(List[torch.Tensor], stacked)
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
return torch.stack(tensors_)
@staticmethod
def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs:
"""
Batch multiple inputs together into a dictionary.
The resulting dictionary has the same keys as the inputs.
If the corresponding value from each input is a tensor and they all
share the same shape, the output value is a single batched tensor;
otherwise, the output value is a list containing the original value
from each input.
"""
if len(inputs_list) == 0:
return {}
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
for inputs in inputs_list:
# For models that supports multiple modalities (e.g. Qwen2-VL),
# different modalities will return different data keys,
# so batch() should skip the same key check.
for k, v in inputs.items():
item_lists[k].append(v)
return {
k: MultiModalKwargs._try_stack(item_list)
for k, item_list in item_lists.items()
}
@staticmethod
def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
json_inputs,
)
return cast(BatchedTensorInputs, json_mapped)
_T = TypeVar("_T")
MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.
The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
"""Modality types that are predefined by vLLM."""
image: MultiModalData[Image.Image]
"""The input image(s)."""
audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
"""The input audio item(s) and corresponding sampling rate(s)."""
video: MultiModalData[Tuple[np.ndarray]]
"""The input video(s)."""
MultiModalDataDict = Union[MultiModalDataBuiltins,
Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item for each modality type to input.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalDataBuiltins` as long as a customized plugin is registered
through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
class PlaceholderRange(TypedDict):
"""
Placeholder location information for multi-modal data.
For example:
Prompt: AAAA BBBB What is in these images?
Images A and B will have:
A: { "offset": 0, "length": 4 }
B: { "offset": 5, "length": 4 }
"""
offset: int
"""The start index of the placeholder in the prompt."""
length: int
"""The length of the placeholder."""
MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
MultiModalKwargs] MultiModalKwargs]
...@@ -192,6 +35,7 @@ Calculate the maximum number of multimodal tokens input to the language ...@@ -192,6 +35,7 @@ Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text. model. This does not include tokens that correspond to the input text.
""" """
_T = TypeVar("_T")
N = TypeVar("N", bound=Type[nn.Module]) N = TypeVar("N", bound=Type[nn.Module])
...@@ -224,7 +68,7 @@ class MultiModalPlugin(ABC): ...@@ -224,7 +68,7 @@ class MultiModalPlugin(ABC):
def _default_input_mapper( def _default_input_mapper(
self, self,
ctx: InputContext, ctx: InputContext,
data: MultiModalData[object], data: MultiModalData[Any],
**mm_processor_kwargs, **mm_processor_kwargs,
) -> MultiModalKwargs: ) -> MultiModalKwargs:
""" """
...@@ -273,8 +117,8 @@ class MultiModalPlugin(ABC): ...@@ -273,8 +117,8 @@ class MultiModalPlugin(ABC):
def map_input( def map_input(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
data: MultiModalData[object], data: MultiModalData[Any],
mm_processor_kwargs: Dict[str, Any], mm_processor_kwargs: Optional[Dict[str, Any]],
) -> MultiModalKwargs: ) -> MultiModalKwargs:
""" """
Transform the data into a dictionary of model inputs using the Transform the data into a dictionary of model inputs using the
...@@ -289,6 +133,7 @@ class MultiModalPlugin(ABC): ...@@ -289,6 +133,7 @@ class MultiModalPlugin(ABC):
- :ref:`input_processing_pipeline` - :ref:`input_processing_pipeline`
- :ref:`enabling_multimodal_inputs` - :ref:`enabling_multimodal_inputs`
""" """
# Avoid circular import # Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture from vllm.model_executor.model_loader import get_model_architecture
...@@ -300,6 +145,9 @@ class MultiModalPlugin(ABC): ...@@ -300,6 +145,9 @@ class MultiModalPlugin(ABC):
raise KeyError(f"No input mapper in {self} is registered for " raise KeyError(f"No input mapper in {self} is registered for "
f"model class {model_cls.__name__}.") f"model class {model_cls.__name__}.")
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
# In the case of the default mapper, we have to get resource # In the case of the default mapper, we have to get resource
# processor through its HuggingFace autoclass; since this goes # processor through its HuggingFace autoclass; since this goes
# through **kwargs, we can't inspect it the same way, so we allow # through **kwargs, we can't inspect it the same way, so we allow
...@@ -508,7 +356,7 @@ class MultiModalPlaceholderMap: ...@@ -508,7 +356,7 @@ class MultiModalPlaceholderMap:
self, self,
positions: range, positions: range,
multi_modal_items: List[_T], multi_modal_items: List[_T],
multi_modal_placeholders: List[PlaceholderRange], multi_modal_placeholders: Sequence[PlaceholderRange],
) -> List[_T]: ) -> List[_T]:
""" """
Adds the multi-modal items that intersect ```positions`` to this Adds the multi-modal items that intersect ```positions`` to this
......
...@@ -3,14 +3,14 @@ from typing import TYPE_CHECKING, Any, Dict, Optional ...@@ -3,14 +3,14 @@ from typing import TYPE_CHECKING, Any, Dict, Optional
import torch import torch
from PIL import Image from PIL import Image
from transformers.image_processing_base import BatchFeature
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_image_processor from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalKwargs, MultiModalPlugin from .base import MultiModalPlugin
from .inputs import ImageItem, MultiModalData, MultiModalKwargs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -41,15 +41,11 @@ class ImagePlugin(MultiModalPlugin): ...@@ -41,15 +41,11 @@ class ImagePlugin(MultiModalPlugin):
def _default_input_mapper( def _default_input_mapper(
self, self,
ctx: InputContext, ctx: InputContext,
data: MultiModalData[object], data: MultiModalData[ImageItem],
**mm_processor_kwargs, **mm_processor_kwargs,
) -> MultiModalKwargs: ) -> MultiModalKwargs:
model_config = ctx.model_config model_config = ctx.model_config
# Processed by input processor
if isinstance(data, BatchFeature):
return MultiModalKwargs(data.data)
# PIL image # PIL image
if isinstance(data, Image.Image) or is_list_of(data, Image.Image): if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor( image_processor = self._get_hf_image_processor(
......
from collections import UserDict, defaultdict
from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple,
TypedDict, TypeVar, Union, cast, final)
import numpy as np
import torch
import torch.types
from PIL.Image import Image
from typing_extensions import TypeAlias
from vllm.utils import JSONTree, is_list_of, json_map_leaves
_T = TypeVar("_T")
# yapf: disable
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image,
which can be passed to a HuggingFace :code:`ImageProcessor`.
"""
VideoItem: TypeAlias = Union[
List[Image],
np.ndarray,
torch.Tensor,
List[np.ndarray],
List[torch.Tensor],
]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video,
which can be passed to a HuggingFace :code:`VideoProcessor`.
"""
AudioItem: TypeAlias = Union[
np.ndarray,
List[float],
Tuple[np.ndarray, float], # DEPRECATED: Use mm_processor_kwargs instead
]
"""
Represents a single audio that can be inputted to a HuggingFace
:code:`AudioProcessor`.
"""
# yapf: enable
MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data item, or a list of data items.
The number of data items allowed per modality is restricted by
:code:`--limit-mm-per-prompt`.
"""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: MultiModalData[ImageItem]
"""The input image(s)."""
video: MultiModalData[VideoItem]
"""The input video(s)."""
audio: MultiModalData[AudioItem]
"""The input audio(s)."""
MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]]
"""
A dictionary containing an entry for each modality type to input.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalDataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
class PlaceholderRange(TypedDict):
"""
Placeholder location information for multi-modal data.
For example:
Prompt: AAAA BBBB What is in these images?
Images A and B will have:
A: { "offset": 0, "length": 4 }
B: { "offset": 5, "length": 4 }
"""
offset: int
"""The start index of the placeholder in the prompt."""
length: int
"""The length of the placeholder."""
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`.
"""
@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
"""
Stack the inner dimensions that have the same shape in
a nested list of tensors.
Thus, a dimension represented by a list means that the inner
dimensions are different for each element along that dimension.
"""
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
# TODO: Remove these once all models have been migrated
if isinstance(nested_tensors, np.ndarray):
return torch.from_numpy(nested_tensors)
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)
stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
return stacked
tensors_ = cast(List[torch.Tensor], stacked)
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
return torch.stack(tensors_)
@staticmethod
def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs:
"""
Batch multiple inputs together into a dictionary.
The resulting dictionary has the same keys as the inputs.
If the corresponding value from each input is a tensor and they all
share the same shape, the output value is a single batched tensor;
otherwise, the output value is a list containing the original value
from each input.
"""
if len(inputs_list) == 0:
return {}
# We need to consider the case where each item in the batch
# contains different modalities (i.e. different keys).
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
for inputs in inputs_list:
for k, v in inputs.items():
item_lists[k].append(v)
return {
k: MultiModalKwargs._try_stack(item_list)
for k, item_list in item_lists.items()
}
@staticmethod
def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
json_inputs,
)
return cast(BatchedTensorInputs, json_mapped)
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""
class MultiModalInputsV2(TypedDict):
"""
Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`,
ready to be passed to vLLM internals.
"""
type: Literal["multimodal"]
"""The type of inputs."""
prompt: str
"""
The original, unprocessed prompt text.
Note:
Since prompt text is not required by vLLM internals, we leave this
unprocessed to save CPU computation. You can still call
:code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
"""
prompt_token_ids: List[int]
"""The processed token IDs which includes placeholder tokens."""
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""
mm_placeholders: MultiModalPlaceholderDict
"""
For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`.
"""
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import (Any, Callable, Collection, Generic, List, Mapping,
Optional, TypedDict, TypeVar, final)
from transformers import BatchFeature
from typing_extensions import TypeAlias
from vllm.inputs import InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem)
_T = TypeVar("_T")
ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]]
"""
Given the original data item, HF-processed data, and index of the processed
item, output the replacement token IDs to be allocated in vLLM.
"""
@dataclass
class ModalityProcessingMetadata(Generic[_T]):
placeholder_replacements: Mapping[str, ReplacementFunc]
"""
A dictionary where each item represents the original placeholder in the
prompt text and the corresponding replacement.
"""
class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: ModalityProcessingMetadata[ImageItem]
video: ModalityProcessingMetadata[VideoItem]
audio: ModalityProcessingMetadata[AudioItem]
MultiModalProcessingMetadata: TypeAlias = \
Mapping[str, ModalityProcessingMetadata[Any]]
"""
A dictionary containing an entry for each modality type to process.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
MultiModalMultiData: TypeAlias = List[_T]
"""
A list of data items, where the number of data items allowed
per modality is restricted by :code:`--limit-mm-per-prompt`.
"""
@final
class MultiModalMultiDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: MultiModalMultiData[ImageItem]
"""The input images."""
video: MultiModalMultiData[VideoItem]
"""The input videos."""
audio: MultiModalMultiData[AudioItem]
"""The input audios."""
MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]]
"""
A dictionary containing an entry for each modality type to input.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalMultiDataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
"""
Convert a :class:`MultiModalDataDict` containing single data items
to a :class:`MultiModalMultiDataDict` containing multiple data items
per entry.
"""
multi_data: Mapping[str, MultiModalMultiData[Any]] = {}
for k, v in data.items():
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = v if is_list_of(v, list) else [v] # type: ignore[index]
elif k in ("image", "audio"):
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
return multi_data
def encode_no_special_tokens(
tokenizer: AnyTokenizer,
text: str,
) -> List[int]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=False)`.
"""
if isinstance(tokenizer, MistralTokenizer):
return tokenizer.tokenizer.encode(text, bos=False, eos=False)
return tokenizer.encode(text, add_special_tokens=False)
@lru_cache
def candidate_placeholders(
tokenizer: AnyTokenizer,
placeholder_text: str,
) -> Collection[List[int]]:
"""Generate token ID sequences that may represent a placeholder text."""
# When the placeholder text is not mapped to a special token ID,
# it may be tokenized differently based on whether it is at the start/end
# of the string. So, we go through each combination of whether the text
# is at the start and end boundaries of the string
# Matches the placeholder when it is in the middle of the string
start_id, = encode_no_special_tokens(tokenizer, "a")
end_id, = encode_no_special_tokens(tokenizer, "b")
candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)
start_id_, *candidate_a = encode_no_special_tokens(
tokenizer,
f"a{placeholder_text}",
)
assert start_id == start_id_
start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
tokenizer,
f"a{placeholder_text}b",
)
assert start_id == start_id_ and end_id == end_id_
*candidate_b, end_id_ = encode_no_special_tokens(
tokenizer,
f"{placeholder_text}b",
)
assert end_id == end_id_
# Remove duplicates (need to convert to tuple to be hashable)
unique_candidates = {
tuple(c)
for c in [candidate_basic, candidate_a, candidate_ab, candidate_b]
}
# Convert back to list
return [list(c) for c in unique_candidates]
def apply_placeholders(
token_ids: List[int],
placeholder_ids: List[int],
get_replacement_ids: Callable[[], List[int]],
) -> Optional[PlaceholderRange]:
"""
Find the first occurrence of :code:`placeholder_ids`,
and replace it with the output of :code:`get_replacement_ids`.
This function updates :code:`token_ids` in place.
"""
placeholder_length = len(placeholder_ids)
for start_idx in range(len(token_ids) - placeholder_length + 1):
if token_ids[start_idx:placeholder_length] == placeholder_ids:
token_ids[start_idx:placeholder_length] = get_replacement_ids()
return PlaceholderRange(offset=start_idx,
length=placeholder_length)
return None
class MultiModalProcessor:
"""
Helper class to process multi-modal inputs to be used in vLLM.
"""
def __init__(
self,
ctx: InputProcessingContext,
metadata: MultiModalProcessingMetadata,
) -> None:
super().__init__()
self.ctx = ctx
self.metadata = metadata
def __call__(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, mm_processor_kwargs)
def apply(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
tokenizer = self.ctx.tokenizer
hf_processor = self.ctx.get_hf_processor()
processed_inputs = hf_processor(
text=prompt, # type: ignore
**mm_data,
**mm_processor_kwargs,
)
new_token_ids, = processed_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(processed_inputs)
mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}
for modality, orig_inputs in to_multi_format(mm_data).items():
assert isinstance(orig_inputs, list)
metadata = self.metadata[modality]
placeholder_replacements = metadata.placeholder_replacements
modality_placeholders: List[PlaceholderRange] = []
for item_idx, orig_item in enumerate(orig_inputs):
for match_text, replace_fn in placeholder_replacements.items():
candidates = candidate_placeholders(tokenizer, match_text)
get_replacement_ids = partial(
replace_fn,
orig_item,
processed_inputs,
item_idx,
)
for match_ids in candidates:
# TODO(youkaichao): Don't update new_token_ids
placeholders = apply_placeholders(
new_token_ids,
match_ids,
get_replacement_ids,
)
if placeholders is not None:
modality_placeholders.append(placeholders)
# yapf: disable
mm_placeholders[modality] = modality_placeholders # type: ignore[index]
# yapf: enable
return MultiModalInputsV2(
type="multimodal",
prompt=prompt,
prompt_token_ids=new_token_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders,
)
import functools import functools
from collections import UserDict from collections import UserDict
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
Sequence, Type, TypeVar)
import torch.nn as nn
from typing_extensions import TypeAlias
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .audio import AudioPlugin from .audio import AudioPlugin
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalKwargs, from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
from .image import ImagePlugin from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import MultiModalProcessor
from .video import VideoPlugin from .video import VideoPlugin
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -15,8 +22,18 @@ if TYPE_CHECKING: ...@@ -15,8 +22,18 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
N = TypeVar("N", bound=Type[nn.Module])
MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext],
MultiModalProcessor]
"""
Constructs a :class:`MultiModalProcessor` instance from the context.
The processing metadata should be derived from the context.
"""
class _MultiModalLimits(UserDict): class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
""" """
Wraps `_limits_by_model` for a more informative error message Wraps `_limits_by_model` for a more informative error message
when attempting to access a model that does not exist. when attempting to access a model that does not exist.
...@@ -45,6 +62,9 @@ class MultiModalRegistry: ...@@ -45,6 +62,9 @@ class MultiModalRegistry:
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins} self._plugins = {p.get_data_key(): p for p in plugins}
self._processor_factories: Dict[Type[nn.Module],
MultiModalProcessorFactory] = {}
# This is used for non-multimodal models # This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
...@@ -243,3 +263,59 @@ class MultiModalRegistry: ...@@ -243,3 +263,59 @@ class MultiModalRegistry:
This should be called after :meth:`init_mm_limits_per_prompt`. This should be called after :meth:`init_mm_limits_per_prompt`.
""" """
return self._limits_by_model[model_config] return self._limits_by_model[model_config]
def register_processor(
self,
factory: MultiModalProcessorFactory,
):
"""
Register a multi-modal processor to a model class.
When the model receives multi-modal data, the provided function is
invoked to transform the data into a dictionary of model inputs.
See also:
- :ref:`input_processing_pipeline`
- :ref:`enabling_multimodal_inputs`
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._processor_factories:
logger.warning(
"Model class %s already has an input mapper "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._processor_factories[model_cls] = factory
return model_cls
return wrapper
def has_processor(self, model_config: "ModelConfig") -> bool:
"""
Test whether a multi-modal processor is defined for a specific model.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
return model_cls in self._processor_factories
def create_processor(
self,
model_config: "ModelConfig",
tokenizer: AnyTokenizer,
) -> MultiModalProcessor:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
processor_factory = self._processor_factories[model_cls]
ctx = InputProcessingContext(model_config, tokenizer)
return processor_factory(ctx)
...@@ -11,9 +11,10 @@ from PIL import Image ...@@ -11,9 +11,10 @@ from PIL import Image
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import global_http_connection from vllm.connections import global_http_connection
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from .inputs import MultiModalDataDict, PlaceholderRange
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_tokenizer = lru_cache(get_tokenizer) cached_get_tokenizer = lru_cache(get_tokenizer)
......
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, Optional
import numpy as np import numpy as np
...@@ -9,8 +9,9 @@ from vllm.transformers_utils.processor import get_video_processor ...@@ -9,8 +9,9 @@ from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalKwargs from .base import MultiModalData
from .image import ImagePlugin from .image import ImagePlugin
from .inputs import MultiModalKwargs, VideoItem
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -20,17 +21,6 @@ logger = init_logger(__name__) ...@@ -20,17 +21,6 @@ logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor) cached_get_video_processor = lru_cache(get_video_processor)
cached_get_tokenizer = lru_cache(get_tokenizer) cached_get_tokenizer = lru_cache(get_tokenizer)
VideoInput = Union[
"np.ndarray", # single video input
List["np.ndarray"],
# TODO: support more types
# List[Image.Image], List[List[Image.Image]],
# "torch.Tensor",
# List["torch.Tensor"],
# List[List["np.ndarrray"]],
# List[List["torch.Tensor"]],
]
class VideoPlugin(ImagePlugin): class VideoPlugin(ImagePlugin):
"""Plugin for video data.""" """Plugin for video data."""
...@@ -53,13 +43,13 @@ class VideoPlugin(ImagePlugin): ...@@ -53,13 +43,13 @@ class VideoPlugin(ImagePlugin):
def _default_input_mapper( def _default_input_mapper(
self, self,
ctx: InputContext, ctx: InputContext,
data: MultiModalData[object], data: MultiModalData[VideoItem],
**mm_processor_kwargs, **mm_processor_kwargs,
) -> MultiModalKwargs: ) -> MultiModalKwargs:
model_config = ctx.model_config model_config = ctx.model_config
if isinstance(data, list) and len(data) == 1: if isinstance(data, list) and len(data) == 1:
data = data[0] data = data[0] # type: ignore
if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
video_processor = self._get_hf_video_processor( video_processor = self._get_hf_video_processor(
......
...@@ -5,25 +5,21 @@ from abc import ABC, abstractmethod ...@@ -5,25 +5,21 @@ from abc import ABC, abstractmethod
from array import array from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cached_property, reduce from functools import reduce
from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
Mapping, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union from typing import Set, Tuple, Union
import msgspec import msgspec
import torch import torch
from typing_extensions import assert_never
from vllm.inputs import SingletonInputs, SingletonInputsAdapter
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
if TYPE_CHECKING:
from vllm.inputs import SingletonInputs
VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_TOKEN_ID_ARRAY_TYPE = "l"
VLLM_INVALID_TOKEN_ID = -1 VLLM_INVALID_TOKEN_ID = -1
...@@ -407,14 +403,14 @@ class Sequence: ...@@ -407,14 +403,14 @@ class Sequence:
def __init__( def __init__(
self, self,
seq_id: int, seq_id: int,
inputs: "SingletonInputs", inputs: SingletonInputs,
block_size: int, block_size: int,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.inputs = inputs self.inputs = SingletonInputsAdapter(inputs)
self.block_size = block_size self.block_size = block_size
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
...@@ -441,59 +437,29 @@ class Sequence: ...@@ -441,59 +437,29 @@ class Sequence:
def n_blocks(self) -> int: def n_blocks(self) -> int:
return (self.get_len() + self.block_size - 1) // self.block_size return (self.get_len() + self.block_size - 1) // self.block_size
@cached_property @property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
inputs = self.inputs return self.inputs.prompt
if inputs["type"] == "token":
return inputs.get("prompt")
assert_never(inputs) @property
@cached_property
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> List[int]:
inputs = self.inputs return self.inputs.prompt_token_ids
if inputs["type"] == "token":
return inputs.get("prompt_token_ids", [])
assert_never(inputs) @property
@cached_property
def prompt_embeds(self) -> Optional[torch.Tensor]: def prompt_embeds(self) -> Optional[torch.Tensor]:
inputs = self.inputs return self.inputs.prompt_embeds
if inputs["type"] == "token":
return None
assert_never(inputs)
@cached_property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
inputs = self.inputs return self.inputs.multi_modal_data
if inputs["type"] == "token":
return inputs.get("multi_modal_data", {})
assert_never(inputs)
@cached_property
def mm_processor_kwargs(self) -> Dict[str, Any]:
inputs = self.inputs
if inputs["type"] == "token":
return inputs.get("mm_processor_kwargs", {})
assert_never(inputs)
@property @property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
inputs = self.inputs return self.inputs.multi_modal_placeholders
if inputs["type"] == "token":
return inputs.get("multi_modal_placeholders", {})
assert_never(inputs) @property
def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.inputs.mm_processor_kwargs
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
......
...@@ -6,6 +6,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -6,6 +6,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
...@@ -321,6 +322,9 @@ class AsyncLLM(EngineClient): ...@@ -321,6 +322,9 @@ class AsyncLLM(EngineClient):
async def get_decoding_config(self): async def get_decoding_config(self):
raise ValueError("Not Supported on V1 yet.") raise ValueError("Not Supported on V1 yet.")
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor
async def get_tokenizer( async def get_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
......
...@@ -7,6 +7,7 @@ from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING ...@@ -7,6 +7,7 @@ from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -32,6 +33,7 @@ class LLMEngine: ...@@ -32,6 +33,7 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
multiprocess_mode: bool = False, multiprocess_mode: bool = False,
) -> None: ) -> None:
...@@ -50,7 +52,7 @@ class LLMEngine: ...@@ -50,7 +52,7 @@ class LLMEngine:
# Processor (convert Inputs --> EngineCoreRequests) # Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config.model_config, self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer, vllm_config.lora_config, self.tokenizer,
input_registry) input_registry, mm_registry)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput) # Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self.detokenizer = Detokenizer( self.detokenizer = Detokenizer(
......
...@@ -2,15 +2,17 @@ import time ...@@ -2,15 +2,17 @@ import time
from typing import Any, Dict, Mapping, Optional, Tuple, Union from typing import Any, Dict, Mapping, Optional, Tuple, Union
from vllm.config import LoRAConfig, ModelConfig from vllm.config import LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
EncoderDecoderLLMInputs, InputRegistry, PromptType) PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import AnyTokenizer from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
...@@ -20,8 +22,9 @@ class Processor: ...@@ -20,8 +22,9 @@ class Processor:
self, self,
model_config: ModelConfig, model_config: ModelConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
tokenizer: AnyTokenizer, tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
self.model_config = model_config self.model_config = model_config
...@@ -31,7 +34,8 @@ class Processor: ...@@ -31,7 +34,8 @@ class Processor:
self.generation_config_fields = _load_generation_config_dict( self.generation_config_fields = _load_generation_config_dict(
model_config) model_config)
self.input_preprocessor = InputPreprocessor(model_config, self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer) self.tokenizer,
mm_registry)
self.input_processor = input_registry.create_input_processor( self.input_processor = input_registry.create_input_processor(
model_config) model_config)
...@@ -73,6 +77,19 @@ class Processor: ...@@ -73,6 +77,19 @@ class Processor:
self._validate_model_inputs(processed_inputs) self._validate_model_inputs(processed_inputs)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = SingletonInputsAdapter(
processed_inputs["decoder"])
encoder_inputs = SingletonInputsAdapter(
processed_inputs["encoder"])
else:
decoder_inputs = SingletonInputsAdapter(processed_inputs)
encoder_inputs = None
# TODO: Impl encoder-decoder
if encoder_inputs is not None:
raise NotImplementedError
assert isinstance(params, SamplingParams) assert isinstance(params, SamplingParams)
# TODO: can we avoid cloning here in multiproc case # TODO: can we avoid cloning here in multiproc case
sampling_params = params.clone() sampling_params = params.clone()
...@@ -81,27 +98,43 @@ class Processor: ...@@ -81,27 +98,43 @@ class Processor:
# Make Request for Detokenizer. # Make Request for Detokenizer.
detokenizer_request = DetokenizerRequest( detokenizer_request = DetokenizerRequest(
request_id, processed_inputs.get("prompt"), request_id,
processed_inputs.get("prompt_token_ids"), decoder_inputs.prompt,
decoder_inputs.prompt_token_ids,
sampling_params.skip_special_tokens, sampling_params.skip_special_tokens,
sampling_params.spaces_between_special_tokens, sampling_params.spaces_between_special_tokens,
sampling_params.output_kind, sampling_params.stop, sampling_params.output_kind,
sampling_params.include_stop_str_in_output) sampling_params.stop,
sampling_params.include_stop_str_in_output,
)
# Make Request for EngineCore. # Make Request for EngineCore.
engine_core_request = EngineCoreRequest( engine_core_request = EngineCoreRequest(
request_id, processed_inputs.get("prompt"), request_id,
processed_inputs.get("prompt_token_ids"), decoder_inputs.prompt,
processed_inputs.get("multi_modal_data"), decoder_inputs.prompt_token_ids,
processed_inputs.get("multi_modal_placeholders"), decoder_inputs.multi_modal_data,
processed_inputs.get("mm_processor_kwargs"), sampling_params, decoder_inputs.multi_modal_placeholders,
eos_token_id, arrival_time, lora_request) decoder_inputs.mm_processor_kwargs,
sampling_params,
eos_token_id,
arrival_time,
lora_request,
)
return detokenizer_request, engine_core_request return detokenizer_request, engine_core_request
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, def _validate_model_inputs(self, inputs: ProcessorInputs):
EncoderDecoderLLMInputs]): if is_encoder_decoder_inputs(inputs):
prompt_ids = inputs.get("prompt_token_ids") # For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
else:
prompt_inputs = inputs
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
if prompt_ids is None or len(prompt_ids) == 0: if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty") raise ValueError("Prompt cannot be empty")
...@@ -117,6 +150,10 @@ class Processor: ...@@ -117,6 +150,10 @@ class Processor:
"inputs, the number of image tokens depends on the number " "inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.") "of images, and possibly their aspect ratios as well.")
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
config = try_get_generation_config( config = try_get_generation_config(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment