Unverified commit 69f46359, authored by Flora Feng, committed by GitHub
Browse files

[Multimodal] Consolidate mm inputs into MultiModalFeatureSpec (#23779)


Signed-off-by: sfeng33 <4florafeng@gmail.com>
parent d9e00dbd
...@@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer, ...@@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer,
request = EngineCoreRequest("", request = EngineCoreRequest("",
prompt_token_ids, prompt_token_ids,
None, None,
None,
None,
params, params,
None, None,
None, None,
......
...@@ -7,7 +7,8 @@ import pytest ...@@ -7,7 +7,8 @@ import pytest
import torch import torch
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
...@@ -37,17 +38,20 @@ def make_request( ...@@ -37,17 +38,20 @@ def make_request(
mm_hashes: Optional[list[str]] = None, mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
): ):
if mm_positions is None: mm_features = []
mm_kwargs = None if mm_positions is not None:
else: for j, position in enumerate(mm_positions):
mm_item = MultiModalKwargsItem.dummy("dummy_m") identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_kwargs = [mm_item] * len(mm_positions) mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
return Request(request_id=request_id, return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs, mm_features=mm_features if mm_features else None,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
pooling_params=None, pooling_params=None,
eos_token_id=100, eos_token_id=100,
......
...@@ -9,7 +9,8 @@ import pytest ...@@ -9,7 +9,8 @@ import pytest
import torch import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
...@@ -32,17 +33,20 @@ def make_request( ...@@ -32,17 +33,20 @@ def make_request(
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
): ):
if mm_positions is None: mm_features = []
mm_kwargs = None if mm_positions is not None:
else: for j, position in enumerate(mm_positions):
mm_item = MultiModalKwargsItem.dummy("dummy_m") identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_kwargs = [mm_item] * len(mm_positions) mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
return Request(request_id=request_id, return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs, mm_features=mm_features if mm_features else None,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams( sampling_params=SamplingParams(
max_tokens=17, prompt_logprobs=prompt_logprobs), max_tokens=17, prompt_logprobs=prompt_logprobs),
pooling_params=None, pooling_params=None,
......
...@@ -8,7 +8,8 @@ import torch ...@@ -8,7 +8,8 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
...@@ -1308,21 +1309,24 @@ def create_requests_with_priority( ...@@ -1308,21 +1309,24 @@ def create_requests_with_priority(
prompt_logprobs=prompt_logprobs) prompt_logprobs=prompt_logprobs)
requests = [] requests = []
for i in range(num_requests): for i in range(num_requests):
mm_features = []
if mm_positions is not None: if mm_positions is not None:
mm_position = mm_positions[i] mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m") for j, position in enumerate(mm_position):
mm_kwargs = [mm_item] * len(mm_position) identifier = f"hash{i}_{j}"
else: mm_feature = MultiModalFeatureSpec(
mm_position = None data=MultiModalKwargsItem.dummy("dummy_m"),
mm_kwargs = None mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
request = Request( request = Request(
request_id=f"{i + starting_idx}", request_id=f"{i + starting_idx}",
prompt_token_ids=[i + starting_idx] * num_tokens, prompt_token_ids=[i + starting_idx] * num_tokens,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_kwargs=mm_kwargs, mm_features=mm_features if mm_features else None,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i], arrival_time=arrival_times[i],
priority=priorities[i], priority=priorities[i],
...@@ -1801,9 +1805,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): ...@@ -1801,9 +1805,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
request = Request( request = Request(
request_id="0", request_id="0",
prompt_token_ids=[0, 1], prompt_token_ids=[0, 1],
multi_modal_kwargs=None, mm_features=None,
multi_modal_hashes=None,
multi_modal_placeholders=None,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
......
...@@ -6,7 +6,8 @@ import torch ...@@ -6,7 +6,8 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash) init_none_hash)
...@@ -139,19 +140,20 @@ def create_requests( ...@@ -139,19 +140,20 @@ def create_requests(
prompt_logprobs=prompt_logprobs) prompt_logprobs=prompt_logprobs)
requests = [] requests = []
for i in range(num_requests): for i in range(num_requests):
mm_features = []
if mm_positions is not None: if mm_positions is not None:
mm_position = mm_positions[i] mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m") for j, position in enumerate(mm_position):
mm_kwargs = [mm_item] * len(mm_position)
# Dummy hash for each mm item should be unique # Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash # since encoder cache tracks entries by hash
mm_hashes = [ identifier = f"hash{i}_{j}"
"hash" + str(i) + "_" + str(j) for j in range(len(mm_position)) mm_feature = MultiModalFeatureSpec(
] data=MultiModalKwargsItem.dummy("dummy_m"),
else: mm_position=position,
mm_position = None identifier=identifier,
mm_kwargs = None modality="image")
mm_hashes = None mm_features.append(mm_feature)
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens) num_tokens)
request = Request( request = Request(
...@@ -159,9 +161,7 @@ def create_requests( ...@@ -159,9 +161,7 @@ def create_requests(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_kwargs=mm_kwargs, mm_features=mm_features if mm_features else None,
multi_modal_placeholders=mm_position,
multi_modal_hashes=mm_hashes,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher, block_hasher=block_hasher,
) )
......
...@@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest: ...@@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest:
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS, prompt_token_ids=PROMPT_TOKENS,
mm_kwargs=None, mm_features=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
pooling_params=None, pooling_params=None,
eos_token_id=None, eos_token_id=None,
......
...@@ -52,9 +52,7 @@ def make_request( ...@@ -52,9 +52,7 @@ def make_request(
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
prompt_token_ids=prompt_tokens_ids, prompt_token_ids=prompt_tokens_ids,
mm_kwargs=None, mm_features=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=params, sampling_params=params,
pooling_params=None, pooling_params=None,
eos_token_id=None, eos_token_id=None,
......
...@@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case(): ...@@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case():
prompt_token_ids = [107, 4606, 236787, 107] prompt_token_ids = [107, 4606, 236787, 107]
params = SamplingParams(skip_special_tokens=True) params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest( request = EngineCoreRequest(
"test", request_id="test",
prompt_token_ids, prompt_token_ids=prompt_token_ids,
None, mm_features=None,
None, sampling_params=params,
None, pooling_params=None,
params, eos_token_id=None,
None, arrival_time=0.0,
None, lora_request=None,
0.0,
None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
) )
......
...@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, ...@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
requests = [ requests = [
EngineCoreRequest(request_id=f"request-{idx}", EngineCoreRequest(request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
...@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
requests = [ requests = [
EngineCoreRequest(request_id=request_id_list[idx], EngineCoreRequest(request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
...@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool,
request = EngineCoreRequest( request = EngineCoreRequest(
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
...@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool,
EngineCoreRequest( EngineCoreRequest(
request_id=request_id_list[idx], request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
...@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors):
EngineCoreRequest( EngineCoreRequest(
request_id=f"request-{idx}", request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, mm_features=None,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
......
...@@ -162,9 +162,7 @@ def create_request(request_id: int, ...@@ -162,9 +162,7 @@ def create_request(request_id: int,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_kwargs=None, mm_features=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn), block_hasher=get_request_block_hasher(block_size, hash_fn),
) )
......
...@@ -12,9 +12,9 @@ from vllm.logger import init_logger ...@@ -12,9 +12,9 @@ from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache from vllm.utils import GiB_bytes, LRUCache
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
from .inputs import (MultiModalFieldElem, MultiModalKwargs, from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem,
MultiModalKwargsItem, MultiModalKwargsItems, MultiModalKwargs, MultiModalKwargsItem,
NestedTensors) MultiModalKwargsItems, NestedTensors)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
...@@ -418,6 +418,16 @@ class BaseMultiModalReceiverCache( ...@@ -418,6 +418,16 @@ class BaseMultiModalReceiverCache(
MultiModalKwargsItem]): MultiModalKwargsItem]):
"""The required interface for caches on P1.""" """The required interface for caches on P1."""
def get_and_update_features(
    self,
    mm_features: list["MultiModalFeatureSpec"],
) -> list["MultiModalFeatureSpec"]:
    """
    Refresh the data of every multimodal feature through this cache.

    Each feature's `data` is replaced, in place, by whatever
    `get_and_update_item` returns for that feature's current data and
    identifier.

    Args:
        mm_features: The features to update. The list and its elements
            are mutated in place.

    Returns:
        The same list object that was passed in.
    """
    for spec in mm_features:
        refreshed = self.get_and_update_item(spec.data, spec.identifier)
        spec.data = refreshed

    return mm_features
class MultiModalReceiverCache(BaseMultiModalReceiverCache): class MultiModalReceiverCache(BaseMultiModalReceiverCache):
""" """
......
...@@ -198,6 +198,29 @@ A dictionary containing nested tensors which have been batched via ...@@ -198,6 +198,29 @@ A dictionary containing nested tensors which have been batched via
""" """
@dataclass
class MultiModalFeatureSpec:
    """
    One multimodal input item together with its processed data and metadata.

    The V1 engine uses this to carry multimodal data through processing and
    caching. A request with several multimodal items holds one
    MultiModalFeatureSpec for each of them.
    """

    data: Optional["MultiModalKwargsItem"]
    """The multimodal data belonging to this feature (may be None)."""

    modality: str
    """Modality derived from the input, e.g., "image", "audio", "video"."""

    identifier: str
    """The mm_hash or uuid used for caching encoder outputs."""

    mm_position: PlaceholderRange
    """Placeholder range, e.g., PlaceholderRange(offset=2, length=336)."""
@dataclass @dataclass
class MultiModalFieldElem: class MultiModalFieldElem:
""" """
......
...@@ -3,14 +3,13 @@ ...@@ -3,14 +3,13 @@
import enum import enum
import time import time
from collections.abc import Sequence
from typing import Any, Optional, Union from typing import Any, Optional, Union
import msgspec import msgspec
import torch import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
...@@ -48,9 +47,7 @@ class EngineCoreRequest( ...@@ -48,9 +47,7 @@ class EngineCoreRequest(
request_id: str request_id: str
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]] mm_features: Optional[list[MultiModalFeatureSpec]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
eos_token_id: Optional[int] eos_token_id: Optional[int]
......
...@@ -434,15 +434,13 @@ class EngineCore: ...@@ -434,15 +434,13 @@ class EngineCore:
This function could be directly used in input processing thread to allow This function could be directly used in input processing thread to allow
request initialization running in parallel with Model forward request initialization running in parallel with Model forward
""" """
if request.mm_hashes is not None:
assert request.mm_kwargs is not None
# Note on thread safety: no race condition. # Note on thread safety: no race condition.
# `mm_receiver_cache` is reset at the end of LLMEngine init, # `mm_receiver_cache` is reset at the end of LLMEngine init,
# and will only accessed in the input processing thread afterwards. # and will only accessed in the input processing thread afterwards.
if self.mm_receiver_cache is not None: if self.mm_receiver_cache is not None and request.mm_features:
request.mm_kwargs = self.mm_receiver_cache.get_and_update( request.mm_features = (
request.mm_kwargs, request.mm_hashes) self.mm_receiver_cache.get_and_update_features(
request.mm_features))
req = Request.from_engine_core_request(request, req = Request.from_engine_core_request(request,
self.request_block_hasher) self.request_block_hasher)
......
...@@ -12,7 +12,7 @@ from vllm.inputs.preprocess import InputPreprocessor ...@@ -12,7 +12,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -346,9 +346,8 @@ class Processor: ...@@ -346,9 +346,8 @@ class Processor:
pooling_params = params.clone() pooling_params = params.clone()
# Multimodal related. # Multimodal related.
sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None mm_features: Optional[list[MultiModalFeatureSpec]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal": if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"] decoder_mm_inputs = decoder_inputs["mm_kwargs"]
decoder_mm_positions = decoder_inputs["mm_placeholders"] decoder_mm_positions = decoder_inputs["mm_placeholders"]
...@@ -359,25 +358,19 @@ class Processor: ...@@ -359,25 +358,19 @@ class Processor:
# in the input sequence. # in the input sequence.
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
sorted_mm_inputs = [ mm_features = []
decoder_mm_inputs[modality][idx] for modality, idx in sorted_mm_idxs:
for modality, idx in sorted_mm_idxs mm_features.append(
] MultiModalFeatureSpec(
sorted_mm_positions = [ data=decoder_mm_inputs[modality][idx],
decoder_mm_positions[modality][idx] modality=modality,
for modality, idx in sorted_mm_idxs identifier=decoder_mm_hashes[modality][idx],
] mm_position=decoder_mm_positions[modality][idx]))
sorted_mm_hashes = [
decoder_mm_hashes[modality][idx]
for modality, idx in sorted_mm_idxs
]
return decoder_inputs.get("prompt"), EngineCoreRequest( return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id, request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"], prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_kwargs=sorted_mm_inputs, mm_features=mm_features,
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=pooling_params, pooling_params=pooling_params,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
......
...@@ -6,10 +6,9 @@ import time ...@@ -6,10 +6,9 @@ import time
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason) EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest from vllm.v1.structured_output.request import StructuredOutputRequest
...@@ -26,14 +25,12 @@ class Request: ...@@ -26,14 +25,12 @@ class Request:
self, self,
request_id: str, request_id: str,
prompt_token_ids: list[int], prompt_token_ids: list[int],
multi_modal_kwargs: Optional[list[MultiModalKwargsItem]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: Optional[SamplingParams], sampling_params: Optional[SamplingParams],
pooling_params: Optional[PoolingParams], pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int], eos_token_id: Optional[int],
client_index: int = 0, client_index: int = 0,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
mm_features: Optional[list[MultiModalFeatureSpec]] = None,
lora_request: Optional["LoRARequest"] = None, lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
...@@ -89,16 +86,14 @@ class Request: ...@@ -89,16 +86,14 @@ class Request:
self.cache_salt: Optional[str] = cache_salt self.cache_salt: Optional[str] = cache_salt
# Multi-modal related # Multi-modal related
self.mm_positions = multi_modal_placeholders or [] self.mm_features = mm_features or []
self.mm_kwargs = multi_modal_kwargs or [] self.num_encoder_inputs = len(self.mm_features)
self.mm_hashes: list[str] = multi_modal_hashes or []
self.num_encoder_inputs = len(self.mm_kwargs)
self.has_encoder_inputs = self.num_encoder_inputs > 0 self.has_encoder_inputs = self.num_encoder_inputs > 0
# TODO(sfeng33): Remove these legacy fields after clearing out all
# Sanity check # references in scheduler and model runner
assert len(self.mm_kwargs) == len(self.mm_positions) self.mm_positions = [f.mm_position for f in self.mm_features]
if self.mm_hashes: self.mm_kwargs = [f.data for f in self.mm_features]
assert len(self.mm_kwargs) == len(self.mm_hashes) self.mm_hashes = [f.identifier for f in self.mm_features]
# Read-only views # Read-only views
# Prevent directly appending to these lists since # Prevent directly appending to these lists since
...@@ -126,20 +121,11 @@ class Request: ...@@ -126,20 +121,11 @@ class Request:
cls, request: EngineCoreRequest, cls, request: EngineCoreRequest,
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
) -> "Request": ) -> "Request":
if request.mm_kwargs is not None:
mm_kwargs_lst = list(request.mm_kwargs)
assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
"mm_kwargs was not updated in EngineCore.add_request")
else:
mm_kwargs_lst = None
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
client_index=request.client_index, client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
multi_modal_kwargs=mm_kwargs_lst, mm_features=request.mm_features,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
pooling_params=request.pooling_params, pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id, eos_token_id=request.eos_token_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment