Unverified Commit 6c0b7f54 authored by Peter Salas's avatar Peter Salas Committed by GitHub
Browse files

[Core][VLM] Add precise multi-modal placeholder tracking (#8346)


Signed-off-by: Peter Salas <peter@fixie.ai>
parent d151fde8
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
TypedDict, TypeVar, Union, cast, final) NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar,
Union, cast, final)
import numpy as np import numpy as np
import torch import torch
...@@ -11,12 +12,15 @@ from PIL import Image ...@@ -11,12 +12,15 @@ from PIL import Image
from torch import nn from torch import nn
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext from vllm.inputs import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
json_map_leaves, resolve_mm_processor_kwargs) json_map_leaves, resolve_mm_processor_kwargs)
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
...@@ -151,6 +155,30 @@ Note: ...@@ -151,6 +155,30 @@ Note:
Read more on that :ref:`here <adding_multimodal_plugin>`. Read more on that :ref:`here <adding_multimodal_plugin>`.
""" """
class PlaceholderRange(TypedDict):
    """
    Span of the prompt occupied by the placeholder of one multi-modal item.

    Example:

        Prompt: AAAA BBBB What is in these images?

    The placeholders of images A and B are described by:

        A: { "offset": 0, "length": 4 }
        B: { "offset": 5, "length": 4 }
    """

    offset: int
    """Index in the prompt at which the placeholder begins."""

    length: int
    """Number of prompt positions covered by the placeholder."""
MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]]
"""
A dictionary mapping each modality name to the list of placeholder ranges
of that modality's items in the prompt.
"""
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
MultiModalInputs] MultiModalInputs]
""" """
...@@ -243,7 +271,7 @@ class MultiModalPlugin(ABC): ...@@ -243,7 +271,7 @@ class MultiModalPlugin(ABC):
return wrapper return wrapper
def map_input(self, model_config: ModelConfig, def map_input(self, model_config: "ModelConfig",
data: MultiModalData[object], data: MultiModalData[object],
mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs: mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs:
""" """
...@@ -332,7 +360,7 @@ class MultiModalPlugin(ABC): ...@@ -332,7 +360,7 @@ class MultiModalPlugin(ABC):
return wrapper return wrapper
def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
""" """
Get the maximum number of multi-modal tokens Get the maximum number of multi-modal tokens
for profiling the memory usage of a model. for profiling the memory usage of a model.
...@@ -366,3 +394,179 @@ class MultiModalPlugin(ABC): ...@@ -366,3 +394,179 @@ class MultiModalPlugin(ABC):
self._validate_max_multimodal_tokens(max_mm_tokens) self._validate_max_multimodal_tokens(max_mm_tokens)
return max_mm_tokens return max_mm_tokens
class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.

    A map accumulates pairs of equally sized ranges: a *source* range into
    the flattened multi-modal embeddings and a *destination* range into the
    embeddings of the tokens being processed. Maps can be concatenated with
    :meth:`extend` and flattened into index lists with :meth:`index_map`.
    """

    class IndexMap(NamedTuple):
        """Flattened source and destination indices of a placeholder map."""
        src: List[int]
        dest: List[int]

    src_ranges: List[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: List[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self) -> None:
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> Tuple[Optional[MultiModalDataDict], Dict[str,
                                                  "MultiModalPlaceholderMap"]]:
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        If the sequence group has no multi-modal data or no placeholders,
        its ``multi_modal_data`` is returned unchanged together with an empty
        dict. Otherwise a shallow copy of the data is returned in which every
        modality that has placeholders is narrowed to the intersecting items
        (and dropped entirely when nothing intersects).

        Consider the following scenarios:

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |.................................|

                images      = [A, B]
                src_ranges  = [(0, 4), (4, 8)]
                dest_ranges = [(0, 4), (5, 9)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |  .....                          |

                images      = [A, B]
                src_ranges  = [(2, 4), (4, 6)]
                dest_ranges = [(0, 2), (3, 5)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |     .........                   |

                images      = [B]
                src_ranges  = [(0, 4)]
                dest_ranges = [(0, 4)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |          .......................|

                images      = []
                src_ranges  = []
                dest_ranges = []
        """
        if (not seq_group.multi_modal_data
                or not seq_group.multi_modal_placeholders):
            return seq_group.multi_modal_data, {}

        # Shallow copy so the sequence group's own dict is never mutated.
        mm_data = {**seq_group.multi_modal_data}
        placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
            MultiModalPlaceholderMap)

        for modality, placeholders in seq_group.multi_modal_placeholders.items(
        ):
            # The modality is removed up front and only re-added below when
            # at least one of its items intersects ``positions``.
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[
                    modality].append_items_from_seq_group(
                        positions, mm_items, placeholders)

                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
            self, positions: range, multi_modal_items: List[_T],
            multi_modal_placeholders: List[PlaceholderRange]) -> List[_T]:
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        Raises:
            ValueError: If the number of items differs from the number of
                placeholders.
        """
        intersecting_items = []

        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )
        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"])
            intersection = range(max(positions.start, placeholder.start),
                                 min(positions.stop, placeholder.stop))

            if not intersection:
                # Skip this multi-modal item.
                continue

            # Destination indices are relative to the start of ``positions``.
            token_embedding_range = range(intersection.start - positions.start,
                                          intersection.stop - positions.start)

            # Source indices are relative to the start of this item's
            # placeholder, shifted past the items already recorded.
            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len)

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # The whole placeholder length is reserved even for a partial
            # intersection (presumably because the item's embeddings are
            # still produced in full -- confirm against the model code).
            self.src_len += len(placeholder)

        # Counted once per call, whether or not any item intersected.
        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """
        # Build the shifted ranges as lists *before* extending. Feeding
        # ``list.extend`` a lazy generator over ``other``'s lists would never
        # terminate when ``other is self``.
        shifted_src = [
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges
        ]
        shifted_dest = [
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges
        ]

        self.src_ranges.extend(shifted_src)
        self.src_len += other.src_len
        self.dest_ranges.extend(shifted_dest)
        self.dest_len += other.dest_len

    def index_map(self) -> "MultiModalPlaceholderMap.IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: If the source and destination ranges do not cover the
                same number of indices.
        """
        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)
from functools import lru_cache from functools import lru_cache
from typing import Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import torch import torch
from PIL import Image from PIL import Image
from transformers.image_processing_base import BatchFeature from transformers.image_processing_base import BatchFeature
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_image_processor from vllm.transformers_utils.processor import get_image_processor
...@@ -13,6 +12,9 @@ from vllm.utils import is_list_of ...@@ -13,6 +12,9 @@ from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs, MultiModalPlugin from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_image_processor = lru_cache(get_image_processor) cached_get_image_processor = lru_cache(get_image_processor)
...@@ -26,7 +28,7 @@ class ImagePlugin(MultiModalPlugin): ...@@ -26,7 +28,7 @@ class ImagePlugin(MultiModalPlugin):
def _get_hf_image_processor( def _get_hf_image_processor(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
): ):
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
......
import functools import functools
from collections import UserDict from collections import UserDict
from typing import Any, Dict, Mapping, Optional, Sequence from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from .audio import AudioPlugin from .audio import AudioPlugin
...@@ -11,6 +10,9 @@ from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, ...@@ -11,6 +10,9 @@ from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
from .image import ImagePlugin from .image import ImagePlugin
from .video import VideoPlugin from .video import VideoPlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -20,7 +22,7 @@ class _MultiModalLimits(UserDict): ...@@ -20,7 +22,7 @@ class _MultiModalLimits(UserDict):
when attempting to access a model that does not exist. when attempting to access a model that does not exist.
""" """
def __getitem__(self, key: ModelConfig) -> Dict[str, int]: def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
try: try:
return super().__getitem__(key) return super().__getitem__(key)
except KeyError as exc: except KeyError as exc:
...@@ -98,7 +100,7 @@ class MultiModalRegistry: ...@@ -98,7 +100,7 @@ class MultiModalRegistry:
def map_input( def map_input(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
data: MultiModalDataDict, data: MultiModalDataDict,
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
...@@ -139,7 +141,7 @@ class MultiModalRegistry: ...@@ -139,7 +141,7 @@ class MultiModalRegistry:
return MultiModalInputs(merged_dict) return MultiModalInputs(merged_dict)
def create_input_mapper(self, model_config: ModelConfig): def create_input_mapper(self, model_config: "ModelConfig"):
""" """
Create an input mapper (see :meth:`map_input`) for a specific model. Create an input mapper (see :meth:`map_input`) for a specific model.
""" """
...@@ -177,7 +179,7 @@ class MultiModalRegistry: ...@@ -177,7 +179,7 @@ class MultiModalRegistry:
""" """
return self.register_max_multimodal_tokens("image", max_mm_tokens) return self.register_max_multimodal_tokens("image", max_mm_tokens)
def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
""" """
Get the maximum number of multi-modal tokens Get the maximum number of multi-modal tokens
for profiling the memory usage of a model. for profiling the memory usage of a model.
...@@ -195,7 +197,7 @@ class MultiModalRegistry: ...@@ -195,7 +197,7 @@ class MultiModalRegistry:
def init_mm_limits_per_prompt( def init_mm_limits_per_prompt(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
) -> None: ) -> None:
""" """
Initialize the maximum number of multi-modal input instances for each Initialize the maximum number of multi-modal input instances for each
...@@ -231,7 +233,7 @@ class MultiModalRegistry: ...@@ -231,7 +233,7 @@ class MultiModalRegistry:
def get_mm_limits_per_prompt( def get_mm_limits_per_prompt(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
) -> Mapping[str, int]: ) -> Mapping[str, int]:
""" """
Get the maximum number of multi-modal input instances for each modality Get the maximum number of multi-modal input instances for each modality
......
...@@ -10,7 +10,7 @@ from PIL import Image ...@@ -10,7 +10,7 @@ from PIL import Image
from vllm.connections import global_http_connection from vllm.connections import global_http_connection
from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal.base import MultiModalDataDict from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -258,7 +258,7 @@ def repeat_and_pad_placeholder_tokens( ...@@ -258,7 +258,7 @@ def repeat_and_pad_placeholder_tokens(
repeat_count: Union[int, List[int]], repeat_count: Union[int, List[int]],
pad_token_left: Optional[int] = None, pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None, pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]: ) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]:
if isinstance(repeat_count, int): if isinstance(repeat_count, int):
repeat_count = [repeat_count] repeat_count = [repeat_count]
...@@ -301,6 +301,7 @@ def repeat_and_pad_placeholder_tokens( ...@@ -301,6 +301,7 @@ def repeat_and_pad_placeholder_tokens(
new_prompt += prompt_parts[-1] new_prompt += prompt_parts[-1]
new_token_ids: List[int] = [] new_token_ids: List[int] = []
placeholder_ranges: List[PlaceholderRange] = []
placeholder_token_idx = 0 placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids): for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id: if token == placeholder_token_id:
...@@ -310,6 +311,10 @@ def repeat_and_pad_placeholder_tokens( ...@@ -310,6 +311,10 @@ def repeat_and_pad_placeholder_tokens(
pad_token_left=pad_token_left, pad_token_left=pad_token_left,
pad_token_right=pad_token_right, pad_token_right=pad_token_right,
) )
placeholder_ranges.append({
"offset": len(new_token_ids),
"length": len(replacement_ids)
})
new_token_ids.extend(replacement_ids) new_token_ids.extend(replacement_ids)
placeholder_token_idx += 1 placeholder_token_idx += 1
...@@ -320,4 +325,14 @@ def repeat_and_pad_placeholder_tokens( ...@@ -320,4 +325,14 @@ def repeat_and_pad_placeholder_tokens(
else: else:
new_token_ids.append(token) new_token_ids.append(token)
return new_prompt, new_token_ids return new_prompt, new_token_ids, placeholder_ranges
def consecutive_placeholder_ranges(num_items: int,
                                   item_size: int) -> List[PlaceholderRange]:
    """
    Build ``num_items`` back-to-back placeholder ranges of a fixed size.

    The first range starts at offset 0 and each following range starts
    ``item_size`` positions after the previous one.
    """
    ranges: List[PlaceholderRange] = []
    for index in range(num_items):
        start = index * item_size
        ranges.append(PlaceholderRange(offset=start, length=item_size))
    return ranges
from functools import lru_cache from functools import lru_cache
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np import numpy as np
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs from .base import MultiModalData, MultiModalInputs
from .image import ImagePlugin from .image import ImagePlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor) cached_get_video_processor = lru_cache(get_video_processor)
...@@ -38,7 +39,7 @@ class VideoPlugin(ImagePlugin): ...@@ -38,7 +39,7 @@ class VideoPlugin(ImagePlugin):
def _get_hf_video_processor( def _get_hf_video_processor(
self, self,
model_config: ModelConfig, model_config: "ModelConfig",
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
): ):
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
...@@ -56,7 +57,10 @@ class VideoPlugin(ImagePlugin): ...@@ -56,7 +57,10 @@ class VideoPlugin(ImagePlugin):
) -> MultiModalInputs: ) -> MultiModalInputs:
model_config = ctx.model_config model_config = ctx.model_config
if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): if isinstance(data, list) and len(data) == 1:
data = data[0]
if isinstance(data, np.ndarray):
video_processor = self._get_hf_video_processor( video_processor = self._get_hf_video_processor(
model_config, model_config,
mm_processor_kwargs, mm_processor_kwargs,
......
...@@ -15,13 +15,13 @@ import torch ...@@ -15,13 +15,13 @@ import torch
from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.inputs import SingletonInputs from vllm.inputs import SingletonInputs
from vllm.multimodal.base import MultiModalDataDict
VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_TOKEN_ID_ARRAY_TYPE = "l"
...@@ -485,7 +485,7 @@ class Sequence: ...@@ -485,7 +485,7 @@ class Sequence:
return cast(List[int], self.inputs.get(prompt_token_ids_key)) return cast(List[int], self.inputs.get(prompt_token_ids_key))
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> MultiModalDataDict:
inputs = self.inputs inputs = self.inputs
if (inputs.get("multi_modal_data") if (inputs.get("multi_modal_data")
...@@ -495,11 +495,15 @@ class Sequence: ...@@ -495,11 +495,15 @@ class Sequence:
) )
return cast( return cast(
"MultiModalDataDict", MultiModalDataDict,
(inputs.get("multi_modal_data") (inputs.get("multi_modal_data")
or inputs.get("encoder_multi_modal_data") or {}), or inputs.get("encoder_multi_modal_data") or {}),
) )
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
return self.inputs.get("multi_modal_placeholders") or {}
@property @property
def mm_processor_kwargs(self) -> Dict[str, Any]: def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.inputs.get("mm_processor_kwargs") or {} return self.inputs.get("mm_processor_kwargs") or {}
...@@ -728,9 +732,13 @@ class SequenceGroup: ...@@ -728,9 +732,13 @@ class SequenceGroup:
if self.encoder_seq is not None else None) if self.encoder_seq is not None else None)
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> MultiModalDataDict:
return self.first_seq.multi_modal_data return self.first_seq.multi_modal_data
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
return self.first_seq.multi_modal_placeholders
@property @property
def mm_processor_kwargs(self) -> Dict[str, Any]: def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.first_seq.mm_processor_kwargs return self.first_seq.mm_processor_kwargs
...@@ -946,6 +954,7 @@ class SequenceGroupMetadata( ...@@ -946,6 +954,7 @@ class SequenceGroupMetadata(
# "MultiModalDataDict" types. We have to use Any due to msgspec # "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts. # doesn't allow to have union of 2 different dicts.
multi_modal_data: Optional[Any] = None multi_modal_data: Optional[Any] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None
encoder_seq_data: Optional[SequenceData] = None encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None cross_block_table: Optional[List[int]] = None
......
import dataclasses import dataclasses
import weakref import weakref
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
...@@ -16,7 +17,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding ...@@ -16,7 +17,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData, from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
...@@ -148,9 +149,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -148,9 +149,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
query_lens=seq_lens, query_lens=seq_lens,
) )
def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata,
computed_len: int, seq_data: SequenceData, computed_len: int,
mm_processor_kwargs: Dict[str, Any]): mm_processor_kwargs: Dict[str, Any]):
# NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group, range(computed_len, len(seq_data.get_token_ids())))
if not mm_data:
return
mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)
# special processing for mrope position deltas. # special processing for mrope position deltas.
...@@ -179,7 +189,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -179,7 +189,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
context_len=computed_len, context_len=computed_len,
) )
seq_data.mrope_position_delta = mrope_position_delta seq_data.mrope_position_delta = mrope_position_delta
return mm_kwargs, mrope_positions return mm_kwargs, placeholder_maps, mrope_positions
def _prepare_prompt( def _prepare_prompt(
self, self,
...@@ -194,6 +204,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -194,6 +204,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping: List[int] = [] slot_mapping: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_modal_inputs_list: List[MultiModalInputs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
...@@ -210,11 +223,15 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -210,11 +223,15 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
input_tokens.extend(prompt_tokens) # Token ids input_tokens.extend(prompt_tokens) # Token ids
mrope_positions = None mrope_positions = None
if (mm_data := seq_group_metadata.multi_modal_data): if seq_group_metadata.multi_modal_data:
mm_kwargs, mrope_positions = self._compute_multi_modal_input( mm_kwargs, placeholder_maps, mrope_positions = self \
seq_data, mm_data, computed_len, ._compute_multi_modal_input(
seq_group_metadata, seq_data, computed_len,
seq_group_metadata.mm_processor_kwargs) seq_group_metadata.mm_processor_kwargs)
multi_modal_inputs_list.append(mm_kwargs) multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
# Token position ids # Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
...@@ -264,6 +281,11 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -264,6 +281,11 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping = torch.tensor(slot_mapping, slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long, dtype=torch.long,
device=self.device) # type: ignore device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
...@@ -275,6 +297,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -275,6 +297,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
num_decode_tokens=0, num_decode_tokens=0,
block_tables=torch.tensor([]), block_tables=torch.tensor([]),
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
...@@ -366,6 +389,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -366,6 +389,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_decode_seq_len=max_decode_seq_len, max_decode_seq_len=max_decode_seq_len,
......
...@@ -306,13 +306,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -306,13 +306,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
decoder_seq_data, decoder_dummy_multi_modal_data \ decoder_dummy_data = self.input_registry \
= self.input_registry.dummy_data_for_profiling( .dummy_data_for_profiling(self.model_config,
self.model_config,
seq_len, seq_len,
self.mm_registry, self.mm_registry,
is_encoder_data=False) is_encoder_data=False)
encoder_seq_data, encoder_dummy_multi_modal_data \ encoder_dummy_data \
= self.input_registry.dummy_data_for_profiling( = self.input_registry.dummy_data_for_profiling(
self.model_config, self.model_config,
seq_len, seq_len,
...@@ -320,26 +319,31 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -320,26 +319,31 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
is_encoder_data=True) is_encoder_data=True)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
assert len(decoder_seq_data.prompt_token_ids) >= seq_len, ( assert len(
decoder_dummy_data.seq_data.prompt_token_ids
) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, " f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(decoder_seq_data.prompt_token_ids)}") f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}"
)
assert decoder_dummy_multi_modal_data is None or \ assert decoder_dummy_data.multi_modal_data is None or \
encoder_dummy_multi_modal_data is None, ( encoder_dummy_data.multi_modal_data is None, (
"Multi-modal data can't be provided in both encoder and decoder" "Multi-modal data can't be provided in both encoder and decoder"
) )
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
is_prompt=True, is_prompt=True,
seq_data={group_id: decoder_seq_data}, seq_data={group_id: decoder_dummy_data.seq_data},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables=None, block_tables=None,
encoder_seq_data=encoder_seq_data, encoder_seq_data=encoder_dummy_data.seq_data,
cross_block_table=None, cross_block_table=None,
multi_modal_data=decoder_dummy_multi_modal_data multi_modal_data=decoder_dummy_data.multi_modal_data
or encoder_dummy_multi_modal_data, or encoder_dummy_data.multi_modal_data,
) multi_modal_placeholders=decoder_dummy_data.
multi_modal_placeholders
or encoder_dummy_data.multi_modal_placeholders)
seqs.append(seq) seqs.append(seq)
# Run the model with the dummy inputs. # Run the model with the dummy inputs.
......
...@@ -40,7 +40,8 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig ...@@ -40,7 +40,8 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry) MultiModalInputs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -242,6 +243,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -242,6 +243,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Multi-modal inputs. # Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None, multi_modal_inputs: Optional[MultiModalInputs] = None,
multi_modal_placeholder_maps: Optional[Dict[
str, MultiModalPlaceholderMap]] = None,
# Whether the prefix cache is hit (prefill only). # Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False, prefix_cache_hit: bool = False,
...@@ -361,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -361,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs self.multi_modal_inputs = multi_modal_inputs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit self.prefix_cache_hit = prefix_cache_hit
self.n_seqs = len(self.seq_ids) self.n_seqs = len(self.seq_ids)
...@@ -635,7 +639,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -635,7 +639,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata): seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input.""" """If multi-modal data is given, add it to the input."""
mm_data = seq_group_metadata.multi_modal_data # NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions.
positions = inter_data.input_positions[0]
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata,
range(positions[0], positions[0] + len(positions)))
if not mm_data: if not mm_data:
return return
...@@ -643,6 +652,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -643,6 +652,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
mm_data, mm_data,
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
inter_data.multi_modal_inputs = mm_kwargs inter_data.multi_modal_inputs = mm_kwargs
inter_data.multi_modal_placeholder_maps = placeholder_maps
# special processing for mrope position deltas. # special processing for mrope position deltas.
if self.runner.model_is_mrope: if self.runner.model_is_mrope:
...@@ -1255,7 +1265,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1255,7 +1265,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
seq_data, dummy_multi_modal_data = self.input_registry \ dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config, .dummy_data_for_profiling(self.model_config,
seq_len, seq_len,
self.mm_registry) self.mm_registry)
...@@ -1263,12 +1273,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1263,12 +1273,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
is_prompt=True, is_prompt=True,
seq_data={group_id: seq_data}, seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables=None, block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id] lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None, if dummy_lora_requests_per_seq else None,
multi_modal_data=dummy_multi_modal_data, multi_modal_data=dummy_data.multi_modal_data,
multi_modal_placeholders=dummy_data.multi_modal_placeholders,
) )
seqs.append(seq) seqs.append(seq)
......
...@@ -46,9 +46,8 @@ def _init_attn_metadata_from_tensor_dict( ...@@ -46,9 +46,8 @@ def _init_attn_metadata_from_tensor_dict(
# Extract the fields used to create AttentionMetadata. # Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {} valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()): for field in dataclasses.fields(attn_backend.get_metadata_cls()):
val = tensor_dict.pop(field.name, None) if field.name in tensor_dict:
if val is not None: valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
valid_attn_kwargs[field.name] = val
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata tensor_dict["attn_metadata"] = attn_metadata
......
from typing import List, NamedTuple, Optional, Tuple from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Tuple
import openvino as ov import openvino as ov
import torch import torch
...@@ -14,7 +15,7 @@ from vllm.model_executor import SamplingMetadata ...@@ -14,7 +15,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -115,6 +116,9 @@ class OpenVINOModelRunner: ...@@ -115,6 +116,9 @@ class OpenVINOModelRunner:
past_lens: List[int] = [] past_lens: List[int] = []
query_lens: List[int] = [] query_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_modal_inputs_list: List[MultiModalInputs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
subsequence_begins: List[int] = [] subsequence_begins: List[int] = []
block_indices: List[int] = [] block_indices: List[int] = []
...@@ -168,15 +172,6 @@ class OpenVINOModelRunner: ...@@ -168,15 +172,6 @@ class OpenVINOModelRunner:
and self.sliding_window is None and self.sliding_window is None
and is_prompt) and is_prompt)
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
mm_processor_kwargs=seq_group_metadata.
mm_processor_kwargs,
)
multi_modal_inputs_list.append(mm_kwargs)
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
# TODO(sang): Combine chunked prefill and prefix caching by # TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size. # only allowing multiple of block_size chunk size.
...@@ -220,7 +215,8 @@ class OpenVINOModelRunner: ...@@ -220,7 +215,8 @@ class OpenVINOModelRunner:
query_lens.append(query_len) query_lens.append(query_len)
input_tokens.extend(tokens) input_tokens.extend(tokens)
input_positions.extend(list(range(computed_len, seq_len))) positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
past_lens.append(computed_len) past_lens.append(computed_len)
subsequence_begins.append(subsequence_begins[-1] + query_len) subsequence_begins.append(subsequence_begins[-1] + query_len)
...@@ -233,6 +229,22 @@ class OpenVINOModelRunner: ...@@ -233,6 +229,22 @@ class OpenVINOModelRunner:
), "seq_len: {}, computed_len: {}, query_len: {}".format( ), "seq_len: {}, computed_len: {}, query_len: {}".format(
seq_len, computed_len, query_len) seq_len, computed_len, query_len)
if seq_group_metadata.multi_modal_data:
# NOTE: mm_data only includes the subset of multi-modal
# items that intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
mm_processor_kwargs=seq_group_metadata.
mm_processor_kwargs)
multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map, )
max_query_len = max(query_lens) max_query_len = max(query_lens)
assert max_query_len > 0, "query_lens: {}".format(query_lens) assert max_query_len > 0, "query_lens: {}".format(query_lens)
...@@ -261,12 +273,19 @@ class OpenVINOModelRunner: ...@@ -261,12 +273,19 @@ class OpenVINOModelRunner:
max_context_len, dtype=torch.int32, max_context_len, dtype=torch.int32,
device=self.device) # type: ignore device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
attn_metadata = self.attn_backend.make_openvino_metadata( attn_metadata = self.attn_backend.make_openvino_metadata(
past_lens=past_lens_tensor, past_lens=past_lens_tensor,
subsequence_begins=subsequence_begins_tensor, subsequence_begins=subsequence_begins_tensor,
block_indices=block_indices_tensor, block_indices=block_indices_tensor,
block_indices_begins=block_indices_begins_tensor, block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor, max_context_len=max_context_len_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
......
...@@ -184,6 +184,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -184,6 +184,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=batch_size * seq_len, num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None, block_tables=None,
context_lens=None, context_lens=None,
) )
...@@ -216,6 +217,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -216,6 +217,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=batch_size * seq_len, num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
) )
...@@ -360,6 +362,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -360,6 +362,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=0, # NOTE: This is not used. num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None, block_tables=None,
context_lens=None, context_lens=None,
) )
...@@ -429,6 +432,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -429,6 +432,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=batch_size, num_decode_tokens=batch_size,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
) )
......
import dataclasses import dataclasses
import time import time
import weakref import weakref
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar) Type, TypeVar)
...@@ -19,7 +20,8 @@ from vllm.model_executor import SamplingMetadataCache ...@@ -19,7 +20,8 @@ from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry) MultiModalInputs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad
...@@ -161,6 +163,9 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -161,6 +163,9 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
slot_mapping: List[int] = [] slot_mapping: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_modal_inputs_list: List[MultiModalInputs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
...@@ -179,7 +184,21 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -179,7 +184,21 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
# Token position ids # Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len))) positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
if seq_group_metadata.multi_modal_data:
# NOTE: mm_data only includes the subset of multi-modal items
# that intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.runner.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
...@@ -220,6 +239,11 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -220,6 +239,11 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
slot_mapping = torch.tensor(slot_mapping, slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long, dtype=torch.long,
device=self.device) # type: ignore device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
max_seqlen = max(seq_lens) max_seqlen = max(seq_lens)
tmp = [0] tmp = [0]
...@@ -230,6 +254,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -230,6 +254,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
seq_lens=seq_lens, seq_lens=seq_lens,
seqlen_q=seqlen_q, seqlen_q=seqlen_q,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
...@@ -313,6 +338,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -313,6 +338,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
seq_lens=seq_lens, seq_lens=seq_lens,
seqlen_q=torch.tensor([]), seqlen_q=torch.tensor([]),
max_seqlen=0, max_seqlen=0,
...@@ -450,7 +476,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -450,7 +476,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len batch_size += seq_len
seq_data, dummy_multi_modal_data = self.input_registry \ dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config, .dummy_data_for_profiling(self.model_config,
seq_len, seq_len,
self.mm_registry) self.mm_registry)
...@@ -458,12 +484,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -458,12 +484,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
is_prompt=True, is_prompt=True,
seq_data={group_id: seq_data}, seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables=None, block_tables=None,
lora_request=None, lora_request=None,
multi_modal_data=dummy_multi_modal_data, multi_modal_data=dummy_data.multi_modal_data,
) multi_modal_placeholders=dummy_data.multi_modal_placeholders)
seqs.append(seq) seqs.append(seq)
# Run the model with the dummy inputs. # Run the model with the dummy inputs.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment