Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69f46359
Unverified
Commit
69f46359
authored
Aug 29, 2025
by
Flora Feng
Committed by
GitHub
Aug 29, 2025
Browse files
[Multimodal] Consolidate mm inputs into MultiModalFeatureSpec (#23779)
Signed-off-by:
sfeng33
<
4florafeng@gmail.com
>
parent
d9e00dbd
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
143 additions
and
146 deletions
+143
-146
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+0
-2
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+13
-9
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+13
-9
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+14
-12
tests/v1/core/utils.py
tests/v1/core/utils.py
+15
-15
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+1
-3
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+1
-3
tests/v1/engine/test_fast_incdec_prefix_err.py
tests/v1/engine/test_fast_incdec_prefix_err.py
+8
-10
tests/v1/engine/test_output_processor.py
tests/v1/engine/test_output_processor.py
+10
-20
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+1
-3
vllm/multimodal/cache.py
vllm/multimodal/cache.py
+13
-3
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+23
-0
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+2
-5
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+7
-9
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+12
-19
vllm/v1/request.py
vllm/v1/request.py
+10
-24
No files found.
tests/tokenization/test_detokenize.py
View file @
69f46359
...
@@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer,
...
@@ -64,8 +64,6 @@ def _run_incremental_decode(tokenizer,
request
=
EngineCoreRequest
(
""
,
request
=
EngineCoreRequest
(
""
,
prompt_token_ids
,
prompt_token_ids
,
None
,
None
,
None
,
None
,
params
,
params
,
None
,
None
,
None
,
None
,
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
69f46359
...
@@ -7,7 +7,8 @@ import pytest
...
@@ -7,7 +7,8 @@ import pytest
import
torch
import
torch
from
vllm.config
import
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
GiB_bytes
,
sha256
,
sha256_cbor_64bit
from
vllm.utils
import
GiB_bytes
,
sha256
,
sha256_cbor_64bit
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
...
@@ -37,17 +38,20 @@ def make_request(
...
@@ -37,17 +38,20 @@ def make_request(
mm_hashes
:
Optional
[
list
[
str
]]
=
None
,
mm_hashes
:
Optional
[
list
[
str
]]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
):
):
if
mm_positions
is
None
:
mm_features
=
[]
mm_kwargs
=
None
if
mm_positions
is
not
None
:
else
:
for
j
,
position
in
enumerate
(
mm_positions
):
mm_item
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
)
identifier
=
mm_hashes
[
j
]
if
mm_hashes
else
f
"hash_
{
j
}
"
mm_kwargs
=
[
mm_item
]
*
len
(
mm_positions
)
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
mm_position
=
position
,
identifier
=
identifier
,
modality
=
"image"
)
mm_features
.
append
(
mm_feature
)
return
Request
(
request_id
=
request_id
,
return
Request
(
request_id
=
request_id
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_kwargs
=
mm_kwargs
,
mm_features
=
mm_features
if
mm_features
else
None
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_placeholders
=
mm_positions
,
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
pooling_params
=
None
,
pooling_params
=
None
,
eos_token_id
=
100
,
eos_token_id
=
100
,
...
...
tests/v1/core/test_prefix_caching.py
View file @
69f46359
...
@@ -9,7 +9,8 @@ import pytest
...
@@ -9,7 +9,8 @@ import pytest
import
torch
import
torch
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
sha256
,
sha256_cbor_64bit
from
vllm.utils
import
sha256
,
sha256_cbor_64bit
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
...
@@ -32,17 +33,20 @@ def make_request(
...
@@ -32,17 +33,20 @@ def make_request(
prompt_logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
):
):
if
mm_positions
is
None
:
mm_features
=
[]
mm_kwargs
=
None
if
mm_positions
is
not
None
:
else
:
for
j
,
position
in
enumerate
(
mm_positions
):
mm_item
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
)
identifier
=
mm_hashes
[
j
]
if
mm_hashes
else
f
"hash_
{
j
}
"
mm_kwargs
=
[
mm_item
]
*
len
(
mm_positions
)
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
mm_position
=
position
,
identifier
=
identifier
,
modality
=
"image"
)
mm_features
.
append
(
mm_feature
)
return
Request
(
request_id
=
request_id
,
return
Request
(
request_id
=
request_id
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_kwargs
=
mm_kwargs
,
mm_features
=
mm_features
if
mm_features
else
None
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_placeholders
=
mm_positions
,
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
max_tokens
=
17
,
prompt_logprobs
=
prompt_logprobs
),
max_tokens
=
17
,
prompt_logprobs
=
prompt_logprobs
),
pooling_params
=
None
,
pooling_params
=
None
,
...
...
tests/v1/core/test_scheduler.py
View file @
69f46359
...
@@ -8,7 +8,8 @@ import torch
...
@@ -8,7 +8,8 @@ import torch
from
vllm.config
import
(
CacheConfig
,
KVTransferConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
KVTransferConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.core.sched.scheduler
import
Scheduler
...
@@ -1308,21 +1309,24 @@ def create_requests_with_priority(
...
@@ -1308,21 +1309,24 @@ def create_requests_with_priority(
prompt_logprobs
=
prompt_logprobs
)
prompt_logprobs
=
prompt_logprobs
)
requests
=
[]
requests
=
[]
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
mm_features
=
[]
if
mm_positions
is
not
None
:
if
mm_positions
is
not
None
:
mm_position
=
mm_positions
[
i
]
mm_position
=
mm_positions
[
i
]
mm_item
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
)
for
j
,
position
in
enumerate
(
mm_position
):
mm_kwargs
=
[
mm_item
]
*
len
(
mm_position
)
identifier
=
f
"hash
{
i
}
_
{
j
}
"
else
:
mm_feature
=
MultiModalFeatureSpec
(
mm_position
=
None
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
mm_kwargs
=
None
mm_position
=
position
,
identifier
=
identifier
,
modality
=
"image"
)
mm_features
.
append
(
mm_feature
)
request
=
Request
(
request
=
Request
(
request_id
=
f
"
{
i
+
starting_idx
}
"
,
request_id
=
f
"
{
i
+
starting_idx
}
"
,
prompt_token_ids
=
[
i
+
starting_idx
]
*
num_tokens
,
prompt_token_ids
=
[
i
+
starting_idx
]
*
num_tokens
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
pooling_params
=
None
,
pooling_params
=
None
,
multi_modal_kwargs
=
mm_kwargs
,
mm_features
=
mm_features
if
mm_features
else
None
,
multi_modal_placeholders
=
mm_position
,
multi_modal_hashes
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
eos_token_id
=
EOS_TOKEN_ID
,
arrival_time
=
arrival_times
[
i
],
arrival_time
=
arrival_times
[
i
],
priority
=
priorities
[
i
],
priority
=
priorities
[
i
],
...
@@ -1801,9 +1805,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
...
@@ -1801,9 +1805,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
request
=
Request
(
request
=
Request
(
request_id
=
"0"
,
request_id
=
"0"
,
prompt_token_ids
=
[
0
,
1
],
prompt_token_ids
=
[
0
,
1
],
multi_modal_kwargs
=
None
,
mm_features
=
None
,
multi_modal_hashes
=
None
,
multi_modal_placeholders
=
None
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
pooling_params
=
None
,
pooling_params
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
eos_token_id
=
EOS_TOKEN_ID
,
...
...
tests/v1/core/utils.py
View file @
69f46359
...
@@ -6,7 +6,8 @@ import torch
...
@@ -6,7 +6,8 @@ import torch
from
vllm.config
import
(
CacheConfig
,
KVTransferConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
KVTransferConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
init_none_hash
)
init_none_hash
)
...
@@ -139,19 +140,20 @@ def create_requests(
...
@@ -139,19 +140,20 @@ def create_requests(
prompt_logprobs
=
prompt_logprobs
)
prompt_logprobs
=
prompt_logprobs
)
requests
=
[]
requests
=
[]
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
mm_features
=
[]
if
mm_positions
is
not
None
:
if
mm_positions
is
not
None
:
mm_position
=
mm_positions
[
i
]
mm_position
=
mm_positions
[
i
]
mm_item
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
)
for
j
,
position
in
enumerate
(
mm_position
):
mm_kwargs
=
[
mm_item
]
*
len
(
mm_position
)
# Dummy hash for each mm item should be unique
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
# since encoder cache tracks entries by hash
mm_hashes
=
[
identifier
=
f
"hash
{
i
}
_
{
j
}
"
"hash"
+
str
(
i
)
+
"_"
+
str
(
j
)
for
j
in
range
(
len
(
mm_position
))
mm_feature
=
MultiModalFeatureSpec
(
]
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
else
:
mm_position
=
position
,
mm_position
=
None
identifier
=
identifier
,
mm_kwargs
=
None
modality
=
"image"
)
mm_hashes
=
None
mm_features
.
append
(
mm_feature
)
prompt_token_ids
=
([
0
]
*
num_tokens
if
same_prompt
else
[
i
]
*
prompt_token_ids
=
([
0
]
*
num_tokens
if
same_prompt
else
[
i
]
*
num_tokens
)
num_tokens
)
request
=
Request
(
request
=
Request
(
...
@@ -159,9 +161,7 @@ def create_requests(
...
@@ -159,9 +161,7 @@ def create_requests(
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
pooling_params
=
None
,
pooling_params
=
None
,
multi_modal_kwargs
=
mm_kwargs
,
mm_features
=
mm_features
if
mm_features
else
None
,
multi_modal_placeholders
=
mm_position
,
multi_modal_hashes
=
mm_hashes
,
eos_token_id
=
EOS_TOKEN_ID
,
eos_token_id
=
EOS_TOKEN_ID
,
block_hasher
=
block_hasher
,
block_hasher
=
block_hasher
,
)
)
...
...
tests/v1/engine/test_engine_core.py
View file @
69f46359
...
@@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest:
...
@@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest:
return
EngineCoreRequest
(
return
EngineCoreRequest
(
request_id
=
str
(
uuid
.
uuid4
()),
request_id
=
str
(
uuid
.
uuid4
()),
prompt_token_ids
=
PROMPT_TOKENS
,
prompt_token_ids
=
PROMPT_TOKENS
,
mm_kwargs
=
None
,
mm_features
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
None
,
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
pooling_params
=
None
,
pooling_params
=
None
,
eos_token_id
=
None
,
eos_token_id
=
None
,
...
...
tests/v1/engine/test_engine_core_client.py
View file @
69f46359
...
@@ -52,9 +52,7 @@ def make_request(
...
@@ -52,9 +52,7 @@ def make_request(
return
EngineCoreRequest
(
return
EngineCoreRequest
(
request_id
=
str
(
uuid
.
uuid4
()),
request_id
=
str
(
uuid
.
uuid4
()),
prompt_token_ids
=
prompt_tokens_ids
,
prompt_token_ids
=
prompt_tokens_ids
,
mm_kwargs
=
None
,
mm_features
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
None
,
sampling_params
=
params
,
sampling_params
=
params
,
pooling_params
=
None
,
pooling_params
=
None
,
eos_token_id
=
None
,
eos_token_id
=
None
,
...
...
tests/v1/engine/test_fast_incdec_prefix_err.py
View file @
69f46359
...
@@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case():
...
@@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case():
prompt_token_ids
=
[
107
,
4606
,
236787
,
107
]
prompt_token_ids
=
[
107
,
4606
,
236787
,
107
]
params
=
SamplingParams
(
skip_special_tokens
=
True
)
params
=
SamplingParams
(
skip_special_tokens
=
True
)
request
=
EngineCoreRequest
(
request
=
EngineCoreRequest
(
"test"
,
request_id
=
"test"
,
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
None
,
mm_features
=
None
,
None
,
sampling_params
=
params
,
None
,
pooling_params
=
None
,
params
,
eos_token_id
=
None
,
None
,
arrival_time
=
0.0
,
None
,
lora_request
=
None
,
0.0
,
None
,
cache_salt
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
data_parallel_rank
=
None
,
)
)
...
...
tests/v1/engine/test_output_processor.py
View file @
69f46359
...
@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
...
@@ -52,11 +52,9 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
requests
=
[
requests
=
[
EngineCoreRequest
(
request_id
=
f
"request-
{
idx
}
"
,
EngineCoreRequest
(
request_id
=
f
"request-
{
idx
}
"
,
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
mm_features
=
None
,
mm_kwargs
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
None
,
eos_token_id
=
None
,
eos_token_id
=
None
,
arrival_time
=
0
,
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
data_parallel_rank
=
None
,
...
@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
...
@@ -401,11 +399,9 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
requests
=
[
requests
=
[
EngineCoreRequest
(
request_id
=
request_id_list
[
idx
],
EngineCoreRequest
(
request_id
=
request_id_list
[
idx
],
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
mm_features
=
None
,
mm_kwargs
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
None
,
eos_token_id
=
None
,
eos_token_id
=
None
,
arrival_time
=
0
,
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
data_parallel_rank
=
None
,
...
@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool,
...
@@ -566,11 +562,9 @@ def test_stop_token(include_stop_str_in_output: bool,
request
=
EngineCoreRequest
(
request
=
EngineCoreRequest
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
mm_features
=
None
,
mm_kwargs
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
None
,
eos_token_id
=
eos_token_id
,
eos_token_id
=
eos_token_id
,
arrival_time
=
0
,
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
data_parallel_rank
=
None
,
...
@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool,
...
@@ -665,11 +659,9 @@ def test_stop_string(include_stop_str_in_output: bool,
EngineCoreRequest
(
EngineCoreRequest
(
request_id
=
request_id_list
[
idx
],
request_id
=
request_id_list
[
idx
],
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
mm_features
=
None
,
mm_kwargs
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
None
,
eos_token_id
=
None
,
eos_token_id
=
None
,
arrival_time
=
0
,
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
data_parallel_rank
=
None
,
...
@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors):
...
@@ -781,11 +773,9 @@ def test_iteration_stats(dummy_test_vectors):
EngineCoreRequest
(
EngineCoreRequest
(
request_id
=
f
"request-
{
idx
}
"
,
request_id
=
f
"request-
{
idx
}
"
,
prompt_token_ids
=
prompt_tokens
,
prompt_token_ids
=
prompt_tokens
,
arrival_time
=
0
,
mm_features
=
None
,
mm_kwargs
=
None
,
mm_hashes
=
None
,
mm_placeholders
=
None
,
eos_token_id
=
None
,
eos_token_id
=
None
,
arrival_time
=
0
,
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
None
,
cache_salt
=
None
,
data_parallel_rank
=
None
,
data_parallel_rank
=
None
,
...
...
tests/v1/kv_connector/unit/utils.py
View file @
69f46359
...
@@ -162,9 +162,7 @@ def create_request(request_id: int,
...
@@ -162,9 +162,7 @@ def create_request(request_id: int,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
pooling_params
=
None
,
pooling_params
=
None
,
multi_modal_kwargs
=
None
,
mm_features
=
None
,
multi_modal_placeholders
=
None
,
multi_modal_hashes
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
eos_token_id
=
EOS_TOKEN_ID
,
block_hasher
=
get_request_block_hasher
(
block_size
,
hash_fn
),
block_hasher
=
get_request_block_hasher
(
block_size
,
hash_fn
),
)
)
...
...
vllm/multimodal/cache.py
View file @
69f46359
...
@@ -12,9 +12,9 @@ from vllm.logger import init_logger
...
@@ -12,9 +12,9 @@ from vllm.logger import init_logger
from
vllm.utils
import
GiB_bytes
,
LRUCache
from
vllm.utils
import
GiB_bytes
,
LRUCache
from
vllm.utils.jsontree
import
json_map_leaves
,
json_reduce_leaves
from
vllm.utils.jsontree
import
json_map_leaves
,
json_reduce_leaves
from
.inputs
import
(
MultiModalF
ieldElem
,
MultiModal
Kwargs
,
from
.inputs
import
(
MultiModalF
eatureSpec
,
MultiModal
FieldElem
,
MultiModalKwargs
Item
,
MultiModalKwargsItem
s
,
MultiModalKwargs
,
MultiModalKwargsItem
,
NestedTensors
)
MultiModalKwargsItems
,
NestedTensors
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
...
@@ -418,6 +418,16 @@ class BaseMultiModalReceiverCache(
...
@@ -418,6 +418,16 @@ class BaseMultiModalReceiverCache(
MultiModalKwargsItem
]):
MultiModalKwargsItem
]):
"""The required interface for caches on P1."""
"""The required interface for caches on P1."""
def
get_and_update_features
(
self
,
mm_features
:
list
[
"MultiModalFeatureSpec"
],
)
->
list
[
"MultiModalFeatureSpec"
]:
"""Update multimodal features with cached encoder outputs."""
for
feature
in
mm_features
:
feature
.
data
=
self
.
get_and_update_item
(
feature
.
data
,
feature
.
identifier
)
return
mm_features
class
MultiModalReceiverCache
(
BaseMultiModalReceiverCache
):
class
MultiModalReceiverCache
(
BaseMultiModalReceiverCache
):
"""
"""
...
...
vllm/multimodal/inputs.py
View file @
69f46359
...
@@ -198,6 +198,29 @@ A dictionary containing nested tensors which have been batched via
...
@@ -198,6 +198,29 @@ A dictionary containing nested tensors which have been batched via
"""
"""
@
dataclass
class
MultiModalFeatureSpec
:
"""
Represents a single multimodal input with its processed data and metadata.
Used by the V1 engine to track multimodal data through processing and
caching. A request containing multiple multimodal items will have one
MultiModalFeatureSpec per item.
"""
data
:
Optional
[
"MultiModalKwargsItem"
]
"""Multimodal data for this feature"""
modality
:
str
"""Based on the input, e.g., "image", "audio", "video"."""
identifier
:
str
"""mm_hash or uuid for caching encoder outputs."""
mm_position
:
PlaceholderRange
"""e.g., PlaceholderRange(offset=2, length=336)"""
@
dataclass
@
dataclass
class
MultiModalFieldElem
:
class
MultiModalFieldElem
:
"""
"""
...
...
vllm/v1/engine/__init__.py
View file @
69f46359
...
@@ -3,14 +3,13 @@
...
@@ -3,14 +3,13 @@
import
enum
import
enum
import
time
import
time
from
collections.abc
import
Sequence
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
msgspec
import
msgspec
import
torch
import
torch
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
MultiModal
KwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModal
FeatureSpec
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.metrics.stats
import
SchedulerStats
...
@@ -48,9 +47,7 @@ class EngineCoreRequest(
...
@@ -48,9 +47,7 @@ class EngineCoreRequest(
request_id
:
str
request_id
:
str
prompt_token_ids
:
list
[
int
]
prompt_token_ids
:
list
[
int
]
mm_kwargs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargsItem
]]]
mm_features
:
Optional
[
list
[
MultiModalFeatureSpec
]]
mm_hashes
:
Optional
[
list
[
str
]]
mm_placeholders
:
Optional
[
list
[
PlaceholderRange
]]
sampling_params
:
Optional
[
SamplingParams
]
sampling_params
:
Optional
[
SamplingParams
]
pooling_params
:
Optional
[
PoolingParams
]
pooling_params
:
Optional
[
PoolingParams
]
eos_token_id
:
Optional
[
int
]
eos_token_id
:
Optional
[
int
]
...
...
vllm/v1/engine/core.py
View file @
69f46359
...
@@ -434,15 +434,13 @@ class EngineCore:
...
@@ -434,15 +434,13 @@ class EngineCore:
This function could be directly used in input processing thread to allow
This function could be directly used in input processing thread to allow
request initialization running in parallel with Model forward
request initialization running in parallel with Model forward
"""
"""
if
request
.
mm_hashes
is
not
None
:
assert
request
.
mm_kwargs
is
not
None
# Note on thread safety: no race condition.
# Note on thread safety: no race condition.
# `mm_receiver_cache` is reset at the end of LLMEngine init,
# `mm_receiver_cache` is reset at the end of LLMEngine init,
# and will only accessed in the input processing thread afterwards.
# and will only accessed in the input processing thread afterwards.
if
self
.
mm_receiver_cache
is
not
None
:
if
self
.
mm_receiver_cache
is
not
None
and
request
.
mm_features
:
request
.
mm_kwargs
=
self
.
mm_receiver_cache
.
get_and_update
(
request
.
mm_features
=
(
request
.
mm_kwargs
,
request
.
mm_hashes
)
self
.
mm_receiver_cache
.
get_and_update_features
(
request
.
mm_features
))
req
=
Request
.
from_engine_core_request
(
request
,
req
=
Request
.
from_engine_core_request
(
request
,
self
.
request_block_hasher
)
self
.
request_block_hasher
)
...
...
vllm/v1/engine/processor.py
View file @
69f46359
...
@@ -12,7 +12,7 @@ from vllm.inputs.preprocess import InputPreprocessor
...
@@ -12,7 +12,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.cache
import
processor_cache_from_config
from
vllm.multimodal.cache
import
processor_cache_from_config
from
vllm.multimodal.inputs
import
MultiModal
KwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModal
FeatureSpec
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
...
@@ -346,9 +346,8 @@ class Processor:
...
@@ -346,9 +346,8 @@ class Processor:
pooling_params
=
params
.
clone
()
pooling_params
=
params
.
clone
()
# Multimodal related.
# Multimodal related.
sorted_mm_inputs
:
Optional
[
list
[
Optional
[
MultiModalKwargsItem
]]]
=
None
mm_features
:
Optional
[
list
[
MultiModalFeatureSpec
]]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
decoder_mm_inputs
=
decoder_inputs
[
"mm_kwargs"
]
decoder_mm_inputs
=
decoder_inputs
[
"mm_kwargs"
]
decoder_mm_positions
=
decoder_inputs
[
"mm_placeholders"
]
decoder_mm_positions
=
decoder_inputs
[
"mm_placeholders"
]
...
@@ -359,25 +358,19 @@ class Processor:
...
@@ -359,25 +358,19 @@ class Processor:
# in the input sequence.
# in the input sequence.
sorted_mm_idxs
=
argsort_mm_positions
(
decoder_mm_positions
)
sorted_mm_idxs
=
argsort_mm_positions
(
decoder_mm_positions
)
sorted_mm_inputs
=
[
mm_features
=
[]
decoder_mm_inputs
[
modality
][
idx
]
for
modality
,
idx
in
sorted_mm_idxs
:
for
modality
,
idx
in
sorted_mm_idxs
mm_features
.
append
(
]
MultiModalFeatureSpec
(
sorted_mm_positions
=
[
data
=
decoder_mm_inputs
[
modality
][
idx
],
decoder_mm_positions
[
modality
][
idx
]
modality
=
modality
,
for
modality
,
idx
in
sorted_mm_idxs
identifier
=
decoder_mm_hashes
[
modality
][
idx
],
]
mm_position
=
decoder_mm_positions
[
modality
][
idx
]))
sorted_mm_hashes
=
[
decoder_mm_hashes
[
modality
][
idx
]
for
modality
,
idx
in
sorted_mm_idxs
]
return
decoder_inputs
.
get
(
"prompt"
),
EngineCoreRequest
(
return
decoder_inputs
.
get
(
"prompt"
),
EngineCoreRequest
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt_token_ids
=
decoder_inputs
[
"prompt_token_ids"
],
prompt_token_ids
=
decoder_inputs
[
"prompt_token_ids"
],
mm_kwargs
=
sorted_mm_inputs
,
mm_features
=
mm_features
,
mm_hashes
=
sorted_mm_hashes
,
mm_placeholders
=
sorted_mm_positions
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
pooling_params
=
pooling_params
,
pooling_params
=
pooling_params
,
eos_token_id
=
eos_token_id
,
eos_token_id
=
eos_token_id
,
...
...
vllm/v1/request.py
View file @
69f46359
...
@@ -6,10 +6,9 @@ import time
...
@@ -6,10 +6,9 @@ import time
from
functools
import
partial
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
,
Union
from
vllm.multimodal.inputs
import
MultiModal
KwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModal
FeatureSpec
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_list_of
from
vllm.v1.engine
import
(
EngineCoreEvent
,
EngineCoreEventType
,
from
vllm.v1.engine
import
(
EngineCoreEvent
,
EngineCoreEventType
,
EngineCoreRequest
,
FinishReason
)
EngineCoreRequest
,
FinishReason
)
from
vllm.v1.structured_output.request
import
StructuredOutputRequest
from
vllm.v1.structured_output.request
import
StructuredOutputRequest
...
@@ -26,14 +25,12 @@ class Request:
...
@@ -26,14 +25,12 @@ class Request:
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt_token_ids
:
list
[
int
],
prompt_token_ids
:
list
[
int
],
multi_modal_kwargs
:
Optional
[
list
[
MultiModalKwargsItem
]],
multi_modal_hashes
:
Optional
[
list
[
str
]],
multi_modal_placeholders
:
Optional
[
list
[
PlaceholderRange
]],
sampling_params
:
Optional
[
SamplingParams
],
sampling_params
:
Optional
[
SamplingParams
],
pooling_params
:
Optional
[
PoolingParams
],
pooling_params
:
Optional
[
PoolingParams
],
eos_token_id
:
Optional
[
int
],
eos_token_id
:
Optional
[
int
],
client_index
:
int
=
0
,
client_index
:
int
=
0
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
mm_features
:
Optional
[
list
[
MultiModalFeatureSpec
]]
=
None
,
lora_request
:
Optional
[
"LoRARequest"
]
=
None
,
lora_request
:
Optional
[
"LoRARequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
...
@@ -89,16 +86,14 @@ class Request:
...
@@ -89,16 +86,14 @@ class Request:
self
.
cache_salt
:
Optional
[
str
]
=
cache_salt
self
.
cache_salt
:
Optional
[
str
]
=
cache_salt
# Multi-modal related
# Multi-modal related
self
.
mm_positions
=
multi_modal_placeholders
or
[]
self
.
mm_features
=
mm_features
or
[]
self
.
mm_kwargs
=
multi_modal_kwargs
or
[]
self
.
num_encoder_inputs
=
len
(
self
.
mm_features
)
self
.
mm_hashes
:
list
[
str
]
=
multi_modal_hashes
or
[]
self
.
num_encoder_inputs
=
len
(
self
.
mm_kwargs
)
self
.
has_encoder_inputs
=
self
.
num_encoder_inputs
>
0
self
.
has_encoder_inputs
=
self
.
num_encoder_inputs
>
0
# TODO(sfeng33): Remove these legacy fields after clearing out all
#
Sanity check
#
references in scheduler and model runner
assert
len
(
self
.
mm_kwargs
)
==
len
(
self
.
mm_positions
)
self
.
mm_positions
=
[
f
.
mm_position
for
f
in
self
.
mm_features
]
if
self
.
mm_
hash
es
:
self
.
mm_kwargs
=
[
f
.
data
for
f
in
self
.
mm_
featur
es
]
assert
len
(
self
.
mm_kwargs
)
==
len
(
self
.
mm_
hash
es
)
self
.
mm_hashes
=
[
f
.
identifier
for
f
in
self
.
mm_
featur
es
]
# Read-only views
# Read-only views
# Prevent directly appending to these lists since
# Prevent directly appending to these lists since
...
@@ -126,20 +121,11 @@ class Request:
...
@@ -126,20 +121,11 @@ class Request:
cls
,
request
:
EngineCoreRequest
,
cls
,
request
:
EngineCoreRequest
,
block_hasher
:
Optional
[
Callable
[[
"Request"
],
list
[
"BlockHash"
]]]
block_hasher
:
Optional
[
Callable
[[
"Request"
],
list
[
"BlockHash"
]]]
)
->
"Request"
:
)
->
"Request"
:
if
request
.
mm_kwargs
is
not
None
:
mm_kwargs_lst
=
list
(
request
.
mm_kwargs
)
assert
is_list_of
(
mm_kwargs_lst
,
MultiModalKwargsItem
),
(
"mm_kwargs was not updated in EngineCore.add_request"
)
else
:
mm_kwargs_lst
=
None
return
cls
(
return
cls
(
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
client_index
=
request
.
client_index
,
client_index
=
request
.
client_index
,
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt_token_ids
=
request
.
prompt_token_ids
,
multi_modal_kwargs
=
mm_kwargs_lst
,
mm_features
=
request
.
mm_features
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
sampling_params
=
request
.
sampling_params
,
sampling_params
=
request
.
sampling_params
,
pooling_params
=
request
.
pooling_params
,
pooling_params
=
request
.
pooling_params
,
eos_token_id
=
request
.
eos_token_id
,
eos_token_id
=
request
.
eos_token_id
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment