Unverified commit e0191a95, authored by Cyrus Leung, committed by GitHub
Browse files

[0/N] Rename `MultiModalInputs` to `MultiModalKwargs` (#10040)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent d7edca1d
...@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence ...@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
from vllm.logger import init_logger from vllm.logger import init_logger
from .audio import AudioPlugin from .audio import AudioPlugin
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalKwargs,
MultiModalPlugin, MultiModalTokensCalc, NestedTensors) MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
from .image import ImagePlugin from .image import ImagePlugin
from .video import VideoPlugin from .video import VideoPlugin
...@@ -103,7 +103,7 @@ class MultiModalRegistry: ...@@ -103,7 +103,7 @@ class MultiModalRegistry:
model_config: "ModelConfig", model_config: "ModelConfig",
data: MultiModalDataDict, data: MultiModalDataDict,
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> MultiModalInputs: ) -> MultiModalKwargs:
""" """
Apply an input mapper to the data passed to the model. Apply an input mapper to the data passed to the model.
...@@ -139,7 +139,7 @@ class MultiModalRegistry: ...@@ -139,7 +139,7 @@ class MultiModalRegistry:
merged_dict[input_key] = input_tensor merged_dict[input_key] = input_tensor
return MultiModalInputs(merged_dict) return MultiModalKwargs(merged_dict)
def create_input_mapper(self, model_config: "ModelConfig"): def create_input_mapper(self, model_config: "ModelConfig"):
""" """
......
...@@ -9,7 +9,7 @@ from vllm.transformers_utils.processor import get_video_processor ...@@ -9,7 +9,7 @@ from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs from .base import MultiModalData, MultiModalKwargs
from .image import ImagePlugin from .image import ImagePlugin
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -55,7 +55,7 @@ class VideoPlugin(ImagePlugin): ...@@ -55,7 +55,7 @@ class VideoPlugin(ImagePlugin):
ctx: InputContext, ctx: InputContext,
data: MultiModalData[object], data: MultiModalData[object],
**mm_processor_kwargs, **mm_processor_kwargs,
) -> MultiModalInputs: ) -> MultiModalKwargs:
model_config = ctx.model_config model_config = ctx.model_config
if isinstance(data, list) and len(data) == 1: if isinstance(data, list) and len(data) == 1:
...@@ -79,7 +79,7 @@ class VideoPlugin(ImagePlugin): ...@@ -79,7 +79,7 @@ class VideoPlugin(ImagePlugin):
logger.error("Failed to process video (%s)", data) logger.error("Failed to process video (%s)", data)
raise raise
return MultiModalInputs(batch_data) return MultiModalKwargs(batch_data)
raise TypeError(f"Invalid video type: {type(data)}") raise TypeError(f"Invalid video type: {type(data)}")
......
...@@ -18,7 +18,7 @@ except (ModuleNotFoundError, ImportError) as err: ...@@ -18,7 +18,7 @@ except (ModuleNotFoundError, ImportError) as err:
"CUDA and ROCm flash attention backend.") from err "CUDA and ROCm flash attention backend.") from err
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs from vllm.multimodal import MultiModalKwargs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner) ModelRunner)
...@@ -280,7 +280,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -280,7 +280,7 @@ class TP1DraftModelRunner(ModelRunner):
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
**kwargs, **kwargs,
) )
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalInputs from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import (CPUModelRunner, from vllm.worker.cpu_model_runner import (CPUModelRunner,
...@@ -287,7 +287,7 @@ class CPUEncoderDecoderModelRunner(CPUModelRunner): ...@@ -287,7 +287,7 @@ class CPUEncoderDecoderModelRunner(CPUModelRunner):
kv_caches, kv_caches,
"attn_metadata": "attn_metadata":
model_input.attn_metadata, model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device), device=self.device),
"intermediate_tensors": "intermediate_tensors":
intermediate_tensors, intermediate_tensors,
......
...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding ...@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalPlaceholderMap) MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData, from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
...@@ -200,7 +200,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -200,7 +200,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping: List[int] = [] slot_mapping: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_model_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[ multi_modal_placeholder_maps: Dict[
str, str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
...@@ -225,7 +225,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -225,7 +225,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
._compute_multi_modal_input( ._compute_multi_modal_input(
seq_group_metadata, seq_data, computed_len, seq_group_metadata, seq_data, computed_len,
seq_group_metadata.mm_processor_kwargs) seq_group_metadata.mm_processor_kwargs)
multi_modal_inputs_list.append(mm_kwargs) multi_model_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items(): for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend( multi_modal_placeholder_maps[modality].extend(
placeholder_map) placeholder_map)
...@@ -297,7 +297,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): ...@@ -297,7 +297,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) multi_modal_kwargs)
...@@ -520,7 +520,7 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): ...@@ -520,7 +520,7 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
kv_caches, kv_caches,
"attn_metadata": "attn_metadata":
model_input.attn_metadata, model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device), device=self.device),
"intermediate_tensors": "intermediate_tensors":
intermediate_tensors, intermediate_tensors,
......
...@@ -8,7 +8,7 @@ from vllm.distributed import get_pp_group ...@@ -8,7 +8,7 @@ from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -104,7 +104,7 @@ class EmbeddingModelRunner( ...@@ -104,7 +104,7 @@ class EmbeddingModelRunner(
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device)) device=self.device))
if (self.observability_config is not None if (self.observability_config is not None
......
...@@ -18,7 +18,7 @@ from vllm.logger import init_logger ...@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry) MultiModalRegistry)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, from vllm.sequence import (IntermediateTensors, PoolerOutput,
...@@ -206,7 +206,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -206,7 +206,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
**seqlen_agnostic_kwargs) **seqlen_agnostic_kwargs)
......
...@@ -36,7 +36,7 @@ from vllm.model_executor import SamplingMetadata ...@@ -36,7 +36,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalKwargs)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData, from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
...@@ -716,7 +716,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -716,7 +716,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
context_lens: List[int] = [] context_lens: List[int] = []
query_lens: List[int] = [] query_lens: List[int] = []
prefix_block_tables: List[List[int]] = [] prefix_block_tables: List[List[int]] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_model_kwargs_list: List[MultiModalKwargs] = []
if len(seq_group_metadata_list) == 0: if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty() return PreparePromptMetadata.empty()
...@@ -777,7 +777,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -777,7 +777,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mm_data = seq_group_metadata.multi_modal_data mm_data = seq_group_metadata.multi_modal_data
if mm_data: if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data) mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs) multi_model_kwargs_list.append(mm_kwargs)
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
...@@ -876,7 +876,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -876,7 +876,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
multi_modal_placeholder_index_maps= multi_modal_placeholder_index_maps=
None # FIXME(kzawora): mutli-modality will not work here None # FIXME(kzawora): mutli-modality will not work here
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
return PreparePromptMetadata(input_tokens=input_tokens, return PreparePromptMetadata(input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
......
...@@ -38,7 +38,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig ...@@ -38,7 +38,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalPlaceholderMap, MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry) MultiModalRegistry)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.layers import PromptAdapterMapping
...@@ -252,7 +252,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -252,7 +252,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
# Multi-modal inputs. # Multi-modal inputs.
multi_modal_inputs: Optional[MultiModalInputs] = None, multi_model_kwargs: Optional[MultiModalKwargs] = None,
multi_modal_placeholder_maps: Optional[Dict[ multi_modal_placeholder_maps: Optional[Dict[
str, MultiModalPlaceholderMap]] = None, str, MultiModalPlaceholderMap]] = None,
...@@ -373,7 +373,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -373,7 +373,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_prompt_mapping or []) prompt_adapter_prompt_mapping or [])
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs self.multi_model_kwargs = multi_model_kwargs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit self.prefix_cache_hit = prefix_cache_hit
...@@ -661,7 +661,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -661,7 +661,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
mm_kwargs = self.multi_modal_input_mapper( mm_kwargs = self.multi_modal_input_mapper(
mm_data, mm_data,
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
inter_data.multi_modal_inputs = mm_kwargs inter_data.multi_model_kwargs = mm_kwargs
inter_data.multi_modal_placeholder_maps = placeholder_maps inter_data.multi_modal_placeholder_maps = placeholder_maps
# special processing for mrope position deltas. # special processing for mrope position deltas.
...@@ -935,11 +935,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -935,11 +935,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
) )
# Multi-modal data. # Multi-modal data.
multi_modal_inputs_list = [ multi_model_kwargs_list = [
data.multi_modal_inputs for data in self.inter_data_list data.multi_model_kwargs for data in self.inter_data_list
if data.multi_modal_inputs is not None if data.multi_model_kwargs is not None
] ]
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
return self.model_input_cls( return self.model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
...@@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
**seqlen_agnostic_kwargs) **seqlen_agnostic_kwargs)
......
...@@ -13,7 +13,7 @@ from vllm.model_executor import SamplingMetadata ...@@ -13,7 +13,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalKwargs)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
...@@ -122,7 +122,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -122,7 +122,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids: List[int] = [] input_block_ids: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_model_kwargs_list: List[MultiModalKwargs] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
...@@ -149,7 +149,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -149,7 +149,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
mm_data, mm_data,
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs, mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs,
) )
multi_modal_inputs_list.append(mm_kwargs) multi_model_kwargs_list.append(mm_kwargs)
max_seq_len = max(seq_lens) max_seq_len = max(seq_lens)
assert max_seq_len > 0 assert max_seq_len > 0
...@@ -167,7 +167,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -167,7 +167,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
return (input_tokens, input_positions, input_block_ids, seq_lens, return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs) multi_modal_kwargs)
...@@ -314,7 +314,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -314,7 +314,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids, input_block_ids=model_input.input_block_ids,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device), device=self.device),
) )
......
...@@ -13,7 +13,7 @@ from vllm.model_executor import SamplingMetadata ...@@ -13,7 +13,7 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalPlaceholderMap) MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.model_runner_base import ModelRunnerBase
...@@ -102,7 +102,7 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -102,7 +102,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
seq_lens: List[int] = [] seq_lens: List[int] = []
past_lens: List[int] = [] past_lens: List[int] = []
query_lens: List[int] = [] query_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_model_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[ multi_modal_placeholder_maps: Dict[
str, str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
...@@ -226,7 +226,7 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -226,7 +226,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
mm_data, mm_data,
mm_processor_kwargs=seq_group_metadata. mm_processor_kwargs=seq_group_metadata.
mm_processor_kwargs) mm_processor_kwargs)
multi_modal_inputs_list.append(mm_kwargs) multi_model_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items(): for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend( multi_modal_placeholder_maps[modality].extend(
...@@ -275,7 +275,7 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -275,7 +275,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
return ModelInput( return ModelInput(
input_tokens, input_tokens,
...@@ -341,7 +341,7 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -341,7 +341,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
kv_caches, kv_caches,
"attn_metadata": "attn_metadata":
attn_metadata, attn_metadata,
**MultiModalInputs.as_kwargs(multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
device=self.device), device=self.device),
} }
......
...@@ -18,7 +18,7 @@ from vllm.model_executor import SamplingMetadataCache ...@@ -18,7 +18,7 @@ from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs, MultiModalPlaceholderMap, MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry) MultiModalRegistry)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
...@@ -160,7 +160,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -160,7 +160,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
seq_lens: List[int] = [] seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = [] multi_model_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[ multi_modal_placeholder_maps: Dict[
str, str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
...@@ -192,7 +192,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -192,7 +192,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
.from_seq_group(seq_group_metadata, positions_range) .from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.runner.multi_modal_input_mapper(mm_data) mm_kwargs = self.runner.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs) multi_model_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items(): for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend( multi_modal_placeholder_maps[modality].extend(
...@@ -264,7 +264,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -264,7 +264,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
block_tables=torch.tensor([], device=self.device, dtype=torch.int), block_tables=torch.tensor([], device=self.device, dtype=torch.int),
) )
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens, return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) multi_modal_kwargs)
...@@ -565,7 +565,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -565,7 +565,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device)) device=self.device))
# Compute the logits in the last pipeline stage. # Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment