Unverified commit d644d88d, authored by Graham King and committed by GitHub
Browse files

fix: Use vllm to load prompt embeds (#8228)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 61d4674c
...@@ -2,9 +2,6 @@ ...@@ -2,9 +2,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import base64
import binascii
import io
import logging import logging
import os import os
import tempfile import tempfile
...@@ -16,10 +13,11 @@ from dataclasses import dataclass ...@@ -16,10 +13,11 @@ from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar
import torch import torch
from vllm.config import VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.renderers.embed_utils import safe_load_prompt_embeds
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
...@@ -371,6 +369,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -371,6 +369,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len: int | None = None, model_max_len: int | None = None,
model_config: ModelConfig | None = None,
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None, generate_endpoint=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
...@@ -388,6 +387,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -388,6 +387,7 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
self.engine_monitor = VllmEngineMonitor(runtime, engine, shutdown_event) self.engine_monitor = VllmEngineMonitor(runtime, engine, shutdown_event)
self.temp_dirs: list[tempfile.TemporaryDirectory] = [] self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len self.model_max_len = model_max_len
self.model_config = model_config
self.enable_multimodal = enable_multimodal self.enable_multimodal = enable_multimodal
# LoRA tracking: name -> LoRAInfo(id, path) # LoRA tracking: name -> LoRAInfo(id, path)
self.loaded_loras: dict[str, LoRAInfo] = {} self.loaded_loras: dict[str, LoRAInfo] = {}
...@@ -1105,41 +1105,31 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -1105,41 +1105,31 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
""" """
Decode base64-encoded prompt embeddings in PyTorch format. Decode base64-encoded prompt embeddings in PyTorch format.
Use vllm's safe loader to prevent out-of-bounds writes from maliciously crafted tensors.
Format: PyTorch tensor serialized with torch.save() and base64-encoded. Format: PyTorch tensor serialized with torch.save() and base64-encoded.
Args: Args:
prompt_embeds_base64: Base64-encoded PyTorch tensor prompt_embeds_base64: Base64-encoded PyTorch tensor
Returns: Returns:
torch.Tensor: Decoded prompt embeddings with preserved shape and dtype torch.Tensor: Decoded prompt embeddings with dim == 2
Raises: Raises:
ValueError: If decoding fails or format is invalid ValueError: If decoding fails or format is invalid
""" """
try: if not isinstance(prompt_embeds_base64, str):
# Step 1: Decode base64 to bytes raise ValueError(
embeds_bytes = base64.b64decode(prompt_embeds_base64) f"Prompt embeds must be base64 encoded string. Got {type(prompt_embeds_base64)}."
# Step 2: Load PyTorch tensor from bytes
buffer = io.BytesIO(embeds_bytes)
embeddings_tensor = torch.load(buffer, weights_only=True)
# Step 3: Validate it's a tensor
if not isinstance(embeddings_tensor, torch.Tensor):
raise ValueError(
f"prompt_embeds must be a torch.Tensor, got {type(embeddings_tensor)}"
)
logger.debug(
f"Decoded PyTorch format embeddings: shape={embeddings_tensor.shape}, "
f"dtype={embeddings_tensor.dtype}, size={len(embeds_bytes)} bytes"
) )
return embeddings_tensor if self.model_config is None:
raise ValueError("ModelConfig is unavailable for prompt_embeds validation.")
except binascii.Error as e: try:
logger.error(f"Invalid base64 encoding in prompt_embeds: {e}") return safe_load_prompt_embeds(
raise ValueError(f"Invalid base64 encoding in prompt_embeds: {e}") self.model_config, prompt_embeds_base64.encode()
)
except Exception as e: except Exception as e:
logger.error(f"Failed to decode prompt_embeds: {e}") logger.error(f"Failed to decode prompt_embeds: {e}")
raise ValueError(f"Failed to decode prompt_embeds as PyTorch tensor: {e}") raise ValueError(f"Failed to decode prompt_embeds as PyTorch tensor: {e}")
...@@ -1163,16 +1153,13 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -1163,16 +1153,13 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
ValueError: If decoding fails or tensor is invalid ValueError: If decoding fails or tensor is invalid
""" """
embeddings_tensor = self._decode_prompt_embeds(prompt_embeds_base64) embeddings_tensor = self._decode_prompt_embeds(prompt_embeds_base64)
if embeddings_tensor.dim() != 2:
raise ValueError(
f"prompt embeds should have dim 2 after vllm processing, but found dim {embeddings_tensor.dim()}"
)
# Extract sequence length from tensor shape for usage reporting # Extract sequence length from tensor shape for usage reporting
# Shape is typically (sequence_length, hidden_dim) or (batch, sequence_length, hidden_dim) sequence_length = embeddings_tensor.shape[0]
if embeddings_tensor.dim() == 2:
sequence_length = embeddings_tensor.shape[0]
elif embeddings_tensor.dim() == 3:
sequence_length = embeddings_tensor.shape[1]
else:
# Fallback for unexpected shapes
sequence_length = embeddings_tensor.shape[0]
# EmbedsInputs TypedDict has: {type: 'embeds', prompt_embeds: Tensor, cache_salt?: str} # EmbedsInputs TypedDict has: {type: 'embeds', prompt_embeds: Tensor, cache_salt?: str}
prompt = EmbedsPrompt(prompt_embeds=embeddings_tensor) prompt = EmbedsPrompt(prompt_embeds=embeddings_tensor)
...@@ -1627,6 +1614,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1627,6 +1614,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len: int | None = None, model_max_len: int | None = None,
model_config: ModelConfig | None = None,
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None, generate_endpoint=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
...@@ -1639,13 +1627,14 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1639,13 +1627,14 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config, config,
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len, model_max_len=model_max_len,
enable_multimodal, model_config=model_config,
generate_endpoint, enable_multimodal=enable_multimodal,
use_vllm_tokenizer, generate_endpoint=generate_endpoint,
shutdown_event, use_vllm_tokenizer=use_vllm_tokenizer,
enable_frontend_decoding, shutdown_event=shutdown_event,
encode_worker_client, enable_frontend_decoding=enable_frontend_decoding,
encode_worker_client=encode_worker_client,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -1904,6 +1893,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1904,6 +1893,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len: int | None = None, model_max_len: int | None = None,
model_config: ModelConfig | None = None,
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None, generate_endpoint=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
...@@ -1916,13 +1906,14 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1916,13 +1906,14 @@ class PrefillWorkerHandler(BaseWorkerHandler):
config, config,
engine, engine,
default_sampling_params, default_sampling_params,
model_max_len, model_max_len=model_max_len,
enable_multimodal, model_config=model_config,
generate_endpoint, enable_multimodal=enable_multimodal,
use_vllm_tokenizer, generate_endpoint=generate_endpoint,
shutdown_event, use_vllm_tokenizer=use_vllm_tokenizer,
enable_frontend_decoding, shutdown_event=shutdown_event,
encode_worker_client, enable_frontend_decoding=enable_frontend_decoding,
encode_worker_client=encode_worker_client,
) )
# Cache Qwen VL grid parameters for computing image_grid_thw from # Cache Qwen VL grid parameters for computing image_grid_thw from
......
...@@ -29,6 +29,7 @@ def mock_handler(): ...@@ -29,6 +29,7 @@ def mock_handler():
pass pass
handler = MockHandler() handler = MockHandler()
handler.model_config = Mock(enable_prompt_embeds=True)
handler._decode_prompt_embeds = BaseWorkerHandler._decode_prompt_embeds.__get__( # type: ignore handler._decode_prompt_embeds = BaseWorkerHandler._decode_prompt_embeds.__get__( # type: ignore
handler handler
) )
...@@ -51,10 +52,8 @@ class TestPromptEmbedsDecode: ...@@ -51,10 +52,8 @@ class TestPromptEmbedsDecode:
[ [
((10, 4096), torch.float32), # 2D: sequence x hidden ((10, 4096), torch.float32), # 2D: sequence x hidden
((10, 768), torch.float32), # 2D: smaller hidden dim ((10, 768), torch.float32), # 2D: smaller hidden dim
((2, 10, 768), torch.float32), # 3D: batch x sequence x hidden
((5, 20, 1024), torch.float16), # 3D with float16
], ],
ids=["2d-4096", "2d-768", "3d-batch", "3d-float16"], ids=["2d-4096", "2d-768"],
) )
def test_decode_valid_embeddings_various_shapes(self, mock_handler, shape, dtype): def test_decode_valid_embeddings_various_shapes(self, mock_handler, shape, dtype):
"""Test decoding embeddings with various shapes and dtypes.""" """Test decoding embeddings with various shapes and dtypes."""
...@@ -113,7 +112,7 @@ class TestPromptEmbedsDecode: ...@@ -113,7 +112,7 @@ class TestPromptEmbedsDecode:
non_tensor = {"key": "value"} non_tensor = {"key": "value"}
embeddings_base64 = encode_tensor_to_base64_obj(non_tensor) embeddings_base64 = encode_tensor_to_base64_obj(non_tensor)
with pytest.raises(ValueError, match="must be a torch.Tensor"): with pytest.raises(ValueError, match="Failed to decode"):
mock_handler._decode_prompt_embeds(embeddings_base64) mock_handler._decode_prompt_embeds(embeddings_base64)
......
...@@ -74,14 +74,18 @@ def _make_handler( ...@@ -74,14 +74,18 @@ def _make_handler(
"""Construct a handler with BaseWorkerHandler.__init__ bypassed.""" """Construct a handler with BaseWorkerHandler.__init__ bypassed."""
if config is None: if config is None:
config = _make_config() config = _make_config()
model_config = MagicMock(enable_prompt_embeds=True)
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None): with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
return mod.DecodeWorkerHandler( handler = mod.DecodeWorkerHandler(
runtime=MagicMock(), runtime=MagicMock(),
config=config, config=config,
engine=MagicMock(), engine=MagicMock(),
default_sampling_params={}, default_sampling_params={},
model_config=model_config,
encode_worker_client=encode_worker_client, encode_worker_client=encode_worker_client,
) )
handler.model_config = model_config
return handler
def _make_raw_frontend_request(image_urls: list[str] | None = None) -> dict: def _make_raw_frontend_request(image_urls: list[str] | None = None) -> dict:
...@@ -317,14 +321,17 @@ def _make_decode_handler( ...@@ -317,14 +321,17 @@ def _make_decode_handler(
) -> mod.DecodeWorkerHandler: ) -> mod.DecodeWorkerHandler:
"""Construct a DecodeWorkerHandler with mocked internals.""" """Construct a DecodeWorkerHandler with mocked internals."""
config = _make_config(model=model, disaggregation_mode=disaggregation_mode) config = _make_config(model=model, disaggregation_mode=disaggregation_mode)
model_config = MagicMock(enable_prompt_embeds=True)
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None): with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
handler = mod.DecodeWorkerHandler( handler = mod.DecodeWorkerHandler(
runtime=MagicMock(), runtime=MagicMock(),
config=config, config=config,
engine=MagicMock(), engine=MagicMock(),
default_sampling_params={}, default_sampling_params={},
model_config=model_config,
) )
handler.config = config handler.config = config
handler.model_config = model_config
handler.enable_multimodal = True handler.enable_multimodal = True
handler.image_loader = MagicMock() handler.image_loader = MagicMock()
handler.embedding_loader = None handler.embedding_loader = None
...@@ -462,14 +469,17 @@ def _make_prefill_handler(model: str = "test-model") -> mod.PrefillWorkerHandler ...@@ -462,14 +469,17 @@ def _make_prefill_handler(model: str = "test-model") -> mod.PrefillWorkerHandler
config = _make_config( config = _make_config(
model=model, is_prefill_worker=True, disaggregation_mode="PREFILL" model=model, is_prefill_worker=True, disaggregation_mode="PREFILL"
) )
model_config = MagicMock(enable_prompt_embeds=True)
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None): with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
handler = mod.PrefillWorkerHandler( handler = mod.PrefillWorkerHandler(
runtime=MagicMock(), runtime=MagicMock(),
config=config, config=config,
engine=MagicMock(), engine=MagicMock(),
default_sampling_params={}, default_sampling_params={},
model_config=model_config,
) )
handler.config = config handler.config = config
handler.model_config = model_config
return handler return handler
......
...@@ -279,6 +279,7 @@ class WorkerFactory: ...@@ -279,6 +279,7 @@ class WorkerFactory:
engine_client, engine_client,
default_sampling_params, default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None), getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
model_config=getattr(vllm_config, "model_config", None),
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
use_vllm_tokenizer=config.use_vllm_tokenizer, use_vllm_tokenizer=config.use_vllm_tokenizer,
...@@ -513,6 +514,7 @@ class WorkerFactory: ...@@ -513,6 +514,7 @@ class WorkerFactory:
engine_client, engine_client,
default_sampling_params, default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None), getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
model_config=getattr(vllm_config, "model_config", None),
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint, generate_endpoint=generate_endpoint,
use_vllm_tokenizer=config.use_vllm_tokenizer, use_vllm_tokenizer=config.use_vllm_tokenizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment