Unverified Commit d644d88d authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix: Use vllm to load prompt embeds (#8228)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 61d4674c
......@@ -2,9 +2,6 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import binascii
import io
import logging
import os
import tempfile
......@@ -16,10 +13,11 @@ from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar
import torch
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.renderers.embed_utils import safe_load_prompt_embeds
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError
......@@ -371,6 +369,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
engine,
default_sampling_params,
model_max_len: int | None = None,
model_config: ModelConfig | None = None,
enable_multimodal: bool = False,
generate_endpoint=None,
use_vllm_tokenizer: bool = False,
......@@ -388,6 +387,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
self.engine_monitor = VllmEngineMonitor(runtime, engine, shutdown_event)
self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len
self.model_config = model_config
self.enable_multimodal = enable_multimodal
# LoRA tracking: name -> LoRAInfo(id, path)
self.loaded_loras: dict[str, LoRAInfo] = {}
......@@ -1105,41 +1105,31 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
"""
Decode base64-encoded prompt embeddings in PyTorch format.
Use vllm's safe loader to prevent out-of-bounds writes from maliciously crafted tensors.
Format: PyTorch tensor serialized with torch.save() and base64-encoded.
Args:
prompt_embeds_base64: Base64-encoded PyTorch tensor
Returns:
torch.Tensor: Decoded prompt embeddings with preserved shape and dtype
torch.Tensor: Decoded prompt embeddings with dim == 2
Raises:
ValueError: If decoding fails or format is invalid
"""
try:
# Step 1: Decode base64 to bytes
embeds_bytes = base64.b64decode(prompt_embeds_base64)
# Step 2: Load PyTorch tensor from bytes
buffer = io.BytesIO(embeds_bytes)
embeddings_tensor = torch.load(buffer, weights_only=True)
# Step 3: Validate it's a tensor
if not isinstance(embeddings_tensor, torch.Tensor):
if not isinstance(prompt_embeds_base64, str):
raise ValueError(
f"prompt_embeds must be a torch.Tensor, got {type(embeddings_tensor)}"
)
logger.debug(
f"Decoded PyTorch format embeddings: shape={embeddings_tensor.shape}, "
f"dtype={embeddings_tensor.dtype}, size={len(embeds_bytes)} bytes"
f"Prompt embeds must be base64 encoded string. Got {type(prompt_embeds_base64)}."
)
return embeddings_tensor
if self.model_config is None:
raise ValueError("ModelConfig is unavailable for prompt_embeds validation.")
except binascii.Error as e:
logger.error(f"Invalid base64 encoding in prompt_embeds: {e}")
raise ValueError(f"Invalid base64 encoding in prompt_embeds: {e}")
try:
return safe_load_prompt_embeds(
self.model_config, prompt_embeds_base64.encode()
)
except Exception as e:
logger.error(f"Failed to decode prompt_embeds: {e}")
raise ValueError(f"Failed to decode prompt_embeds as PyTorch tensor: {e}")
......@@ -1163,15 +1153,12 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
ValueError: If decoding fails or tensor is invalid
"""
embeddings_tensor = self._decode_prompt_embeds(prompt_embeds_base64)
if embeddings_tensor.dim() != 2:
raise ValueError(
f"prompt embeds should have dim 2 after vllm processing, but found dim {embeddings_tensor.dim()}"
)
# Extract sequence length from tensor shape for usage reporting
# Shape is typically (sequence_length, hidden_dim) or (batch, sequence_length, hidden_dim)
if embeddings_tensor.dim() == 2:
sequence_length = embeddings_tensor.shape[0]
elif embeddings_tensor.dim() == 3:
sequence_length = embeddings_tensor.shape[1]
else:
# Fallback for unexpected shapes
sequence_length = embeddings_tensor.shape[0]
# EmbedsInputs TypedDict has: {type: 'embeds', prompt_embeds: Tensor, cache_salt?: str}
......@@ -1627,6 +1614,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine,
default_sampling_params,
model_max_len: int | None = None,
model_config: ModelConfig | None = None,
enable_multimodal: bool = False,
generate_endpoint=None,
use_vllm_tokenizer: bool = False,
......@@ -1639,13 +1627,14 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config,
engine,
default_sampling_params,
model_max_len,
enable_multimodal,
generate_endpoint,
use_vllm_tokenizer,
shutdown_event,
enable_frontend_decoding,
encode_worker_client,
model_max_len=model_max_len,
model_config=model_config,
enable_multimodal=enable_multimodal,
generate_endpoint=generate_endpoint,
use_vllm_tokenizer=use_vllm_tokenizer,
shutdown_event=shutdown_event,
enable_frontend_decoding=enable_frontend_decoding,
encode_worker_client=encode_worker_client,
)
async def generate(self, request, context):
......@@ -1904,6 +1893,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
engine,
default_sampling_params,
model_max_len: int | None = None,
model_config: ModelConfig | None = None,
enable_multimodal: bool = False,
generate_endpoint=None,
use_vllm_tokenizer: bool = False,
......@@ -1916,13 +1906,14 @@ class PrefillWorkerHandler(BaseWorkerHandler):
config,
engine,
default_sampling_params,
model_max_len,
enable_multimodal,
generate_endpoint,
use_vllm_tokenizer,
shutdown_event,
enable_frontend_decoding,
encode_worker_client,
model_max_len=model_max_len,
model_config=model_config,
enable_multimodal=enable_multimodal,
generate_endpoint=generate_endpoint,
use_vllm_tokenizer=use_vllm_tokenizer,
shutdown_event=shutdown_event,
enable_frontend_decoding=enable_frontend_decoding,
encode_worker_client=encode_worker_client,
)
# Cache Qwen VL grid parameters for computing image_grid_thw from
......
......@@ -29,6 +29,7 @@ def mock_handler():
pass
handler = MockHandler()
handler.model_config = Mock(enable_prompt_embeds=True)
handler._decode_prompt_embeds = BaseWorkerHandler._decode_prompt_embeds.__get__( # type: ignore
handler
)
......@@ -51,10 +52,8 @@ class TestPromptEmbedsDecode:
[
((10, 4096), torch.float32), # 2D: sequence x hidden
((10, 768), torch.float32), # 2D: smaller hidden dim
((2, 10, 768), torch.float32), # 3D: batch x sequence x hidden
((5, 20, 1024), torch.float16), # 3D with float16
],
ids=["2d-4096", "2d-768", "3d-batch", "3d-float16"],
ids=["2d-4096", "2d-768"],
)
def test_decode_valid_embeddings_various_shapes(self, mock_handler, shape, dtype):
"""Test decoding embeddings with various shapes and dtypes."""
......@@ -113,7 +112,7 @@ class TestPromptEmbedsDecode:
non_tensor = {"key": "value"}
embeddings_base64 = encode_tensor_to_base64_obj(non_tensor)
with pytest.raises(ValueError, match="must be a torch.Tensor"):
with pytest.raises(ValueError, match="Failed to decode"):
mock_handler._decode_prompt_embeds(embeddings_base64)
......
......@@ -74,14 +74,18 @@ def _make_handler(
"""Construct a handler with BaseWorkerHandler.__init__ bypassed."""
if config is None:
config = _make_config()
model_config = MagicMock(enable_prompt_embeds=True)
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
return mod.DecodeWorkerHandler(
handler = mod.DecodeWorkerHandler(
runtime=MagicMock(),
config=config,
engine=MagicMock(),
default_sampling_params={},
model_config=model_config,
encode_worker_client=encode_worker_client,
)
handler.model_config = model_config
return handler
def _make_raw_frontend_request(image_urls: list[str] | None = None) -> dict:
......@@ -317,14 +321,17 @@ def _make_decode_handler(
) -> mod.DecodeWorkerHandler:
"""Construct a DecodeWorkerHandler with mocked internals."""
config = _make_config(model=model, disaggregation_mode=disaggregation_mode)
model_config = MagicMock(enable_prompt_embeds=True)
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
handler = mod.DecodeWorkerHandler(
runtime=MagicMock(),
config=config,
engine=MagicMock(),
default_sampling_params={},
model_config=model_config,
)
handler.config = config
handler.model_config = model_config
handler.enable_multimodal = True
handler.image_loader = MagicMock()
handler.embedding_loader = None
......@@ -462,14 +469,17 @@ def _make_prefill_handler(model: str = "test-model") -> mod.PrefillWorkerHandler
config = _make_config(
model=model, is_prefill_worker=True, disaggregation_mode="PREFILL"
)
model_config = MagicMock(enable_prompt_embeds=True)
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
handler = mod.PrefillWorkerHandler(
runtime=MagicMock(),
config=config,
engine=MagicMock(),
default_sampling_params={},
model_config=model_config,
)
handler.config = config
handler.model_config = model_config
return handler
......
......@@ -279,6 +279,7 @@ class WorkerFactory:
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
model_config=getattr(vllm_config, "model_config", None),
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
use_vllm_tokenizer=config.use_vllm_tokenizer,
......@@ -513,6 +514,7 @@ class WorkerFactory:
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
model_config=getattr(vllm_config, "model_config", None),
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
use_vllm_tokenizer=config.use_vllm_tokenizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment