Unverified Commit 6f9619a2 authored by Wenqi Glantz's avatar Wenqi Glantz Committed by GitHub
Browse files

feat(vllm): Add prompt embeds support for pre-computed inference inputs (#4739)


Signed-off-by: Wenqi Glantz <wglantz@nvidia.com>
parent bfb95df7
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import base64
import binascii
import io
import logging import logging
import os import os
import tempfile import tempfile
...@@ -11,7 +14,8 @@ from abc import ABC, abstractmethod ...@@ -11,7 +14,8 @@ from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final from typing import Any, AsyncGenerator, Dict, Final
from vllm.inputs import TextPrompt, TokensPrompt import torch
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
...@@ -674,6 +678,85 @@ class BaseWorkerHandler(ABC): ...@@ -674,6 +678,85 @@ class BaseWorkerHandler(ABC):
except Exception as e: except Exception as e:
logger.warning(f"Failed to clean up temp directory: {e}") logger.warning(f"Failed to clean up temp directory: {e}")
def _decode_prompt_embeds(self, prompt_embeds_base64: str):
"""
Decode base64-encoded prompt embeddings in PyTorch format.
Format: PyTorch tensor serialized with torch.save() and base64-encoded.
This matches NIM-LLM's implementation for compatibility.
Args:
prompt_embeds_base64: Base64-encoded PyTorch tensor
Returns:
torch.Tensor: Decoded prompt embeddings with preserved shape and dtype
Raises:
ValueError: If decoding fails or format is invalid
"""
try:
# Step 1: Decode base64 to bytes
embeds_bytes = base64.b64decode(prompt_embeds_base64)
# Step 2: Load PyTorch tensor from bytes
buffer = io.BytesIO(embeds_bytes)
embeddings_tensor = torch.load(buffer, weights_only=True)
# Step 3: Validate it's a tensor
if not isinstance(embeddings_tensor, torch.Tensor):
raise ValueError(
f"prompt_embeds must be a torch.Tensor, got {type(embeddings_tensor)}"
)
logger.debug(
f"Decoded PyTorch format embeddings: shape={embeddings_tensor.shape}, "
f"dtype={embeddings_tensor.dtype}, size={len(embeds_bytes)} bytes"
)
return embeddings_tensor
except binascii.Error as e:
logger.error(f"Invalid base64 encoding in prompt_embeds: {e}")
raise ValueError(f"Invalid base64 encoding in prompt_embeds: {e}")
except Exception as e:
logger.error(f"Failed to decode prompt_embeds: {e}")
raise ValueError(f"Failed to decode prompt_embeds as PyTorch tensor: {e}")
def _create_prompt_from_embeddings(
    self, prompt_embeds_base64: str
) -> tuple[EmbedsPrompt, int, torch.Tensor]:
    """
    Decode prompt embeddings and create EmbedsPrompt for vLLM.

    Args:
        prompt_embeds_base64: Base64-encoded PyTorch tensor

    Returns:
        Tuple of (EmbedsPrompt, sequence_length, tensor) where:
        - EmbedsPrompt: The vLLM prompt input
        - sequence_length: Extracted from tensor shape for usage statistics
        - tensor: The decoded tensor (for logging shape/dtype)

    Raises:
        ValueError: If decoding fails or the tensor has no dimensions
    """
    embeddings_tensor = self._decode_prompt_embeds(prompt_embeds_base64)

    # A 0-dim (scalar) tensor has an empty shape, so there is no axis to read
    # a sequence length from. Reject it with the documented ValueError instead
    # of letting shape[0] raise IndexError.
    if embeddings_tensor.dim() == 0:
        raise ValueError(
            "prompt_embeds must have at least one dimension, got a 0-dim tensor"
        )

    # Extract sequence length from tensor shape for usage reporting.
    # Shape is typically (sequence_length, hidden_dim) or
    # (batch, sequence_length, hidden_dim); any other rank falls back to axis 0.
    sequence_axis = 1 if embeddings_tensor.dim() == 3 else 0
    sequence_length = embeddings_tensor.shape[sequence_axis]

    # EmbedsInputs TypedDict has: {type: 'embeds', prompt_embeds: Tensor, cache_salt?: str}
    prompt = EmbedsPrompt(prompt_embeds=embeddings_tensor)
    return prompt, sequence_length, embeddings_tensor
async def _extract_multimodal_data( async def _extract_multimodal_data(
self, request: Dict[str, Any] self, request: Dict[str, Any]
) -> Dict[str, Any] | None: ) -> Dict[str, Any] | None:
...@@ -725,20 +808,95 @@ class BaseWorkerHandler(ABC): ...@@ -725,20 +808,95 @@ class BaseWorkerHandler(ABC):
return vllm_mm_data if vllm_mm_data else None return vllm_mm_data if vllm_mm_data else None
def _build_prompt_from_request(
    self,
    request: Dict[str, Any],
    request_id: str,
    multi_modal_data: Dict[str, Any] | None,
    log_prefix: str = "",
) -> tuple[TokensPrompt | EmbedsPrompt | None, int | None, Dict[str, Any] | None]:
    """
    Turn a request dict into a vLLM prompt, preferring prompt_embeds over token_ids.

    Args:
        request: Request dict; a truthy "prompt_embeds" entry selects the
            embeddings path, otherwise "token_ids" is used
        request_id: Request ID, used only in log messages
        multi_modal_data: Optional multimodal data; attached to the
            TokensPrompt only (the embeddings path does not use it)
        log_prefix: Prefix for log messages (e.g., "Prefill " for prefill requests)

    Returns:
        Tuple of (prompt, embedding_sequence_length, error_dict) where:
        - Token path: (TokensPrompt, None, None)
        - Embeddings path, success: (EmbedsPrompt, sequence_length, None)
        - Embeddings path, failure: (None, None, error_dict to yield)
    """
    encoded_embeds = request.get("prompt_embeds")

    if not encoded_embeds:
        # Normal path: use token IDs
        tokens_prompt = TokensPrompt(
            prompt_token_ids=request["token_ids"], multi_modal_data=multi_modal_data
        )
        return tokens_prompt, None, None

    try:
        embeds_prompt, sequence_length, tensor = self._create_prompt_from_embeddings(
            encoded_embeds
        )
        logger.info(
            f"{log_prefix}Using prompt embeddings: shape={tensor.shape}, "
            f"dtype={tensor.dtype}, sequence_length={sequence_length}, "
            f"request_id={request_id}"
        )
        return embeds_prompt, sequence_length, None
    except Exception as e:
        # Failures are reported to the caller as a result dict rather than
        # raised, so the handler can yield them to the client.
        subject = log_prefix.lower().strip() or "request"
        logger.error(f"Failed to process prompt_embeds for {subject} {request_id}: {e}")
        failure = {
            "finish_reason": f"error: Invalid prompt_embeds: {e}",
            "token_ids": [],
        }
        return None, None, failure
@staticmethod @staticmethod
def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]: def _build_completion_usage(
request_output: RequestOutput,
embedding_sequence_length: int | None = None,
) -> Dict[str, Any]:
"""
Build completion usage statistics.
Args:
request_output: vLLM RequestOutput object
embedding_sequence_length: If using prompt embeddings, the sequence length
extracted from the embeddings tensor shape
Returns:
Dict with prompt_tokens, completion_tokens, total_tokens, prompt_tokens_details
"""
# Determine prompt token count:
# - For embeddings: use embedding_sequence_length from tensor shape
# - For normal text: use len(prompt_token_ids)
if embedding_sequence_length is not None:
prompt_tokens = embedding_sequence_length
elif request_output.prompt_token_ids:
prompt_tokens = len(request_output.prompt_token_ids)
else:
prompt_tokens = None
completion_tokens = len(request_output.outputs[0].token_ids)
return { return {
"prompt_tokens": ( "prompt_tokens": prompt_tokens,
len(request_output.prompt_token_ids) "completion_tokens": completion_tokens,
if request_output.prompt_token_ids
else None
),
"completion_tokens": len(request_output.outputs[0].token_ids),
"total_tokens": ( "total_tokens": (
len(request_output.prompt_token_ids) prompt_tokens + completion_tokens if prompt_tokens is not None else None
+ len(request_output.outputs[0].token_ids)
if request_output.prompt_token_ids
else None
), ),
"prompt_tokens_details": ( "prompt_tokens_details": (
{"cached_tokens": request_output.num_cached_tokens} {"cached_tokens": request_output.num_cached_tokens}
...@@ -821,6 +979,40 @@ class BaseWorkerHandler(ABC): ...@@ -821,6 +979,40 @@ class BaseWorkerHandler(ABC):
# TODO: properly propagate the trace-flags from current span. # TODO: properly propagate the trace-flags from current span.
return {"traceparent": f"00-{trace_id}-{span_id}-01"} return {"traceparent": f"00-{trace_id}-{span_id}-01"}
@staticmethod
def _log_with_lora_context(
message: str,
request_id: str,
lora_request=None,
level: str = "debug",
**kwargs,
) -> None:
"""
Log a message with optional LoRA context.
Args:
message: Base message to log (can include {lora_info} placeholder)
request_id: Request ID for correlation
lora_request: Optional LoRA request object
level: Log level ("debug" or "info")
**kwargs: Additional format arguments for the message
"""
if lora_request:
lora_info = f" with LoRA {lora_request.lora_name}"
else:
lora_info = ""
formatted_message = message.format(
request_id=request_id,
lora_info=lora_info,
**kwargs,
)
if level == "info":
logger.info(formatted_message)
else:
logger.debug(formatted_message)
async def generate_tokens( async def generate_tokens(
self, self,
prompt, prompt,
...@@ -828,19 +1020,16 @@ class BaseWorkerHandler(ABC): ...@@ -828,19 +1020,16 @@ class BaseWorkerHandler(ABC):
request_id, request_id,
data_parallel_rank=None, data_parallel_rank=None,
lora_request=None, lora_request=None,
embedding_sequence_length=None,
trace_headers=None, trace_headers=None,
): ):
try: try:
# Log LoRA usage for this generation (debug level to avoid log spam) # Log LoRA usage for this generation (debug level to avoid log spam)
if lora_request: self._log_with_lora_context(
logger.debug( "Starting token generation for request {request_id}{lora_info}",
f"Starting token generation for request {request_id} with LoRA: " request_id,
f"{lora_request.lora_name} (ID: {lora_request.lora_int_id})" lora_request,
) )
else:
logger.debug(
f"Starting token generation for request {request_id} (no LoRA)"
)
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt, prompt,
sampling_params, sampling_params,
...@@ -856,12 +1045,17 @@ class BaseWorkerHandler(ABC): ...@@ -856,12 +1045,17 @@ class BaseWorkerHandler(ABC):
# res is vllm's RequestOutput # res is vllm's RequestOutput
if not res.outputs: if not res.outputs:
if lora_request: self._log_with_lora_context(
logger.debug( "Request {request_id}{lora_info} returned no outputs",
f"Request {request_id} with LoRA {lora_request.lora_name} " request_id,
"returned no outputs" lora_request,
) )
yield {"finish_reason": "error", "token_ids": []} # Use string format "error: message" for consistency with vLLM's string-based finish_reason
# Rust will parse this into FinishReason::Error(message)
yield {
"finish_reason": "error: No outputs from vLLM engine",
"token_ids": [],
}
break break
output = res.outputs[0] output = res.outputs[0]
...@@ -882,20 +1076,18 @@ class BaseWorkerHandler(ABC): ...@@ -882,20 +1076,18 @@ class BaseWorkerHandler(ABC):
out[ out[
"completion_usage" "completion_usage"
] = BaseWorkerHandler._build_completion_usage( ] = BaseWorkerHandler._build_completion_usage(
request_output=res request_output=res,
embedding_sequence_length=embedding_sequence_length,
) )
# Log completion with LoRA info (debug level to avoid log spam) # Log completion with LoRA info (debug level to avoid log spam)
if lora_request: self._log_with_lora_context(
logger.debug( "Completed token generation for request {request_id}{lora_info}: "
f"Completed token generation for request {request_id} with LoRA " "{output_tokens} output tokens, finish_reason={finish_reason}",
f"{lora_request.lora_name}: {next_total_toks} output tokens, " request_id,
f"finish_reason={output.finish_reason}" lora_request,
) output_tokens=next_total_toks,
else: finish_reason=output.finish_reason,
logger.debug( )
f"Completed token generation for request {request_id}: "
f"{next_total_toks} output tokens, finish_reason={output.finish_reason}"
)
if output.stop_reason: if output.stop_reason:
out["stop_reason"] = output.stop_reason out["stop_reason"] = output.stop_reason
yield out yield out
...@@ -957,9 +1149,13 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -957,9 +1149,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Extract and decode multimodal data if present # Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request) multi_modal_data = await self._extract_multimodal_data(request)
prompt = TokensPrompt( # Build prompt from request (handles both prompt_embeds and token_ids)
prompt_token_ids=request["token_ids"], multi_modal_data=multi_modal_data prompt, embedding_sequence_length, error = self._build_prompt_from_request(
request, request_id, multi_modal_data
) )
if error is not None:
yield error
return
# Build sampling params from request # Build sampling params from request
sampling_params = build_sampling_params( sampling_params = build_sampling_params(
...@@ -1017,6 +1213,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1017,6 +1213,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
request_id, request_id,
data_parallel_rank=dp_rank, data_parallel_rank=dp_rank,
lora_request=lora_request, lora_request=lora_request,
embedding_sequence_length=embedding_sequence_length,
trace_headers=trace_headers, trace_headers=trace_headers,
): ):
if prefill_result is not None and "completion_usage" in tok: if prefill_result is not None and "completion_usage" in tok:
...@@ -1151,10 +1348,15 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1151,10 +1348,15 @@ class PrefillWorkerHandler(BaseWorkerHandler):
# Extract and decode multimodal data if present # Extract and decode multimodal data if present
multi_modal_data = await self._extract_multimodal_data(request) multi_modal_data = await self._extract_multimodal_data(request)
token_ids = request["token_ids"] # Build prompt from request (handles both prompt_embeds and token_ids)
prompt = TokensPrompt( prompt, embedding_sequence_length, error = self._build_prompt_from_request(
prompt_token_ids=token_ids, multi_modal_data=multi_modal_data request, request_id, multi_modal_data, log_prefix="Prefill "
) )
if error is not None:
# Prefill errors need disaggregated_params field
error["disaggregated_params"] = None
yield error
return
# Build sampling params from request using shared utility # Build sampling params from request using shared utility
sampling_params = build_sampling_params( sampling_params = build_sampling_params(
...@@ -1236,17 +1438,21 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1236,17 +1438,21 @@ class PrefillWorkerHandler(BaseWorkerHandler):
else None else None
), ),
"completion_usage": BaseWorkerHandler._build_completion_usage( "completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res request_output=res,
embedding_sequence_length=embedding_sequence_length,
), ),
} }
# Log prefill completion with LoRA info # Log prefill completion with LoRA info
if lora_request: self._log_with_lora_context(
logger.info( "Prefill completed for request {request_id}{lora_info}: "
f"Prefill completed for request {request_id} with LoRA {lora_request.lora_name}: " "generated {token_count} token(s), has_kv_params={has_kv_params}",
f"generated {len(token_ids)} token(s), " request_id,
f"has_kv_params={res.kv_transfer_params is not None}" lora_request,
) level="info" if lora_request else "debug",
token_count=len(token_ids),
has_kv_params=res.kv_transfer_params is not None,
)
yield output yield output
except asyncio.CancelledError: except asyncio.CancelledError:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for prompt embeddings support in vLLM backend."""
import base64
import io
from unittest.mock import Mock
import numpy as np
import pytest
import torch
from dynamo.vllm.handlers import BaseWorkerHandler
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
]
@pytest.fixture
def mock_handler():
    """Return a bare stand-in object exposing only ``_decode_prompt_embeds``.

    The real method is taken from BaseWorkerHandler and bound to an empty
    object, so the tests do not have to construct a BaseWorkerHandler.
    """
    stub = type("MockHandler", (), {})()
    decode = BaseWorkerHandler._decode_prompt_embeds
    stub._decode_prompt_embeds = decode.__get__(stub, type(stub))
    return stub
def encode_tensor_to_base64(tensor: torch.Tensor) -> str:
    """Serialize *tensor* with ``torch.save`` and return the bytes as base64 text."""
    with io.BytesIO() as stream:
        torch.save(tensor, stream)
        # getvalue() returns the whole buffer regardless of stream position.
        serialized = stream.getvalue()
    return base64.b64encode(serialized).decode("utf-8")
class TestPromptEmbedsDecode:
    """Exercise ``_decode_prompt_embeds`` with well-formed and malformed payloads."""

    @pytest.mark.parametrize(
        "shape,dtype",
        [
            ((10, 4096), torch.float32),  # 2D: sequence x hidden
            ((10, 768), torch.float32),  # 2D: smaller hidden dim
            ((2, 10, 768), torch.float32),  # 3D: batch x sequence x hidden
            ((5, 20, 1024), torch.float16),  # 3D with float16
        ],
        ids=["2d-4096", "2d-768", "3d-batch", "3d-float16"],
    )
    def test_decode_valid_embeddings_various_shapes(self, mock_handler, shape, dtype):
        """A tensor keeps its shape, dtype and values across encode/decode."""
        original = torch.randn(*shape, dtype=dtype)

        decoded = mock_handler._decode_prompt_embeds(encode_tensor_to_base64(original))

        assert isinstance(decoded, torch.Tensor)
        assert decoded.shape == shape, f"Shape should be preserved: {shape}"
        assert decoded.dtype == dtype, f"Dtype should be preserved: {dtype}"
        torch.testing.assert_close(decoded, original, rtol=1e-5, atol=1e-5)

    @pytest.mark.parametrize(
        "invalid_input,error_match,description",
        [
            # Invalid base64
            ("not-valid-base64!!!", r"(Invalid base64|Failed to decode)", "bad base64"),
            # Empty string
            ("", r".", "empty string"),
            # Raw bytes (not PyTorch format)
            (
                base64.b64encode(b"not a pytorch tensor").decode("utf-8"),
                r"Failed to decode.*PyTorch",
                "raw bytes",
            ),
            # Corrupted PyTorch format
            (
                base64.b64encode(b"PK\x03\x04" + b"invalid_data" * 10).decode("utf-8"),
                r"Failed to decode.*PyTorch",
                "corrupted zip",
            ),
        ],
        ids=["bad-base64", "empty", "raw-bytes", "corrupted-zip"],
    )
    def test_decode_invalid_inputs(
        self, mock_handler, invalid_input, error_match, description
    ):
        """Each malformed payload is rejected with a ValueError matching ``error_match``.

        ``description`` is not used by the test body; it only labels the case
        in the parameter table.
        """
        with pytest.raises(ValueError, match=error_match):
            mock_handler._decode_prompt_embeds(invalid_input)

    def test_decode_numpy_format_rejected(self, mock_handler):
        """An array written with ``np.save`` is not accepted as prompt embeddings."""
        array = np.random.randn(10, 768).astype(np.float32)
        stream = io.BytesIO()
        np.save(stream, array)
        payload = base64.b64encode(stream.getvalue()).decode("utf-8")

        with pytest.raises(ValueError, match="Failed to decode.*PyTorch"):
            mock_handler._decode_prompt_embeds(payload)

    def test_decode_non_tensor_object_rejected(self, mock_handler):
        """A ``torch.save``-d object that is not a tensor is rejected."""
        payload = encode_tensor_to_base64_obj({"key": "value"})

        with pytest.raises(ValueError, match="must be a torch.Tensor"):
            mock_handler._decode_prompt_embeds(payload)
def encode_tensor_to_base64_obj(obj) -> str:
    """Serialize any ``torch.save``-able object and return the bytes as base64 text."""
    with io.BytesIO() as stream:
        torch.save(obj, stream)
        # getvalue() returns the whole buffer regardless of stream position.
        serialized = stream.getvalue()
    return base64.b64encode(serialized).decode("utf-8")
class TestEmbeddingsDataFormats:
    """Round-trip checks over different tensor lengths and value ranges."""

    @pytest.mark.parametrize("size", [128, 384, 768, 1024, 2048, 4096])
    def test_various_embedding_sizes(self, mock_handler, size):
        """A 1-D float32 tensor of ``size`` elements decodes unchanged."""
        original = torch.randn(size, dtype=torch.float32)

        decoded = mock_handler._decode_prompt_embeds(encode_tensor_to_base64(original))

        assert decoded.shape == (size,), f"Failed for size {size}"
        torch.testing.assert_close(decoded, original, rtol=1e-5, atol=1e-5)

    @pytest.mark.parametrize(
        "values",
        [
            [0.0, 0.0, 0.0],  # Zeros
            [1.0, 1.0, 1.0],  # Ones
            [-1.0, 0.0, 1.0],  # Mixed
            [1e-6, 1e-3, 1e3],  # Various magnitudes
            [3.14159265, 2.71828182, 1.41421356],  # Precise values
        ],
        ids=["zeros", "ones", "mixed", "magnitudes", "precise"],
    )
    def test_embedding_value_ranges_preserved(self, mock_handler, values):
        """float32 values of differing magnitude come back within 1e-6 tolerance."""
        original = torch.tensor(values, dtype=torch.float32)

        decoded = mock_handler._decode_prompt_embeds(encode_tensor_to_base64(original))

        torch.testing.assert_close(decoded, original, rtol=1e-6, atol=1e-6)
class TestUsageStatistics:
    """Checks on the usage dict built by ``BaseWorkerHandler._build_completion_usage``."""

    @pytest.mark.parametrize(
        "prompt_token_ids,embedding_seq_len,completion_tokens,expected_prompt,expected_total",
        [
            # Embeddings: use embedding_sequence_length
            ([], 10, 5, 10, 15),
            # Text: use len(prompt_token_ids)
            ([1, 2, 3, 4, 5, 6, 7], None, 3, 7, 10),
            # Embeddings override token_ids
            ([1, 2, 3], 20, 2, 20, 22),
            # Zero sequence length edge case
            ([], 0, 2, 0, 2),
        ],
        ids=["embeddings", "text", "embeddings-override", "zero-seq-len"],
    )
    def test_build_completion_usage(
        self,
        prompt_token_ids,
        embedding_seq_len,
        completion_tokens,
        expected_prompt,
        expected_total,
    ):
        """Prompt, completion and total counts match the expected values per scenario."""
        generated = Mock(token_ids=list(range(completion_tokens)))
        request_output = Mock(
            prompt_token_ids=prompt_token_ids,
            outputs=[generated],
            num_cached_tokens=0,
        )

        usage = BaseWorkerHandler._build_completion_usage(
            request_output, embedding_sequence_length=embedding_seq_len
        )

        assert usage["prompt_tokens"] == expected_prompt
        assert usage["completion_tokens"] == completion_tokens
        assert usage["total_tokens"] == expected_total

    def test_build_completion_usage_no_prompt_info(self):
        """With neither token ids nor an embedding length, prompt and total are None."""
        request_output = Mock(
            prompt_token_ids=None,
            outputs=[Mock(token_ids=[1, 2, 3])],
            num_cached_tokens=0,
        )

        usage = BaseWorkerHandler._build_completion_usage(
            request_output, embedding_sequence_length=None
        )

        assert usage["prompt_tokens"] is None
        assert usage["completion_tokens"] == 3
        assert usage["total_tokens"] is None

    def test_build_completion_usage_with_cached_tokens(self):
        """num_cached_tokens is surfaced under prompt_tokens_details."""
        request_output = Mock(
            prompt_token_ids=[1, 2, 3, 4, 5],
            outputs=[Mock(token_ids=[6, 7])],
            num_cached_tokens=3,
        )

        usage = BaseWorkerHandler._build_completion_usage(
            request_output, embedding_sequence_length=None
        )

        assert usage["prompt_tokens"] == 5
        assert usage["completion_tokens"] == 2
        assert usage["prompt_tokens_details"] == {"cached_tokens": 3}
...@@ -11,7 +11,9 @@ networks: ...@@ -11,7 +11,9 @@ networks:
services: services:
nats-server: nats-server:
image: nats:2.11.4 image: nats:2.11.4
command: [ "-js", "--trace", "-m", "8222" ] # max_payload set to 15MB to accommodate 10MB decoded embeddings
# Base64 encoding: 10MB → ~13.3MB, 15MB provides ~1.7MB buffer for metadata
command: [ "-js", "--trace", "-m", "8222", "--max_payload", "15728640" ]
ports: ports:
- 4222:4222 - 4222:4222
- 6222:6222 - 6222:6222
......
...@@ -417,8 +417,9 @@ nats: ...@@ -417,8 +417,9 @@ nats:
# jetstream: # jetstream:
# max_memory_store: << 1GB >> # max_memory_store: << 1GB >>
merge: merge:
# 10MB which allows for larger context size : The default NATS max payload size is 1MB, and 256K tokens (with tokens being int32 - 4 bytes each) tips over that 1MB max. # 15MB to accommodate prompt embeddings: 10MB decoded → ~13.3MB base64-encoded + metadata
max_payload: 10485760 # Also allows larger context: 256K tokens (int32 - 4 bytes each) = 1MB
max_payload: 15728640
patch: [] patch: []
############################################################ ############################################################
......
...@@ -42,6 +42,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1)) ...@@ -42,6 +42,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
| [**Load Based Planner**](../../../docs/planner/load_planner.md) | 🚧 | WIP | | [**Load Based Planner**](../../../docs/planner/load_planner.md) | 🚧 | WIP |
| [**KVBM**](../../../docs/kvbm/kvbm_architecture.md) | ✅ | | | [**KVBM**](../../../docs/kvbm/kvbm_architecture.md) | ✅ | |
| [**LMCache**](./LMCache_Integration.md) | ✅ | | | [**LMCache**](./LMCache_Integration.md) | ✅ | |
| [**Prompt Embeddings**](./prompt-embeddings.md) | ✅ | Requires `--enable-prompt-embeds` flag |
### Large Scale P/D and WideEP Features ### Large Scale P/D and WideEP Features
...@@ -152,6 +153,10 @@ vLLM workers are configured through command-line arguments. Key parameters inclu ...@@ -152,6 +153,10 @@ vLLM workers are configured through command-line arguments. Key parameters inclu
- `--is-prefill-worker`: Enable prefill-only mode for disaggregated serving - `--is-prefill-worker`: Enable prefill-only mode for disaggregated serving
- `--metrics-endpoint-port`: Port for publishing KV metrics to Dynamo - `--metrics-endpoint-port`: Port for publishing KV metrics to Dynamo
- `--connector`: Specify which kv_transfer_config you want vllm to use `[nixl, lmcache, kvbm, none]`. This is a helper flag which overwrites the engines KVTransferConfig. - `--connector`: Specify which kv_transfer_config you want vllm to use `[nixl, lmcache, kvbm, none]`. This is a helper flag which overwrites the engines KVTransferConfig.
- `--enable-prompt-embeds`: **Enable prompt embeddings feature** (opt-in, default: disabled)
- **Required for:** Accepting pre-computed prompt embeddings via API
- **Default behavior:** Prompt embeddings DISABLED - requests with `prompt_embeds` will fail
- **Error without flag:** `ValueError: You must set --enable-prompt-embeds to input prompt_embeds`
See `args.py` for the full list of configuration options and their defaults. See `args.py` for the full list of configuration options and their defaults.
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->
# Prompt Embeddings
Dynamo supports prompt embeddings (also known as prompt embeds) as a secure alternative input method to traditional text prompts. By allowing applications to use pre-computed embeddings for inference, this feature not only offers greater flexibility in prompt engineering but also significantly enhances privacy and data security. With prompt embeddings, sensitive user data can be transformed into embeddings before ever reaching the inference server, reducing the risk of exposing confidential information during the AI workflow.
## How It Works
| Path | What Happens |
|------|--------------|
| **Text prompt** | Tokenize → Embedding Layer → Transformer |
| **Prompt embeds** | Validate → Bypass Embedding → Transformer |
## Architecture
```mermaid
flowchart LR
subgraph FE["Frontend (Rust)"]
A[Request] --> B{prompt_embeds?}
B -->|No| C[🔴 Tokenize text]
B -->|Yes| D[🟢 Validate base64+size]
C --> E[token_ids, ISL=N]
D --> F[token_ids=empty, skip ISL]
end
subgraph RT["Router (NATS)"]
G[Route PreprocessedRequest]
end
subgraph WK["Worker (Python)"]
H[TokensPrompt#40;token_ids#41;]
I[Decode → EmbedsPrompt#40;tensor#41;]
end
subgraph VLLM["vLLM Engine"]
J[🔴 Embedding Layer]
K[🟢 Bypass Embedding]
L[Transformer Layers]
M[LM Head → Response]
end
E --> G
F --> G
G -->|Normal| H
G -->|Embeds| I
H --> J --> L
I --> K --> L
L --> M
```
| Layer | **Normal Flow** | **Prompt Embeds** |
|---|---|---|
| **Frontend (Rust)** | 🔴 Tokenize text → token_ids, compute ISL | 🟢 Validate base64+size, skip tokenization |
| **Router (NATS)** | Forward token_ids in PreprocessedRequest | Forward prompt_embeds string |
| **Worker (Python)** | `TokensPrompt(token_ids)` | Decode base64 → `EmbedsPrompt(tensor)` |
| **vLLM Engine** | 🔴 Embedding Layer → Transformer | 🟢 Bypass Embedding → Transformer |
## Quick Start
Send pre-computed prompt embeddings directly to vLLM, bypassing tokenization.
### 1. Enable Feature
```bash
python -m dynamo.vllm --model <model-name> --enable-prompt-embeds
```
> **Required:** The `--enable-prompt-embeds` flag must be set or requests will fail.
### 2. Send Request
```python
import torch
import base64
import io
from openai import OpenAI
# Prepare embeddings (sequence_length, hidden_dim)
embeddings = torch.randn(10, 4096, dtype=torch.float32)
# Encode
buffer = io.BytesIO()
torch.save(embeddings, buffer)
buffer.seek(0)
embeddings_base64 = base64.b64encode(buffer.read()).decode()
# Send
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
prompt="", # Can be empty or present; prompt_embeds takes precedence
max_tokens=100,
extra_body={"prompt_embeds": embeddings_base64}
)
```
## Configuration
### Docker Compose
```yaml
vllm-worker:
command:
- python
- -m
- dynamo.vllm
- --model
- meta-llama/Meta-Llama-3.1-8B-Instruct
- --enable-prompt-embeds # Add this
```
### Kubernetes
```yaml
extraPodSpec:
mainContainer:
args:
- "--model"
- "meta-llama/Meta-Llama-3.1-8B-Instruct"
- "--enable-prompt-embeds" # Add this
```
### NATS Configuration
NATS needs a 15MB payload limit (already configured in default deployments):
```yaml
# Docker Compose - deploy/docker-compose.yml
nats-server:
command: ["-js", "--trace", "-m", "8222", "--max_payload", "15728640"]
# Kubernetes - deploy/cloud/helm/platform/values.yaml
nats:
config:
merge:
max_payload: 15728640
```
## API Reference
### Request
```json
{
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": "",
"prompt_embeds": "<base64-encoded-pytorch-tensor>",
"max_tokens": 100
}
```
**Requirements:**
- **Format:** PyTorch tensor serialized with `torch.save()` and base64-encoded
- **Size:** 100 bytes - 10MB (decoded)
- **Shape:** `(seq_len, hidden_dim)` or `(batch, seq_len, hidden_dim)`
- **Dtype:** `torch.float32` (recommended)
**Field Precedence:**
- Both `prompt` and `prompt_embeds` can be provided in the same request
- When both are present, **`prompt_embeds` takes precedence** and `prompt` is ignored
- The `prompt` field can be empty (`""`) when using `prompt_embeds`
### Response
Standard OpenAI format with accurate usage:
```json
{
"usage": {
"prompt_tokens": 10, // Extracted from embedding shape
"completion_tokens": 15,
"total_tokens": 25
}
}
```
## Errors
| Error | Fix |
|-------|-----|
| `ValueError: You must set --enable-prompt-embeds` | Add `--enable-prompt-embeds` to worker |
| `prompt_embeds must be valid base64` | Use `.decode('utf-8')` after `base64.b64encode()` |
| `decoded data must be at least 100 bytes` | Increase sequence length |
| `exceeds maximum size of 10MB` | Reduce sequence length |
| `must be a torch.Tensor` | Use `torch.save()` not NumPy |
| `size of tensor must match` | Use correct hidden dimension for model |
## Examples
### Streaming
```python
stream = client.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
prompt="",
max_tokens=100,
stream=True,
extra_body={"prompt_embeds": embeddings_base64}
)
for chunk in stream:
if chunk.choices:
print(chunk.choices[0].text, end="", flush=True)
```
### Load from File
```python
embeddings = torch.load("embeddings.pt")
buffer = io.BytesIO()
torch.save(embeddings, buffer)
buffer.seek(0)
embeddings_base64 = base64.b64encode(buffer.read()).decode()
# Use in request...
```
## Limitations
- ❌ Requires `--enable-prompt-embeds` flag (disabled by default)
- ❌ PyTorch format only (NumPy not supported)
- ❌ 10MB decoded size limit
- ❌ Cannot mix with multimodal data (images/video)
## Testing
Comprehensive test coverage ensures reliability:
- **Unit Tests:** 31 tests (11 Rust + 20 Python)
- Validation, decoding, format handling, error cases, usage statistics
- **Integration Tests:** 21 end-to-end tests
- Core functionality, performance, formats, concurrency, usage statistics
Run integration tests:
```bash
# Start worker with flag
python -m dynamo.vllm --model Qwen/Qwen3-0.6B --enable-prompt-embeds
# Run tests
pytest tests/integration/test_prompt_embeds_integration.py -v
```
## See Also
- [vLLM Backend](README.md)
- [vLLM Configuration](README.md#configuration)
...@@ -70,6 +70,7 @@ ...@@ -70,6 +70,7 @@
backends/vllm/LMCache_Integration.md backends/vllm/LMCache_Integration.md
backends/vllm/multi-node.md backends/vllm/multi-node.md
backends/vllm/prometheus.md backends/vllm/prometheus.md
backends/vllm/prompt-embeddings.md
backends/vllm/speculative_decoding.md backends/vllm/speculative_decoding.md
benchmarks/kv-router-ab-testing.md benchmarks/kv-router-ab-testing.md
......
...@@ -75,9 +75,19 @@ extraPodSpec: ...@@ -75,9 +75,19 @@ extraPodSpec:
- "python3" - "python3"
- "-m" - "-m"
- "dynamo.vllm" - "dynamo.vllm"
# Model-specific arguments - "--model"
- "Qwen/Qwen3-0.6B"
# Optional: Enable prompt embeddings feature
# - "--enable-prompt-embeds"
# Other model-specific arguments
``` ```
**Common vLLM Flags:**
- `--enable-prompt-embeds`: Enable prompt embeddings feature
- `--enable-multimodal`: Enable multimodal (vision) support
- `--is-prefill-worker`: Prefill-only mode for disaggregated serving
- `--connector [nixl|lmcache|kvbm|none]`: KV transfer backend selection
## Prerequisites ## Prerequisites
Before using these templates, ensure you have: Before using these templates, ensure you have:
......
...@@ -105,6 +105,13 @@ pub struct CreateCompletionRequest { ...@@ -105,6 +105,13 @@ pub struct CreateCompletionRequest {
/// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document. /// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.
pub prompt: Prompt, pub prompt: Prompt,
/// Base64-encoded PyTorch tensor containing pre-computed embeddings.
/// At least one of prompt or prompt_embeds is required.
/// If both are provided, prompt_embeds takes precedence.
/// Maximum size: 10MB decoded.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_embeds: Option<String>,
/// The suffix that comes after a completion of inserted text. /// The suffix that comes after a completion of inserted text.
/// ///
/// This parameter is only supported for `gpt-3.5-turbo-instruct`. /// This parameter is only supported for `gpt-3.5-turbo-instruct`.
......
...@@ -953,8 +953,11 @@ impl ...@@ -953,8 +953,11 @@ impl
let mut response_generator = Box::new(response_generator); let mut response_generator = Box::new(response_generator);
// update isl // Update ISL only for text prompts (embeddings get sequence length from tensor shape)
response_generator.update_isl(common_request.token_ids.len() as u32); if common_request.prompt_embeds.is_none() {
let isl = common_request.token_ids.len() as u32;
response_generator.update_isl(isl);
}
// repack the common completion request // repack the common completion request
let common_request = context.map(|_| common_request); let common_request = context.map(|_| common_request);
...@@ -1095,7 +1098,20 @@ impl ...@@ -1095,7 +1098,20 @@ impl
let mut response_generator = Box::new(response_generator); let mut response_generator = Box::new(response_generator);
// convert the chat completion request to a common completion request // convert the chat completion request to a common completion request
let mut builder = self.builder(&request)?; let mut builder = self.builder(&request)?;
let annotations = self.gather_tokens(&request, &mut builder, None)?;
// Check if embeddings are provided - skip tokenization path
let annotations = if let Some(ref prompt_embeds) = request.inner.prompt_embeds {
// Skip tokenization for embeddings
builder.token_ids(vec![]); // Empty token IDs
builder.prompt_embeds(Some(prompt_embeds.clone()));
// No token annotations
HashMap::new()
} else {
// Normal path: tokenize the prompt
self.gather_tokens(&request, &mut builder, None)?
};
// Gather multimodal data (works with both embeddings and text prompts)
self.gather_multi_modal_data(&request, &mut builder).await?; self.gather_multi_modal_data(&request, &mut builder).await?;
let mut common_request = builder.build()?; let mut common_request = builder.build()?;
...@@ -1103,8 +1119,11 @@ impl ...@@ -1103,8 +1119,11 @@ impl
// Attach the timing tracker to the request so downstream components can record metrics // Attach the timing tracker to the request so downstream components can record metrics
common_request.tracker = response_generator.tracker(); common_request.tracker = response_generator.tracker();
// update isl // Update ISL only for text prompts (embeddings get sequence length from tensor shape)
response_generator.update_isl(common_request.token_ids.len() as u32); if common_request.prompt_embeds.is_none() {
let isl = common_request.token_ids.len() as u32;
response_generator.update_isl(isl);
}
// repack the common completion request // repack the common completion request
let common_request = context.map(|_| common_request); let common_request = context.map(|_| common_request);
......
...@@ -90,6 +90,12 @@ pub struct PreprocessedRequest { ...@@ -90,6 +90,12 @@ pub struct PreprocessedRequest {
/// Type of prompt /// Type of prompt
pub token_ids: Vec<TokenIdType>, pub token_ids: Vec<TokenIdType>,
/// Base64-encoded PyTorch tensor containing pre-computed embeddings
/// If provided, this takes precedence over token_ids for inference
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_embeds: Option<String>,
// Multimodal data // Multimodal data
#[builder(default)] #[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
......
...@@ -381,14 +381,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -381,14 +381,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
self.usage.completion_tokens += token_length; self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details, // If backend provides completion_usage, use it to update usage stats
// propagate the entire details struct to usage tracking // This is critical for prompt embeddings where prompt_tokens comes from
if let Some(prompt_details) = delta // the embedding sequence length computed by the worker
.completion_usage if let Some(completion_usage) = delta.completion_usage.as_ref() {
.as_ref() // Update prompt_tokens from worker if provided (e.g., for embeddings)
.and_then(|usage| usage.prompt_tokens_details.as_ref()) self.usage.prompt_tokens = completion_usage.prompt_tokens;
{
self.usage.prompt_tokens_details = Some(prompt_details.clone()); // Propagate prompt token details if provided
if let Some(prompt_details) = completion_usage.prompt_tokens_details.as_ref() {
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
} }
let logprobs = self.create_logprobs( let logprobs = self.create_logprobs(
......
...@@ -421,7 +421,13 @@ impl ValidateRequest for NvCreateCompletionRequest { ...@@ -421,7 +421,13 @@ impl ValidateRequest for NvCreateCompletionRequest {
fn validate(&self) -> Result<(), anyhow::Error> { fn validate(&self) -> Result<(), anyhow::Error> {
validate::validate_no_unsupported_fields(&self.unsupported_fields)?; validate::validate_no_unsupported_fields(&self.unsupported_fields)?;
validate::validate_model(&self.inner.model)?; validate::validate_model(&self.inner.model)?;
validate::validate_prompt(&self.inner.prompt)?;
// Validate prompt and prompt_embeds together (checks presence, format, and content)
validate::validate_prompt_or_embeds(
Some(&self.inner.prompt),
self.inner.prompt_embeds.as_deref(),
)?;
validate::validate_suffix(self.inner.suffix.as_deref())?; validate::validate_suffix(self.inner.suffix.as_deref())?;
validate::validate_max_tokens(self.inner.max_tokens)?; validate::validate_max_tokens(self.inner.max_tokens)?;
validate::validate_temperature(self.inner.temperature)?; validate::validate_temperature(self.inner.temperature)?;
...@@ -458,7 +464,9 @@ impl ValidateRequest for NvCreateCompletionRequest { ...@@ -458,7 +464,9 @@ impl ValidateRequest for NvCreateCompletionRequest {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::engines::ValidateRequest;
use crate::protocols::common::OutputOptionsProvider; use crate::protocols::common::OutputOptionsProvider;
use base64::Engine;
use serde_json::json; use serde_json::json;
#[test] #[test]
...@@ -500,6 +508,146 @@ mod tests { ...@@ -500,6 +508,146 @@ mod tests {
} }
} }
#[test]
fn test_prompt_embeds_only() {
    // 256 zero bytes: well above the minimum decoded size enforced by validation.
    let payload = base64::engine::general_purpose::STANDARD.encode(vec![0u8; 256]);

    let body = json!({
        "model": "test-model",
        "prompt": "test",
        "prompt_embeds": payload
    });

    let request = serde_json::from_value::<NvCreateCompletionRequest>(body)
        .expect("Failed to deserialize request");

    // The request must validate and the embeddings field must survive deserialization.
    assert!(ValidateRequest::validate(&request).is_ok());
    assert!(request.inner.prompt_embeds.is_some());
}
#[test]
fn test_both_prompt_and_embeds() {
    // A request may carry both a text prompt and embeddings; validation accepts
    // the combination (precedence between the two is decided downstream).
    let payload = base64::engine::general_purpose::STANDARD.encode(vec![0u8; 256]);

    let body = json!({
        "model": "test-model",
        "prompt": "Hello",
        "prompt_embeds": payload
    });

    let request = serde_json::from_value::<NvCreateCompletionRequest>(body)
        .expect("Failed to deserialize request");

    assert!(ValidateRequest::validate(&request).is_ok());
}
#[test]
fn test_invalid_base64() {
    // '-' and '!' are not part of the standard base64 alphabet, so decoding fails
    // regardless of length (19 chars x 10 = 190 chars).
    let malformed = "not-valid-base64!!!".repeat(10);

    let body = json!({
        "model": "test-model",
        "prompt": "test",
        "prompt_embeds": malformed
    });

    let request = serde_json::from_value::<NvCreateCompletionRequest>(body)
        .expect("Failed to deserialize request");

    let outcome = ValidateRequest::validate(&request);
    assert!(outcome.is_err());

    // The error must identify the base64 problem.
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("base64"));
}
#[test]
fn test_embeds_too_large() {
    // The limit applies to the DECODED payload, so build 11MB of raw bytes and
    // encode them (the base64 text itself is roughly a third larger).
    let oversized_raw = vec![0u8; 11 * 1024 * 1024];
    let oversized = base64::engine::general_purpose::STANDARD.encode(&oversized_raw);

    let body = json!({
        "model": "test-model",
        "prompt": "test",
        "prompt_embeds": oversized
    });

    let request = serde_json::from_value::<NvCreateCompletionRequest>(body)
        .expect("Failed to deserialize request");

    let outcome = ValidateRequest::validate(&request);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("10MB"));
}
#[test]
fn test_embeds_too_small() {
    // 20 decoded bytes is below the 100-byte minimum enforced by validation.
    let small_data = vec![0u8; 20];
    let encoded = base64::engine::general_purpose::STANDARD.encode(&small_data);

    let json_str = json!({
        "model": "test-model",
        "prompt": "test",
        "prompt_embeds": encoded
    });

    let request: NvCreateCompletionRequest =
        serde_json::from_value(json_str).expect("Failed to deserialize request");

    let result = ValidateRequest::validate(&request);
    assert!(result.is_err());

    // Pin the minimum-size message specifically. Matching on a loose keyword
    // such as "decoded" would also accept the maximum-size error and hide a
    // wrong-branch regression.
    let err_msg = result.unwrap_err().to_string();
    assert!(
        err_msg.contains("at least 100 bytes"),
        "expected minimum-size error, got: {err_msg}"
    );
}
#[test]
fn test_embeddings_with_empty_prompt() {
    // When embeddings are supplied the text prompt is not content-checked,
    // so an empty string must be accepted.
    let payload = base64::engine::general_purpose::STANDARD.encode(vec![0u8; 256]);

    let body = json!({
        "model": "test-model",
        "prompt": "",
        "prompt_embeds": payload
    });

    let request = serde_json::from_value::<NvCreateCompletionRequest>(body)
        .expect("Failed to deserialize request");

    assert!(ValidateRequest::validate(&request).is_ok());
}
#[test]
fn test_empty_prompt_without_embeddings_fails() {
    // No prompt_embeds field at all: the empty text prompt must be rejected.
    let body = json!({
        "model": "test-model",
        "prompt": "",
    });

    let request = serde_json::from_value::<NvCreateCompletionRequest>(body)
        .expect("Failed to deserialize request");

    let outcome = ValidateRequest::validate(&request);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("cannot be empty"));
}
#[test] #[test]
fn test_stop() { fn test_stop() {
let null_stop = json!({ let null_stop = json!({
......
...@@ -300,14 +300,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -300,14 +300,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
self.usage.completion_tokens += token_length; self.usage.completion_tokens += token_length;
// If backend provides completion_usage with prompt token details, // If backend provides completion_usage, use it to update usage stats
// propagate the entire details struct to usage tracking // This is critical for prompt embeddings where prompt_tokens comes from
if let Some(prompt_details) = delta // the embedding sequence length computed by the worker
.completion_usage if let Some(completion_usage) = delta.completion_usage.as_ref() {
.as_ref() // Update prompt_tokens from worker if provided (e.g., for embeddings)
.and_then(|usage| usage.prompt_tokens_details.as_ref()) self.usage.prompt_tokens = completion_usage.prompt_tokens;
{
self.usage.prompt_tokens_details = Some(prompt_details.clone()); // Propagate prompt token details if provided
if let Some(prompt_details) = completion_usage.prompt_tokens_details.as_ref() {
self.usage.prompt_tokens_details = Some(prompt_details.clone());
}
} }
let logprobs = self.create_logprobs( let logprobs = self.create_logprobs(
......
...@@ -507,6 +507,75 @@ pub fn validate_prompt(prompt: &dynamo_async_openai::types::Prompt) -> Result<() ...@@ -507,6 +507,75 @@ pub fn validate_prompt(prompt: &dynamo_async_openai::types::Prompt) -> Result<()
Ok(()) Ok(())
} }
/// Validates the `prompt` / `prompt_embeds` pair of a completion request.
///
/// Rules applied, in order:
/// - `prompt_embeds` present: only the embeddings payload is validated
///   (base64 encoding and size limits); the text prompt is not inspected and
///   may therefore be empty or a placeholder.
/// - `prompt_embeds` absent, `prompt` present: the text prompt is validated.
/// - both absent: the request is rejected.
///
/// Format for prompt_embeds: PyTorch tensor serialized with torch.save() and base64-encoded
pub fn validate_prompt_or_embeds(
    prompt: Option<&dynamo_async_openai::types::Prompt>,
    prompt_embeds: Option<&str>,
) -> Result<(), anyhow::Error> {
    match (prompt_embeds, prompt) {
        // Embeddings win: validate them and ignore the text prompt's content.
        (Some(embeds), _) => validate_prompt_embeds_format(embeds)?,
        // Text-only request: the prompt itself must be valid.
        (None, Some(text_prompt)) => validate_prompt(text_prompt)?,
        // Nothing to run inference on.
        (None, None) => {
            anyhow::bail!("At least one of 'prompt' or 'prompt_embeds' must be provided")
        }
    }

    Ok(())
}
/// Validates prompt_embeds format (internal helper)
/// Format: PyTorch tensor serialized with torch.save() and base64-encoded
///
/// Checks performed:
/// 1. The encoded string is not longer than any base64 text that could decode
///    to the maximum allowed size (cheap pre-check, avoids decoding and
///    allocating for grossly oversized untrusted input).
/// 2. The string is valid standard (padded) base64.
/// 3. The decoded payload is at least 100 bytes and at most 10MB.
///
/// The tensor contents are NOT inspected here; only encoding and size.
///
/// # Errors
/// Returns an error describing the first failed check.
fn validate_prompt_embeds_format(embeds: &str) -> Result<(), anyhow::Error> {
    use base64::{Engine as _, engine::general_purpose};

    const MIN_SIZE: usize = 100;
    const MAX_SIZE: usize = 10 * 1024 * 1024;
    // Length of the padded base64 encoding of exactly MAX_SIZE bytes. Any
    // longer input either decodes to more than MAX_SIZE bytes or is not
    // canonical base64, so it can be rejected without decoding.
    const MAX_ENCODED_LEN: usize = (MAX_SIZE + 2) / 3 * 4;

    if embeds.len() > MAX_ENCODED_LEN {
        anyhow::bail!(
            "prompt_embeds decoded data exceeds maximum size of 10MB, got {} base64 characters",
            embeds.len()
        );
    }

    // Validate base64 encoding
    let decoded = general_purpose::STANDARD
        .decode(embeds)
        .map_err(|_| anyhow::anyhow!("prompt_embeds must be valid base64-encoded data"))?;

    // Check minimum size on decoded bytes (100 bytes)
    if decoded.len() < MIN_SIZE {
        anyhow::bail!(
            "prompt_embeds decoded data must be at least {MIN_SIZE} bytes, got {} bytes",
            decoded.len()
        );
    }

    // Check maximum size on decoded bytes (10MB). Kept as the authoritative
    // check; the pre-check above is only an early exit.
    if decoded.len() > MAX_SIZE {
        anyhow::bail!(
            "prompt_embeds decoded data exceeds maximum size of 10MB, got {} bytes",
            decoded.len()
        );
    }

    Ok(())
}
/// Validates an optional prompt_embeds field on its own (public wrapper).
///
/// A missing field is always valid; a present field must pass the same
/// encoding and size checks as in the combined prompt/embeddings validation.
/// Format: PyTorch tensor serialized with torch.save() and base64-encoded
pub fn validate_prompt_embeds(prompt_embeds: Option<&str>) -> Result<(), anyhow::Error> {
    match prompt_embeds {
        Some(embeds) => validate_prompt_embeds_format(embeds),
        None => Ok(()),
    }
}
/// Validates logprobs parameter (for completion requests) /// Validates logprobs parameter (for completion requests)
pub fn validate_logprobs(logprobs: Option<u8>) -> Result<(), anyhow::Error> { pub fn validate_logprobs(logprobs: Option<u8>) -> Result<(), anyhow::Error> {
if let Some(value) = logprobs if let Some(value) = logprobs
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Shared fixtures for frontend tests.""" """Shared fixtures for frontend tests.
Handles conditional test collection to prevent import errors when required
dependencies are not installed in the current environment.
"""
import importlib.util
import logging import logging
import os import os
import shutil import shutil
...@@ -19,6 +24,20 @@ from tests.utils.port_utils import allocate_port, deallocate_port ...@@ -19,6 +24,20 @@ from tests.utils.port_utils import allocate_port, deallocate_port
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def pytest_ignore_collect(collection_path, config):
    """Tell pytest to skip test files whose optional dependencies are missing.

    Only ``test_prompt_embeds.py`` is affected: it is ignored when either
    ``openai`` or ``torch`` cannot be found. For every other path (and when
    both packages are present) ``None`` is returned so that pytest applies its
    default collection behaviour.
    """
    if collection_path.name != "test_prompt_embeds.py":
        return None

    # Checked in this order; the first missing module is enough to skip the file.
    required_modules = ("openai", "torch")
    if any(importlib.util.find_spec(module) is None for module in required_modules):
        return True

    return None
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def start_services_with_http( def start_services_with_http(
request, runtime_services_dynamic_ports, dynamo_dynamic_ports request, runtime_services_dynamic_ports, dynamo_dynamic_ports
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
End-to-end tests for prompt embeddings support in Dynamo.
These tests validate behavior that cannot be covered by Rust unit tests:
- Streaming responses with embeddings
- Python-side tensor decoding errors
- Usage statistics from worker (the v2.0.4 bug fix)
- Large payload handling through NATS
- Concurrent request handling
Validation tests (base64, size limits, empty prompt) are covered by Rust unit tests
in lib/llm/src/protocols/openai/completions.rs
"""
import base64
import concurrent.futures
import io
import logging
import pytest
import torch
from openai import OpenAI
logger = logging.getLogger(__name__)
# Test model - small and fast for CI
TEST_MODEL = "Qwen/Qwen3-0.6B"
@pytest.fixture
def dynamo_client():
    """Return an OpenAI client configured for the local Dynamo frontend.

    The client targets ``http://localhost:8000/v1`` and uses the placeholder
    API key ``"EMPTY"``.
    """
    client = OpenAI(
        base_url="http://localhost:8000/v1",
        api_key="EMPTY",
    )
    return client
def create_embeddings_base64(shape: tuple[int, ...]) -> str:
    """Build a random float32 tensor of ``shape`` and return it base64-encoded.

    The tensor is serialized with ``torch.save`` into an in-memory buffer and
    the resulting bytes are base64-encoded into a ``str``.
    """
    tensor = torch.randn(*shape, dtype=torch.float32)
    with io.BytesIO() as sink:
        torch.save(tensor, sink)
        serialized = sink.getvalue()
    return base64.b64encode(serialized).decode("utf-8")
@pytest.mark.integration
@pytest.mark.vllm
@pytest.mark.nightly
@pytest.mark.gpu_1
@pytest.mark.model(TEST_MODEL)
class TestPromptEmbedsE2E:
    """
    End-to-end tests for prompt embeddings.

    These tests require a running Dynamo instance with vLLM backend.
    They validate behavior that Rust unit tests cannot cover.
    """

    def test_streaming_with_embeddings(self, dynamo_client):
        """
        Test streaming responses work correctly with embeddings.

        This is E2E only - Rust tests can't verify streaming behavior.
        """
        embeddings_base64 = create_embeddings_base64((10, 1024))

        stream = dynamo_client.completions.create(
            model=TEST_MODEL,
            prompt="",
            max_tokens=10,
            stream=True,
            extra_body={"prompt_embeds": embeddings_base64},
        )

        chunks = list(stream)
        assert len(chunks) > 0, "Should receive at least one chunk"

        # The final chunk of a stream may carry no choices, so look at the last
        # chunk that does. Checking only `chunks[-1]` behind an `if` would let
        # the test pass without ever verifying a finish_reason.
        chunks_with_choices = [chunk for chunk in chunks if chunk.choices]
        assert chunks_with_choices, "Should receive at least one chunk with choices"
        assert (
            chunks_with_choices[-1].choices[0].finish_reason is not None
        ), "Last chunk with choices should have finish_reason"

    def test_invalid_tensor_data_rejected(self, dynamo_client):
        """
        Test that invalid tensor data is properly rejected by Python decoder.

        This tests the Python-side torch.load() error handling, which
        Rust validation cannot cover (Rust only checks base64 and size).
        """
        # Create data that passes Rust validation (valid base64, >100 bytes)
        # but fails Python torch.load()
        invalid_data = b"this is not a valid pytorch tensor format!" * 10
        invalid_base64 = base64.b64encode(invalid_data).decode("utf-8")

        with pytest.raises(Exception) as exc_info:
            dynamo_client.completions.create(
                model=TEST_MODEL,
                prompt="",
                max_tokens=5,
                extra_body={"prompt_embeds": invalid_base64},
            )

        error_msg = str(exc_info.value).lower()
        assert any(
            keyword in error_msg
            for keyword in ["pytorch", "tensor", "invalid", "decode", "error"]
        ), f"Expected tensor decode error, got: {error_msg}"

    def test_usage_prompt_tokens_not_zero(self, dynamo_client):
        """
        CRITICAL REGRESSION TEST: Ensure prompt_tokens is correctly reported.

        This validates the v2.0.4 fix where prompt_tokens was incorrectly
        reported as 0 when using embeddings. The worker extracts sequence
        length from tensor shape and includes it in completion_usage.

        Rust tests cannot verify this - it requires E2E validation.
        """
        sequence_length = 20
        embeddings_base64 = create_embeddings_base64((sequence_length, 1024))

        response = dynamo_client.completions.create(
            model=TEST_MODEL,
            prompt="",
            max_tokens=3,
            extra_body={"prompt_embeds": embeddings_base64},
        )

        assert response.usage is not None, "Should have usage statistics"
        # Checked separately from the equality below so that the original
        # regression (prompt_tokens == 0) is reported with its own message.
        assert (
            response.usage.prompt_tokens != 0
        ), "BUG REGRESSION: prompt_tokens is 0! This was the bug in v2.0.3."
        assert (
            response.usage.prompt_tokens == sequence_length
        ), f"Expected prompt_tokens={sequence_length}, got {response.usage.prompt_tokens}"
        assert response.usage.total_tokens == (
            response.usage.prompt_tokens + response.usage.completion_tokens
        ), "total_tokens should equal prompt_tokens + completion_tokens"

    def test_large_embeddings_through_nats(self, dynamo_client):
        """
        Test large embeddings are handled correctly through NATS.

        This validates the NATS max_payload configuration (15MB) handles
        large embedding payloads. Rust unit tests can't test this E2E path.
        """
        # Create ~7MB embeddings (well under 10MB limit, but large enough to stress NATS)
        large_shape = (1700, 1024)  # ~6.6MB of float32 data
        large_embeds = torch.randn(large_shape, dtype=torch.float32)

        buffer = io.BytesIO()
        torch.save(large_embeds, buffer)
        large_bytes = buffer.getvalue()

        # Verify the precondition BEFORE sending: an oversized payload would be
        # rejected by request validation and the failure would then look like a
        # transport problem instead of a broken fixture.
        assert len(large_bytes) < 10 * 1024 * 1024, "Test data should be under 10MB"

        large_base64 = base64.b64encode(large_bytes).decode("utf-8")

        logger.info(
            f"Testing large embeddings: {len(large_bytes)/1024/1024:.2f}MB decoded"
        )

        response = dynamo_client.completions.create(
            model=TEST_MODEL,
            prompt="",
            max_tokens=5,
            extra_body={"prompt_embeds": large_base64},
        )

        assert response.choices, "Large embeddings should produce valid response"

    def test_concurrent_embeddings_requests(self, dynamo_client):
        """
        Test concurrent requests with embeddings are handled correctly.

        This validates the worker can handle multiple embedding requests
        simultaneously without race conditions or resource conflicts.
        """
        embeddings_base64 = create_embeddings_base64((10, 1024))

        def send_request():
            return dynamo_client.completions.create(
                model=TEST_MODEL,
                prompt="",
                max_tokens=5,
                extra_body={"prompt_embeds": embeddings_base64},
            )

        # Send 5 concurrent requests
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(send_request) for _ in range(5)]
            # f.result() re-raises any exception from the worker thread, so a
            # failed request fails the test here.
            results = [f.result() for f in concurrent.futures.as_completed(futures)]

        assert len(results) == 5, "All concurrent requests should complete"
        for response in results:
            assert response.choices, "Each response should have choices"
            assert len(response.choices[0].text) > 0, "Each response should have text"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment