Unverified Commit 0b8bb86b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[1/N] Initial prototype for multi-modal processor (#10044)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent bb7991aa
import enum
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional, Union
from vllm.inputs.data import DecoderOnlyInputs
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams
......@@ -9,23 +9,20 @@ from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.inputs import DecoderOnlyInputs
class Request:
def __init__(
self,
request_id: str,
inputs: "DecoderOnlyInputs",
inputs: DecoderOnlyInputs,
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.request_id = request_id
self.inputs = inputs
self.inputs = SingletonInputsAdapter(inputs)
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
......@@ -41,17 +38,17 @@ class Request:
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
self.prompt = inputs.get("prompt")
self.prompt_token_ids = inputs["prompt_token_ids"]
self.prompt = self.inputs.prompt
self.prompt_token_ids = self.inputs.prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0
# Raw multimodal data before the mm input mapper (e.g., PIL images).
self.mm_data = inputs.get("multi_modal_data")
self.mm_processor_kwargs = inputs.get("mm_processor_kwargs")
mm_positions = inputs.get("multi_modal_placeholders")
self.mm_data = self.inputs.multi_modal_data
self.mm_processor_kwargs = self.inputs.mm_processor_kwargs
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
# FIXME(woosuk): Support other modalities.
self.mm_positions = mm_positions.get("image", [])
......@@ -64,8 +61,7 @@ class Request:
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
request_id=request.request_id,
inputs=DecoderOnlyInputs(
type="token",
inputs=token_inputs(
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
multi_modal_data=request.mm_data,
......@@ -114,7 +110,7 @@ class Request:
return RequestStatus.get_finished_reason(self.status)
def has_encoder_inputs(self) -> bool:
return self.mm_data is not None
return len(self.mm_data) > 0
@property
def num_encoder_inputs(self) -> int:
......
......@@ -28,7 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
if TYPE_CHECKING:
from vllm.multimodal.base import PlaceholderRange
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
......
......@@ -148,19 +148,29 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
query_lens=seq_lens,
)
def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata,
seq_data: SequenceData, computed_len: int,
mm_processor_kwargs: Dict[str, Any]):
def _compute_multi_modal_input(
self,
seq_data: SequenceData,
computed_len: int,
seq_group_metadata: SequenceGroupMetadata,
):
# NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group, range(computed_len, len(seq_data.get_token_ids())))
seq_group_metadata,
range(computed_len, len(seq_data.get_token_ids())),
)
if not mm_data:
return
return None, None, None
mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)
if self.runner.mm_registry.has_processor(self.runner.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
# special processing for mrope position deltas.
mrope_positions = None
......@@ -202,7 +212,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_model_kwargs_list: List[MultiModalKwargs] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
......@@ -223,11 +233,14 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
mrope_positions = None
if seq_group_metadata.multi_modal_data:
mm_kwargs, placeholder_maps, mrope_positions = self \
._compute_multi_modal_input(
seq_group_metadata, seq_data, computed_len,
seq_group_metadata.mm_processor_kwargs)
multi_model_kwargs_list.append(mm_kwargs)
(
mm_kwargs,
placeholder_maps,
mrope_positions,
) = self._compute_multi_modal_input(seq_data, computed_len,
seq_group_metadata)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
......@@ -302,7 +315,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
multi_modal_placeholder_index_maps=placeholder_index_maps,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
......
......@@ -716,7 +716,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
context_lens: List[int] = []
query_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_model_kwargs_list: List[MultiModalKwargs] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
......@@ -777,7 +777,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_model_kwargs_list.append(mm_kwargs)
multi_modal_kwargs_list.append(mm_kwargs)
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
......@@ -876,7 +876,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
multi_modal_placeholder_index_maps=
None # FIXME(kzawora): mutli-modality will not work here
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return PreparePromptMetadata(input_tokens=input_tokens,
input_positions=input_positions,
......
......@@ -252,7 +252,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
# Multi-modal inputs.
multi_model_kwargs: Optional[MultiModalKwargs] = None,
multi_modal_kwargs: Optional[MultiModalKwargs] = None,
multi_modal_placeholder_maps: Optional[Dict[
str, MultiModalPlaceholderMap]] = None,
......@@ -373,7 +373,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_adapter_prompt_mapping or [])
self.prompt_adapter_request = prompt_adapter_request
self.multi_model_kwargs = multi_model_kwargs
self.multi_modal_kwargs = multi_modal_kwargs
self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
self.prefix_cache_hit = prefix_cache_hit
......@@ -661,10 +661,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if not mm_data:
return
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
inter_data.multi_model_kwargs = mm_kwargs
if self.runner.mm_registry.has_processor(self.runner.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
inter_data.multi_modal_kwargs = mm_kwargs
inter_data.multi_modal_placeholder_maps = placeholder_maps
# special processing for mrope position deltas.
......@@ -938,11 +943,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
)
# Multi-modal data.
multi_model_kwargs_list = [
data.multi_model_kwargs for data in self.inter_data_list
if data.multi_model_kwargs is not None
multi_modal_kwargs_list = [
data.multi_modal_kwargs for data in self.inter_data_list
if data.multi_modal_kwargs is not None
]
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return self.model_input_cls(
input_tokens=input_tokens_tensor,
......
......@@ -67,7 +67,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self.pin_memory = is_pin_memory_available()
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config)
# Lazy initialization.
......@@ -122,7 +123,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids: List[int] = []
seq_lens: List[int] = []
multi_model_kwargs_list: List[MultiModalKwargs] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
......@@ -144,12 +145,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
# Process multi-modal data
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs,
)
multi_model_kwargs_list.append(mm_kwargs)
if self.mm_registry.has_processor(self.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
multi_modal_kwargs_list.append(mm_kwargs)
max_seq_len = max(seq_lens)
assert max_seq_len > 0
......@@ -167,7 +171,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
dtype=torch.long,
device=self.device)
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs)
......
......@@ -70,7 +70,8 @@ class OpenVINOModelRunner(ModelRunnerBase):
)
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config)
# Lazy initialization.
......@@ -102,7 +103,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
seq_lens: List[int] = []
past_lens: List[int] = []
query_lens: List[int] = []
multi_model_kwargs_list: List[MultiModalKwargs] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
......@@ -222,11 +223,15 @@ class OpenVINOModelRunner(ModelRunnerBase):
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
mm_processor_kwargs=seq_group_metadata.
mm_processor_kwargs)
multi_model_kwargs_list.append(mm_kwargs)
if self.mm_registry.has_processor(self.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
......@@ -275,7 +280,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
multi_modal_placeholder_index_maps=placeholder_index_maps,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return ModelInput(
input_tokens,
......
......@@ -160,7 +160,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_model_kwargs_list: List[MultiModalKwargs] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
......@@ -191,8 +191,16 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
mm_data, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
mm_kwargs = self.runner.multi_modal_input_mapper(mm_data)
multi_model_kwargs_list.append(mm_kwargs)
if self.runner.mm_registry.has_processor(
self.runner.model_config):
mm_kwargs = mm_data
else:
mm_kwargs = self.runner.multi_modal_input_mapper(
mm_data,
seq_group_metadata.mm_processor_kwargs,
)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
......@@ -264,7 +272,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment