Unverified Commit 6c0b7f54 authored by Peter Salas's avatar Peter Salas Committed by GitHub
Browse files

[Core][VLM] Add precise multi-modal placeholder tracking (#8346)


Signed-off-by: default avatarPeter Salas <peter@fixie.ai>
parent d151fde8
import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type,
TypedDict, TypeVar, Union, cast, final)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar,
Union, cast, final)
import numpy as np
import torch
......@@ -11,12 +12,15 @@ from PIL import Image
from torch import nn
from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
json_map_leaves, resolve_mm_processor_kwargs)
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__)
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
......@@ -151,6 +155,30 @@ Note:
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
class PlaceholderRange(TypedDict):
    """
    Describes where one multi-modal item's placeholder sits in a prompt.

    For example, given the prompt::

        AAAA BBBB What is in these images?

    images A and B are located by::

        A: { "offset": 0, "length": 4 }
        B: { "offset": 5, "length": 4 }
    """

    #: The start index of the placeholder in the prompt.
    offset: int

    #: The length of the placeholder.
    length: int
# Keys are modality names; each value lists the placeholder ranges for that
# modality's items, paired positionally with the items themselves.
MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
MultiModalInputs]
"""
......@@ -243,7 +271,7 @@ class MultiModalPlugin(ABC):
return wrapper
def map_input(self, model_config: ModelConfig,
def map_input(self, model_config: "ModelConfig",
data: MultiModalData[object],
mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs:
"""
......@@ -332,7 +360,7 @@ class MultiModalPlugin(ABC):
return wrapper
def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
......@@ -366,3 +394,179 @@ class MultiModalPlugin(ABC):
self._validate_max_multimodal_tokens(max_mm_tokens)
return max_mm_tokens
class MultiModalPlaceholderMap:
    """
    Relates multi-modal embeddings to their corresponding placeholders.

    The map accumulates pairs of index ranges: each entry of ``src_ranges``
    selects multi-modal embedding vectors and the entry of ``dest_ranges`` at
    the same position selects the placeholder embeddings they replace.
    """

    class IndexMap(NamedTuple):
        """Flattened source/destination indices produced by ``index_map``."""
        src: List[int]
        dest: List[int]

    src_ranges: List[range]
    """
    The indices of the multi-modal embeddings that will replace the
    corresponding placeholder embeddings pointed to by ``dest_ranges``.
    """

    src_len: int
    """
    The total number of flattened multi-modal embeddings.
    """

    dest_ranges: List[range]
    """
    The indices of the placeholder embeddings that will be replaced by the
    multimodal embeddings.
    """

    dest_len: int
    """
    The total number of embeddings in the destination tensor.
    """

    def __init__(self) -> None:
        self.src_ranges = []
        self.src_len = 0
        self.dest_ranges = []
        self.dest_len = 0

    @classmethod
    def from_seq_group(
        cls, seq_group: "SequenceGroupMetadata", positions: range
    ) -> "Tuple[Optional[MultiModalDataDict], Dict[str, MultiModalPlaceholderMap]]":
        """
        Returns the multi-modal items that intersect with the portion of a
        prompt (``seq_group``) represented by ``positions``, as well as a
        ``MultiModalPlaceholderMap`` that relates the multi-modal embedding
        vectors to their corresponding placeholders.

        If the sequence group has no multi-modal data or no placeholders, its
        ``multi_modal_data`` is returned unchanged together with an empty
        dict. Otherwise the returned data is a shallow copy that only keeps
        modalities without placeholders plus the intersecting items of the
        modalities that have them.

        Consider the following scenarios:

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |.................................|

                images      = [A, B]
                src_ranges  = [(0, 4), (4, 8)]
                dest_ranges = [(0, 4), (5, 9)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |  .....                          |

                images      = [A, B]
                src_ranges  = [(2, 4), (4, 6)]
                dest_ranges = [(0, 2), (3, 5)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |     .........                   |

                images      = [B]
                src_ranges  = [(0, 4)]
                dest_ranges = [(0, 4)]

            Prompt:    |AAAA BBBB What's in these images?|
            Positions: |          .......................|

                images      = []
                src_ranges  = []
                dest_ranges = []
        """
        if (not seq_group.multi_modal_data
                or not seq_group.multi_modal_placeholders):
            return seq_group.multi_modal_data, {}

        # Shallow copy so that the sequence group's own dict is not mutated
        # by the pops below.
        mm_data = {**seq_group.multi_modal_data}
        placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
            MultiModalPlaceholderMap)

        for modality, placeholders in (
                seq_group.multi_modal_placeholders.items()):
            # Always removed; re-added below only if something intersects.
            mm_items = mm_data.pop(modality)
            if not isinstance(mm_items, list):
                mm_items = [mm_items]

            if positions:
                intersecting_items = placeholder_maps[
                    modality].append_items_from_seq_group(
                        positions, mm_items, placeholders)

                if intersecting_items:
                    mm_data[modality] = intersecting_items

        return mm_data, placeholder_maps

    def append_items_from_seq_group(
            self, positions: range, multi_modal_items: "List[_T]",
            multi_modal_placeholders: "List[PlaceholderRange]"
    ) -> "List[_T]":
        """
        Adds the multi-modal items that intersect ``positions`` to this
        placeholder map and returns the intersecting items.

        ``positions`` is treated as a block of ``len(positions)`` embeddings
        appended to the destination: the recorded destination ranges start
        at the current ``dest_len``, which is then advanced by
        ``len(positions)``. Source ranges start at the current ``src_len``,
        which advances by the full placeholder length of every intersecting
        item.

        Raises:
            ValueError: if the number of items and placeholders differ.
        """
        if len(multi_modal_items) != len(multi_modal_placeholders):
            raise ValueError(
                "Multi-modal placeholders and items must have the same length."
            )

        intersecting_items = []
        # Offset of this block of positions within the destination. It is 0
        # for a freshly created map; using it keeps repeated appends
        # consistent with ``extend``.
        dest_base = self.dest_len

        for placeholder_dict, mm_item in zip(multi_modal_placeholders,
                                             multi_modal_items):
            placeholder = range(
                placeholder_dict["offset"],
                placeholder_dict["offset"] + placeholder_dict["length"])
            intersection = range(max(positions.start, placeholder.start),
                                 min(positions.stop, placeholder.stop))

            if not intersection:
                # Skip this multi-modal item.
                continue

            token_embedding_range = range(
                intersection.start - positions.start + dest_base,
                intersection.stop - positions.start + dest_base)

            multimodal_embedding_range = range(
                intersection.start - placeholder.start + self.src_len,
                intersection.stop - placeholder.start + self.src_len)

            intersecting_items.append(mm_item)
            self.dest_ranges.append(token_embedding_range)
            self.src_ranges.append(multimodal_embedding_range)
            # The whole item counts towards the source length, even when only
            # part of its placeholder intersects ``positions``.
            self.src_len += len(placeholder)

        self.dest_len += len(positions)
        return intersecting_items

    def extend(self, other: "MultiModalPlaceholderMap"):
        """
        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
        instance based on the source and destination tensors being
        concatenated.
        """
        # Shift the other map's ranges past everything already recorded here.
        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len

        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> "MultiModalPlaceholderMap.IndexMap":
        """
        Finalizes the placeholder map into lists of indices that can be used to
        index the source and destination tensors.

        Raises:
            ValueError: if the flattened source and destination index lists
                differ in length.
        """
        src_indices = [i for r in self.src_ranges for i in r]
        dest_indices = [i for r in self.dest_ranges for i in r]

        if len(src_indices) != len(dest_indices):
            raise ValueError(
                f"The number of source ({len(src_indices)}) and destination "
                f"indices ({len(dest_indices)}) must be the same.")

        return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                 dest=dest_indices)
from functools import lru_cache
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
import torch
from PIL import Image
from transformers.image_processing_base import BatchFeature
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_image_processor
......@@ -13,6 +12,9 @@ from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
cached_get_image_processor = lru_cache(get_image_processor)
......@@ -26,7 +28,7 @@ class ImagePlugin(MultiModalPlugin):
def _get_hf_image_processor(
self,
model_config: ModelConfig,
model_config: "ModelConfig",
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
):
if mm_processor_kwargs is None:
......
import functools
from collections import UserDict
from typing import Any, Dict, Mapping, Optional, Sequence
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
from vllm.config import ModelConfig
from vllm.logger import init_logger
from .audio import AudioPlugin
......@@ -11,6 +10,9 @@ from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
from .image import ImagePlugin
from .video import VideoPlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
......@@ -20,7 +22,7 @@ class _MultiModalLimits(UserDict):
when attempting to access a model that does not exist.
"""
def __getitem__(self, key: ModelConfig) -> Dict[str, int]:
def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
try:
return super().__getitem__(key)
except KeyError as exc:
......@@ -98,7 +100,7 @@ class MultiModalRegistry:
def map_input(
self,
model_config: ModelConfig,
model_config: "ModelConfig",
data: MultiModalDataDict,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> MultiModalInputs:
......@@ -139,7 +141,7 @@ class MultiModalRegistry:
return MultiModalInputs(merged_dict)
def create_input_mapper(self, model_config: ModelConfig):
def create_input_mapper(self, model_config: "ModelConfig"):
"""
Create an input mapper (see :meth:`map_input`) for a specific model.
"""
......@@ -177,7 +179,7 @@ class MultiModalRegistry:
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)
def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
......@@ -195,7 +197,7 @@ class MultiModalRegistry:
def init_mm_limits_per_prompt(
self,
model_config: ModelConfig,
model_config: "ModelConfig",
) -> None:
"""
Initialize the maximum number of multi-modal input instances for each
......@@ -231,7 +233,7 @@ class MultiModalRegistry:
def get_mm_limits_per_prompt(
self,
model_config: ModelConfig,
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
......
......@@ -10,7 +10,7 @@ from PIL import Image
from vllm.connections import global_http_connection
from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
from vllm.logger import init_logger
from vllm.multimodal.base import MultiModalDataDict
from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
logger = init_logger(__name__)
......@@ -258,7 +258,7 @@ def repeat_and_pad_placeholder_tokens(
repeat_count: Union[int, List[int]],
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]:
if isinstance(repeat_count, int):
repeat_count = [repeat_count]
......@@ -301,6 +301,7 @@ def repeat_and_pad_placeholder_tokens(
new_prompt += prompt_parts[-1]
new_token_ids: List[int] = []
placeholder_ranges: List[PlaceholderRange] = []
placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id:
......@@ -310,6 +311,10 @@ def repeat_and_pad_placeholder_tokens(
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
placeholder_ranges.append({
"offset": len(new_token_ids),
"length": len(replacement_ids)
})
new_token_ids.extend(replacement_ids)
placeholder_token_idx += 1
......@@ -320,4 +325,14 @@ def repeat_and_pad_placeholder_tokens(
else:
new_token_ids.append(token)
return new_prompt, new_token_ids
return new_prompt, new_token_ids, placeholder_ranges
def consecutive_placeholder_ranges(num_items: int,
                                   item_size: int) -> List[PlaceholderRange]:
    """Build ``num_items`` placeholder ranges of ``item_size`` laid end to end.

    The i-th range starts at ``i * item_size``, so the first one starts at 0
    and every later range begins where the previous one stops.
    """
    ranges: List[PlaceholderRange] = []
    for index in range(num_items):
        start = index * item_size
        ranges.append(PlaceholderRange(offset=start, length=item_size))
    return ranges
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import numpy as np
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs
from .image import ImagePlugin
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor)
......@@ -38,7 +39,7 @@ class VideoPlugin(ImagePlugin):
def _get_hf_video_processor(
self,
model_config: ModelConfig,
model_config: "ModelConfig",
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
):
if mm_processor_kwargs is None:
......@@ -56,7 +57,10 @@ class VideoPlugin(ImagePlugin):
) -> MultiModalInputs:
model_config = ctx.model_config
if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray):
if isinstance(data, list) and len(data) == 1:
data = data[0]
if isinstance(data, np.ndarray):
video_processor = self._get_hf_video_processor(
model_config,
mm_processor_kwargs,
......
......@@ -15,13 +15,13 @@ import torch
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
if TYPE_CHECKING:
from vllm.inputs import SingletonInputs
from vllm.multimodal.base import MultiModalDataDict
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
......@@ -485,7 +485,7 @@ class Sequence:
return cast(List[int], self.inputs.get(prompt_token_ids_key))
@property
def multi_modal_data(self) -> "MultiModalDataDict":
def multi_modal_data(self) -> MultiModalDataDict:
inputs = self.inputs
if (inputs.get("multi_modal_data")
......@@ -495,11 +495,15 @@ class Sequence:
)
return cast(
"MultiModalDataDict",
MultiModalDataDict,
(inputs.get("multi_modal_data")
or inputs.get("encoder_multi_modal_data") or {}),
)
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
    """Placeholder ranges stored under ``"multi_modal_placeholders"`` in the
    inputs, or an empty dict when that entry is missing or falsy."""
    placeholders = self.inputs.get("multi_modal_placeholders")
    if not placeholders:
        return {}
    return placeholders
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.inputs.get("mm_processor_kwargs") or {}
......@@ -728,9 +732,13 @@ class SequenceGroup:
if self.encoder_seq is not None else None)
@property
def multi_modal_data(self) -> "MultiModalDataDict":
def multi_modal_data(self) -> MultiModalDataDict:
return self.first_seq.multi_modal_data
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
    """Placeholder ranges of this group, as reported by ``first_seq``."""
    first = self.first_seq
    return first.multi_modal_placeholders
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.first_seq.mm_processor_kwargs
......@@ -946,6 +954,7 @@ class SequenceGroupMetadata(
# "MultiModalDataDict" types. We have to use Any due to msgspec
# not allowing a union of 2 different dicts.
multi_modal_data: Optional[Any] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None
......
import dataclasses
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
......@@ -16,7 +17,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.transformers_utils.config import uses_mrope
......@@ -148,9 +149,18 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
query_lens=seq_lens,
)
def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
computed_len: int,
def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata,
seq_data: SequenceData, computed_len: int,
mm_processor_kwargs: Dict[str, Any]):
# NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group, range(computed_len, len(seq_data.get_token_ids())))
if not mm_data:
return
mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)
# special processing for mrope position deltas.
......@@ -179,7 +189,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
context_len=computed_len,
)
seq_data.mrope_position_delta = mrope_position_delta
return mm_kwargs, mrope_positions
return mm_kwargs, placeholder_maps, mrope_positions
def _prepare_prompt(
self,
......@@ -194,6 +204,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
......@@ -210,11 +223,15 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
input_tokens.extend(prompt_tokens) # Token ids
mrope_positions = None
if (mm_data := seq_group_metadata.multi_modal_data):
mm_kwargs, mrope_positions = self._compute_multi_modal_input(
seq_data, mm_data, computed_len,
if seq_group_metadata.multi_modal_data:
mm_kwargs, placeholder_maps, mrope_positions = self \
._compute_multi_modal_input(
seq_group_metadata, seq_data, computed_len,
seq_group_metadata.mm_processor_kwargs)
multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
......@@ -264,6 +281,11 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
......@@ -275,6 +297,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
num_decode_tokens=0,
block_tables=torch.tensor([]),
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
......@@ -366,6 +389,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_decode_seq_len=max_decode_seq_len,
......
......@@ -306,13 +306,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
decoder_seq_data, decoder_dummy_multi_modal_data \
= self.input_registry.dummy_data_for_profiling(
self.model_config,
decoder_dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry,
is_encoder_data=False)
encoder_seq_data, encoder_dummy_multi_modal_data \
encoder_dummy_data \
= self.input_registry.dummy_data_for_profiling(
self.model_config,
seq_len,
......@@ -320,26 +319,31 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
is_encoder_data=True)
# Having more tokens is over-conservative but otherwise fine
assert len(decoder_seq_data.prompt_token_ids) >= seq_len, (
assert len(
decoder_dummy_data.seq_data.prompt_token_ids
) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(decoder_seq_data.prompt_token_ids)}")
f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}"
)
assert decoder_dummy_multi_modal_data is None or \
encoder_dummy_multi_modal_data is None, (
assert decoder_dummy_data.multi_modal_data is None or \
encoder_dummy_data.multi_modal_data is None, (
"Multi-modal data can't be provided in both encoder and decoder"
)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: decoder_seq_data},
seq_data={group_id: decoder_dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
encoder_seq_data=encoder_seq_data,
encoder_seq_data=encoder_dummy_data.seq_data,
cross_block_table=None,
multi_modal_data=decoder_dummy_multi_modal_data
or encoder_dummy_multi_modal_data,
)
multi_modal_data=decoder_dummy_data.multi_modal_data
or encoder_dummy_data.multi_modal_data,
multi_modal_placeholders=decoder_dummy_data.
multi_modal_placeholders
or encoder_dummy_data.multi_modal_placeholders)
seqs.append(seq)
# Run the model with the dummy inputs.
......
......@@ -40,7 +40,8 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry)
MultiModalInputs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
......@@ -242,6 +243,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None,
multi_modal_placeholder_maps: Optional[Dict[
str, MultiModalPlaceholderMap]] = None,
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False,
......@@ -361,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit
self.n_seqs = len(self.seq_ids)
......@@ -635,7 +639,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
seq_group_metadata: SequenceGroupMetadata):
"""If multi-modal data is given, add it to the input."""
mm_data = seq_group_metadata.multi_modal_data
# NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions.
positions = inter_data.input_positions[0]
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata,
range(positions[0], positions[0] + len(positions)))
if not mm_data:
return
......@@ -643,6 +652,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
mm_data,
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
inter_data.multi_modal_inputs = mm_kwargs
inter_data.multi_modal_placeholder_maps = placeholder_maps
# special processing for mrope position deltas.
if self.runner.model_is_mrope:
......@@ -1255,7 +1265,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, dummy_multi_modal_data = self.input_registry \
dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
......@@ -1263,12 +1273,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
multi_modal_data=dummy_multi_modal_data,
multi_modal_data=dummy_data.multi_modal_data,
multi_modal_placeholders=dummy_data.multi_modal_placeholders,
)
seqs.append(seq)
......
......@@ -46,9 +46,8 @@ def _init_attn_metadata_from_tensor_dict(
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_attn_kwargs[field.name] = val
if field.name in tensor_dict:
valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata
......
from typing import List, NamedTuple, Optional, Tuple
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Tuple
import openvino as ov
import torch
......@@ -14,7 +15,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
MultiModalInputs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata
logger = init_logger(__name__)
......@@ -115,6 +116,9 @@ class OpenVINOModelRunner:
past_lens: List[int] = []
query_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
subsequence_begins: List[int] = []
block_indices: List[int] = []
......@@ -168,15 +172,6 @@ class OpenVINOModelRunner:
and self.sliding_window is None
and is_prompt)
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
mm_processor_kwargs=seq_group_metadata.
mm_processor_kwargs,
)
multi_modal_inputs_list.append(mm_kwargs)
block_table = seq_group_metadata.block_tables[seq_id]
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
......@@ -220,7 +215,8 @@ class OpenVINOModelRunner:
query_lens.append(query_len)
input_tokens.extend(tokens)
input_positions.extend(list(range(computed_len, seq_len)))
positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
past_lens.append(computed_len)
subsequence_begins.append(subsequence_begins[-1] + query_len)
......@@ -233,6 +229,22 @@ class OpenVINOModelRunner:
), "seq_len: {}, computed_len: {}, query_len: {}".format(
seq_len, computed_len, query_len)
if seq_group_metadata.multi_modal_data:
# NOTE: mm_data only includes the subset of multi-modal
# items that intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
mm_processor_kwargs=seq_group_metadata.
mm_processor_kwargs)
multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map, )
max_query_len = max(query_lens)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
......@@ -261,12 +273,19 @@ class OpenVINOModelRunner:
max_context_len, dtype=torch.int32,
device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
attn_metadata = self.attn_backend.make_openvino_metadata(
past_lens=past_lens_tensor,
subsequence_begins=subsequence_begins_tensor,
block_indices=block_indices_tensor,
block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
......
......@@ -184,6 +184,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None,
context_lens=None,
)
......@@ -216,6 +217,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=0,
num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables,
context_lens=context_lens,
)
......@@ -360,6 +362,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=None,
context_lens=None,
)
......@@ -429,6 +432,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
block_tables=block_tables,
context_lens=context_lens,
)
......
import dataclasses
import time
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar)
......@@ -19,7 +20,8 @@ from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalRegistry)
MultiModalInputs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad
......@@ -161,6 +163,9 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
......@@ -179,7 +184,21 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len)))
positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
if seq_group_metadata.multi_modal_data:
# NOTE: mm_data only includes the subset of multi-modal items
# that intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.runner.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
......@@ -220,6 +239,11 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
max_seqlen = max(seq_lens)
tmp = [0]
......@@ -230,6 +254,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
seq_lens=seq_lens,
seqlen_q=seqlen_q,
max_seqlen=max_seqlen,
......@@ -313,6 +338,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
seq_lens=seq_lens,
seqlen_q=torch.tensor([]),
max_seqlen=0,
......@@ -450,7 +476,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, dummy_multi_modal_data = self.input_registry \
dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
......@@ -458,12 +484,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=None,
multi_modal_data=dummy_multi_modal_data,
)
multi_modal_data=dummy_data.multi_modal_data,
multi_modal_placeholders=dummy_data.multi_modal_placeholders)
seqs.append(seq)
# Run the model with the dummy inputs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment