Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
cb88fdc7
"docs/vscode:/vscode.git/clone" did not exist on "8ad3b9a2984eaafe74cf1b7bf257eb828b32f016"
Unverified
Commit
cb88fdc7
authored
Feb 12, 2026
by
Qi Wang
Committed by
GitHub
Feb 13, 2026
Browse files
feat: add encode client and embedding cache to PD worker (#6029)
parent
166e1f4d
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
327 additions
and
232 deletions
+327
-232
components/src/dynamo/common/configuration/groups/runtime_args.py
...ts/src/dynamo/common/configuration/groups/runtime_args.py
+14
-1
components/src/dynamo/vllm/args.py
components/src/dynamo/vllm/args.py
+0
-1
components/src/dynamo/vllm/multimodal_handlers/__init__.py
components/src/dynamo/vllm/multimodal_handlers/__init__.py
+4
-4
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
.../dynamo/vllm/multimodal_handlers/encode_worker_handler.py
+0
-13
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
.../vllm/multimodal_handlers/multimodal_pd_worker_handler.py
+239
-0
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
...nts/src/dynamo/vllm/multimodal_handlers/worker_handler.py
+4
-213
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
...imodal_handlers/test_vllm_multimodal_pd_worker_handler.py
+66
-0
No files found.
components/src/dynamo/common/configuration/groups/runtime_args.py
View file @
cb88fdc7
...
@@ -27,12 +27,17 @@ class DynamoRuntimeConfig(ConfigBase):
...
@@ -27,12 +27,17 @@ class DynamoRuntimeConfig(ConfigBase):
custom_jinja_template
:
Optional
[
str
]
=
None
custom_jinja_template
:
Optional
[
str
]
=
None
endpoint_types
:
str
endpoint_types
:
str
dump_config_to
:
Optional
[
str
]
=
None
dump_config_to
:
Optional
[
str
]
=
None
multimodal_embedding_cache_capacity_gb
:
float
def
validate
(
self
)
->
None
:
def
validate
(
self
)
->
None
:
# TODO get a better way for spot fixes like this.
# TODO get a better way for spot fixes like this.
self
.
enable_local_indexer
=
not
self
.
durable_kv_events
self
.
enable_local_indexer
=
not
self
.
durable_kv_events
# For simplicity, we do not prepend "dyn-" unless it's absolutely necessary. These are
# exemplary exceptions:
# - To avoid name conflicts with different backends, prefix "dyn-" for dynamo specific
# args.
class
DynamoRuntimeArgGroup
(
ArgGroup
):
class
DynamoRuntimeArgGroup
(
ArgGroup
):
"""Dynamo runtime configuration parameters (common to all backends)."""
"""Dynamo runtime configuration parameters (common to all backends)."""
...
@@ -89,7 +94,6 @@ class DynamoRuntimeArgGroup(ArgGroup):
...
@@ -89,7 +94,6 @@ class DynamoRuntimeArgGroup(ArgGroup):
)
)
# Optional: tool/reasoning parsers (choices from dynamo._core when available)
# Optional: tool/reasoning parsers (choices from dynamo._core when available)
# To avoid name conflicts with different backends, prefix "dyn-" for dynamo specific args
add_argument
(
add_argument
(
g
,
g
,
flag_name
=
"--dyn-tool-call-parser"
,
flag_name
=
"--dyn-tool-call-parser"
,
...
@@ -130,3 +134,12 @@ class DynamoRuntimeArgGroup(ArgGroup):
...
@@ -130,3 +134,12 @@ class DynamoRuntimeArgGroup(ArgGroup):
default
=
None
,
default
=
None
,
help
=
"Dump resolved configuration to the specified file path."
,
help
=
"Dump resolved configuration to the specified file path."
,
)
)
add_argument
(
g
,
flag_name
=
"--multimodal-embedding-cache-capacity-gb"
,
env_var
=
"DYN_MULTIMODAL_EMBEDDING_CACHE_CAPACITY_GB"
,
default
=
0
,
arg_type
=
float
,
help
=
"Capacity of the multimodal embedding cache in GB. 0 = disabled."
,
)
components/src/dynamo/vllm/args.py
View file @
cb88fdc7
...
@@ -79,7 +79,6 @@ def parse_args() -> Config:
...
@@ -79,7 +79,6 @@ def parse_args() -> Config:
Returns:
Returns:
Config: Parsed configuration object.
Config: Parsed configuration object.
"""
"""
dynamo_runtime_argspec
=
DynamoRuntimeArgGroup
()
dynamo_runtime_argspec
=
DynamoRuntimeArgGroup
()
dynamo_vllm_argspec
=
DynamoVllmArgGroup
()
dynamo_vllm_argspec
=
DynamoVllmArgGroup
()
...
...
components/src/dynamo/vllm/multimodal_handlers/__init__.py
View file @
cb88fdc7
...
@@ -5,14 +5,14 @@ from dynamo.vllm.multimodal_handlers.encode_worker_handler import (
...
@@ -5,14 +5,14 @@ from dynamo.vllm.multimodal_handlers.encode_worker_handler import (
EncodeWorkerHandler
,
EncodeWorkerHandler
,
VLLMEncodeWorkerHandler
,
VLLMEncodeWorkerHandler
,
)
)
from
dynamo.vllm.multimodal_handlers.multimodal_pd_worker_handler
import
(
MultimodalPDWorkerHandler
,
)
from
dynamo.vllm.multimodal_handlers.preprocessed_handler
import
(
from
dynamo.vllm.multimodal_handlers.preprocessed_handler
import
(
ECProcessorHandler
,
ECProcessorHandler
,
PreprocessedHandler
,
PreprocessedHandler
,
)
)
from
dynamo.vllm.multimodal_handlers.worker_handler
import
(
from
dynamo.vllm.multimodal_handlers.worker_handler
import
MultimodalDecodeWorkerHandler
MultimodalDecodeWorkerHandler
,
MultimodalPDWorkerHandler
,
)
__all__
=
[
__all__
=
[
"EncodeWorkerHandler"
,
"EncodeWorkerHandler"
,
...
...
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
View file @
cb88fdc7
...
@@ -35,19 +35,6 @@ from ..multimodal_utils.model import is_qwen_vl_model
...
@@ -35,19 +35,6 @@ from ..multimodal_utils.model import is_qwen_vl_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
try
:
import
cupy
as
array_module
if
not
array_module
.
cuda
.
is_available
():
raise
ImportError
(
"CUDA is not available."
)
DEVICE
=
"cuda"
logger
.
info
(
"Using cupy for array operations (GPU mode)."
)
except
ImportError
as
e
:
logger
.
warning
(
f
"Failed to import cupy, falling back to numpy:
{
e
}
."
)
import
numpy
as
array_module
DEVICE
=
"cpu"
CACHE_SIZE_MAXIMUM
=
8
CACHE_SIZE_MAXIMUM
=
8
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
...
...
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
0 → 100644
View file @
cb88fdc7
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
copy
import
logging
from
collections
import
defaultdict
from
typing
import
Any
import
torch
from
vllm.inputs.data
import
TokensPrompt
from
vllm.v1.engine.async_llm
import
AsyncLLM
import
dynamo.nixl_connect
as
connect
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
MultimodalEmbeddingCacheManager
,
)
from
dynamo.runtime
import
Client
,
Component
,
DistributedRuntime
from
..args
import
Config
from
..handlers
import
BaseWorkerHandler
from
..multimodal_utils
import
ImageLoader
,
MyRequestOutput
,
vLLMMultimodalRequest
from
..multimodal_utils.model
import
is_qwen_vl_model
from
..multimodal_utils.prefill_worker_utils
import
(
accumulate_embeddings
,
load_embeddings
,
)
logger
=
logging
.
getLogger
(
__name__
)
class
MultimodalPDWorkerHandler
(
BaseWorkerHandler
):
"""Prefill/Decode or Prefill-only worker for multimodal serving"""
def
__init__
(
self
,
runtime
,
component
:
Component
,
engine_client
:
AsyncLLM
,
config
:
Config
,
encode_worker_client
:
Client
|
None
=
None
,
decode_worker_client
:
Client
|
None
=
None
,
shutdown_event
=
None
,
):
# Get default_sampling_params from config
default_sampling_params
=
(
config
.
engine_args
.
create_model_config
().
get_diff_sampling_param
()
)
# Call BaseWorkerHandler.__init__ with proper parameters
super
().
__init__
(
runtime
,
component
,
engine_client
,
default_sampling_params
,
enable_multimodal
=
config
.
enable_multimodal
,
shutdown_event
=
shutdown_event
,
)
self
.
config
=
config
self
.
encode_worker_client
=
encode_worker_client
self
.
decode_worker_client
=
decode_worker_client
self
.
enable_disagg
=
config
.
is_prefill_worker
self
.
embedding_cache_manager
:
MultimodalEmbeddingCacheManager
|
None
=
None
if
config
.
multimodal_embedding_cache_capacity_gb
>
0
:
capacity_bytes
=
int
(
config
.
multimodal_embedding_cache_capacity_gb
*
1024
**
3
)
self
.
embedding_cache_manager
=
MultimodalEmbeddingCacheManager
(
capacity_bytes
)
# Initialize multimodal-specific components
logger
.
info
(
"Multimodal PD Worker startup started."
)
if
"video"
in
self
.
config
.
model
.
lower
():
self
.
EMBEDDINGS_DTYPE
=
torch
.
uint8
else
:
self
.
EMBEDDINGS_DTYPE
=
torch
.
float16
self
.
EMBEDDINGS_DEVICE
=
"cpu"
# Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently.
# Note: This is synchronous initialization, async initialization happens in async_init
self
.
_connector
:
connect
.
Connector
|
None
=
(
None
# Will be initialized in async_init
)
self
.
image_loader
=
ImageLoader
()
logger
.
info
(
"Multimodal PD Worker has been initialized"
)
async
def
async_init
(
self
,
runtime
:
DistributedRuntime
):
"""Async initialization for connector that requires async setup"""
# Initialize the connector asynchronously
self
.
_connector
=
connect
.
Connector
()
logger
.
info
(
"Multimodal PD Worker async initialization completed."
)
async
def
generate
(
self
,
request
:
vLLMMultimodalRequest
,
context
):
logger
.
debug
(
f
"Got raw request:
{
request
}
"
)
if
type
(
request
)
is
not
vLLMMultimodalRequest
:
if
type
(
request
)
is
str
:
request
=
vLLMMultimodalRequest
.
model_validate_json
(
request
)
else
:
request
=
vLLMMultimodalRequest
.
model_validate
(
request
)
logger
.
debug
(
f
"Received PD request: {{ id:
{
request
.
request_id
}
}}."
)
multi_modal_data
:
dict
[
str
,
Any
]
=
defaultdict
(
list
)
for
mi
in
request
.
multimodal_inputs
:
if
mi
.
multimodal_input
.
image_url
:
# PIL image path — used by both EC consumer mode
# (vLLM looks up cached embeddings via mm_hash) and
# non-disaggregated mode (vLLM encodes inline).
multi_modal_data
[
"image"
].
append
(
await
self
.
image_loader
.
load_image
(
mi
.
multimodal_input
.
image_url
)
)
else
:
# Pre-computed embeddings via NIXL RDMA or local safetensors
embeddings
=
await
load_embeddings
(
mi
,
self
.
EMBEDDINGS_DTYPE
,
self
.
EMBEDDINGS_DEVICE
,
self
.
_connector
,
)
accumulate_embeddings
(
multi_modal_data
,
self
.
config
.
model
,
self
.
EMBEDDINGS_DTYPE
,
embeddings
,
mi
.
image_grid_thw
,
)
# For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
# from the constructed multimodal data so decode can reconstruct its
# multi_modal_data consistently for multiple images.
if
is_qwen_vl_model
(
self
.
config
.
model
)
and
isinstance
(
multi_modal_data
.
get
(
"image"
),
dict
):
image_data
=
multi_modal_data
[
"image"
]
image_grid_thw
=
image_data
.
get
(
"image_grid_thw"
)
image_embeds
=
image_data
.
get
(
"image_embeds"
)
if
image_grid_thw
is
not
None
:
request
.
image_grid_thw
=
(
image_grid_thw
.
tolist
()
if
isinstance
(
image_grid_thw
,
torch
.
Tensor
)
else
image_grid_thw
)
if
image_embeds
is
not
None
:
request
.
embeddings_shape
=
list
(
image_embeds
.
shape
)
# Remove the image features from the request as they are not required
# Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade
request
.
multimodal_inputs
=
[]
logger
.
info
(
f
"Prepared multimodal data size:
{
len
(
multi_modal_data
[
'image'
])
}
"
)
logger
.
debug
(
"Multimodal data keys: %s"
,
list
(
multi_modal_data
.
keys
()))
# Deepcopy the request to avoid modifying the original
# when we adjust sampling params for prefill
pd_request
=
copy
.
deepcopy
(
request
)
# Do prefill and remote decode if enable_disagg is true
if
self
.
enable_disagg
and
self
.
decode_worker_client
:
extra_args
=
pd_request
.
sampling_params
.
extra_args
or
{}
extra_args
[
"kv_transfer_params"
]
=
{
"do_remote_decode"
:
True
,
}
pd_request
.
sampling_params
.
extra_args
=
extra_args
pd_request
.
sampling_params
.
max_tokens
=
1
pd_request
.
sampling_params
.
min_tokens
=
1
logger
.
debug
(
"Prefill request: %s"
,
pd_request
)
gen
=
self
.
engine_client
.
generate
(
prompt
=
TokensPrompt
(
prompt_token_ids
=
pd_request
.
engine_prompt
[
"prompt_token_ids"
],
multi_modal_data
=
multi_modal_data
,
),
sampling_params
=
pd_request
.
sampling_params
,
request_id
=
pd_request
.
request_id
,
)
if
self
.
enable_disagg
and
self
.
decode_worker_client
:
decode_request
=
copy
.
deepcopy
(
request
)
async
for
prefill_response
in
gen
:
# For Qwen VL models with mRoPE: Keep the ORIGINAL unexpanded prompt.
# The decode worker will pass multi_modal_data which causes vLLM to
# expand the prompt identically to prefill, ensuring block counts match.
#
# For other models: Use the expanded prompt from prefill response.
# These models don't pass multi_modal_data in decode, so they need
# the already-expanded prompt to match the KV cache layout.
if
not
is_qwen_vl_model
(
self
.
config
.
model
):
decode_request
.
engine_prompt
[
"prompt_token_ids"
]
=
prefill_response
.
prompt_token_ids
logger
.
debug
(
f
"Prefill response kv_transfer_params:
{
prefill_response
.
kv_transfer_params
}
"
)
extra_args
=
decode_request
.
sampling_params
.
extra_args
or
{}
extra_args
[
"kv_transfer_params"
]
=
prefill_response
.
kv_transfer_params
extra_args
.
pop
(
"serialized_request"
,
None
)
decode_request
.
sampling_params
.
extra_args
=
extra_args
logger
.
debug
(
"Decode request: %s"
,
decode_request
)
async
for
(
decode_response
)
in
await
self
.
decode_worker_client
.
round_robin
(
decode_request
.
model_dump_json
()
):
output
=
MyRequestOutput
.
model_validate_json
(
decode_response
.
data
())
# type: ignore[attr-defined]
yield
MyRequestOutput
(
request_id
=
output
.
request_id
,
prompt
=
output
.
prompt
,
prompt_token_ids
=
output
.
prompt_token_ids
,
prompt_logprobs
=
output
.
prompt_logprobs
,
outputs
=
output
.
outputs
,
finished
=
output
.
finished
,
metrics
=
output
.
metrics
,
kv_transfer_params
=
output
.
kv_transfer_params
,
).
model_dump_json
()
else
:
async
for
response
in
gen
:
logger
.
debug
(
f
"Response kv_transfer_params:
{
response
.
kv_transfer_params
}
"
)
logger
.
debug
(
f
"length of expanded prompt ids:
{
len
(
response
.
prompt_token_ids
)
}
"
)
# logger.info(f"Response outputs: {response.outputs}")
yield
MyRequestOutput
(
request_id
=
response
.
request_id
,
prompt
=
response
.
prompt
,
prompt_token_ids
=
response
.
prompt_token_ids
,
prompt_logprobs
=
response
.
prompt_logprobs
,
outputs
=
response
.
outputs
,
finished
=
response
.
finished
,
metrics
=
response
.
metrics
,
kv_transfer_params
=
response
.
kv_transfer_params
,
).
model_dump_json
()
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
View file @
cb88fdc7
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
copy
import
logging
import
logging
from
collections
import
defaultdict
from
typing
import
Any
import
torch
from
vllm.inputs.data
import
TokensPrompt
from
vllm.inputs.data
import
TokensPrompt
from
vllm.v1.engine.async_llm
import
AsyncLLM
import
dynamo.nixl_connect
as
connect
import
dynamo.nixl_connect
as
connect
from
dynamo.runtime
import
Client
,
Component
,
DistributedRuntime
from
dynamo.runtime
import
DistributedRuntime
from
..args
import
Config
from
..handlers
import
BaseWorkerHandler
from
..handlers
import
BaseWorkerHandler
from
..multimodal_utils
import
ImageLoader
,
MyRequestOutput
,
vLLMMultimodalRequest
from
..multimodal_utils
import
MyRequestOutput
,
vLLMMultimodalRequest
from
..multimodal_utils.model
import
construct_qwen_decode_mm_data
,
is_qwen_vl_model
from
..multimodal_utils.model
import
construct_qwen_decode_mm_data
,
is_qwen_vl_model
from
..multimodal_utils.prefill_worker_utils
import
(
accumulate_embeddings
,
load_embeddings
,
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -32,7 +24,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -32,7 +24,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
runtime
,
runtime
,
component
,
component
,
engine_client
,
engine_client
,
config
,
config
:
Config
,
shutdown_event
=
None
,
shutdown_event
=
None
,
):
):
# Get default_sampling_params from config
# Get default_sampling_params from config
...
@@ -111,204 +103,3 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -111,204 +103,3 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
metrics
=
response
.
metrics
,
metrics
=
response
.
metrics
,
kv_transfer_params
=
response
.
kv_transfer_params
,
kv_transfer_params
=
response
.
kv_transfer_params
,
).
model_dump_json
()
).
model_dump_json
()
class
MultimodalPDWorkerHandler
(
BaseWorkerHandler
):
"""Prefill/Decode or Prefill-only worker for multimodal serving"""
def
__init__
(
self
,
runtime
,
component
:
Component
,
engine_client
:
AsyncLLM
,
config
,
decode_worker_client
:
Client
|
None
=
None
,
shutdown_event
=
None
,
):
# Get default_sampling_params from config
default_sampling_params
=
(
config
.
engine_args
.
create_model_config
().
get_diff_sampling_param
()
)
# Call BaseWorkerHandler.__init__ with proper parameters
super
().
__init__
(
runtime
,
component
,
engine_client
,
default_sampling_params
,
enable_multimodal
=
config
.
enable_multimodal
,
shutdown_event
=
shutdown_event
,
)
self
.
config
=
config
self
.
decode_worker_client
=
decode_worker_client
self
.
enable_disagg
=
config
.
is_prefill_worker
# Initialize multimodal-specific components
logger
.
info
(
"Multimodal PD Worker startup started."
)
if
"video"
in
self
.
config
.
model
.
lower
():
self
.
EMBEDDINGS_DTYPE
=
torch
.
uint8
else
:
self
.
EMBEDDINGS_DTYPE
=
torch
.
float16
self
.
EMBEDDINGS_DEVICE
=
"cpu"
# Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently.
# Note: This is synchronous initialization, async initialization happens in async_init
self
.
_connector
:
connect
.
Connector
|
None
=
(
None
# Will be initialized in async_init
)
self
.
image_loader
=
ImageLoader
()
logger
.
info
(
"Multimodal PD Worker has been initialized"
)
async
def
async_init
(
self
,
runtime
:
DistributedRuntime
):
"""Async initialization for connector that requires async setup"""
# Initialize the connector asynchronously
self
.
_connector
=
connect
.
Connector
()
logger
.
info
(
"Multimodal PD Worker async initialization completed."
)
async
def
generate
(
self
,
request
:
vLLMMultimodalRequest
,
context
):
logger
.
debug
(
f
"Got raw request:
{
request
}
"
)
if
type
(
request
)
is
not
vLLMMultimodalRequest
:
if
type
(
request
)
is
str
:
request
=
vLLMMultimodalRequest
.
model_validate_json
(
request
)
else
:
request
=
vLLMMultimodalRequest
.
model_validate
(
request
)
logger
.
debug
(
f
"Received PD request: {{ id:
{
request
.
request_id
}
}}."
)
multi_modal_data
:
dict
[
str
,
Any
]
=
defaultdict
(
list
)
for
mi
in
request
.
multimodal_inputs
:
if
mi
.
multimodal_input
.
image_url
:
# PIL image path — used by both EC consumer mode
# (vLLM looks up cached embeddings via mm_hash) and
# non-disaggregated mode (vLLM encodes inline).
multi_modal_data
[
"image"
].
append
(
await
self
.
image_loader
.
load_image
(
mi
.
multimodal_input
.
image_url
)
)
else
:
# Pre-computed embeddings via NIXL RDMA or local safetensors
embeddings
=
await
load_embeddings
(
mi
,
self
.
EMBEDDINGS_DTYPE
,
self
.
EMBEDDINGS_DEVICE
,
self
.
_connector
,
)
accumulate_embeddings
(
multi_modal_data
,
self
.
config
.
model
,
self
.
EMBEDDINGS_DTYPE
,
embeddings
,
mi
.
image_grid_thw
,
)
# For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
# from the constructed multimodal data so decode can reconstruct its
# multi_modal_data consistently for multiple images.
if
is_qwen_vl_model
(
self
.
config
.
model
)
and
isinstance
(
multi_modal_data
.
get
(
"image"
),
dict
):
image_data
=
multi_modal_data
[
"image"
]
image_grid_thw
=
image_data
.
get
(
"image_grid_thw"
)
image_embeds
=
image_data
.
get
(
"image_embeds"
)
if
image_grid_thw
is
not
None
:
request
.
image_grid_thw
=
(
image_grid_thw
.
tolist
()
if
isinstance
(
image_grid_thw
,
torch
.
Tensor
)
else
image_grid_thw
)
if
image_embeds
is
not
None
:
request
.
embeddings_shape
=
list
(
image_embeds
.
shape
)
# Remove the image features from the request as they are not required
# Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade
request
.
multimodal_inputs
=
[]
logger
.
info
(
f
"Prepared multimodal data size:
{
len
(
multi_modal_data
[
'image'
])
}
"
)
logger
.
info
(
f
"
{
multi_modal_data
}
"
)
# Deepcopy the request to avoid modifying the original
# when we adjust sampling params for prefill
pd_request
=
copy
.
deepcopy
(
request
)
# Do prefill and remote decode if enable_disagg is true
if
self
.
enable_disagg
and
self
.
decode_worker_client
:
extra_args
=
pd_request
.
sampling_params
.
extra_args
or
{}
extra_args
[
"kv_transfer_params"
]
=
{
"do_remote_decode"
:
True
,
}
pd_request
.
sampling_params
.
extra_args
=
extra_args
pd_request
.
sampling_params
.
max_tokens
=
1
pd_request
.
sampling_params
.
min_tokens
=
1
logger
.
debug
(
"Prefill request: %s"
,
pd_request
)
gen
=
self
.
engine_client
.
generate
(
prompt
=
TokensPrompt
(
prompt_token_ids
=
pd_request
.
engine_prompt
[
"prompt_token_ids"
],
multi_modal_data
=
multi_modal_data
,
),
sampling_params
=
pd_request
.
sampling_params
,
request_id
=
pd_request
.
request_id
,
)
if
self
.
enable_disagg
and
self
.
decode_worker_client
:
decode_request
=
copy
.
deepcopy
(
request
)
async
for
prefill_response
in
gen
:
# For Qwen VL models with mRoPE: Keep the ORIGINAL unexpanded prompt.
# The decode worker will pass multi_modal_data which causes vLLM to
# expand the prompt identically to prefill, ensuring block counts match.
#
# For other models: Use the expanded prompt from prefill response.
# These models don't pass multi_modal_data in decode, so they need
# the already-expanded prompt to match the KV cache layout.
if
not
is_qwen_vl_model
(
self
.
config
.
model
):
decode_request
.
engine_prompt
[
"prompt_token_ids"
]
=
prefill_response
.
prompt_token_ids
logger
.
debug
(
f
"Prefill response kv_transfer_params:
{
prefill_response
.
kv_transfer_params
}
"
)
extra_args
=
decode_request
.
sampling_params
.
extra_args
or
{}
extra_args
[
"kv_transfer_params"
]
=
prefill_response
.
kv_transfer_params
extra_args
.
pop
(
"serialized_request"
,
None
)
decode_request
.
sampling_params
.
extra_args
=
extra_args
logger
.
debug
(
"Decode request: %s"
,
decode_request
)
async
for
(
decode_response
)
in
await
self
.
decode_worker_client
.
round_robin
(
decode_request
.
model_dump_json
()
):
output
=
MyRequestOutput
.
model_validate_json
(
decode_response
.
data
())
# type: ignore[attr-defined]
yield
MyRequestOutput
(
request_id
=
output
.
request_id
,
prompt
=
output
.
prompt
,
prompt_token_ids
=
output
.
prompt_token_ids
,
prompt_logprobs
=
output
.
prompt_logprobs
,
outputs
=
output
.
outputs
,
finished
=
output
.
finished
,
metrics
=
output
.
metrics
,
kv_transfer_params
=
output
.
kv_transfer_params
,
).
model_dump_json
()
else
:
async
for
response
in
gen
:
logger
.
debug
(
f
"Response kv_transfer_params:
{
response
.
kv_transfer_params
}
"
)
logger
.
debug
(
f
"length of expanded prompt ids:
{
len
(
response
.
prompt_token_ids
)
}
"
)
# logger.info(f"Response outputs: {response.outputs}")
yield
MyRequestOutput
(
request_id
=
response
.
request_id
,
prompt
=
response
.
prompt
,
prompt_token_ids
=
response
.
prompt_token_ids
,
prompt_logprobs
=
response
.
prompt_logprobs
,
outputs
=
response
.
outputs
,
finished
=
response
.
finished
,
metrics
=
response
.
metrics
,
kv_transfer_params
=
response
.
kv_transfer_params
,
).
model_dump_json
()
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
0 → 100644
View file @
cb88fdc7
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for MultimodalPDWorkerHandler.__init__."""
from
unittest.mock
import
MagicMock
,
patch
import
pytest
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
MultimodalEmbeddingCacheManager
,
)
from
dynamo.vllm.multimodal_handlers
import
multimodal_pd_worker_handler
as
mod
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
vllm
,
pytest
.
mark
.
gpu_0
,
pytest
.
mark
.
multimodal
,
]
def
_make_config
(
model
:
str
=
"test-model"
,
is_prefill_worker
:
bool
=
False
,
enable_multimodal
:
bool
=
True
,
multimodal_embedding_cache_capacity_gb
:
float
=
0
,
)
->
MagicMock
:
"""Create a mock Config with the fields used by MultimodalPDWorkerHandler.__init__."""
config
=
MagicMock
()
config
.
model
=
model
config
.
is_prefill_worker
=
is_prefill_worker
config
.
enable_multimodal
=
enable_multimodal
config
.
multimodal_embedding_cache_capacity_gb
=
(
multimodal_embedding_cache_capacity_gb
)
config
.
engine_args
.
create_model_config
.
return_value
.
get_diff_sampling_param
.
return_value
=
(
{}
)
return
config
class
TestMultimodalPDWorkerHandlerInit
:
"""Tests for MultimodalPDWorkerHandler.__init__ focusing on embedding cache."""
def
test_init_with_embedding_cache
(
self
):
"""When capacity > 0, a MultimodalEmbeddingCacheManager is created with correct byte size."""
capacity_gb
=
0.1
config
=
_make_config
(
multimodal_embedding_cache_capacity_gb
=
capacity_gb
)
with
(
patch
.
object
(
mod
.
BaseWorkerHandler
,
"__init__"
,
return_value
=
None
),
patch
.
object
(
mod
,
"ImageLoader"
,
new_callable
=
MagicMock
),
):
handler
=
mod
.
MultimodalPDWorkerHandler
(
runtime
=
MagicMock
(),
component
=
MagicMock
(),
engine_client
=
MagicMock
(),
config
=
config
,
)
assert
isinstance
(
handler
.
embedding_cache_manager
,
MultimodalEmbeddingCacheManager
)
expected_bytes
=
int
(
capacity_gb
*
1024
**
3
)
assert
handler
.
embedding_cache_manager
.
_capacity_bytes
==
expected_bytes
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment