Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
80955ef4
Unverified
Commit
80955ef4
authored
Feb 25, 2026
by
Kris Hung
Committed by
GitHub
Feb 25, 2026
Browse files
perf: Keep embeddings on GPU Embedding Sender in EPD pipeline + minor fixes (#6535)
parent
5734f5c4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
18 deletions
+26
-18
components/src/dynamo/common/multimodal/embedding_transfer.py
...onents/src/dynamo/common/multimodal/embedding_transfer.py
+4
-6
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
.../dynamo/vllm/multimodal_handlers/encode_worker_handler.py
+11
-11
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
.../vllm/multimodal_handlers/multimodal_pd_worker_handler.py
+2
-0
components/src/dynamo/vllm/multimodal_utils/model.py
components/src/dynamo/vllm/multimodal_utils/model.py
+9
-1
No files found.
components/src/dynamo/common/multimodal/embedding_transfer.py
View file @
80955ef4
...
@@ -361,13 +361,11 @@ class NixlPersistentEmbeddingSender(AbstractEmbeddingSender):
...
@@ -361,13 +361,11 @@ class NixlPersistentEmbeddingSender(AbstractEmbeddingSender):
Returns:
Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
"""
"""
# If not staging embedding and embedding is on CPU, we explicitly copy
if
stage_embeddings
:
# the tensor as torch.Tensor.cpu() will return original tensor if it's already on CPU
transfer_buf
=
embeddings
if
not
stage_embeddings
and
not
embeddings
.
is_cuda
:
embeddings_cpu
=
embeddings
.
clone
().
detach
()
else
:
else
:
embeddings_cpu
=
embeddings
.
c
pu
()
transfer_buf
=
embeddings
.
c
lone
().
detach
()
descriptor
=
nixl_connect
.
Descriptor
(
embeddings_cpu
)
descriptor
=
nixl_connect
.
Descriptor
(
transfer_buf
)
readable_op
=
await
self
.
connector
.
create_readable
(
descriptor
)
readable_op
=
await
self
.
connector
.
create_readable
(
descriptor
)
request
=
TransferRequest
(
request
=
TransferRequest
(
...
...
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
View file @
80955ef4
...
@@ -42,7 +42,7 @@ ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1))
...
@@ -42,7 +42,7 @@ ENABLE_ENCODER_CACHE = int(os.getenv("ENABLE_ENCODER_CACHE", 1))
class
EmbeddingItem
:
class
EmbeddingItem
:
key
:
str
key
:
str
image_grid_thw
:
list
image_grid_thw
:
list
embeddings
_cpu
:
torch
.
Tensor
embeddings
:
torch
.
Tensor
class
EncodeWorkerHandler
:
class
EncodeWorkerHandler
:
...
@@ -140,11 +140,11 @@ class EncodeWorkerHandler:
...
@@ -140,11 +140,11 @@ class EncodeWorkerHandler:
if
self
.
embedding_cache
is
not
None
and
self
.
embedding_cache
.
has_key
(
if
self
.
embedding_cache
is
not
None
and
self
.
embedding_cache
.
has_key
(
embedding_key
embedding_key
):
):
(
image_grid_thw
,
embeddings
_cpu
)
=
self
.
embedding_cache
.
get
(
(
image_grid_thw
,
embeddings
)
=
self
.
embedding_cache
.
get
(
embedding_key
embedding_key
)
)
embedding_lists
[
idx
]
=
EmbeddingItem
(
embedding_lists
[
idx
]
=
EmbeddingItem
(
embedding_key
,
image_grid_thw
,
embeddings
_cpu
embedding_key
,
image_grid_thw
,
embeddings
)
)
# compute
# compute
else
:
else
:
...
@@ -200,7 +200,7 @@ class EncodeWorkerHandler:
...
@@ -200,7 +200,7 @@ class EncodeWorkerHandler:
//
merge_size
//
merge_size
//
merge_size
//
merge_size
).
tolist
()
).
tolist
()
splitted_embeddings
=
embeddings
.
cpu
().
squeeze
(
0
).
split
(
sizes
)
splitted_embeddings
=
embeddings
.
squeeze
(
0
).
split
(
sizes
)
logger
.
debug
(
logger
.
debug
(
f
"Splitted embeddings lengths:
{
[
e
.
shape
for
e
in
splitted_embeddings
]
}
"
f
"Splitted embeddings lengths:
{
[
e
.
shape
for
e
in
splitted_embeddings
]
}
"
)
)
...
@@ -209,7 +209,7 @@ class EncodeWorkerHandler:
...
@@ -209,7 +209,7 @@ class EncodeWorkerHandler:
# embeddings already has batch dimension for images, so we can directly
# embeddings already has batch dimension for images, so we can directly
# split by batch dimension
# split by batch dimension
logger
.
debug
(
f
"image embedding shape:
{
embeddings
.
shape
}
"
)
logger
.
debug
(
f
"image embedding shape:
{
embeddings
.
shape
}
"
)
splitted_embeddings
=
embeddings
.
cpu
()
splitted_embeddings
=
embeddings
image_grid_thw
=
(
image_grid_thw
=
(
image_embeds
[
"image_grid_thw"
].
tolist
()
image_embeds
[
"image_grid_thw"
].
tolist
()
...
@@ -230,7 +230,7 @@ class EncodeWorkerHandler:
...
@@ -230,7 +230,7 @@ class EncodeWorkerHandler:
embedding_lists
[
list_idx
].
key
,
embedding_lists
[
list_idx
].
key
,
(
(
embedding_lists
[
list_idx
].
image_grid_thw
,
embedding_lists
[
list_idx
].
image_grid_thw
,
embedding_lists
[
list_idx
].
embeddings
_cpu
,
embedding_lists
[
list_idx
].
embeddings
,
),
),
)
)
...
@@ -240,7 +240,7 @@ class EncodeWorkerHandler:
...
@@ -240,7 +240,7 @@ class EncodeWorkerHandler:
send_tasks
=
[
send_tasks
=
[
asyncio
.
create_task
(
asyncio
.
create_task
(
self
.
embedding_sender
.
send_embeddings
(
self
.
embedding_sender
.
send_embeddings
(
embedding_item
.
embeddings
_cpu
,
stage_embeddings
=
True
embedding_item
.
embeddings
,
stage_embeddings
=
True
)
)
)
)
for
embedding_item
in
embedding_lists
for
embedding_item
in
embedding_lists
...
@@ -252,7 +252,7 @@ class EncodeWorkerHandler:
...
@@ -252,7 +252,7 @@ class EncodeWorkerHandler:
for
idx
,
item
in
enumerate
(
zip
(
embedding_lists
,
transfer_requests
)):
for
idx
,
item
in
enumerate
(
zip
(
embedding_lists
,
transfer_requests
)):
embedding_item
,
transfer_request
=
item
embedding_item
,
transfer_request
=
item
logger
.
debug
(
logger
.
debug
(
f
"
{
embedding_item
.
embeddings
_cpu
.
shape
}
prepared for transfer."
f
"
{
embedding_item
.
embeddings
.
shape
}
prepared for transfer."
)
)
# Update request for transfer metadata
# Update request for transfer metadata
request
.
multimodal_inputs
[
idx
].
multimodal_input
.
image_url
=
None
request
.
multimodal_inputs
[
idx
].
multimodal_input
.
image_url
=
None
...
@@ -260,13 +260,13 @@ class EncodeWorkerHandler:
...
@@ -260,13 +260,13 @@ class EncodeWorkerHandler:
idx
idx
].
image_grid_thw
=
embedding_item
.
image_grid_thw
].
image_grid_thw
=
embedding_item
.
image_grid_thw
request
.
multimodal_inputs
[
idx
].
embeddings_shape
=
tuple
(
request
.
multimodal_inputs
[
idx
].
embeddings_shape
=
tuple
(
embedding_item
.
embeddings
_cpu
.
shape
embedding_item
.
embeddings
.
shape
)
)
request
.
multimodal_inputs
[
idx
].
serialized_request
=
transfer_request
[
0
]
request
.
multimodal_inputs
[
idx
].
serialized_request
=
transfer_request
[
0
]
# Keep a reference of the embedding
_cpu
and only drop reference when the transfer is done
# Keep a reference of the embedding and only drop reference when the transfer is done
self
.
send_complete_queue
.
put_nowait
(
self
.
send_complete_queue
.
put_nowait
(
(
transfer_request
[
1
],
embedding_item
.
embeddings
_cpu
)
(
transfer_request
[
1
],
embedding_item
.
embeddings
)
)
)
logger
.
debug
(
f
"Request:
{
request
.
model_dump_json
()
}
"
)
logger
.
debug
(
f
"Request:
{
request
.
model_dump_json
()
}
"
)
...
...
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
View file @
80955ef4
...
@@ -129,6 +129,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -129,6 +129,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
for
item
in
mm_data
.
get
(
IMAGE_URL_KEY
,
[]):
for
item
in
mm_data
.
get
(
IMAGE_URL_KEY
,
[]):
if
isinstance
(
item
,
dict
)
and
"Url"
in
item
:
if
isinstance
(
item
,
dict
)
and
"Url"
in
item
:
image_urls
.
append
(
item
[
"Url"
])
image_urls
.
append
(
item
[
"Url"
])
elif
isinstance
(
item
,
dict
)
and
"Decoded"
in
item
:
image_urls
.
append
(
item
[
"Decoded"
])
sampling_params
=
build_sampling_params
(
sampling_params
=
build_sampling_params
(
raw_request
,
self
.
default_sampling_params
raw_request
,
self
.
default_sampling_params
...
...
components/src/dynamo/vllm/multimodal_utils/model.py
View file @
80955ef4
...
@@ -33,6 +33,10 @@ VLLM_ENCODER = int(os.getenv("VLLM_ENCODER", 1))
...
@@ -33,6 +33,10 @@ VLLM_ENCODER = int(os.getenv("VLLM_ENCODER", 1))
class
SupportedModels
:
class
SupportedModels
:
"""Supported multimodal model identifiers"""
"""Supported multimodal model identifiers"""
# TODO: Replace this explicit model list with dynamic detection using
# HF config `architectures` field or vLLM's model registry, so any
# vLLM-supported VLM works without maintaining entries here.
LLAVA_1_5_7B
=
"llava-hf/llava-1.5-7b-hf"
LLAVA_1_5_7B
=
"llava-hf/llava-1.5-7b-hf"
QWEN_2_VL_2B
=
"Qwen/Qwen2-VL-2B-Instruct"
QWEN_2_VL_2B
=
"Qwen/Qwen2-VL-2B-Instruct"
QWEN_2_5_VL_3B
=
"Qwen/Qwen2.5-VL-3B-Instruct"
QWEN_2_5_VL_3B
=
"Qwen/Qwen2.5-VL-3B-Instruct"
...
@@ -42,6 +46,8 @@ class SupportedModels:
...
@@ -42,6 +46,8 @@ class SupportedModels:
QWEN_3_VL_30B_A3B
=
"Qwen/Qwen3-VL-30B-A3B-Instruct"
QWEN_3_VL_30B_A3B
=
"Qwen/Qwen3-VL-30B-A3B-Instruct"
QWEN_3_VL_30B_A3B_FP8
=
"Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
QWEN_3_VL_30B_A3B_FP8
=
"Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
QWEN_3_VL_8B_FP8
=
"Qwen/Qwen3-VL-8B-Instruct-FP8"
QWEN_3_VL_8B_FP8
=
"Qwen/Qwen3-VL-8B-Instruct-FP8"
QWEN_3_VL_4B
=
"Qwen/Qwen3-VL-4B-Instruct"
QWEN_3_VL_4B_FP8
=
"Qwen/Qwen3-VL-4B-Instruct-FP8"
LLAVA_NEXT_VIDEO_7B
=
"llava-hf/LLaVA-NeXT-Video-7B-hf"
LLAVA_NEXT_VIDEO_7B
=
"llava-hf/LLaVA-NeXT-Video-7B-hf"
...
@@ -124,6 +130,8 @@ QWEN_VL_MODELS = [
...
@@ -124,6 +130,8 @@ QWEN_VL_MODELS = [
SupportedModels
.
QWEN_3_VL_30B_A3B
,
SupportedModels
.
QWEN_3_VL_30B_A3B
,
SupportedModels
.
QWEN_3_VL_30B_A3B_FP8
,
SupportedModels
.
QWEN_3_VL_30B_A3B_FP8
,
SupportedModels
.
QWEN_3_VL_8B_FP8
,
SupportedModels
.
QWEN_3_VL_8B_FP8
,
SupportedModels
.
QWEN_3_VL_4B
,
SupportedModels
.
QWEN_3_VL_4B_FP8
,
]
]
...
@@ -159,7 +167,7 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
...
@@ -159,7 +167,7 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
# Load only the vision model via vLLM
# Load only the vision model via vLLM
vllm_model
=
LLM
(
vllm_model
=
LLM
(
model
=
model_id
,
model
=
model_id
,
enforce_eager
=
Tru
e
,
enforce_eager
=
Fals
e
,
kv_cache_memory_bytes
=
1024
kv_cache_memory_bytes
=
1024
*
1024
*
1024
*
8
,
# 8MB KV cache for vLLM to complete the init lifecycle, encoder-only doesn't require KV cache.
*
8
,
# 8MB KV cache for vLLM to complete the init lifecycle, encoder-only doesn't require KV cache.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment