Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
20ccc9b2
Unverified
Commit
20ccc9b2
authored
Feb 10, 2026
by
Qi Wang
Committed by
GitHub
Feb 10, 2026
Browse files
refactor: add prefill_worker_utils in vLLM (#6017)
parent
1aab7f6b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
149 additions
and
109 deletions
+149
-109
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
...nts/src/dynamo/vllm/multimodal_handlers/worker_handler.py
+26
-109
components/src/dynamo/vllm/multimodal_utils/__init__.py
components/src/dynamo/vllm/multimodal_utils/__init__.py
+6
-0
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
.../src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
+117
-0
No files found.
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
View file @
20ccc9b2
...
...
@@ -3,10 +3,9 @@
import
copy
import
logging
import
os
from
collections
import
defaultdict
from
typing
import
Any
import
safetensors
import
torch
from
vllm.inputs.data
import
TokensPrompt
from
vllm.v1.engine.async_llm
import
AsyncLLM
...
...
@@ -15,18 +14,15 @@ import dynamo.nixl_connect as connect
from
dynamo.runtime
import
Client
,
Component
,
DistributedRuntime
from
..handlers
import
BaseWorkerHandler
from
..multimodal_utils
import
(
ImageLoader
,
MyRequestOutput
,
construct_mm_data
,
vLLMMultimodalRequest
,
)
from
..multimodal_utils
import
ImageLoader
,
MyRequestOutput
,
vLLMMultimodalRequest
from
..multimodal_utils.model
import
construct_qwen_decode_mm_data
,
is_qwen_vl_model
from
..multimodal_utils.prefill_worker_utils
import
(
accumulate_embeddings
,
load_embeddings
,
)
logger
=
logging
.
getLogger
(
__name__
)
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
class
MultimodalDecodeWorkerHandler
(
BaseWorkerHandler
):
"""Decode worker for disaggregated multimodal serving"""
...
...
@@ -181,109 +177,30 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request
=
vLLMMultimodalRequest
.
model_validate
(
request
)
logger
.
debug
(
f
"Received PD request: {{ id:
{
request
.
request_id
}
}}."
)
multi_modal_data
=
defaultdict
(
list
)
multi_modal_data
:
dict
[
str
,
Any
]
=
defaultdict
(
list
)
for
mi
in
request
.
multimodal_inputs
:
# ECConnector consumer mode: vLLM loads embeddings automatically from disk
# We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
if
self
.
config
.
ec_consumer_mode
:
logger
.
debug
(
f
"[
{
request
.
request_id
}
] ECConnector consumer mode: "
f
"vLLM will load embeddings from cache using mm_hash"
)
# Use PIL image loading - vLLM will detect it's already in EC cache
# and load from disk instead of reprocessing
if
mi
.
multimodal_input
.
image_url
:
multi_modal_data
[
"image"
].
append
(
await
self
.
image_loader
.
load_image
(
mi
.
multimodal_input
.
image_url
)
)
elif
mi
.
multimodal_input
.
video_url
:
# For video, load as image placeholder (vLLM will use EC cache)
multi_modal_data
[
"image"
].
append
(
await
self
.
image_loader
.
load_image
(
request
.
multimodal_input
.
video_url
)
)
else
:
raise
ValueError
(
"ECConnector mode requires multimodal_input with image/video URL"
)
elif
(
mi
.
multimodal_input
.
image_url
is
None
and
mi
.
multimodal_input
.
video_url
is
None
):
# Process embeddings using the connector
# Create a descriptor based on the embedding shape.
if
TRANSFER_LOCAL
:
logger
.
info
(
"PD: Loading local safetensors file"
)
embeddings
=
safetensors
.
torch
.
load_file
(
mi
.
serialized_request
)[
"ec_cache"
]
else
:
embeddings
=
torch
.
empty
(
mi
.
embeddings_shape
,
dtype
=
self
.
EMBEDDINGS_DTYPE
,
device
=
self
.
EMBEDDINGS_DEVICE
,
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
if
descriptor
is
None
:
raise
RuntimeError
(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op
=
await
self
.
_connector
.
begin_read
(
mi
.
serialized_request
,
descriptor
)
await
read_op
.
wait_for_completion
()
if
"video"
in
self
.
config
.
model
.
lower
():
video_numpy
=
embeddings
.
numpy
()
mm_data
=
construct_mm_data
(
self
.
config
.
model
,
self
.
EMBEDDINGS_DTYPE
,
video_numpy
=
video_numpy
,
)
multi_modal_data
[
"video"
].
append
(
mm_data
[
"video"
])
else
:
mm_data
=
construct_mm_data
(
self
.
config
.
model
,
self
.
EMBEDDINGS_DTYPE
,
image_embeds
=
embeddings
,
image_grid_thw
=
mi
.
image_grid_thw
,
)
if
isinstance
(
mm_data
[
"image"
],
dict
):
if
multi_modal_data
[
"image"
]
==
[]:
multi_modal_data
[
"image"
]
=
mm_data
[
"image"
]
else
:
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
# Merging tensors
multi_modal_data
[
"image"
][
"image_embeds"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
][
"image_embeds"
],
mm_data
[
"image"
][
"image_embeds"
],
)
)
multi_modal_data
[
"image"
][
"image_grid_thw"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
][
"image_grid_thw"
],
mm_data
[
"image"
][
"image_grid_thw"
],
)
)
else
:
logger
.
info
(
f
"Get embedding of shape
{
mm_data
[
'image'
].
shape
}
"
)
# [gluo FIXME] embedding with multiple images?
if
multi_modal_data
[
"image"
]
==
[]:
multi_modal_data
[
"image"
]
=
mm_data
[
"image"
]
else
:
multi_modal_data
[
"image"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
],
mm_data
[
"image"
])
)
else
:
# Use PIL image instead of image embeddings
if
mi
.
multimodal_input
.
image_url
:
# PIL image path — used by both EC consumer mode
# (vLLM looks up cached embeddings via mm_hash) and
# non-disaggregated mode (vLLM encodes inline).
multi_modal_data
[
"image"
].
append
(
await
self
.
image_loader
.
load_image
(
mi
.
multimodal_input
.
image_url
)
)
else
:
# Pre-computed embeddings via NIXL RDMA or local safetensors
embeddings
=
await
load_embeddings
(
mi
,
self
.
EMBEDDINGS_DTYPE
,
self
.
EMBEDDINGS_DEVICE
,
self
.
_connector
,
)
accumulate_embeddings
(
multi_modal_data
,
self
.
config
.
model
,
self
.
EMBEDDINGS_DTYPE
,
embeddings
,
mi
.
image_grid_thw
,
)
# For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
# from the constructed multimodal data so decode can reconstruct its
...
...
components/src/dynamo/vllm/multimodal_utils/__init__.py
View file @
20ccc9b2
...
...
@@ -19,6 +19,10 @@ from dynamo.vllm.multimodal_utils.model import (
construct_mm_data
,
load_vision_model
,
)
from
dynamo.vllm.multimodal_utils.prefill_worker_utils
import
(
accumulate_embeddings
,
load_embeddings
,
)
from
dynamo.vllm.multimodal_utils.protocol
import
(
MultiModalGroup
,
MultiModalInput
,
...
...
@@ -51,4 +55,6 @@ __all__ = [
"vLLMMultimodalRequest"
,
"VLLMNativeEncoderRequest"
,
"VLLMNativeEncoderResponse"
,
"accumulate_embeddings"
,
"load_embeddings"
,
]
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
0 → 100644
View file @
20ccc9b2
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
import
os
from
typing
import
Any
,
Dict
import
safetensors
import
torch
import
dynamo.nixl_connect
as
connect
from
.model
import
construct_mm_data
from
.protocol
import
MultiModalGroup
logger
=
logging
.
getLogger
(
__name__
)
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
async
def
load_embeddings
(
mi
:
MultiModalGroup
,
embeddings_dtype
:
torch
.
dtype
,
embeddings_device
:
str
,
connector
:
connect
.
Connector
|
None
,
)
->
torch
.
Tensor
:
"""Load pre-computed embedding tensor via local safetensors or NIXL RDMA.
Args:
mi: A single MultiModalGroup whose ``serialized_request`` field
contains either a local file path or NIXL RDMA metadata.
embeddings_dtype: Torch dtype for the tensor (used for RDMA path).
embeddings_device: Device string for the tensor (used for RDMA path).
connector: NIXL Connector for RDMA reads (required when TRANSFER_LOCAL=0).
Returns:
The embedding tensor loaded into CPU memory.
"""
if
TRANSFER_LOCAL
:
logger
.
info
(
"PD: Loading local safetensors file"
)
return
safetensors
.
torch
.
load_file
(
mi
.
serialized_request
)[
"ec_cache"
]
embeddings
=
torch
.
empty
(
mi
.
embeddings_shape
,
dtype
=
embeddings_dtype
,
device
=
embeddings_device
,
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
if
descriptor
is
None
:
raise
RuntimeError
(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op
=
await
connector
.
begin_read
(
mi
.
serialized_request
,
descriptor
)
await
read_op
.
wait_for_completion
()
return
embeddings
def
accumulate_embeddings
(
multi_modal_data
:
Dict
[
str
,
Any
],
model
:
str
,
embeddings_dtype
:
torch
.
dtype
,
embeddings
:
torch
.
Tensor
,
image_grid_thw
,
)
->
None
:
"""Construct model-specific mm_data from embeddings and merge into the
accumulated ``multi_modal_data`` dict (mutated in-place).
Handles both video (numpy conversion) and image modalities, including
the Qwen-VL dict-style embeddings with ``image_embeds`` + ``image_grid_thw``.
"""
if
"video"
in
model
.
lower
():
video_numpy
=
embeddings
.
numpy
()
mm_data
=
construct_mm_data
(
model
,
embeddings_dtype
,
video_numpy
=
video_numpy
,
)
multi_modal_data
[
"video"
].
append
(
mm_data
[
"video"
])
return
mm_data
=
construct_mm_data
(
model
,
embeddings_dtype
,
image_embeds
=
embeddings
,
image_grid_thw
=
image_grid_thw
,
)
if
isinstance
(
mm_data
[
"image"
],
dict
):
# Qwen-VL style: dict with image_embeds + image_grid_thw tensors
if
multi_modal_data
[
"image"
]
==
[]:
multi_modal_data
[
"image"
]
=
mm_data
[
"image"
]
else
:
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
multi_modal_data
[
"image"
][
"image_embeds"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
][
"image_embeds"
],
mm_data
[
"image"
][
"image_embeds"
],
)
)
multi_modal_data
[
"image"
][
"image_grid_thw"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
][
"image_grid_thw"
],
mm_data
[
"image"
][
"image_grid_thw"
],
)
)
else
:
# Plain tensor embeddings
logger
.
info
(
f
"Get embedding of shape
{
mm_data
[
'image'
].
shape
}
"
)
# [gluo FIXME] embedding with multiple images?
if
multi_modal_data
[
"image"
]
==
[]:
multi_modal_data
[
"image"
]
=
mm_data
[
"image"
]
else
:
multi_modal_data
[
"image"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
],
mm_data
[
"image"
])
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment