Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
cb88fdc7
Unverified
Commit
cb88fdc7
authored
Feb 12, 2026
by
Qi Wang
Committed by
GitHub
Feb 13, 2026
Browse files
feat: add encode client and embedding cache to PD worker (#6029)
parent
166e1f4d
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
327 additions
and
232 deletions
+327
-232
components/src/dynamo/common/configuration/groups/runtime_args.py
...ts/src/dynamo/common/configuration/groups/runtime_args.py
+14
-1
components/src/dynamo/vllm/args.py
components/src/dynamo/vllm/args.py
+0
-1
components/src/dynamo/vllm/multimodal_handlers/__init__.py
components/src/dynamo/vllm/multimodal_handlers/__init__.py
+4
-4
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
.../dynamo/vllm/multimodal_handlers/encode_worker_handler.py
+0
-13
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
.../vllm/multimodal_handlers/multimodal_pd_worker_handler.py
+239
-0
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
...nts/src/dynamo/vllm/multimodal_handlers/worker_handler.py
+4
-213
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
...imodal_handlers/test_vllm_multimodal_pd_worker_handler.py
+66
-0
No files found.
components/src/dynamo/common/configuration/groups/runtime_args.py
View file @
cb88fdc7
...
@@ -27,12 +27,17 @@ class DynamoRuntimeConfig(ConfigBase):
...
@@ -27,12 +27,17 @@ class DynamoRuntimeConfig(ConfigBase):
custom_jinja_template
:
Optional
[
str
]
=
None
custom_jinja_template
:
Optional
[
str
]
=
None
endpoint_types
:
str
endpoint_types
:
str
dump_config_to
:
Optional
[
str
]
=
None
dump_config_to
:
Optional
[
str
]
=
None
multimodal_embedding_cache_capacity_gb
:
float
def
validate
(
self
)
->
None
:
def
validate
(
self
)
->
None
:
# TODO get a better way for spot fixes like this.
# TODO get a better way for spot fixes like this.
self
.
enable_local_indexer
=
not
self
.
durable_kv_events
self
.
enable_local_indexer
=
not
self
.
durable_kv_events
# For simplicity, we do not prepend "dyn-" unless it's absolutely necessary. These are
# exemplary exceptions:
# - To avoid name conflicts with different backends, prefix "dyn-" for dynamo specific
# args.
class
DynamoRuntimeArgGroup
(
ArgGroup
):
class
DynamoRuntimeArgGroup
(
ArgGroup
):
"""Dynamo runtime configuration parameters (common to all backends)."""
"""Dynamo runtime configuration parameters (common to all backends)."""
...
@@ -89,7 +94,6 @@ class DynamoRuntimeArgGroup(ArgGroup):
...
@@ -89,7 +94,6 @@ class DynamoRuntimeArgGroup(ArgGroup):
)
)
# Optional: tool/reasoning parsers (choices from dynamo._core when available)
# Optional: tool/reasoning parsers (choices from dynamo._core when available)
# To avoid name conflicts with different backends, prefix "dyn-" for dynamo specific args
add_argument
(
add_argument
(
g
,
g
,
flag_name
=
"--dyn-tool-call-parser"
,
flag_name
=
"--dyn-tool-call-parser"
,
...
@@ -130,3 +134,12 @@ class DynamoRuntimeArgGroup(ArgGroup):
...
@@ -130,3 +134,12 @@ class DynamoRuntimeArgGroup(ArgGroup):
default
=
None
,
default
=
None
,
help
=
"Dump resolved configuration to the specified file path."
,
help
=
"Dump resolved configuration to the specified file path."
,
)
)
add_argument
(
g
,
flag_name
=
"--multimodal-embedding-cache-capacity-gb"
,
env_var
=
"DYN_MULTIMODAL_EMBEDDING_CACHE_CAPACITY_GB"
,
default
=
0
,
arg_type
=
float
,
help
=
"Capacity of the multimodal embedding cache in GB. 0 = disabled."
,
)
components/src/dynamo/vllm/args.py
View file @
cb88fdc7
...
@@ -79,7 +79,6 @@ def parse_args() -> Config:
...
@@ -79,7 +79,6 @@ def parse_args() -> Config:
Returns:
Returns:
Config: Parsed configuration object.
Config: Parsed configuration object.
"""
"""
dynamo_runtime_argspec
=
DynamoRuntimeArgGroup
()
dynamo_runtime_argspec
=
DynamoRuntimeArgGroup
()
dynamo_vllm_argspec
=
DynamoVllmArgGroup
()
dynamo_vllm_argspec
=
DynamoVllmArgGroup
()
...
...
components/src/dynamo/vllm/multimodal_handlers/__init__.py
View file @
cb88fdc7
...
@@ -5,14 +5,14 @@ from dynamo.vllm.multimodal_handlers.encode_worker_handler import (
...
@@ -5,14 +5,14 @@ from dynamo.vllm.multimodal_handlers.encode_worker_handler import (
EncodeWorkerHandler
,
EncodeWorkerHandler
,
VLLMEncodeWorkerHandler
,
VLLMEncodeWorkerHandler
,
)
)
from
dynamo.vllm.multimodal_handlers.multimodal_pd_worker_handler
import
(
MultimodalPDWorkerHandler
,
)
from
dynamo.vllm.multimodal_handlers.preprocessed_handler
import
(
from
dynamo.vllm.multimodal_handlers.preprocessed_handler
import
(
ECProcessorHandler
,
ECProcessorHandler
,
PreprocessedHandler
,
PreprocessedHandler
,
)
)
from
dynamo.vllm.multimodal_handlers.worker_handler
import
(
from
dynamo.vllm.multimodal_handlers.worker_handler
import
MultimodalDecodeWorkerHandler
MultimodalDecodeWorkerHandler
,
MultimodalPDWorkerHandler
,
)
__all__
=
[
__all__
=
[
"EncodeWorkerHandler"
,
"EncodeWorkerHandler"
,
...
...
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
View file @
cb88fdc7
...
@@ -35,19 +35,6 @@ from ..multimodal_utils.model import is_qwen_vl_model
...
@@ -35,19 +35,6 @@ from ..multimodal_utils.model import is_qwen_vl_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
try
:
import
cupy
as
array_module
if
not
array_module
.
cuda
.
is_available
():
raise
ImportError
(
"CUDA is not available."
)
DEVICE
=
"cuda"
logger
.
info
(
"Using cupy for array operations (GPU mode)."
)
except
ImportError
as
e
:
logger
.
warning
(
f
"Failed to import cupy, falling back to numpy:
{
e
}
."
)
import
numpy
as
array_module
DEVICE
=
"cpu"
CACHE_SIZE_MAXIMUM
=
8
CACHE_SIZE_MAXIMUM
=
8
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
...
...
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
0 → 100644
View file @
cb88fdc7
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
copy
import
logging
from
collections
import
defaultdict
from
typing
import
Any
import
torch
from
vllm.inputs.data
import
TokensPrompt
from
vllm.v1.engine.async_llm
import
AsyncLLM
import
dynamo.nixl_connect
as
connect
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
MultimodalEmbeddingCacheManager
,
)
from
dynamo.runtime
import
Client
,
Component
,
DistributedRuntime
from
..args
import
Config
from
..handlers
import
BaseWorkerHandler
from
..multimodal_utils
import
ImageLoader
,
MyRequestOutput
,
vLLMMultimodalRequest
from
..multimodal_utils.model
import
is_qwen_vl_model
from
..multimodal_utils.prefill_worker_utils
import
(
accumulate_embeddings
,
load_embeddings
,
)
logger
=
logging
.
getLogger
(
__name__
)
class MultimodalPDWorkerHandler(BaseWorkerHandler):
    """Prefill/Decode or Prefill-only worker for multimodal serving.

    The handler turns an incoming ``vLLMMultimodalRequest`` into a vLLM
    ``TokensPrompt`` (token ids plus multimodal data) and streams results
    back as JSON strings. When the worker is configured as a prefill worker
    and a decode client is supplied, it runs a one-token prefill locally and
    forwards the request to a remote decode worker.
    """

    def __init__(
        self,
        runtime,
        component: Component,
        engine_client: AsyncLLM,
        config: Config,
        encode_worker_client: Client | None = None,
        decode_worker_client: Client | None = None,
        shutdown_event=None,
    ):
        """Set up the handler state.

        Args:
            runtime: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
            component: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
            engine_client: vLLM engine used by ``generate``.
            config: Worker configuration; ``model``, ``is_prefill_worker``,
                ``enable_multimodal``, ``engine_args`` and
                ``multimodal_embedding_cache_capacity_gb`` are read here.
            encode_worker_client: Optional client for an encode worker. It is
                only stored; nothing in this class calls it yet.
            decode_worker_client: Optional client for a remote decode worker.
            shutdown_event: Forwarded unchanged to the base class.
        """
        # Get default_sampling_params from config
        default_sampling_params = (
            config.engine_args.create_model_config().get_diff_sampling_param()
        )

        # Call BaseWorkerHandler.__init__ with proper parameters
        super().__init__(
            runtime,
            component,
            engine_client,
            default_sampling_params,
            enable_multimodal=config.enable_multimodal,
            shutdown_event=shutdown_event,
        )

        self.config = config
        self.encode_worker_client = encode_worker_client
        self.decode_worker_client = decode_worker_client
        self.enable_disagg = config.is_prefill_worker

        # The embedding cache is opt-in: a capacity of 0 (or less) leaves it disabled.
        self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
        if config.multimodal_embedding_cache_capacity_gb > 0:
            # GB -> bytes, truncated to an integer.
            capacity_bytes = int(
                config.multimodal_embedding_cache_capacity_gb * 1024**3
            )
            self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
                capacity_bytes
            )

        # Initialize multimodal-specific components
        logger.info("Multimodal PD Worker startup started.")

        # The embedding dtype is chosen purely from the model name.
        if "video" in self.config.model.lower():
            self.EMBEDDINGS_DTYPE = torch.uint8
        else:
            self.EMBEDDINGS_DTYPE = torch.float16

        self.EMBEDDINGS_DEVICE = "cpu"

        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        # Note: This is synchronous initialization, async initialization happens in async_init
        self._connector: connect.Connector | None = (
            None  # Will be initialized in async_init
        )
        self.image_loader = ImageLoader()

        logger.info("Multimodal PD Worker has been initialized")

    async def async_init(self, runtime: DistributedRuntime):
        """Async initialization for connector that requires async setup"""
        # Initialize the connector asynchronously
        self._connector = connect.Connector()
        logger.info("Multimodal PD Worker async initialization completed.")

    @staticmethod
    def _dump_output(output) -> str:
        """Copy the response fields into a ``MyRequestOutput`` and return its JSON."""
        return MyRequestOutput(
            request_id=output.request_id,
            prompt=output.prompt,
            prompt_token_ids=output.prompt_token_ids,
            prompt_logprobs=output.prompt_logprobs,
            outputs=output.outputs,
            finished=output.finished,
            metrics=output.metrics,
            kv_transfer_params=output.kv_transfer_params,
        ).model_dump_json()

    async def generate(self, request: vLLMMultimodalRequest, context):
        """Run prefill (and optionally remote decode) for one request.

        Args:
            request: A ``vLLMMultimodalRequest``, or a JSON string / other
                payload that is validated into one.
            context: Unused by this method.

        Yields:
            JSON strings produced by ``MyRequestOutput.model_dump_json()``.
        """
        # Lazy %-style arguments avoid rendering the whole request when DEBUG is off.
        logger.debug("Got raw request: %s", request)
        if not isinstance(request, vLLMMultimodalRequest):
            if isinstance(request, str):
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug("Received PD request: { id: %s }.", request.request_id)

        multi_modal_data: dict[str, Any] = defaultdict(list)
        for mi in request.multimodal_inputs:
            if mi.multimodal_input.image_url:
                # PIL image path — used by both EC consumer mode
                # (vLLM looks up cached embeddings via mm_hash) and
                # non-disaggregated mode (vLLM encodes inline).
                multi_modal_data["image"].append(
                    await self.image_loader.load_image(mi.multimodal_input.image_url)
                )
            else:
                # Pre-computed embeddings via NIXL RDMA or local safetensors
                embeddings = await load_embeddings(
                    mi,
                    self.EMBEDDINGS_DTYPE,
                    self.EMBEDDINGS_DEVICE,
                    self._connector,
                )
                accumulate_embeddings(
                    multi_modal_data,
                    self.config.model,
                    self.EMBEDDINGS_DTYPE,
                    embeddings,
                    mi.image_grid_thw,
                )

        # For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
        # from the constructed multimodal data so decode can reconstruct its
        # multi_modal_data consistently for multiple images.
        if is_qwen_vl_model(self.config.model) and isinstance(
            multi_modal_data.get("image"), dict
        ):
            image_data = multi_modal_data["image"]
            image_grid_thw = image_data.get("image_grid_thw")
            image_embeds = image_data.get("image_embeds")

            if image_grid_thw is not None:
                request.image_grid_thw = (
                    image_grid_thw.tolist()
                    if isinstance(image_grid_thw, torch.Tensor)
                    else image_grid_thw
                )
            if image_embeds is not None:
                request.embeddings_shape = list(image_embeds.shape)

        # Remove the image features from the request as they are not required
        # Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade
        request.multimodal_inputs = []

        # NOTE(review): indexing the defaultdict inserts an empty "image" entry
        # when the request carried none, and for dict-shaped image data the
        # length is the number of keys. Kept as-is because it affects what is
        # passed to the engine — confirm whether that is intended.
        logger.info(
            "Prepared multimodal data size: %s", len(multi_modal_data["image"])
        )
        logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys()))

        # Deepcopy the request to avoid modifying the original
        # when we adjust sampling params for prefill
        pd_request = copy.deepcopy(request)

        # Resolve the disaggregation decision once; both branches below use it.
        decode_client = self.decode_worker_client if self.enable_disagg else None

        # Do prefill and remote decode if enable_disagg is true
        if decode_client:
            extra_args = pd_request.sampling_params.extra_args or {}
            extra_args["kv_transfer_params"] = {
                "do_remote_decode": True,
            }
            pd_request.sampling_params.extra_args = extra_args
            # Prefill only needs to emit a single token.
            pd_request.sampling_params.max_tokens = 1
            pd_request.sampling_params.min_tokens = 1

            logger.debug("Prefill request: %s", pd_request)

        gen = self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=pd_request.sampling_params,
            request_id=pd_request.request_id,
        )

        if decode_client:
            decode_request = copy.deepcopy(request)
            async for prefill_response in gen:
                # For Qwen VL models with mRoPE: Keep the ORIGINAL unexpanded prompt.
                # The decode worker will pass multi_modal_data which causes vLLM to
                # expand the prompt identically to prefill, ensuring block counts match.
                #
                # For other models: Use the expanded prompt from prefill response.
                # These models don't pass multi_modal_data in decode, so they need
                # the already-expanded prompt to match the KV cache layout.
                if not is_qwen_vl_model(self.config.model):
                    decode_request.engine_prompt[
                        "prompt_token_ids"
                    ] = prefill_response.prompt_token_ids
                logger.debug(
                    "Prefill response kv_transfer_params: %s",
                    prefill_response.kv_transfer_params,
                )
                extra_args = decode_request.sampling_params.extra_args or {}
                extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
                extra_args.pop("serialized_request", None)
                decode_request.sampling_params.extra_args = extra_args
                logger.debug("Decode request: %s", decode_request)
                async for decode_response in await decode_client.round_robin(
                    decode_request.model_dump_json()
                ):
                    output = MyRequestOutput.model_validate_json(
                        decode_response.data()  # type: ignore[attr-defined]
                    )
                    yield self._dump_output(output)
        else:
            async for response in gen:
                logger.debug(
                    "Response kv_transfer_params: %s", response.kv_transfer_params
                )
                logger.debug(
                    "length of expanded prompt ids: %s",
                    len(response.prompt_token_ids),
                )
                yield self._dump_output(response)
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
View file @
cb88fdc7
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
copy
import
logging
import
logging
from
collections
import
defaultdict
from
typing
import
Any
import
torch
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
TokensPrompt
from
vllm.v1.engine.async_llm
import
AsyncLLM
import
dynamo.nixl_connect
as
connect
import
dynamo.nixl_connect
as
connect
from
dynamo.runtime
import
Client
,
Component
,
DistributedRuntime
from
dynamo.runtime
import
DistributedRuntime
from
..args
import
Config
from
..handlers
import
BaseWorkerHandler
from
..handlers
import
BaseWorkerHandler
from
..multimodal_utils
import
ImageLoader
,
MyRequestOutput
,
vLLMMultimodalRequest
from
..multimodal_utils
import
MyRequestOutput
,
vLLMMultimodalRequest
from
..multimodal_utils.model
import
construct_qwen_decode_mm_data
,
is_qwen_vl_model
from
..multimodal_utils.model
import
construct_qwen_decode_mm_data
,
is_qwen_vl_model
from
..multimodal_utils.prefill_worker_utils
import
(
accumulate_embeddings
,
load_embeddings
,
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -32,7 +24,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -32,7 +24,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
runtime
,
runtime
,
component
,
component
,
engine_client
,
engine_client
,
config
,
config
:
Config
,
shutdown_event
=
None
,
shutdown_event
=
None
,
):
):
# Get default_sampling_params from config
# Get default_sampling_params from config
...
@@ -111,204 +103,3 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -111,204 +103,3 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
metrics
=
response
.
metrics
,
metrics
=
response
.
metrics
,
kv_transfer_params
=
response
.
kv_transfer_params
,
kv_transfer_params
=
response
.
kv_transfer_params
,
).
model_dump_json
()
).
model_dump_json
()
class MultimodalPDWorkerHandler(BaseWorkerHandler):
    """Prefill/Decode or Prefill-only worker for multimodal serving.

    Builds a vLLM ``TokensPrompt`` from a ``vLLMMultimodalRequest`` and streams
    the results back as JSON strings. When the worker is a prefill worker and
    a decode client is supplied, it runs a one-token prefill locally and
    forwards the request to a remote decode worker.
    """

    def __init__(
        self,
        runtime,
        component: Component,
        engine_client: AsyncLLM,
        config,
        decode_worker_client: Client | None = None,
        shutdown_event=None,
    ):
        """Set up the handler state.

        Args:
            runtime: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
            component: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
            engine_client: vLLM engine used by ``generate``.
            config: Worker configuration; ``model``, ``is_prefill_worker``,
                ``enable_multimodal`` and ``engine_args`` are read here.
            decode_worker_client: Optional client for a remote decode worker.
            shutdown_event: Forwarded unchanged to the base class.
        """
        # Get default_sampling_params from config
        default_sampling_params = (
            config.engine_args.create_model_config().get_diff_sampling_param()
        )

        # Call BaseWorkerHandler.__init__ with proper parameters
        super().__init__(
            runtime,
            component,
            engine_client,
            default_sampling_params,
            enable_multimodal=config.enable_multimodal,
            shutdown_event=shutdown_event,
        )

        self.config = config
        self.decode_worker_client = decode_worker_client
        self.enable_disagg = config.is_prefill_worker

        # Initialize multimodal-specific components
        logger.info("Multimodal PD Worker startup started.")

        # The embedding dtype is chosen purely from the model name.
        if "video" in self.config.model.lower():
            self.EMBEDDINGS_DTYPE = torch.uint8
        else:
            self.EMBEDDINGS_DTYPE = torch.float16

        self.EMBEDDINGS_DEVICE = "cpu"

        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        # Note: This is synchronous initialization, async initialization happens in async_init
        self._connector: connect.Connector | None = (
            None  # Will be initialized in async_init
        )
        self.image_loader = ImageLoader()

        logger.info("Multimodal PD Worker has been initialized")

    async def async_init(self, runtime: DistributedRuntime):
        """Async initialization for connector that requires async setup"""
        # Initialize the connector asynchronously
        self._connector = connect.Connector()
        logger.info("Multimodal PD Worker async initialization completed.")

    @staticmethod
    def _dump_output(output) -> str:
        """Copy the response fields into a ``MyRequestOutput`` and return its JSON."""
        return MyRequestOutput(
            request_id=output.request_id,
            prompt=output.prompt,
            prompt_token_ids=output.prompt_token_ids,
            prompt_logprobs=output.prompt_logprobs,
            outputs=output.outputs,
            finished=output.finished,
            metrics=output.metrics,
            kv_transfer_params=output.kv_transfer_params,
        ).model_dump_json()

    async def generate(self, request: vLLMMultimodalRequest, context):
        """Run prefill (and optionally remote decode) for one request.

        Args:
            request: A ``vLLMMultimodalRequest``, or a JSON string / other
                payload that is validated into one.
            context: Unused by this method.

        Yields:
            JSON strings produced by ``MyRequestOutput.model_dump_json()``.
        """
        # Lazy %-style arguments avoid rendering the whole request when DEBUG is off.
        logger.debug("Got raw request: %s", request)
        if not isinstance(request, vLLMMultimodalRequest):
            if isinstance(request, str):
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug("Received PD request: { id: %s }.", request.request_id)

        multi_modal_data: dict[str, Any] = defaultdict(list)
        for mi in request.multimodal_inputs:
            if mi.multimodal_input.image_url:
                # PIL image path — used by both EC consumer mode
                # (vLLM looks up cached embeddings via mm_hash) and
                # non-disaggregated mode (vLLM encodes inline).
                multi_modal_data["image"].append(
                    await self.image_loader.load_image(mi.multimodal_input.image_url)
                )
            else:
                # Pre-computed embeddings via NIXL RDMA or local safetensors
                embeddings = await load_embeddings(
                    mi,
                    self.EMBEDDINGS_DTYPE,
                    self.EMBEDDINGS_DEVICE,
                    self._connector,
                )
                accumulate_embeddings(
                    multi_modal_data,
                    self.config.model,
                    self.EMBEDDINGS_DTYPE,
                    embeddings,
                    mi.image_grid_thw,
                )

        # For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
        # from the constructed multimodal data so decode can reconstruct its
        # multi_modal_data consistently for multiple images.
        if is_qwen_vl_model(self.config.model) and isinstance(
            multi_modal_data.get("image"), dict
        ):
            image_data = multi_modal_data["image"]
            image_grid_thw = image_data.get("image_grid_thw")
            image_embeds = image_data.get("image_embeds")

            if image_grid_thw is not None:
                request.image_grid_thw = (
                    image_grid_thw.tolist()
                    if isinstance(image_grid_thw, torch.Tensor)
                    else image_grid_thw
                )
            if image_embeds is not None:
                request.embeddings_shape = list(image_embeds.shape)

        # Remove the image features from the request as they are not required
        # Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade
        request.multimodal_inputs = []

        # NOTE(review): indexing the defaultdict inserts an empty "image" entry
        # when the request carried none. Kept as-is because it affects what is
        # passed to the engine — confirm whether that is intended.
        logger.info(
            "Prepared multimodal data size: %s", len(multi_modal_data["image"])
        )
        # Log only the keys, at DEBUG: dumping the full payload (decoded images
        # or embedding tensors) at INFO for every request floods the log.
        logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys()))

        # Deepcopy the request to avoid modifying the original
        # when we adjust sampling params for prefill
        pd_request = copy.deepcopy(request)

        # Resolve the disaggregation decision once; both branches below use it.
        decode_client = self.decode_worker_client if self.enable_disagg else None

        # Do prefill and remote decode if enable_disagg is true
        if decode_client:
            extra_args = pd_request.sampling_params.extra_args or {}
            extra_args["kv_transfer_params"] = {
                "do_remote_decode": True,
            }
            pd_request.sampling_params.extra_args = extra_args
            # Prefill only needs to emit a single token.
            pd_request.sampling_params.max_tokens = 1
            pd_request.sampling_params.min_tokens = 1

            logger.debug("Prefill request: %s", pd_request)

        gen = self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=pd_request.sampling_params,
            request_id=pd_request.request_id,
        )

        if decode_client:
            decode_request = copy.deepcopy(request)
            async for prefill_response in gen:
                # For Qwen VL models with mRoPE: Keep the ORIGINAL unexpanded prompt.
                # The decode worker will pass multi_modal_data which causes vLLM to
                # expand the prompt identically to prefill, ensuring block counts match.
                #
                # For other models: Use the expanded prompt from prefill response.
                # These models don't pass multi_modal_data in decode, so they need
                # the already-expanded prompt to match the KV cache layout.
                if not is_qwen_vl_model(self.config.model):
                    decode_request.engine_prompt[
                        "prompt_token_ids"
                    ] = prefill_response.prompt_token_ids
                logger.debug(
                    "Prefill response kv_transfer_params: %s",
                    prefill_response.kv_transfer_params,
                )
                extra_args = decode_request.sampling_params.extra_args or {}
                extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
                extra_args.pop("serialized_request", None)
                decode_request.sampling_params.extra_args = extra_args
                logger.debug("Decode request: %s", decode_request)
                async for decode_response in await decode_client.round_robin(
                    decode_request.model_dump_json()
                ):
                    output = MyRequestOutput.model_validate_json(
                        decode_response.data()  # type: ignore[attr-defined]
                    )
                    yield self._dump_output(output)
        else:
            async for response in gen:
                logger.debug(
                    "Response kv_transfer_params: %s", response.kv_transfer_params
                )
                logger.debug(
                    "length of expanded prompt ids: %s",
                    len(response.prompt_token_ids),
                )
                yield self._dump_output(response)
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
0 → 100644
View file @
cb88fdc7
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for MultimodalPDWorkerHandler.__init__."""
from
unittest.mock
import
MagicMock
,
patch
import
pytest
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
MultimodalEmbeddingCacheManager
,
)
from
dynamo.vllm.multimodal_handlers
import
multimodal_pd_worker_handler
as
mod
# Module-level pytest marks: pytest applies every mark in ``pytestmark`` to
# each test collected from this file.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.vllm,
    pytest.mark.gpu_0,
    pytest.mark.multimodal,
]
def
_make_config
(
model
:
str
=
"test-model"
,
is_prefill_worker
:
bool
=
False
,
enable_multimodal
:
bool
=
True
,
multimodal_embedding_cache_capacity_gb
:
float
=
0
,
)
->
MagicMock
:
"""Create a mock Config with the fields used by MultimodalPDWorkerHandler.__init__."""
config
=
MagicMock
()
config
.
model
=
model
config
.
is_prefill_worker
=
is_prefill_worker
config
.
enable_multimodal
=
enable_multimodal
config
.
multimodal_embedding_cache_capacity_gb
=
(
multimodal_embedding_cache_capacity_gb
)
config
.
engine_args
.
create_model_config
.
return_value
.
get_diff_sampling_param
.
return_value
=
(
{}
)
return
config
class TestMultimodalPDWorkerHandlerInit:
    """Constructor tests for MultimodalPDWorkerHandler, centred on the embedding cache."""

    def test_init_with_embedding_cache(self):
        """A positive capacity yields a MultimodalEmbeddingCacheManager sized in bytes."""
        capacity_gb = 0.1
        config = _make_config(multimodal_embedding_cache_capacity_gb=capacity_gb)

        # Stub out the base-class constructor and the image loader so that only
        # the handler's own __init__ logic runs.
        base_init_stub = patch.object(
            mod.BaseWorkerHandler, "__init__", return_value=None
        )
        image_loader_stub = patch.object(mod, "ImageLoader", new_callable=MagicMock)

        with base_init_stub, image_loader_stub:
            handler = mod.MultimodalPDWorkerHandler(
                runtime=MagicMock(),
                component=MagicMock(),
                engine_client=MagicMock(),
                config=config,
            )

        cache = handler.embedding_cache_manager
        assert isinstance(cache, MultimodalEmbeddingCacheManager)
        expected_bytes = int(capacity_gb * 1024**3)
        assert cache._capacity_bytes == expected_bytes
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment