Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
20ccc9b2
"lib/llm/src/entrypoint/input/common.rs" did not exist on "73fdfb8ab84c9f56982d7d6074ef4d2f2a214150"
Unverified
Commit
20ccc9b2
authored
Feb 10, 2026
by
Qi Wang
Committed by
GitHub
Feb 10, 2026
Browse files
refactor: add prefill_worker_utils in vLLM (#6017)
parent
1aab7f6b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
149 additions
and
109 deletions
+149
-109
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
...nts/src/dynamo/vllm/multimodal_handlers/worker_handler.py
+26
-109
components/src/dynamo/vllm/multimodal_utils/__init__.py
components/src/dynamo/vllm/multimodal_utils/__init__.py
+6
-0
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
.../src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
+117
-0
No files found.
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
View file @
20ccc9b2
...
...
@@ -3,10 +3,9 @@
import
copy
import
logging
import
os
from
collections
import
defaultdict
from
typing
import
Any
import
safetensors
import
torch
from
vllm.inputs.data
import
TokensPrompt
from
vllm.v1.engine.async_llm
import
AsyncLLM
...
...
@@ -15,18 +14,15 @@ import dynamo.nixl_connect as connect
from
dynamo.runtime
import
Client
,
Component
,
DistributedRuntime
from
..handlers
import
BaseWorkerHandler
from
..multimodal_utils
import
(
ImageLoader
,
MyRequestOutput
,
construct_mm_data
,
vLLMMultimodalRequest
,
)
from
..multimodal_utils
import
ImageLoader
,
MyRequestOutput
,
vLLMMultimodalRequest
from
..multimodal_utils.model
import
construct_qwen_decode_mm_data
,
is_qwen_vl_model
from
..multimodal_utils.prefill_worker_utils
import
(
accumulate_embeddings
,
load_embeddings
,
)
logger
=
logging
.
getLogger
(
__name__
)
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
class
MultimodalDecodeWorkerHandler
(
BaseWorkerHandler
):
"""Decode worker for disaggregated multimodal serving"""
...
...
@@ -181,108 +177,29 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
request
=
vLLMMultimodalRequest
.
model_validate
(
request
)
logger
.
debug
(
f
"Received PD request: {{ id:
{
request
.
request_id
}
}}."
)
multi_modal_data
=
defaultdict
(
list
)
multi_modal_data
:
dict
[
str
,
Any
]
=
defaultdict
(
list
)
for
mi
in
request
.
multimodal_inputs
:
# ECConnector consumer mode: vLLM loads embeddings automatically from disk
# We need to pass multimodal_input so vLLM can generate mm_hash and look up cache
if
self
.
config
.
ec_consumer_mode
:
logger
.
debug
(
f
"[
{
request
.
request_id
}
] ECConnector consumer mode: "
f
"vLLM will load embeddings from cache using mm_hash"
)
# Use PIL image loading - vLLM will detect it's already in EC cache
# and load from disk instead of reprocessing
if
mi
.
multimodal_input
.
image_url
:
# PIL image path — used by both EC consumer mode
# (vLLM looks up cached embeddings via mm_hash) and
# non-disaggregated mode (vLLM encodes inline).
multi_modal_data
[
"image"
].
append
(
await
self
.
image_loader
.
load_image
(
mi
.
multimodal_input
.
image_url
)
)
elif
mi
.
multimodal_input
.
video_url
:
# For video, load as image placeholder (vLLM will use EC cache)
multi_modal_data
[
"image"
].
append
(
await
self
.
image_loader
.
load_image
(
request
.
multimodal_input
.
video_url
)
)
else
:
raise
ValueError
(
"ECConnector mode requires multimodal_input with image/video URL"
await
self
.
image_loader
.
load_image
(
mi
.
multimodal_input
.
image_url
)
)
elif
(
mi
.
multimodal_input
.
image_url
is
None
and
mi
.
multimodal_input
.
video_url
is
None
):
# Process embeddings using the connector
# Create a descriptor based on the embedding shape.
if
TRANSFER_LOCAL
:
logger
.
info
(
"PD: Loading local safetensors file"
)
embeddings
=
safetensors
.
torch
.
load_file
(
mi
.
serialized_request
)[
"ec_cache"
]
else
:
embeddings
=
torch
.
empty
(
mi
.
embeddings_shape
,
dtype
=
self
.
EMBEDDINGS_DTYPE
,
device
=
self
.
EMBEDDINGS_DEVICE
,
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
if
descriptor
is
None
:
raise
RuntimeError
(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op
=
await
self
.
_connector
.
begin_read
(
mi
.
serialized_request
,
descriptor
)
await
read_op
.
wait_for_completion
()
if
"video"
in
self
.
config
.
model
.
lower
():
video_numpy
=
embeddings
.
numpy
()
mm_data
=
construct_mm_data
(
self
.
config
.
model
,
# Pre-computed embeddings via NIXL RDMA or local safetensors
embeddings
=
await
load_embeddings
(
mi
,
self
.
EMBEDDINGS_DTYPE
,
video_numpy
=
video_numpy
,
self
.
EMBEDDINGS_DEVICE
,
self
.
_connector
,
)
multi_modal_data
[
"video"
].
append
(
mm_data
[
"video"
])
else
:
mm_data
=
construct_mm_data
(
accumulate_embeddings
(
multi_modal_data
,
self
.
config
.
model
,
self
.
EMBEDDINGS_DTYPE
,
image_embeds
=
embeddings
,
image_grid_thw
=
mi
.
image_grid_thw
,
)
if
isinstance
(
mm_data
[
"image"
],
dict
):
if
multi_modal_data
[
"image"
]
==
[]:
multi_modal_data
[
"image"
]
=
mm_data
[
"image"
]
else
:
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
# Merging tensors
multi_modal_data
[
"image"
][
"image_embeds"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
][
"image_embeds"
],
mm_data
[
"image"
][
"image_embeds"
],
)
)
multi_modal_data
[
"image"
][
"image_grid_thw"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
][
"image_grid_thw"
],
mm_data
[
"image"
][
"image_grid_thw"
],
)
)
else
:
logger
.
info
(
f
"Get embedding of shape
{
mm_data
[
'image'
].
shape
}
"
)
# [gluo FIXME] embedding with multiple images?
if
multi_modal_data
[
"image"
]
==
[]:
multi_modal_data
[
"image"
]
=
mm_data
[
"image"
]
else
:
multi_modal_data
[
"image"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
],
mm_data
[
"image"
])
)
else
:
# Use PIL image instead of image embeddings
multi_modal_data
[
"image"
].
append
(
await
self
.
image_loader
.
load_image
(
mi
.
multimodal_input
.
image_url
)
embeddings
,
mi
.
image_grid_thw
,
)
# For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
...
...
components/src/dynamo/vllm/multimodal_utils/__init__.py
View file @
20ccc9b2
...
...
@@ -19,6 +19,10 @@ from dynamo.vllm.multimodal_utils.model import (
construct_mm_data
,
load_vision_model
,
)
from
dynamo.vllm.multimodal_utils.prefill_worker_utils
import
(
accumulate_embeddings
,
load_embeddings
,
)
from
dynamo.vllm.multimodal_utils.protocol
import
(
MultiModalGroup
,
MultiModalInput
,
...
...
@@ -51,4 +55,6 @@ __all__ = [
"vLLMMultimodalRequest"
,
"VLLMNativeEncoderRequest"
,
"VLLMNativeEncoderResponse"
,
"accumulate_embeddings"
,
"load_embeddings"
,
]
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
0 → 100644
View file @
20ccc9b2
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
import
os
from
typing
import
Any
,
Dict
import
safetensors
import
torch
import
dynamo.nixl_connect
as
connect
from
.model
import
construct_mm_data
from
.protocol
import
MultiModalGroup
logger
=
logging
.
getLogger
(
__name__
)
TRANSFER_LOCAL
=
int
(
os
.
getenv
(
"TRANSFER_LOCAL"
,
1
))
async
def
load_embeddings
(
mi
:
MultiModalGroup
,
embeddings_dtype
:
torch
.
dtype
,
embeddings_device
:
str
,
connector
:
connect
.
Connector
|
None
,
)
->
torch
.
Tensor
:
"""Load pre-computed embedding tensor via local safetensors or NIXL RDMA.
Args:
mi: A single MultiModalGroup whose ``serialized_request`` field
contains either a local file path or NIXL RDMA metadata.
embeddings_dtype: Torch dtype for the tensor (used for RDMA path).
embeddings_device: Device string for the tensor (used for RDMA path).
connector: NIXL Connector for RDMA reads (required when TRANSFER_LOCAL=0).
Returns:
The embedding tensor loaded into CPU memory.
"""
if
TRANSFER_LOCAL
:
logger
.
info
(
"PD: Loading local safetensors file"
)
return
safetensors
.
torch
.
load_file
(
mi
.
serialized_request
)[
"ec_cache"
]
embeddings
=
torch
.
empty
(
mi
.
embeddings_shape
,
dtype
=
embeddings_dtype
,
device
=
embeddings_device
,
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
if
descriptor
is
None
:
raise
RuntimeError
(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op
=
await
connector
.
begin_read
(
mi
.
serialized_request
,
descriptor
)
await
read_op
.
wait_for_completion
()
return
embeddings
def
accumulate_embeddings
(
multi_modal_data
:
Dict
[
str
,
Any
],
model
:
str
,
embeddings_dtype
:
torch
.
dtype
,
embeddings
:
torch
.
Tensor
,
image_grid_thw
,
)
->
None
:
"""Construct model-specific mm_data from embeddings and merge into the
accumulated ``multi_modal_data`` dict (mutated in-place).
Handles both video (numpy conversion) and image modalities, including
the Qwen-VL dict-style embeddings with ``image_embeds`` + ``image_grid_thw``.
"""
if
"video"
in
model
.
lower
():
video_numpy
=
embeddings
.
numpy
()
mm_data
=
construct_mm_data
(
model
,
embeddings_dtype
,
video_numpy
=
video_numpy
,
)
multi_modal_data
[
"video"
].
append
(
mm_data
[
"video"
])
return
mm_data
=
construct_mm_data
(
model
,
embeddings_dtype
,
image_embeds
=
embeddings
,
image_grid_thw
=
image_grid_thw
,
)
if
isinstance
(
mm_data
[
"image"
],
dict
):
# Qwen-VL style: dict with image_embeds + image_grid_thw tensors
if
multi_modal_data
[
"image"
]
==
[]:
multi_modal_data
[
"image"
]
=
mm_data
[
"image"
]
else
:
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
multi_modal_data
[
"image"
][
"image_embeds"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
][
"image_embeds"
],
mm_data
[
"image"
][
"image_embeds"
],
)
)
multi_modal_data
[
"image"
][
"image_grid_thw"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
][
"image_grid_thw"
],
mm_data
[
"image"
][
"image_grid_thw"
],
)
)
else
:
# Plain tensor embeddings
logger
.
info
(
f
"Get embedding of shape
{
mm_data
[
'image'
].
shape
}
"
)
# [gluo FIXME] embedding with multiple images?
if
multi_modal_data
[
"image"
]
==
[]:
multi_modal_data
[
"image"
]
=
mm_data
[
"image"
]
else
:
multi_modal_data
[
"image"
]
=
torch
.
cat
(
(
multi_modal_data
[
"image"
],
mm_data
[
"image"
])
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment