Unverified Commit 19b927e5 authored by Cyrus Leung, committed by GitHub
Browse files

[Core] Use individual MM items in P0/P1 cache and model runner (#22570)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 20d65aa7
...@@ -5,7 +5,7 @@ import base64 ...@@ -5,7 +5,7 @@ import base64
import mimetypes import mimetypes
import os import os
from tempfile import NamedTemporaryFile, TemporaryDirectory from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, NamedTuple, Optional from typing import TYPE_CHECKING, NamedTuple
import numpy as np import numpy as np
import pytest import pytest
...@@ -19,14 +19,12 @@ from vllm.distributed.parallel_state import (init_distributed_environment, ...@@ -19,14 +19,12 @@ from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel) initialize_model_parallel)
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector, from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
merge_and_sort_multimodal_metadata,
run_dp_sharded_vision_model) run_dp_sharded_vision_model)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables from vllm.utils import get_open_port, update_environment_variables
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.hasher import MultiModalHashDict
from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.multimodal.inputs import MultiModalPlaceholderDict
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
...@@ -178,19 +176,17 @@ async def test_fetch_video_http(video_url: str, num_frames: int): ...@@ -178,19 +176,17 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
assert metadata_sync == metadata_async assert metadata_sync == metadata_async
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`. # Used for `test_argsort_mm_positions`.
class TestCase(NamedTuple): class TestCase(NamedTuple):
mm_positions: "MultiModalPlaceholderDict" mm_positions: "MultiModalPlaceholderDict"
mm_hashes: Optional["MultiModalHashDict"] expected_modality_idxs: list[tuple[str, int]]
expected_modalities: list[str]
expected_ranges: list[PlaceholderRange]
expected_hashes: Optional[list[str]]
def test_merge_and_sort_multimodal_metadata(): def test_argsort_mm_positions():
test_cases = [ test_cases = [
# Single modality should return result as is but flattened # Single modality
## Internally sorted
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
...@@ -198,34 +194,27 @@ def test_merge_and_sort_multimodal_metadata(): ...@@ -198,34 +194,27 @@ def test_merge_and_sort_multimodal_metadata():
PlaceholderRange(offset=3, length=2), PlaceholderRange(offset=3, length=2),
] ]
}, },
mm_hashes={"image": ["hash1", "hash2"]}, expected_modality_idxs=[
expected_modalities=["image", "image"], ("image", 0),
expected_ranges=[ ("image", 1),
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=2),
], ],
expected_hashes=["hash1", "hash2"],
), ),
## Internally unsorted
# Single modality without hashes return None for mm hash.
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
PlaceholderRange(offset=3, length=2),
PlaceholderRange(offset=0, length=2), PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=2),
] ]
}, },
mm_hashes=None, expected_modality_idxs=[
expected_modalities=["image", "image"], ("image", 1),
expected_ranges=[ ("image", 0),
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=2, length=2),
], ],
expected_hashes=None,
), ),
# Multiple modalities with hashes should return sorted modalities # Two modalities
# and flattened ranges and hashes. ## Internally sorted
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
...@@ -237,47 +226,54 @@ def test_merge_and_sort_multimodal_metadata(): ...@@ -237,47 +226,54 @@ def test_merge_and_sort_multimodal_metadata():
PlaceholderRange(offset=2, length=3), PlaceholderRange(offset=2, length=3),
] ]
}, },
mm_hashes={ expected_modality_idxs=[
"image": ["image_hash1", "image_hash2"], ("audio", 0),
"audio": ["audio_hash1", "audio_hash2"], ("audio", 1),
}, ("image", 0),
expected_modalities=["audio", "audio", "image", "image"], ("image", 1),
expected_ranges=[ ],
PlaceholderRange(offset=0, length=2), ),
PlaceholderRange(offset=2, length=3), ## Interleaved, internally sorted
PlaceholderRange(offset=7, length=4), TestCase(
PlaceholderRange(offset=11, length=5), mm_positions={
"image": [
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=8, length=2),
], ],
expected_hashes=[ "audio": [
"audio_hash1", "audio_hash2", "image_hash1", "image_hash2" PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=11, length=4),
]
},
expected_modality_idxs=[
("image", 0),
("audio", 0),
("image", 1),
("audio", 1),
], ],
), ),
## Interleaved, internally unsorted
# Multiple modalities without hashes should return sorted modalities
# and flattened ranges and None.
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
PlaceholderRange(offset=7, length=4), PlaceholderRange(offset=8, length=2),
PlaceholderRange(offset=11, length=5), PlaceholderRange(offset=0, length=4),
], ],
"audio": [ "audio": [
PlaceholderRange(offset=0, length=2), PlaceholderRange(offset=11, length=4),
PlaceholderRange(offset=2, length=3), PlaceholderRange(offset=5, length=2),
] ]
}, },
mm_hashes=None, expected_modality_idxs=[
expected_modalities=["audio", "audio", "image", "image"], ("image", 1),
expected_ranges=[ ("audio", 1),
PlaceholderRange(offset=0, length=2), ("image", 0),
PlaceholderRange(offset=2, length=3), ("audio", 0),
PlaceholderRange(offset=7, length=4),
PlaceholderRange(offset=11, length=5),
], ],
expected_hashes=None,
), ),
# Three modalities # Three modalities
## Internally sorted
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
...@@ -293,72 +289,16 @@ def test_merge_and_sort_multimodal_metadata(): ...@@ -293,72 +289,16 @@ def test_merge_and_sort_multimodal_metadata():
PlaceholderRange(offset=12, length=6), PlaceholderRange(offset=12, length=6),
] ]
}, },
mm_hashes={ expected_modality_idxs=[
"image": ["image_hash1", "image_hash2"], ("audio", 0),
"audio": ["audio_hash1"], ("video", 0),
"video": ["video_hash1", "video_hash2", "video_hash3"] ("video", 1),
}, ("video", 2),
expected_modalities=[ ("image", 0),
"audio", "video", "video", "video", "image", "image" ("image", 1),
],
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=3, length=4),
PlaceholderRange(offset=7, length=5),
PlaceholderRange(offset=12, length=6),
PlaceholderRange(offset=15, length=7),
PlaceholderRange(offset=22, length=8),
],
expected_hashes=[
"audio_hash1", "video_hash1", "video_hash2", "video_hash3",
"image_hash1", "image_hash2"
],
),
]
for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
expected_hashes) in test_cases:
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
mm_positions, mm_hashes)
assert modalities == expected_modalities
assert ranges == expected_ranges
assert hashes == expected_hashes
def test_merge_and_sort_multimodal_metadata_with_interleaving():
test_cases = [
# <image> <audio> <image> <audio>
TestCase(
mm_positions={
"image": [
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=8, length=2),
],
"audio": [
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=11, length=4),
]
},
mm_hashes={
"image": ["image_hash1", "image_hash2"],
"audio": ["audio_hash1", "audio_hash2"],
},
expected_modalities=["image", "audio", "image", "audio"],
expected_ranges=[
PlaceholderRange(offset=0, length=4),
PlaceholderRange(offset=5, length=2),
PlaceholderRange(offset=8, length=2),
PlaceholderRange(offset=11, length=4),
],
expected_hashes=[
"image_hash1", "audio_hash1", "image_hash2", "audio_hash2"
], ],
), ),
## Interleaved, internally sorted
# <image> <image> <audio> <video> <image>
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
...@@ -373,58 +313,43 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving(): ...@@ -373,58 +313,43 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
PlaceholderRange(offset=8, length=5), PlaceholderRange(offset=8, length=5),
] ]
}, },
mm_hashes=None, expected_modality_idxs=[
expected_modalities=["image", "image", "audio", "video", "image"], ("image", 0),
expected_ranges=[ ("image", 1),
PlaceholderRange(offset=0, length=2), ("audio", 0),
PlaceholderRange(offset=2, length=3), ("video", 0),
PlaceholderRange(offset=5, length=2), ("image", 2),
PlaceholderRange(offset=8, length=5),
PlaceholderRange(offset=20, length=4),
], ],
expected_hashes=None,
), ),
## Interleaved, internally unsorted
# <image> <audio> <video> <image> with hashes
TestCase( TestCase(
mm_positions={ mm_positions={
"image": [ "image": [
PlaceholderRange(offset=0, length=2), PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=18, length=4), PlaceholderRange(offset=20, length=4),
PlaceholderRange(offset=2, length=3),
], ],
"audio": [ "audio": [
PlaceholderRange(offset=6, length=2), PlaceholderRange(offset=5, length=2),
], ],
"video": [ "video": [
PlaceholderRange(offset=10, length=5), PlaceholderRange(offset=8, length=5),
] ]
}, },
mm_hashes={ expected_modality_idxs=[
"image": ["image_hash1", "image_hash2"], ("image", 0),
"audio": ["audio_hash1"], ("image", 2),
"video": ["video_hash1"], ("audio", 0),
}, ("video", 0),
expected_modalities=["image", "audio", "video", "image"], ("image", 1),
expected_ranges=[
PlaceholderRange(offset=0, length=2),
PlaceholderRange(offset=6, length=2),
PlaceholderRange(offset=10, length=5),
PlaceholderRange(offset=18, length=4),
],
expected_hashes=[
"image_hash1", "audio_hash1", "video_hash1", "image_hash2"
], ],
), ),
] ]
for (mm_positions, mm_hashes, expected_modalities, expected_ranges, for mm_positions, expected_modality_idxs in test_cases:
expected_hashes) in test_cases: modality_idxs = argsort_mm_positions(mm_positions)
modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
mm_positions, mm_hashes)
assert modalities == expected_modalities assert modality_idxs == expected_modality_idxs
assert ranges == expected_ranges
assert hashes == expected_hashes
class SimpleLinearModel(torch.nn.Module): class SimpleLinearModel(torch.nn.Module):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib import importlib
from typing import Optional
import pytest import pytest
import torch import torch
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
...@@ -27,20 +30,29 @@ from vllm.v1.request import Request ...@@ -27,20 +30,29 @@ from vllm.v1.request import Request
# yapf: enable # yapf: enable
def make_request(request_id, def make_request(
prompt_token_ids, request_id: str,
mm_positions=None, prompt_token_ids: list[int],
mm_hashes=None, mm_positions: Optional[list[PlaceholderRange]] = None,
cache_salt=None): mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None: if mm_positions is None:
multi_modal_inputs = None mm_kwargs = None
else: else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions)
return Request( return Request(
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs, multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes, multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions, multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
...@@ -316,7 +328,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks(): ...@@ -316,7 +328,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks():
def test_generate_block_hash_extra_keys(): def test_generate_block_hash_extra_keys():
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(20)], prompt_token_ids=[_ for _ in range(20)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=5), PlaceholderRange(offset=0, length=5),
...@@ -348,7 +360,7 @@ def test_generate_block_hash_extra_keys(): ...@@ -348,7 +360,7 @@ def test_generate_block_hash_extra_keys():
def test_generate_block_hash_extra_keys_no_mm_inputs(): def test_generate_block_hash_extra_keys_no_mm_inputs():
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
...@@ -361,7 +373,7 @@ def test_generate_block_hash_extra_keys_no_mm_inputs(): ...@@ -361,7 +373,7 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
def test_generate_block_hash_extra_keys_cache_salt(): def test_generate_block_hash_extra_keys_cache_salt():
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
...@@ -382,7 +394,7 @@ def test_generate_block_hash_extra_keys_cache_salt(): ...@@ -382,7 +394,7 @@ def test_generate_block_hash_extra_keys_cache_salt():
# works together with other extra keys # works together with other extra keys
request_mm = make_request( request_mm = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(20)], prompt_token_ids=[_ for _ in range(20)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=5), PlaceholderRange(offset=0, length=5),
...@@ -420,7 +432,7 @@ def test_hash_request_tokens(hash_fn): ...@@ -420,7 +432,7 @@ def test_hash_request_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn) init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=0, length=3),
...@@ -450,7 +462,7 @@ def test_hash_tokens_different_mm_input(hash_fn): ...@@ -450,7 +462,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn) init_none_hash(hash_fn)
request1 = make_request( request1 = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=0, length=3),
...@@ -459,7 +471,7 @@ def test_hash_tokens_different_mm_input(hash_fn): ...@@ -459,7 +471,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
mm_hashes=["hash1", "hash2"], mm_hashes=["hash1", "hash2"],
) )
request2 = make_request( request2 = make_request(
request_id=1, request_id="1",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=0, length=3),
...@@ -479,7 +491,7 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): ...@@ -479,7 +491,7 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
init_none_hash(hash_fn) init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
...@@ -844,7 +856,7 @@ def test_allocate_with_lookahead(): ...@@ -844,7 +856,7 @@ def test_allocate_with_lookahead():
) )
request = make_request( request = make_request(
request_id=0, request_id="0",
prompt_token_ids=[], prompt_token_ids=[],
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
......
...@@ -9,7 +9,9 @@ import pytest ...@@ -9,7 +9,9 @@ import pytest
import torch import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
...@@ -21,21 +23,30 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, ...@@ -21,21 +23,30 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec) KVCacheGroupSpec, SlidingWindowSpec)
def make_request(request_id, def make_request(
prompt_token_ids, request_id: str,
mm_positions=None, prompt_token_ids: list[int],
mm_hashes=None, mm_positions: Optional[list[PlaceholderRange]] = None,
mm_hashes: Optional[list[str]] = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None): cache_salt: Optional[str] = None,
):
if mm_positions is None: if mm_positions is None:
multi_modal_inputs = None mm_kwargs = None
else: else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions)
return Request( return Request(
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs, multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes, multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions, multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17, sampling_params=SamplingParams(max_tokens=17,
......
...@@ -8,7 +8,9 @@ import torch ...@@ -8,7 +8,9 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
...@@ -1304,7 +1306,7 @@ def create_requests_with_priority( ...@@ -1304,7 +1306,7 @@ def create_requests_with_priority(
priorities: list[int], priorities: list[int],
arrival_times: Optional[list[float]] = None, arrival_times: Optional[list[float]] = None,
num_tokens: int = 10, num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None, mm_positions: Optional[list[list[PlaceholderRange]]] = None,
max_tokens: int = 16, max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None, stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None): prompt_logprobs: Optional[int] = None):
...@@ -1323,16 +1325,23 @@ def create_requests_with_priority( ...@@ -1323,16 +1325,23 @@ def create_requests_with_priority(
for i in range(num_requests): for i in range(num_requests):
if mm_positions is not None: if mm_positions is not None:
mm_position = mm_positions[i] mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position) mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_position)
else: else:
mm_position = None mm_position = None
mm_inputs = None mm_kwargs = None
request = Request( request = Request(
request_id=f"{i}", request_id=f"{i}",
prompt_token_ids=[i] * num_tokens, prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_inputs=mm_inputs, multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position, multi_modal_placeholders=mm_position,
multi_modal_hashes=None, multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
...@@ -1816,7 +1825,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): ...@@ -1816,7 +1825,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
request = Request( request = Request(
request_id="0", request_id="0",
prompt_token_ids=[0, 1], prompt_token_ids=[0, 1],
multi_modal_inputs=None, multi_modal_kwargs=None,
multi_modal_hashes=None, multi_modal_hashes=None,
multi_modal_placeholders=None, multi_modal_placeholders=None,
sampling_params=sampling_params, sampling_params=sampling_params,
......
...@@ -6,7 +6,9 @@ import torch ...@@ -6,7 +6,9 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
...@@ -115,7 +117,7 @@ def create_scheduler( ...@@ -115,7 +117,7 @@ def create_scheduler(
def create_requests( def create_requests(
num_requests: int, num_requests: int,
num_tokens: int = 10, num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None, mm_positions: Optional[list[list[PlaceholderRange]]] = None,
max_tokens: int = 16, max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None, stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
...@@ -129,10 +131,17 @@ def create_requests( ...@@ -129,10 +131,17 @@ def create_requests(
for i in range(num_requests): for i in range(num_requests):
if mm_positions is not None: if mm_positions is not None:
mm_position = mm_positions[i] mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position) mm_elem = MultiModalFieldElem(
modality="dummy_m",
key="dummy_k",
data=None,
field=MultiModalBatchedField(),
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_position)
else: else:
mm_position = None mm_position = None
mm_inputs = None mm_kwargs = None
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens) num_tokens)
request = Request( request = Request(
...@@ -140,7 +149,7 @@ def create_requests( ...@@ -140,7 +149,7 @@ def create_requests(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_inputs=mm_inputs, multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position, multi_modal_placeholders=mm_position,
multi_modal_hashes=None, multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
......
...@@ -35,7 +35,7 @@ def make_request() -> EngineCoreRequest: ...@@ -35,7 +35,7 @@ def make_request() -> EngineCoreRequest:
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS, prompt_token_ids=PROMPT_TOKENS,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
......
...@@ -52,7 +52,7 @@ def make_request( ...@@ -52,7 +52,7 @@ def make_request(
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
prompt_token_ids=prompt_tokens_ids, prompt_token_ids=prompt_tokens_ids,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
sampling_params=params, sampling_params=params,
......
...@@ -53,7 +53,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, ...@@ -53,7 +53,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
EngineCoreRequest(request_id=f"request-{idx}", EngineCoreRequest(request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
...@@ -402,7 +402,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -402,7 +402,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
EngineCoreRequest(request_id=request_id_list[idx], EngineCoreRequest(request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
...@@ -567,7 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -567,7 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool,
request_id=request_id, request_id=request_id,
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
...@@ -666,7 +666,7 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -666,7 +666,7 @@ def test_stop_string(include_stop_str_in_output: bool,
request_id=request_id_list[idx], request_id=request_id_list[idx],
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
...@@ -782,7 +782,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -782,7 +782,7 @@ def test_iteration_stats(dummy_test_vectors):
request_id=f"request-{idx}", request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens, prompt_token_ids=prompt_tokens,
arrival_time=0, arrival_time=0,
mm_inputs=None, mm_kwargs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
......
...@@ -154,7 +154,7 @@ def create_request( ...@@ -154,7 +154,7 @@ def create_request(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
multi_modal_inputs=None, multi_modal_kwargs=None,
multi_modal_placeholders=None, multi_modal_placeholders=None,
multi_modal_hashes=None, multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
......
...@@ -64,7 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -64,7 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData( NewRequestData(
req_id=req_id, req_id=req_id,
prompt_token_ids=[1, 2, 3], prompt_token_ids=[1, 2, 3],
mm_inputs=[], mm_kwargs=[],
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
......
...@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(), sampling_params=_create_sampling_params(),
pooling_params=None, pooling_params=None,
mm_inputs=[], mm_kwargs=[],
mm_positions=[], mm_positions=[],
block_ids=([], ), block_ids=([], ),
generator=None, generator=None,
......
...@@ -120,7 +120,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -120,7 +120,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData( NewRequestData(
req_id=req_id, req_id=req_id,
prompt_token_ids=[1, 2, 3], prompt_token_ids=[1, 2, 3],
mm_inputs=[], mm_kwargs=[],
mm_hashes=[], mm_hashes=[],
mm_positions=[], mm_positions=[],
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass, replace
from functools import partial from functools import partial
from itertools import accumulate from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
...@@ -198,7 +198,7 @@ A dictionary containing nested tensors which have been batched via ...@@ -198,7 +198,7 @@ A dictionary containing nested tensors which have been batched via
""" """
@dataclass(frozen=True) @dataclass
class MultiModalFieldElem: class MultiModalFieldElem:
""" """
Represents a keyword argument corresponding to a multi-modal item Represents a keyword argument corresponding to a multi-modal item
...@@ -218,11 +218,14 @@ class MultiModalFieldElem: ...@@ -218,11 +218,14 @@ class MultiModalFieldElem:
i.e. the name of the keyword argument to be passed to the model. i.e. the name of the keyword argument to be passed to the model.
""" """
data: NestedTensors data: Optional[NestedTensors]
""" """
The tensor data of this field in The tensor data of this field in
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs], [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
i.e. the value of the keyword argument to be passed to the model. i.e. the value of the keyword argument to be passed to the model.
It may be set to `None` if it is determined that the item is cached
in `EngineCore`.
""" """
field: "BaseMultiModalField" field: "BaseMultiModalField"
...@@ -235,8 +238,15 @@ class MultiModalFieldElem: ...@@ -235,8 +238,15 @@ class MultiModalFieldElem:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return False return False
if self.data is None:
data_equal = other.data is None
elif other.data is None:
data_equal = self.data is None
else:
data_equal = nested_tensors_equal(self.data, other.data)
return ((self.modality, self.key) == (other.modality, other.key) return ((self.modality, self.key) == (other.modality, other.key)
and nested_tensors_equal(self.data, other.data) and data_equal
and type(self.field) == type(other.field)) # noqa: E721 and type(self.field) == type(other.field)) # noqa: E721
...@@ -280,10 +290,20 @@ class BaseMultiModalField(ABC): ...@@ -280,10 +290,20 @@ class BaseMultiModalField(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod
def _reduce_data(
    self,
    batch: list[NestedTensors],
    *,
    pin_memory: bool,
) -> NestedTensors:
    """
    Merge the per-item data in `batch` into a single value.

    Subclasses must override this hook; the base implementation only
    raises.

    Args:
        batch: The data of each element being merged, in order.
        pin_memory: Keyword-only flag with no default at this level.
            The tensor-merging overrides visible in this file forward it
            to `torch.empty(..., pin_memory=...)` when allocating the
            output buffer.

    Raises:
        NotImplementedError: Always, unless overridden.
    """
    raise NotImplementedError
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: def reduce_data(
self,
elems: list[MultiModalFieldElem],
*,
pin_memory: bool = False,
) -> NestedTensors:
""" """
Merge the data from multiple instances of Merge the data from multiple instances of
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]. [`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].
...@@ -295,7 +315,13 @@ class BaseMultiModalField(ABC): ...@@ -295,7 +315,13 @@ class BaseMultiModalField(ABC):
if len(set(field_types)) > 1: if len(set(field_types)) > 1:
raise ValueError(f"Cannot merge different {field_types=}") raise ValueError(f"Cannot merge different {field_types=}")
return self._reduce_data([item.data for item in elems]) validated_data = list[NestedTensors]()
for i, elem in enumerate(elems):
assert elem.data is not None, (
f"Cannot merge with empty `elems[{i}]`")
validated_data.append(elem.data)
return self._reduce_data(validated_data, pin_memory=pin_memory)
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -314,7 +340,12 @@ class MultiModalBatchedField(BaseMultiModalField): ...@@ -314,7 +340,12 @@ class MultiModalBatchedField(BaseMultiModalField):
field_factory = self._field_factory(modality=modality, key=key) field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(item) for item in data] return [field_factory(item) for item in data]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
if len(batch) == 1: if len(batch) == 1:
# An optimization when `batch` contains only one tensor: # An optimization when `batch` contains only one tensor:
...@@ -323,7 +354,11 @@ class MultiModalBatchedField(BaseMultiModalField): ...@@ -323,7 +354,11 @@ class MultiModalBatchedField(BaseMultiModalField):
return batch[0].unsqueeze(0).contiguous() return batch[0].unsqueeze(0).contiguous()
first_shape = batch[0].shape first_shape = batch[0].shape
if all(elem.shape == first_shape for elem in batch): if all(elem.shape == first_shape for elem in batch):
return torch.stack(batch) out = torch.empty((len(batch), *batch[0].shape),
dtype=batch[0].dtype,
device=batch[0].device,
pin_memory=pin_memory)
return torch.stack(batch, out=out)
return batch return batch
...@@ -350,7 +385,12 @@ class MultiModalFlatField(BaseMultiModalField): ...@@ -350,7 +385,12 @@ class MultiModalFlatField(BaseMultiModalField):
"torch.Tensor is required for multiple slices" "torch.Tensor is required for multiple slices"
return [field_factory(data[cast(slice, s)]) for s in self.slices] return [field_factory(data[cast(slice, s)]) for s in self.slices]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def _reduce_data(
self,
batch: list[NestedTensors],
*,
pin_memory: bool,
) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
if len(batch) == 1: if len(batch) == 1:
# An optimization when `batch` contains only one tensor: # An optimization when `batch` contains only one tensor:
...@@ -358,13 +398,21 @@ class MultiModalFlatField(BaseMultiModalField): ...@@ -358,13 +398,21 @@ class MultiModalFlatField(BaseMultiModalField):
# - will achieve zero-copy if the tensor is contiguous # - will achieve zero-copy if the tensor is contiguous
return batch[0].contiguous() return batch[0].contiguous()
def _expect_same_shape(tensor: torch.Tensor): dim = self.dim + (self.dim < 0) * len(batch[0].shape)
return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:]
first_shape = _expect_same_shape(batch[0]) def _shape_before_after(tensor: torch.Tensor):
return tensor.shape[:dim], tensor.shape[dim + 1:]
if all(_expect_same_shape(elem) == first_shape for elem in batch): first_shape = _shape_before_after(batch[0])
return torch.concat(batch, dim=self.dim)
if all(_shape_before_after(elem) == first_shape for elem in batch):
shape_before, shape_after = first_shape
shape_concat = sum(item.shape[dim] for item in batch)
out = torch.empty((*shape_before, shape_concat, *shape_after),
dtype=batch[0].dtype,
device=batch[0].device,
pin_memory=pin_memory)
return torch.concat(batch, dim=self.dim, out=out)
assert self.dim == 0, "dim == 0 is required for nested list" assert self.dim == 0, "dim == 0 is required for nested list"
return [e for elem in batch for e in elem] return [e for elem in batch for e in elem]
...@@ -387,7 +435,12 @@ class MultiModalSharedField(BaseMultiModalField): ...@@ -387,7 +435,12 @@ class MultiModalSharedField(BaseMultiModalField):
field_factory = self._field_factory(modality=modality, key=key) field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(data)] * self.batch_size return [field_factory(data)] * self.batch_size
def _reduce_data(
    self,
    batch: list[NestedTensors],
    *,
    pin_memory: bool,
) -> NestedTensors:
    """
    Return the first entry of `batch` unchanged.

    The elements built for this field are the same element repeated
    `batch_size` times, so no merging is done here. `pin_memory` is
    accepted for signature compatibility and is not used.

    Raises:
        IndexError: If `batch` is empty.
    """
    return batch[0]
...@@ -594,11 +647,53 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): ...@@ -594,11 +647,53 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
def from_elems(elems: Sequence[MultiModalFieldElem]): def from_elems(elems: Sequence[MultiModalFieldElem]):
return MultiModalKwargsItem({elem.key: elem for elem in elems}) return MultiModalKwargsItem({elem.key: elem for elem in elems})
def __init__(self, data: Mapping[str, MultiModalFieldElem]) -> None:
    """
    Build an item from a mapping of keyword-argument name to element.

    All elements must share one modality; that modality and the
    "empty" flag are computed once here and cached on the instance.
    """
    super().__init__(data)

    # An item describes exactly one multi-modal input, so every element
    # must agree on the modality (this also rejects an empty mapping).
    modalities = {elem.modality for elem in self.data.values()}
    assert len(modalities) == 1, f"Found different modalities={modalities}"
    self._modality = next(iter(modalities))

    # The item counts as empty as soon as any element has no data.
    # NOTE: this flag is computed only here; nothing visible in this class
    # refreshes it if the mapping is mutated afterwards.
    self._is_empty = any(elem.data is None for elem in self.values())
@property
def modality(self) -> str:
    """The single modality shared by all elements (cached at init)."""
    return self._modality
@property
def is_empty(self) -> bool:
    """
    Whether any element had `data is None` when the item was constructed.
    """
    return self._is_empty
def get_data(self) -> Optional[Mapping[str, NestedTensors]]:
    """
    Return the data of every element keyed by its name.

    Returns:
        `None` if the item was flagged empty at construction; otherwise
        a new dict mapping each key to the corresponding `elem.data`.
    """
    if self._is_empty:
        return None

    out_data = dict[str, NestedTensors]()
    for key, elem in self.items():
        # Guards against an element whose data was cleared after the
        # empty flag was computed.
        assert elem.data is not None, (
            f"Cannot get data of empty `elem[{key!r}]`")
        out_data[key] = elem.data

    return out_data
def require_data(self) -> Mapping[str, NestedTensors]:
    """
    Like `get_data`, but an empty item is an error instead of `None`.

    Raises:
        RuntimeError: If `get_data()` returns `None`.
    """
    if (data := self.get_data()) is None:
        raise RuntimeError("Cannot get data of empty item")

    return data
# These methods create a new item to avoid mutating cached items in place
def with_data(self, data: Mapping[str, NestedTensors]):
    """
    Return a new item whose elements carry the values from `data`.

    Each element is copied via `dataclasses.replace` with its `data`
    field set to `data[key]`; this item is left untouched.

    Raises:
        KeyError: If `data` lacks one of this item's keys.
    """
    return MultiModalKwargsItem({
        key: replace(elem, data=data[key])
        for key, elem in self.items()
    })
def without_data(self):
    """
    Return a new item with every element's `data` set to `None`.

    Elements are copied via `dataclasses.replace`, so this item (which
    may be held in a cache) is not mutated.
    """
    return MultiModalKwargsItem({
        key: replace(elem, data=None)
        for key, elem in self.items()
    })
# NOTE: UserDict is for V0 compatibility. # NOTE: UserDict is for V0 compatibility.
...@@ -650,7 +745,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -650,7 +745,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return MultiModalKwargs.from_items(items) return MultiModalKwargs.from_items(items)
@staticmethod @staticmethod
def from_items(items: Sequence[MultiModalKwargsItem]): def from_items(
items: Sequence[MultiModalKwargsItem],
*,
pin_memory: bool = False,
):
"""Construct a new """Construct a new
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
from multiple items.""" from multiple items."""
...@@ -660,7 +759,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -660,7 +759,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
elems_by_key[key].append(elem) elems_by_key[key].append(elem)
data = { data = {
key: elems[0].field.reduce_data(elems) key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0 for key, elems in elems_by_key.items() if len(elems) > 0
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import asyncio import asyncio
import atexit import atexit
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from itertools import groupby from itertools import groupby
from pathlib import Path from pathlib import Path
...@@ -13,6 +14,7 @@ import numpy as np ...@@ -13,6 +14,7 @@ import numpy as np
import numpy.typing as npt import numpy.typing as npt
import torch import torch
from PIL import Image, UnidentifiedImageError from PIL import Image, UnidentifiedImageError
from typing_extensions import deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection from vllm.connections import HTTPConnection, global_http_connection
...@@ -23,17 +25,17 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -23,17 +25,17 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from .audio import AudioMediaIO from .audio import AudioMediaIO
from .base import MediaIO from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO from .video import VideoMediaIO
_M = TypeVar("_M") _M = TypeVar("_M")
if TYPE_CHECKING: if TYPE_CHECKING:
from .hasher import MultiModalHashDict from .inputs import (BatchedTensorInputs, MultiModalKwargs,
from .inputs import MultiModalKwargs, MultiModalPlaceholderDict MultiModalKwargsItem, MultiModalPlaceholderDict)
else: else:
MultiModalHashDict = Any BatchedTensorInputs = Any
MultiModalKwargs = Any MultiModalKwargs = Any
MultiModalKwargsItem = Any
MultiModalPlaceholderDict = Any MultiModalPlaceholderDict = Any
global_thread_pool = ThreadPoolExecutor( global_thread_pool = ThreadPoolExecutor(
...@@ -331,79 +333,32 @@ def encode_video_base64(frames: npt.NDArray) -> str: ...@@ -331,79 +333,32 @@ def encode_video_base64(frames: npt.NDArray) -> str:
return video_io.encode_base64(frames) return video_io.encode_base64(frames)
def argsort_mm_positions(
        mm_positions: "MultiModalPlaceholderDict") -> list[tuple[str, int]]:
    """
    Given a `MultiModalPlaceholderDict`, output a sequence of keys to
    sort the dictionary by `offset` (starting index in the input sequence)
    in ascending order.

    Items with equal `offset` keep the order in which they are encountered
    when iterating `mm_positions` (modality by modality, then by index),
    because `sorted` is stable. An empty mapping yields an empty list.

    Returns:
        A list of `(modality, idx)`, which can be used to access an item
        by `mm_positions[modality][idx]`.
    """
    # Flatten {modality: [item, ...]} into (modality, idx, item) triples so
    # that items from all modalities can be ordered together.
    flat_items = ((modality, idx, item)
                  for modality, items in mm_positions.items()
                  for idx, item in enumerate(items))

    sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset)

    return [(modality, idx) for modality, idx, _ in sorted_flat_items]
# Temporary back-compatibility for plugins that define model runner
@deprecated("`group_mm_inputs_by_modality` is superseded by "
"`group_mm_kwargs_by_modality` and will be removed in v0.13. "
"Please use `group_mm_kwargs_by_modality` instead.")
def group_mm_inputs_by_modality( def group_mm_inputs_by_modality(
mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]: mm_inputs: list[MultiModalKwargs]) -> list[list[MultiModalKwargs]]:
"""Group consecutive MultiModalKwargs from mm_inputs with the same modality
together into the same list for batching purpose. For MultiModalKwargs with
multiple modalities, put them into their own list.
Args:
mm_inputs: List of MultiModalKwargs.
Returns:
list[list[vllm.multimodal.MultiModalKwargs]]: List of list of
`MultiModalKwargs`, each inner list contains consecutive
`MultiModalKwargs` with same modality.
"""
if not mm_inputs: if not mm_inputs:
return [] return []
...@@ -426,6 +381,48 @@ def group_mm_inputs_by_modality( ...@@ -426,6 +381,48 @@ def group_mm_inputs_by_modality(
] ]
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together into the same `MultiModalKwargs` instance.

    Only *adjacent* items are grouped: the same modality appearing again
    after a different one starts a new group.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: Forwarded to `MultiModalKwargs.as_kwargs`.
        pin_memory: Forwarded to `MultiModalKwargs.batch`.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    # Imported here rather than at module level (the module-level name is
    # only bound to the real class under TYPE_CHECKING).
    from vllm.multimodal.inputs import MultiModalKwargs

    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        # Materialize the group: it is iterated below and its length is
        # reported to the caller.
        items_lst = list(items)

        # mm_kwargs_group = MultiModalKwargs.from_items(items_lst,
        #                                               pin_memory=pin_memory)

        # if device is not None:
        #     mm_kwargs_group = json_map_leaves(lambda x: x.to(device=device),
        #                                       mm_kwargs_group.data)

        # TODO: Once V0 is removed, we can use the merging logic above
        # to avoid creating an extra batch dimension (except for fields
        # that are meant to be stacked anyway).
        # We will also need to update each model to remove `flatten_bn`.
        mm_kwargs_group = MultiModalKwargs.as_kwargs(
            MultiModalKwargs.batch(
                [MultiModalKwargs.from_items([item]) for item in items_lst],
                pin_memory=pin_memory,
            ),
            device=device,
        )

        yield modality, len(items_lst), mm_kwargs_group
def run_dp_sharded_vision_model(image_input: torch.Tensor, def run_dp_sharded_vision_model(image_input: torch.Tensor,
vision_model: torch.nn.Module) -> torch.Tensor: vision_model: torch.nn.Module) -> torch.Tensor:
"""Run a vision model with data parallelism (DP) sharding. The function """Run a vision model with data parallelism (DP) sharding. The function
......
...@@ -13,7 +13,7 @@ if TYPE_CHECKING: ...@@ -13,7 +13,7 @@ if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata) KVConnectorMetadata)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -24,7 +24,7 @@ class NewRequestData: ...@@ -24,7 +24,7 @@ class NewRequestData:
req_id: str req_id: str
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs] mm_kwargs: list[MultiModalKwargsItem]
mm_hashes: list[str] mm_hashes: list[str]
mm_positions: list[PlaceholderRange] mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
...@@ -42,7 +42,7 @@ class NewRequestData: ...@@ -42,7 +42,7 @@ class NewRequestData:
return cls( return cls(
req_id=request.request_id, req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
mm_inputs=request.mm_inputs, mm_kwargs=request.mm_kwargs,
mm_hashes=request.mm_hashes, mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions, mm_positions=request.mm_positions,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
...@@ -56,7 +56,7 @@ class NewRequestData: ...@@ -56,7 +56,7 @@ class NewRequestData:
return (f"NewRequestData(" return (f"NewRequestData("
f"req_id={self.req_id}," f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids}," f"prompt_token_ids={self.prompt_token_ids},"
f"mm_inputs={self.mm_inputs}," f"mm_kwargs={self.mm_kwargs},"
f"mm_hashes={self.mm_hashes}," f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions}," f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
...@@ -70,7 +70,7 @@ class NewRequestData: ...@@ -70,7 +70,7 @@ class NewRequestData:
return (f"NewRequestData(" return (f"NewRequestData("
f"req_id={self.req_id}," f"req_id={self.req_id},"
f"prompt_token_ids_len={len(self.prompt_token_ids)}," f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"mm_inputs={self.mm_inputs}," f"mm_kwargs={self.mm_kwargs},"
f"mm_hashes={self.mm_hashes}," f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions}," f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
......
...@@ -3,15 +3,13 @@ ...@@ -3,15 +3,13 @@
import enum import enum
import time import time
from collections.abc import Sequence
from typing import Any, Optional, Union from typing import Any, Optional, Union
import msgspec import msgspec
import torch import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
...@@ -49,7 +47,7 @@ class EngineCoreRequest( ...@@ -49,7 +47,7 @@ class EngineCoreRequest(
request_id: str request_id: str
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_kwargs: Optional[list[MultiModalKwargsItem]]
mm_hashes: Optional[list[str]] mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]] mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
......
...@@ -409,12 +409,13 @@ class EngineCore: ...@@ -409,12 +409,13 @@ class EngineCore:
request initialization running in parallel with Model forward request initialization running in parallel with Model forward
""" """
if request.mm_hashes is not None: if request.mm_hashes is not None:
assert request.mm_inputs is not None assert request.mm_kwargs is not None
# Note on thread safety: no race condition. # Note on thread safety: no race condition.
# `mm_input_cache_server` is reset at the end of LLMEngine init, # `mm_input_cache_server` is reset at the end of LLMEngine init,
# and will only be accessed in the input processing thread afterwards. # and will only be accessed in the input processing thread afterwards.
request.mm_inputs = self.mm_input_cache_server.get_and_update( request.mm_kwargs = self.mm_input_cache_server.get_and_update(
request.mm_inputs, request.mm_hashes) request.mm_kwargs, request.mm_hashes)
req = Request.from_engine_core_request(request) req = Request.from_engine_core_request(request)
if req.use_structured_output: if req.use_structured_output:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Mapping
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from vllm.multimodal import MultiModalKwargs, MultiModalRegistry from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.utils import is_list_of from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -17,23 +17,23 @@ if TYPE_CHECKING: ...@@ -17,23 +17,23 @@ if TYPE_CHECKING:
# -- P0: # -- P0:
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of # - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# each input multi-modal item (e.g. image), # each input multi-modal item (e.g. image),
# - BaseMultiModalProcessor processes the input items into `mm_inputs`, # - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
# which are MultiModalKwargsItem instances that each correspond to an # which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item. # input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_inputs` and corresponding # - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size # `mm_hash` for each item. It stores the `mm_hash` as keys and the size
# of `mm_inputs`, but not the `mm_inputs` themselves, to avoid taking # of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
# up additional memory in P0. # up additional memory in P0.
# - The `mm_hash` is always sent to P1. # - The `mm_hash` is always sent to P1.
# - The corresponding `mm_inputs` are only sent to P1 if they are not cached # - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer. # in MultiModalInputCacheServer.
# #
# -- P1: # -- P1:
# - If the `mm_hash` is cached (i.e. `mm_inputs` are not sent from P0), # - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_inputs`. # MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
# - If the `mm_hash` is not cached (i.e. `mm_inputs` are sent from P0), # - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
# MultiModalInputCacheServer stores `mm_inputs` under the key `mm_hash`. # MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_inputs` are sent to # - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
# the engine for model execution. # the engine for model execution.
# #
# Both Client and Server must perform cache update and eviction based on the # Both Client and Server must perform cache update and eviction based on the
...@@ -58,26 +58,24 @@ class MultiModalInputCacheClient: ...@@ -58,26 +58,24 @@ class MultiModalInputCacheClient:
def get_and_update(
    self,
    mm_kwargs: list[MultiModalKwargsItem],
    mm_hashes: list[str],
) -> list[MultiModalKwargsItem]:
    """
    Strip the data from items whose hash is already recorded, and record
    the rest.

    Args:
        mm_kwargs: One item per multi-modal input.
        mm_hashes: The hash of each item, in the same order.

    Returns:
        `mm_kwargs` itself when the cache is disabled. Otherwise a new
        list in which items with a recorded hash are replaced by
        `item.without_data()` and all other items are passed through
        unchanged.
    """
    if not self.enabled:
        return mm_kwargs

    assert len(mm_kwargs) == len(mm_hashes)

    out_mm_items = list[MultiModalKwargsItem]()
    for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
        if self.mm_cache.get(mm_hash) is not None:
            # Hash already recorded: forward the item without its data.
            out_mm_items.append(mm_item.without_data())
        else:
            # First sighting: store only the wrapped metadata of the data
            # under this hash, and forward the item with its data intact.
            self.mm_cache[mm_hash] = \
                MultiModalCacheItemMetadata.wraps(mm_item.require_data())
            out_mm_items.append(mm_item)

    return out_mm_items
def reset(self) -> None: def reset(self) -> None:
self.mm_cache.clear() self.mm_cache.clear()
...@@ -93,30 +91,28 @@ class MultiModalInputCacheServer: ...@@ -93,30 +91,28 @@ class MultiModalInputCacheServer:
self.enabled = mm_registry.enable_mm_input_cache(model_config) self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache( self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(), model_config.get_mm_input_cache_gb(),
MultiModalKwargs, Mapping[str, NestedTensors],
) )
def get_and_update(
    self,
    mm_kwargs: list[MultiModalKwargsItem],
    mm_hashes: list[str],
) -> list[MultiModalKwargsItem]:
    """
    Restore the data of items that arrived without it, and store the
    data of items that arrived with it.

    Args:
        mm_kwargs: One item per multi-modal input; an item whose
            `get_data()` is `None` is filled from the cache.
        mm_hashes: The hash of each item, in the same order.

    Returns:
        `mm_kwargs` itself when the cache is disabled. Otherwise a new
        list in which every data-less item is replaced by
        `item.with_data(cached_data)`.

    Raises:
        KeyError: Presumably, if a data-less item's hash is not in the
            cache (depends on the cache's `__getitem__`; not visible
            here).
    """
    if not self.enabled:
        return mm_kwargs

    assert len(mm_kwargs) == len(mm_hashes)

    out_mm_items = list[MultiModalKwargsItem]()
    for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
        if (mm_data := mm_item.get_data()) is None:
            # Item arrived without data: rebuild it from the cached data.
            out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash]))
        else:
            # Item arrived with data: remember the data under its hash.
            self.mm_cache[mm_hash] = mm_data
            out_mm_items.append(mm_item)

    return out_mm_items
def reset(self) -> None: def reset(self) -> None:
self.mm_cache.clear() self.mm_cache.clear()
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -10,11 +10,10 @@ from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs ...@@ -10,11 +10,10 @@ from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
MultiModalRegistry) from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
...@@ -296,57 +295,42 @@ class Processor: ...@@ -296,57 +295,42 @@ class Processor:
pooling_params = params.clone() pooling_params = params.clone()
# Multimodal related. # Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal": if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"] decoder_mm_inputs = decoder_inputs["mm_kwargs"]
decoder_mm_positions = decoder_inputs["mm_placeholders"]
decoder_mm_hashes = decoder_inputs.get("mm_hashes")
# Merge and flatten multimodal placeholders, hashes and inputs # Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position # from dictionaries to lists, and sort them by each item's position
# in the input sequence. # in the input sequence.
( sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
sorted_item_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
decoder_inputs["mm_placeholders"],
decoder_inputs["mm_hashes"] if return_mm_hashes else None,
)
# The output of merged multi-modal processor (`decoder_mm_inputs`) sorted_mm_inputs = [
# is a single MultiModalKwargs for all items from all modalities. decoder_mm_inputs.get_item(modality, idx)
# This code flattens kwargs for individual items in a list and for modality, idx in sorted_mm_idxs
# sorts them by each item's position in the input sequence if there ]
# are multiple modalities. sorted_mm_positions = [
unique_modalities = set(sorted_item_modalities) decoder_mm_positions[modality][idx]
if len(unique_modalities) > 1: for modality, idx in sorted_mm_idxs
orig_sorted_mm_inputs = [] ]
used_indices = {modality: 0 for modality in unique_modalities} sorted_mm_hashes = None if decoder_mm_hashes is None else [
decoder_mm_hashes[modality][idx]
for modality in sorted_item_modalities: for modality, idx in sorted_mm_idxs
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
orig_sorted_mm_inputs.append(
MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
] ]
if sorted_mm_hashes is not None: if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update( sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
orig_sorted_mm_inputs, sorted_mm_hashes) sorted_mm_inputs,
else: sorted_mm_hashes,
sorted_mm_inputs = orig_sorted_mm_inputs )
return decoder_inputs.get("prompt"), EngineCoreRequest( return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id, request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"], prompt_token_ids=decoder_inputs["prompt_token_ids"],
mm_inputs=sorted_mm_inputs, mm_kwargs=sorted_mm_inputs,
mm_hashes=sorted_mm_hashes, mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions, mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params, sampling_params=sampling_params,
......
...@@ -5,7 +5,7 @@ import enum ...@@ -5,7 +5,7 @@ import enum
import time import time
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -24,7 +24,7 @@ class Request: ...@@ -24,7 +24,7 @@ class Request:
self, self,
request_id: str, request_id: str,
prompt_token_ids: list[int], prompt_token_ids: list[int],
multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_kwargs: Optional[list[MultiModalKwargsItem]],
multi_modal_hashes: Optional[list[str]], multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]], multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: Optional[SamplingParams], sampling_params: Optional[SamplingParams],
...@@ -84,15 +84,15 @@ class Request: ...@@ -84,15 +84,15 @@ class Request:
# Multi-modal related # Multi-modal related
self.mm_positions = multi_modal_placeholders or [] self.mm_positions = multi_modal_placeholders or []
self.mm_inputs = multi_modal_inputs or [] self.mm_kwargs = multi_modal_kwargs or []
self.mm_hashes: list[str] = multi_modal_hashes or [] self.mm_hashes: list[str] = multi_modal_hashes or []
self.num_encoder_inputs = len(self.mm_inputs) self.num_encoder_inputs = len(self.mm_kwargs)
self.has_encoder_inputs = self.num_encoder_inputs > 0 self.has_encoder_inputs = self.num_encoder_inputs > 0
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_kwargs) == len(self.mm_positions)
if self.mm_hashes: if self.mm_hashes:
assert len(self.mm_inputs) == len(self.mm_hashes) assert len(self.mm_kwargs) == len(self.mm_hashes)
# Read-only views # Read-only views
# Prevent directly appending to these lists since # Prevent directly appending to these lists since
...@@ -110,16 +110,15 @@ class Request: ...@@ -110,16 +110,15 @@ class Request:
@classmethod @classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None: if request.mm_kwargs is not None:
assert isinstance(request.mm_inputs, list) assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
assert is_list_of(request.mm_inputs, MultiModalKwargs), ( "mm_kwargs was not updated in EngineCore.add_request")
"mm_inputs was not updated in EngineCore.add_request")
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
client_index=request.client_index, client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
multi_modal_inputs=request.mm_inputs, multi_modal_kwargs=request.mm_kwargs,
multi_modal_hashes=request.mm_hashes, multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders, multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment