Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
d644d88d
Unverified
Commit
d644d88d
authored
Apr 17, 2026
by
Graham King
Committed by
GitHub
Apr 17, 2026
Browse files
fix: Use vllm to load prompt embeds (#8228)
Signed-off-by:
Graham King
<
grahamk@nvidia.com
>
parent
61d4674c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
53 deletions
+55
-53
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+39
-48
components/src/dynamo/vllm/tests/test_vllm_prompt_embeds.py
components/src/dynamo/vllm/tests/test_vllm_prompt_embeds.py
+3
-4
components/src/dynamo/vllm/tests/test_vllm_worker_handler.py
components/src/dynamo/vllm/tests/test_vllm_worker_handler.py
+11
-1
components/src/dynamo/vllm/worker_factory.py
components/src/dynamo/vllm/worker_factory.py
+2
-0
No files found.
components/src/dynamo/vllm/handlers.py
View file @
d644d88d
...
...
@@ -2,9 +2,6 @@
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
base64
import
binascii
import
io
import
logging
import
os
import
tempfile
...
...
@@ -16,10 +13,11 @@ from dataclasses import dataclass
from
typing
import
Any
,
AsyncIterator
,
Dict
,
Final
,
Generic
,
Optional
,
TypeVar
import
torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.inputs
import
EmbedsPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.renderers.embed_utils
import
safe_load_prompt_embeds
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.v1.engine.exceptions
import
EngineDeadError
...
...
@@ -371,6 +369,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
engine
,
default_sampling_params
,
model_max_len
:
int
|
None
=
None
,
model_config
:
ModelConfig
|
None
=
None
,
enable_multimodal
:
bool
=
False
,
generate_endpoint
=
None
,
use_vllm_tokenizer
:
bool
=
False
,
...
...
@@ -388,6 +387,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
self
.
engine_monitor
=
VllmEngineMonitor
(
runtime
,
engine
,
shutdown_event
)
self
.
temp_dirs
:
list
[
tempfile
.
TemporaryDirectory
]
=
[]
self
.
model_max_len
=
model_max_len
self
.
model_config
=
model_config
self
.
enable_multimodal
=
enable_multimodal
# LoRA tracking: name -> LoRAInfo(id, path)
self
.
loaded_loras
:
dict
[
str
,
LoRAInfo
]
=
{}
...
...
@@ -1105,41 +1105,31 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
"""
Decode base64-encoded prompt embeddings in PyTorch format.
Use vllm's safe loader to prevent out-of-bounds writes from maliciously crafted tensors.
Format: PyTorch tensor serialized with torch.save() and base64-encoded.
Args:
prompt_embeds_base64: Base64-encoded PyTorch tensor
Returns:
torch.Tensor: Decoded prompt embeddings with
preserved shape and dtype
torch.Tensor: Decoded prompt embeddings with
dim == 2
Raises:
ValueError: If decoding fails or format is invalid
"""
try
:
# Step 1: Decode base64 to bytes
embeds_bytes
=
base64
.
b64decode
(
prompt_embeds_base64
)
# Step 2: Load PyTorch tensor from bytes
buffer
=
io
.
BytesIO
(
embeds_bytes
)
embeddings_tensor
=
torch
.
load
(
buffer
,
weights_only
=
True
)
# Step 3: Validate it's a tensor
if
not
isinstance
(
embeddings_tensor
,
torch
.
Tensor
):
if
not
isinstance
(
prompt_embeds_base64
,
str
):
raise
ValueError
(
f
"prompt_embeds must be a torch.Tensor, got
{
type
(
embeddings_tensor
)
}
"
)
logger
.
debug
(
f
"Decoded PyTorch format embeddings: shape=
{
embeddings_tensor
.
shape
}
, "
f
"dtype=
{
embeddings_tensor
.
dtype
}
, size=
{
len
(
embeds_bytes
)
}
bytes"
f
"Prompt embeds must be base64 encoded string. Got
{
type
(
prompt_embeds_base64
)
}
."
)
return
embeddings_tensor
if
self
.
model_config
is
None
:
raise
ValueError
(
"ModelConfig is unavailable for prompt_embeds validation."
)
except
binascii
.
Error
as
e
:
logger
.
error
(
f
"Invalid base64 encoding in prompt_embeds:
{
e
}
"
)
raise
ValueError
(
f
"Invalid base64 encoding in prompt_embeds:
{
e
}
"
)
try
:
return
safe_load_prompt_embeds
(
self
.
model_config
,
prompt_embeds_base64
.
encode
()
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to decode prompt_embeds:
{
e
}
"
)
raise
ValueError
(
f
"Failed to decode prompt_embeds as PyTorch tensor:
{
e
}
"
)
...
...
@@ -1163,15 +1153,12 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
ValueError: If decoding fails or tensor is invalid
"""
embeddings_tensor
=
self
.
_decode_prompt_embeds
(
prompt_embeds_base64
)
if
embeddings_tensor
.
dim
()
!=
2
:
raise
ValueError
(
f
"prompt embeds should have dim 2 after vllm processing, but found dim
{
embeddings_tensor
.
dim
()
}
"
)
# Extract sequence length from tensor shape for usage reporting
# Shape is typically (sequence_length, hidden_dim) or (batch, sequence_length, hidden_dim)
if
embeddings_tensor
.
dim
()
==
2
:
sequence_length
=
embeddings_tensor
.
shape
[
0
]
elif
embeddings_tensor
.
dim
()
==
3
:
sequence_length
=
embeddings_tensor
.
shape
[
1
]
else
:
# Fallback for unexpected shapes
sequence_length
=
embeddings_tensor
.
shape
[
0
]
# EmbedsInputs TypedDict has: {type: 'embeds', prompt_embeds: Tensor, cache_salt?: str}
...
...
@@ -1627,6 +1614,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine
,
default_sampling_params
,
model_max_len
:
int
|
None
=
None
,
model_config
:
ModelConfig
|
None
=
None
,
enable_multimodal
:
bool
=
False
,
generate_endpoint
=
None
,
use_vllm_tokenizer
:
bool
=
False
,
...
...
@@ -1639,13 +1627,14 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config
,
engine
,
default_sampling_params
,
model_max_len
,
enable_multimodal
,
generate_endpoint
,
use_vllm_tokenizer
,
shutdown_event
,
enable_frontend_decoding
,
encode_worker_client
,
model_max_len
=
model_max_len
,
model_config
=
model_config
,
enable_multimodal
=
enable_multimodal
,
generate_endpoint
=
generate_endpoint
,
use_vllm_tokenizer
=
use_vllm_tokenizer
,
shutdown_event
=
shutdown_event
,
enable_frontend_decoding
=
enable_frontend_decoding
,
encode_worker_client
=
encode_worker_client
,
)
async
def
generate
(
self
,
request
,
context
):
...
...
@@ -1904,6 +1893,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
engine
,
default_sampling_params
,
model_max_len
:
int
|
None
=
None
,
model_config
:
ModelConfig
|
None
=
None
,
enable_multimodal
:
bool
=
False
,
generate_endpoint
=
None
,
use_vllm_tokenizer
:
bool
=
False
,
...
...
@@ -1916,13 +1906,14 @@ class PrefillWorkerHandler(BaseWorkerHandler):
config
,
engine
,
default_sampling_params
,
model_max_len
,
enable_multimodal
,
generate_endpoint
,
use_vllm_tokenizer
,
shutdown_event
,
enable_frontend_decoding
,
encode_worker_client
,
model_max_len
=
model_max_len
,
model_config
=
model_config
,
enable_multimodal
=
enable_multimodal
,
generate_endpoint
=
generate_endpoint
,
use_vllm_tokenizer
=
use_vllm_tokenizer
,
shutdown_event
=
shutdown_event
,
enable_frontend_decoding
=
enable_frontend_decoding
,
encode_worker_client
=
encode_worker_client
,
)
# Cache Qwen VL grid parameters for computing image_grid_thw from
...
...
components/src/dynamo/vllm/tests/test_vllm_prompt_embeds.py
View file @
d644d88d
...
...
@@ -29,6 +29,7 @@ def mock_handler():
pass
handler
=
MockHandler
()
handler
.
model_config
=
Mock
(
enable_prompt_embeds
=
True
)
handler
.
_decode_prompt_embeds
=
BaseWorkerHandler
.
_decode_prompt_embeds
.
__get__
(
# type: ignore
handler
)
...
...
@@ -51,10 +52,8 @@ class TestPromptEmbedsDecode:
[
((
10
,
4096
),
torch
.
float32
),
# 2D: sequence x hidden
((
10
,
768
),
torch
.
float32
),
# 2D: smaller hidden dim
((
2
,
10
,
768
),
torch
.
float32
),
# 3D: batch x sequence x hidden
((
5
,
20
,
1024
),
torch
.
float16
),
# 3D with float16
],
ids
=
[
"2d-4096"
,
"2d-768"
,
"3d-batch"
,
"3d-float16"
],
ids
=
[
"2d-4096"
,
"2d-768"
],
)
def
test_decode_valid_embeddings_various_shapes
(
self
,
mock_handler
,
shape
,
dtype
):
"""Test decoding embeddings with various shapes and dtypes."""
...
...
@@ -113,7 +112,7 @@ class TestPromptEmbedsDecode:
non_tensor
=
{
"key"
:
"value"
}
embeddings_base64
=
encode_tensor_to_base64_obj
(
non_tensor
)
with
pytest
.
raises
(
ValueError
,
match
=
"
must be a torch.Tensor
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"
Failed to decode
"
):
mock_handler
.
_decode_prompt_embeds
(
embeddings_base64
)
...
...
components/src/dynamo/vllm/tests/test_vllm_worker_handler.py
View file @
d644d88d
...
...
@@ -74,14 +74,18 @@ def _make_handler(
"""Construct a handler with BaseWorkerHandler.__init__ bypassed."""
if
config
is
None
:
config
=
_make_config
()
model_config
=
MagicMock
(
enable_prompt_embeds
=
True
)
with
patch
.
object
(
mod
.
BaseWorkerHandler
,
"__init__"
,
return_value
=
None
):
return
mod
.
DecodeWorkerHandler
(
handler
=
mod
.
DecodeWorkerHandler
(
runtime
=
MagicMock
(),
config
=
config
,
engine
=
MagicMock
(),
default_sampling_params
=
{},
model_config
=
model_config
,
encode_worker_client
=
encode_worker_client
,
)
handler
.
model_config
=
model_config
return
handler
def
_make_raw_frontend_request
(
image_urls
:
list
[
str
]
|
None
=
None
)
->
dict
:
...
...
@@ -317,14 +321,17 @@ def _make_decode_handler(
)
->
mod
.
DecodeWorkerHandler
:
"""Construct a DecodeWorkerHandler with mocked internals."""
config
=
_make_config
(
model
=
model
,
disaggregation_mode
=
disaggregation_mode
)
model_config
=
MagicMock
(
enable_prompt_embeds
=
True
)
with
patch
.
object
(
mod
.
BaseWorkerHandler
,
"__init__"
,
return_value
=
None
):
handler
=
mod
.
DecodeWorkerHandler
(
runtime
=
MagicMock
(),
config
=
config
,
engine
=
MagicMock
(),
default_sampling_params
=
{},
model_config
=
model_config
,
)
handler
.
config
=
config
handler
.
model_config
=
model_config
handler
.
enable_multimodal
=
True
handler
.
image_loader
=
MagicMock
()
handler
.
embedding_loader
=
None
...
...
@@ -462,14 +469,17 @@ def _make_prefill_handler(model: str = "test-model") -> mod.PrefillWorkerHandler
config
=
_make_config
(
model
=
model
,
is_prefill_worker
=
True
,
disaggregation_mode
=
"PREFILL"
)
model_config
=
MagicMock
(
enable_prompt_embeds
=
True
)
with
patch
.
object
(
mod
.
BaseWorkerHandler
,
"__init__"
,
return_value
=
None
):
handler
=
mod
.
PrefillWorkerHandler
(
runtime
=
MagicMock
(),
config
=
config
,
engine
=
MagicMock
(),
default_sampling_params
=
{},
model_config
=
model_config
,
)
handler
.
config
=
config
handler
.
model_config
=
model_config
return
handler
...
...
components/src/dynamo/vllm/worker_factory.py
View file @
d644d88d
...
...
@@ -279,6 +279,7 @@ class WorkerFactory:
engine_client
,
default_sampling_params
,
getattr
(
getattr
(
vllm_config
,
"model_config"
,
None
),
"max_model_len"
,
None
),
model_config
=
getattr
(
vllm_config
,
"model_config"
,
None
),
enable_multimodal
=
config
.
enable_multimodal
,
generate_endpoint
=
generate_endpoint
,
use_vllm_tokenizer
=
config
.
use_vllm_tokenizer
,
...
...
@@ -513,6 +514,7 @@ class WorkerFactory:
engine_client
,
default_sampling_params
,
getattr
(
getattr
(
vllm_config
,
"model_config"
,
None
),
"max_model_len"
,
None
),
model_config
=
getattr
(
vllm_config
,
"model_config"
,
None
),
enable_multimodal
=
config
.
enable_multimodal
,
generate_endpoint
=
generate_endpoint
,
use_vllm_tokenizer
=
config
.
use_vllm_tokenizer
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment