Unverified Commit 9f76d060 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: text to image vLLM Omni (#5912)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent d14d6ff4
...@@ -265,7 +265,7 @@ def parse_args() -> Config: ...@@ -265,7 +265,7 @@ def parse_args() -> Config:
"--stage-configs-path", "--stage-configs-path",
type=str, type=str,
default=None, default=None,
help="Path to vLLM-Omni stage configuration YAML file. Required for --omni.", help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).",
) )
parser.add_argument( parser.add_argument(
"--store-kv", "--store-kv",
...@@ -388,9 +388,9 @@ def parse_args() -> Config: ...@@ -388,9 +388,9 @@ def parse_args() -> Config:
) )
# Validate omni worker requirements # Validate omni worker requirements
if args.omni and not args.stage_configs_path: if args.stage_configs_path and not args.omni:
raise ValueError( raise ValueError(
"--stage-configs-path is required when using --omni. " "--stage-configs-path is only allowed when using --omni. "
"Specify a YAML file containing stage configurations for the multi-stage pipeline." "Specify a YAML file containing stage configurations for the multi-stage pipeline."
) )
...@@ -452,7 +452,8 @@ def parse_args() -> Config: ...@@ -452,7 +452,8 @@ def parse_args() -> Config:
config.request_plane = args.request_plane config.request_plane = args.request_plane
config.event_plane = args.event_plane config.event_plane = args.event_plane
config.enable_local_indexer = not args.durable_kv_events config.enable_local_indexer = not args.durable_kv_events
config.use_vllm_tokenizer = args.use_vllm_tokenizer # For omni mode, use vLLM (AsyncOmni) tokenizer on backend
config.use_vllm_tokenizer = args.use_vllm_tokenizer or args.omni
config.sleep_mode_level = args.sleep_mode_level config.sleep_mode_level = args.sleep_mode_level
# use_kv_events is set later in overwrite_args() based on kv_events_config # use_kv_events is set later in overwrite_args() based on kv_events_config
......
...@@ -1215,18 +1215,12 @@ async def init_omni( ...@@ -1215,18 +1215,12 @@ async def init_omni(
component = runtime.namespace(config.namespace).component(config.component) component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint) generate_endpoint = component.endpoint(config.endpoint)
# Load default sampling params from model config (same as other workers)
default_sampling_params = (
config.engine_args.create_model_config().get_diff_sampling_param()
)
logger.info(f"Loaded default sampling params: {default_sampling_params}")
# Initialize OmniHandler with Omni orchestrator # Initialize OmniHandler with Omni orchestrator
handler = OmniHandler( handler = OmniHandler(
runtime=runtime, runtime=runtime,
component=component, component=component,
config=config, config=config,
default_sampling_params=default_sampling_params, default_sampling_params={},
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
) )
...@@ -1241,11 +1235,9 @@ async def init_omni( ...@@ -1241,11 +1235,9 @@ async def init_omni(
return return
# TODO: extend for multi-stage pipelines # TODO: extend for multi-stage pipelines
# Register as Chat endpoint for text-to-text generation
# Use Tokens input since we're doing token-based processing
await register_llm( await register_llm(
ModelInput.Tokens, ModelInput.Text,
ModelType.Chat, ModelType.Images,
generate_endpoint, generate_endpoint,
config.model, config.model,
config.served_model_name, config.served_model_name,
...@@ -1254,7 +1246,6 @@ async def init_omni( ...@@ -1254,7 +1246,6 @@ async def init_omni(
logger.info("Starting to serve Omni worker endpoint...") logger.info("Starting to serve Omni worker endpoint...")
# Create health check payload (extracts BOS token from AsyncOmni)
health_check_payload = ( health_check_payload = (
await VllmOmniHealthCheckPayload.create(handler.engine_client) await VllmOmniHealthCheckPayload.create(handler.engine_client)
).to_dict() ).to_dict()
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Omni handler for text-to-text generation using vLLM-Omni orchestrator."""
import asyncio import asyncio
import logging import logging
import time
from typing import Any, AsyncGenerator, Dict from typing import Any, AsyncGenerator, Dict
from vllm import SamplingParams from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm_omni.entrypoints import AsyncOmni from vllm_omni.entrypoints import AsyncOmni
from vllm_omni.inputs.data import OmniTextPrompt, OmniTokensPrompt
from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params
...@@ -32,10 +30,6 @@ class OmniHandler(BaseWorkerHandler): ...@@ -32,10 +30,6 @@ class OmniHandler(BaseWorkerHandler):
f"Initializing OmniHandler for multi-stage pipelines with model: {config.model}" f"Initializing OmniHandler for multi-stage pipelines with model: {config.model}"
) )
# Initialize AsyncOmni with stage configuration
# Note: stage_configs_path is validated as required in args.py
logger.info(f"Using stage config from: {config.stage_configs_path}")
omni_kwargs = { omni_kwargs = {
"model": config.model, "model": config.model,
"trust_remote_code": config.engine_args.trust_remote_code, "trust_remote_code": config.engine_args.trust_remote_code,
...@@ -54,95 +48,242 @@ class OmniHandler(BaseWorkerHandler): ...@@ -54,95 +48,242 @@ class OmniHandler(BaseWorkerHandler):
self.config = config self.config = config
self.model_max_len = config.engine_args.max_model_len self.model_max_len = config.engine_args.max_model_len
self.shutdown_event = shutdown_event self.shutdown_event = shutdown_event
self.use_vllm_tokenizer = config.use_vllm_tokenizer
logger.info("OmniHandler initialized successfully for text-to-text generation") logger.info("OmniHandler initialized successfully for text-to-text generation")
async def generate( async def generate(
self, request: Dict[str, Any], context self, request: Dict[str, Any], context
) -> AsyncGenerator[Dict, None]: ) -> AsyncGenerator[Dict, None]:
"""Generate text using AsyncOmni orchestrator. Currently supports text-to-text only.""" """Generate outputs using AsyncOmni orchestrator with OpenAI-compatible format.
Supports text-to-text and text-to-image generation based on stage configuration.
Returns OpenAI-compatible streaming chunks with detokenized text.
"""
request_id = context.id() request_id = context.id()
logger.debug(f"Omni Request ID: {request_id}") logger.debug(f"Omni Request ID: {request_id}")
# Extract token_ids from internal protocol format if self.use_vllm_tokenizer:
async for chunk in self._generate_openai_mode(request, context, request_id):
yield chunk
else:
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
# Not used right now
async def _generate_token_mode(self, request, context, request_id):
"""
This mode returns token-ids as output
Text input -> Token-ids output
"""
token_ids = request.get("token_ids") token_ids = request.get("token_ids")
if not token_ids: prompt = OmniTokensPrompt(token_ids=token_ids)
logger.error(f"Request {request_id}: No token_ids found in request") num_output_tokens_so_far = 0
try:
async for stage_output in self.engine_client.generate(
prompt=prompt,
request_id=request_id,
):
vllm_output = stage_output.request_output
if not vllm_output.outputs:
logger.warning(f"Request {request_id} returned no outputs")
yield {
"finish_reason": "error: No outputs from vLLM engine",
"token_ids": [],
}
break
output = vllm_output.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = self._normalize_finish_reason(
output.finish_reason
)
out["completion_usage"] = self._build_completion_usage(vllm_output)
logger.debug(
f"Completed generation for request {request_id}: "
f"{next_total_toks} output tokens, finish_reason={output.finish_reason}"
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except GeneratorExit:
# Shutdown was triggered during generation
logger.info(f"Request {request_id} aborted due to shutdown")
raise
except Exception as e:
logger.error(f"Error during generation for request {request_id}: {e}")
yield { yield {
"finish_reason": "error: No token_ids in request", "finish_reason": f"error: {str(e)}",
"token_ids": [], "token_ids": [],
} }
return
logger.info( async def _generate_openai_mode(self, request, context, request_id):
f"Request {request_id}: Generating text for {len(token_ids)} input tokens" """
) This mode returns OpenAI-compatible streaming chunks
Text input -> Text output / Image output
"""
# (ayushag) TODO: Support all types of OmniPrompt; only text prompts work right now.
# (ayushag) TODO: Document all I/O formats from vLLM-Omni.
# TODO: OmniTextPrompt also accepts negative prompts; add support for them.
# TODO: Support multimodal content too. That involves applying the tokenizer to the prompt and loading images; follow the general multimodal support pattern.
prompt = self._extract_text_prompt(request)
prompt = OmniTextPrompt(prompt=prompt)
# Build sampling parameters from request # Build sampling parameters from request
sampling_params = self._build_sampling_params(request) # (ayushag) TODO: Need to add proper multi-stage sampling param support
sampling_params_list = [sampling_params] # sampling_params = self._build_sampling_params(request)
# sampling_params_list = [sampling_params]
tokens_prompt: TokensPrompt = { previous_text = ""
"prompt_token_ids": token_ids,
}
async with self._abort_monitor(context, request_id): async with self._abort_monitor(context, request_id):
try: try:
num_output_tokens_so_far = 0
async for stage_output in self.engine_client.generate( async for stage_output in self.engine_client.generate(
prompt=tokens_prompt, # Pass TokensPrompt format prompt=prompt,
request_id=request_id, request_id=request_id,
sampling_params_list=sampling_params_list, # sampling_params_list=sampling_params_list,
): ):
# stage_output is OmniRequestOutput
# For text generation: stage_output.request_output is a single vLLM RequestOutput
if ( if (
stage_output.final_output_type == "text" stage_output.final_output_type == "text"
and stage_output.request_output and stage_output.request_output
): ):
vllm_output = stage_output.request_output # Text generation (LLM stage)
chunk = self._format_text_chunk(
if not vllm_output.outputs: stage_output.request_output,
logger.warning(f"Request {request_id} returned no outputs") request_id,
yield { previous_text,
"finish_reason": "error: No outputs from vLLM engine", )
"token_ids": [], if chunk:
} # Update previous_text for delta calculation
break output = stage_output.request_output.outputs[0]
previous_text = output.text
output = vllm_output.outputs[0] yield chunk
next_total_toks = len(output.token_ids)
elif (
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} stage_output.final_output_type == "image"
and stage_output.images
if output.finish_reason: ):
out["finish_reason"] = self._normalize_finish_reason( # Image generation (diffusion stage)
output.finish_reason chunk = self._format_image_chunk(
) stage_output.images,
out["completion_usage"] = self._build_completion_usage( request_id,
vllm_output )
) if chunk:
logger.debug( yield chunk
f"Completed generation for request {request_id}: "
f"{next_total_toks} output tokens, finish_reason={output.finish_reason}"
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except GeneratorExit: except GeneratorExit:
# Shutdown was triggered during generation
logger.info(f"Request {request_id} aborted due to shutdown") logger.info(f"Request {request_id} aborted due to shutdown")
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error during generation for request {request_id}: {e}") logger.error(f"Error during generation for request {request_id}: {e}")
yield { yield self._error_chunk(request_id, str(e))
"finish_reason": f"error: {str(e)}",
"token_ids": [], def _format_text_chunk(
self,
request_output,
request_id: str,
previous_text: str,
) -> Dict[str, Any] | None:
"""Format text output as OpenAI chat completion chunk."""
if not request_output.outputs:
return self._error_chunk(request_id, "No outputs from engine")
output = request_output.outputs[0]
# Calculate delta text (new text since last chunk)
delta_text = output.text[len(previous_text) :]
chunk = {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self.config.served_model_name or self.config.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": delta_text,
},
"finish_reason": self._normalize_finish_reason(output.finish_reason)
if output.finish_reason
else None,
} }
],
}
# Add usage on final chunk
if output.finish_reason:
chunk["usage"] = self._build_completion_usage(request_output)
return chunk
def _format_image_chunk(
self,
images: list,
request_id: str,
) -> Dict[str, Any] | None:
"""Format image output as OpenAI chat completion chunk with base64 data URLs."""
import base64
from io import BytesIO
if not images:
return self._error_chunk(request_id, "No images generated")
# Convert images to base64 data URLs
data_urls = []
for idx, img in enumerate(images):
# Convert PIL image to base64
buffer = BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Create data URL (can be opened directly in browser)
data_url = f"data:image/png;base64,{img_base64}"
data_urls.append(data_url)
logger.info(f"Generated image {idx} for request {request_id}")
chunk = {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self.config.served_model_name or self.config.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": [
{"type": "image_url", "image_url": {"url": data_url}}
for data_url in data_urls
],
},
"finish_reason": "stop",
}
],
}
return chunk
def _extract_text_prompt(self, request: Dict[str, Any]) -> str | None:
"""Extract text prompt from request."""
# OpenAI messages format - extract text content only
messages = request.get("messages", [])
# Assumes single user message
for message in messages:
if message.get("role") == "user":
return message.get("content")
return ""
def _build_sampling_params(self, request: Dict[str, Any]) -> SamplingParams: def _build_sampling_params(self, request: Dict[str, Any]) -> SamplingParams:
"""Build sampling params using shared handler utility.""" """Build sampling params using shared handler utility."""
...@@ -150,6 +291,25 @@ class OmniHandler(BaseWorkerHandler): ...@@ -150,6 +291,25 @@ class OmniHandler(BaseWorkerHandler):
request, self.default_sampling_params, self.model_max_len request, self.default_sampling_params, self.model_max_len
) )
def _error_chunk(self, request_id: str, error_message: str) -> Dict[str, Any]:
"""Create an error chunk in OpenAI format."""
return {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self.config.served_model_name or self.config.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": f"Error: {error_message}",
},
"finish_reason": "error",
}
],
}
def cleanup(self): def cleanup(self):
"""Cleanup AsyncOmni orchestrator resources.""" """Cleanup AsyncOmni orchestrator resources."""
try: try:
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->
# [Experimental] Running Omni Models with vLLM
Dynamo supports omni (multimodal generation) models via the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This enables multi-stage pipelines for tasks like text-to-text and text-to-image generation through an OpenAI-compatible API.
## Prerequisites
This guide assumes familiarity with deploying Dynamo with vLLM as described in [README.md](/docs/backends/vllm/README.md).
## Quick Start
### Text-to-Text
Launch an aggregated deployment (frontend + omni worker) using the provided script:
```bash
bash examples/backends/vllm/launch/agg_omni.sh
```
This starts `Qwen/Qwen2.5-Omni-7B` with a single-stage thinker config on one GPU. Override the model with:
```bash
bash examples/backends/vllm/launch/agg_omni.sh --model <your-model>
```
Test the deployment:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2.5-Omni-7B",
"messages": [{"role": "user", "content": "What is 2+2?"}],
"max_tokens": 50,
"stream": false
}'
```
### Text-to-Image
Text-to-image uses vLLM-Omni's built-in default stage configs (no custom YAML needed). Launch without a stage config path so vLLM-Omni loads the model's default multi-stage pipeline:
```bash
# Start frontend
python -m dynamo.frontend &
# Start omni worker (vLLM-Omni loads default stage configs for the model)
DYN_SYSTEM_PORT=8081 python -m dynamo.vllm \
--model <your-text-to-image-model> \
--omni \
--connector none
```
Send a chat completion request with a text prompt; the generated images come back as base64-encoded PNG data URLs:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "<your-text-to-image-model>",
"messages": [{"role": "user", "content": "A cat sitting on a windowsill"}],
"stream": false
}'
```
The `content` field of the response holds one `image_url` part per generated image, each carrying a `data:image/png;base64,...` URL:
```json
{
"choices": [{
"delta": {
"content": [
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
]
}
}]
}
```
## Key Flags
| Flag | Description |
|------|-------------|
| `--omni` | Enable vLLM-Omni orchestrator (required) |
| `--stage-configs-path <path>` | Path to stage config YAML (optional; vLLM-Omni uses model defaults if omitted) |
| `--connector none` | Disable KV connector (recommended for omni) |
## Stage Configuration
Omni pipelines are configured via YAML stage configs. See [`examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml`](/examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml) for an example. Key fields:
- **`model_stage`**: Pipeline stage name (e.g., `thinker`, `talker`, `code2wav`)
- **`final_output_type`**: Output format — `text` or `image`
- **`is_comprehension`**: Whether this stage processes input text/multimodal content
For full documentation on stage config format, supported fields, and multi-stage pipeline examples, see the [vLLM-Omni Stage Configs documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/).
## Current Limitations
- Only text prompts are supported (no multimodal input yet)
- KV cache events are not published for omni workers
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
backends/vllm/multi-node.md backends/vllm/multi-node.md
backends/vllm/prometheus.md backends/vllm/prometheus.md
backends/vllm/prompt-embeddings.md backends/vllm/prompt-embeddings.md
backends/vllm/vllm-omni.md
backends/sglang/expert-distribution-eplb.md backends/sglang/expert-distribution-eplb.md
backends/sglang/gpt-oss.md backends/sglang/gpt-oss.md
......
...@@ -36,7 +36,7 @@ stage_args: ...@@ -36,7 +36,7 @@ stage_args:
temperature: 0.0 temperature: 0.0
top_p: 1.0 top_p: 1.0
top_k: -1 top_k: -1
max_tokens: 2048 max_tokens: 20
repetition_penalty: 1.1 repetition_penalty: 1.1
seed: 42 seed: 42
detokenize: false # Token-based processing for Dynamo detokenize: true # Token-based processing for Dynamo
...@@ -375,6 +375,77 @@ pub enum ChatCompletionRequestToolMessageContent { ...@@ -375,6 +375,77 @@ pub enum ChatCompletionRequestToolMessageContent {
Array(Vec<ChatCompletionRequestToolMessageContentPart>), Array(Vec<ChatCompletionRequestToolMessageContentPart>),
} }
// Omni Specific Multimodal Content Types
// These types are used for assistant message responses that contain multimodal content
/// Response content part for text in assistant messages.
///
/// Payload of the `Text` variant of `ChatCompletionResponseContentPart`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ChatCompletionResponseContentPartText {
/// The text carried by this content part.
pub text: String,
}
/// Response content part for image URLs in assistant messages.
///
/// Payload of the `ImageUrl` variant of `ChatCompletionResponseContentPart`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ChatCompletionResponseContentPartImageUrl {
/// Location of the image (plain URL or base64 data URL) plus optional detail level.
pub image_url: ImageUrlResponse,
}
/// Response content part for video URLs in assistant messages.
///
/// Payload of the `VideoUrl` variant of `ChatCompletionResponseContentPart`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ChatCompletionResponseContentPartVideoUrl {
/// Location of the video (plain URL or data URL).
pub video_url: VideoUrlResponse,
}
/// Response content part for audio URLs in assistant messages.
///
/// Payload of the `AudioUrl` variant of `ChatCompletionResponseContentPart`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ChatCompletionResponseContentPartAudioUrl {
/// Location of the audio (plain URL or data URL).
pub audio_url: AudioUrlResponse,
}
/// Image URL in response messages (supports data URLs with base64).
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ImageUrlResponse {
/// The URL of the image, either a URL or a data URL (data:image/png;base64,...)
pub url: String,
/// Optional detail level (for compatibility with OpenAI).
/// Left out of the serialized output entirely when it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
/// Video URL in response messages.
///
/// Wrapped by `ChatCompletionResponseContentPartVideoUrl`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct VideoUrlResponse {
/// The URL of the video, either a URL or a data URL
pub url: String,
}
/// Audio URL in response messages.
///
/// Wrapped by `ChatCompletionResponseContentPartAudioUrl`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct AudioUrlResponse {
/// The URL of the audio, either a URL or a data URL
pub url: String,
}
/// Content parts for assistant responses supporting multiple modalities (text, images, videos, audio).
///
/// Serialized as an internally tagged enum: each part carries a `type` field
/// whose value is the snake_case variant name (`text`, `image_url`,
/// `video_url`, `audio_url`), with the wrapped struct's fields alongside it,
/// e.g. `{"type": "image_url", "image_url": {"url": "..."}}`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChatCompletionResponseContentPart {
// Tagged `"type": "text"`.
Text(ChatCompletionResponseContentPartText),
// Tagged `"type": "image_url"`.
ImageUrl(ChatCompletionResponseContentPartImageUrl),
// Tagged `"type": "video_url"`.
VideoUrl(ChatCompletionResponseContentPartVideoUrl),
// Tagged `"type": "audio_url"`.
AudioUrl(ChatCompletionResponseContentPartAudioUrl),
}
/// Assistant message content - can be a simple string or an array of content parts.
///
/// The enum is `untagged`, so no discriminator is written: `Text` serializes
/// as a bare JSON string and `Parts` as a JSON array. On deserialization the
/// variants are tried in declaration order, so a string becomes `Text` and an
/// array of tagged parts becomes `Parts`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ChatCompletionMessageContent {
/// Simple text content (backward compatible)
Text(String),
/// Array of content parts (for multimodal responses)
Parts(Vec<ChatCompletionResponseContentPart>),
}
#[derive(ToSchema, Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] #[derive(ToSchema, Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)]
#[builder(name = "ChatCompletionRequestUserMessageArgs")] #[builder(name = "ChatCompletionRequestUserMessageArgs")]
#[builder(pattern = "mutable")] #[builder(pattern = "mutable")]
...@@ -486,9 +557,9 @@ pub struct ChatCompletionResponseMessageAudio { ...@@ -486,9 +557,9 @@ pub struct ChatCompletionResponseMessageAudio {
/// A chat completion message generated by the model. /// A chat completion message generated by the model.
#[derive(ToSchema, Debug, Deserialize, Serialize, Clone, PartialEq)] #[derive(ToSchema, Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ChatCompletionResponseMessage { pub struct ChatCompletionResponseMessage {
/// The contents of the message. /// The contents of the message - can be a string or array of content parts
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>, pub content: Option<ChatCompletionMessageContent>,
/// The refusal message generated by the model. /// The refusal message generated by the model.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub refusal: Option<String>, pub refusal: Option<String>,
...@@ -1094,8 +1165,8 @@ pub struct ChatCompletionMessageToolCallChunk { ...@@ -1094,8 +1165,8 @@ pub struct ChatCompletionMessageToolCallChunk {
/// A chat completion delta generated by streamed model responses. /// A chat completion delta generated by streamed model responses.
#[derive(ToSchema, Debug, Deserialize, Serialize, Clone, PartialEq)] #[derive(ToSchema, Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponseDelta { pub struct ChatCompletionStreamResponseDelta {
/// The contents of the chunk message. /// The contents of the chunk message - can be a string or array of content parts
pub content: Option<String>, pub content: Option<ChatCompletionMessageContent>,
/// Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model. /// Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.
#[deprecated] #[deprecated]
pub function_call: Option<FunctionCallStream>, pub function_call: Option<FunctionCallStream>,
......
...@@ -311,7 +311,7 @@ fn register_llm<'p>( ...@@ -311,7 +311,7 @@ fn register_llm<'p>(
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
// For TensorBased and Images models, skip HuggingFace downloads and register directly // For TensorBased and Images models, skip HuggingFace downloads and register directly
// These model types don't require tokenizers // Images models (vLLM-Omni) handle model loading internally, no tokenizer extraction needed
if is_tensor_based || is_images { if is_tensor_based || is_images {
let model_name = model_name.unwrap_or_else(|| source_path.clone()); let model_name = model_name.unwrap_or_else(|| source_path.clone());
let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name); let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name);
......
...@@ -405,7 +405,7 @@ impl ...@@ -405,7 +405,7 @@ impl
delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta{ delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta{
//role: c.choices[0].delta.role, //role: c.choices[0].delta.role,
role: Some(dynamo_async_openai::types::Role::Assistant), role: Some(dynamo_async_openai::types::Role::Assistant),
content: Some(from_assistant), content: Some(dynamo_async_openai::types::ChatCompletionMessageContent::Text(from_assistant)),
tool_calls: None, tool_calls: None,
refusal: None, refusal: None,
function_call: None, function_call: None,
......
...@@ -246,7 +246,8 @@ pub fn final_response_to_one_chunk_stream( ...@@ -246,7 +246,8 @@ pub fn final_response_to_one_chunk_stream(
mod tests { mod tests {
use super::*; use super::*;
use dynamo_async_openai::types::{ use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason, Role, ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionStreamResponseDelta,
FinishReason, Role,
}; };
use futures::StreamExt; use futures::StreamExt;
use futures::stream; use futures::stream;
...@@ -261,7 +262,7 @@ mod tests { ...@@ -261,7 +262,7 @@ mod tests {
index, index,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant), role: Some(Role::Assistant),
content: Some(content), content: Some(ChatCompletionMessageContent::Text(content)),
tool_calls: None, tool_calls: None,
function_call: None, function_call: None,
refusal: None, refusal: None,
...@@ -337,7 +338,10 @@ mod tests { ...@@ -337,7 +338,10 @@ mod tests {
.as_ref() .as_ref()
.and_then(|d| d.choices.first()) .and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref()) .and_then(|c| c.delta.content.as_ref())
.cloned() .and_then(|content| match content {
ChatCompletionMessageContent::Text(text) => Some(text.clone()),
ChatCompletionMessageContent::Parts(_) => None,
})
.unwrap_or_default() .unwrap_or_default()
} }
...@@ -423,7 +427,9 @@ mod tests { ...@@ -423,7 +427,9 @@ mod tests {
index: 0, index: 0,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant), role: Some(Role::Assistant),
content: Some("Content".to_string()), content: Some(ChatCompletionMessageContent::Text(
"Content".to_string(),
)),
tool_calls: None, tool_calls: None,
function_call: None, function_call: None,
refusal: None, refusal: None,
......
...@@ -169,6 +169,7 @@ impl ModelManager { ...@@ -169,6 +169,7 @@ impl ModelManager {
.chain(self.list_completions_models()) .chain(self.list_completions_models())
.chain(self.list_embeddings_models()) .chain(self.list_embeddings_models())
.chain(self.list_tensor_models()) .chain(self.list_tensor_models())
.chain(self.list_images_models())
.chain(self.list_prefill_models()) .chain(self.list_prefill_models())
.collect() .collect()
} }
...@@ -193,6 +194,10 @@ impl ModelManager { ...@@ -193,6 +194,10 @@ impl ModelManager {
self.prefill_engines.read().list() self.prefill_engines.read().list()
} }
pub fn list_images_models(&self) -> Vec<String> {
self.images_engines.read().list()
}
pub fn add_completions_model( pub fn add_completions_model(
&self, &self,
model: &str, model: &str,
......
...@@ -71,6 +71,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[ ...@@ -71,6 +71,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Completions, ModelType::Completions,
ModelType::Embedding, ModelType::Embedding,
ModelType::TensorBased, ModelType::TensorBased,
ModelType::Images,
ModelType::Prefill, ModelType::Prefill,
]; ];
...@@ -284,12 +285,14 @@ impl ModelWatcher { ...@@ -284,12 +285,14 @@ impl ModelWatcher {
let completions_model_remove_err = self.manager.remove_completions_model(&model_name); let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name); let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name); let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let images_model_remove_err = self.manager.remove_images_model(&model_name);
let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name); let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
let mut chat_model_removed = false; let mut chat_model_removed = false;
let mut completions_model_removed = false; let mut completions_model_removed = false;
let mut embeddings_model_removed = false; let mut embeddings_model_removed = false;
let mut tensor_model_removed = false; let mut tensor_model_removed = false;
let mut images_model_removed = false;
let mut prefill_model_removed = false; let mut prefill_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() { if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
...@@ -305,6 +308,9 @@ impl ModelWatcher { ...@@ -305,6 +308,9 @@ impl ModelWatcher {
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() { if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true; tensor_model_removed = true;
} }
if images_model_remove_err.is_ok() && self.manager.list_images_models().is_empty() {
images_model_removed = true;
}
if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() { if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
prefill_model_removed = true; prefill_model_removed = true;
} }
...@@ -313,15 +319,17 @@ impl ModelWatcher { ...@@ -313,15 +319,17 @@ impl ModelWatcher {
&& !completions_model_removed && !completions_model_removed
&& !embeddings_model_removed && !embeddings_model_removed
&& !tensor_model_removed && !tensor_model_removed
&& !images_model_removed
&& !prefill_model_removed && !prefill_model_removed
{ {
tracing::debug!( tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}", "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}, images_model_removed: {}, prefill_model_removed: {}",
model_name, model_name,
chat_model_removed, chat_model_removed,
completions_model_removed, completions_model_removed,
embeddings_model_removed, embeddings_model_removed,
tensor_model_removed, tensor_model_removed,
images_model_removed,
prefill_model_removed prefill_model_removed
); );
} else { } else {
...@@ -330,6 +338,7 @@ impl ModelWatcher { ...@@ -330,6 +338,7 @@ impl ModelWatcher {
|| (completions_model_removed && *model_type == ModelType::Completions) || (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding) || (embeddings_model_removed && *model_type == ModelType::Embedding)
|| (tensor_model_removed && *model_type == ModelType::TensorBased) || (tensor_model_removed && *model_type == ModelType::TensorBased)
|| (images_model_removed && *model_type == ModelType::Images)
|| (prefill_model_removed && *model_type == ModelType::Prefill)) || (prefill_model_removed && *model_type == ModelType::Prefill))
&& let Some(tx) = &self.model_update_tx && let Some(tx) = &self.model_update_tx
{ {
...@@ -614,7 +623,7 @@ impl ModelWatcher { ...@@ -614,7 +623,7 @@ impl ModelWatcher {
self.manager self.manager
.add_embeddings_model(card.name(), checksum, embedding_engine)?; .add_embeddings_model(card.name(), checksum, embedding_engine)?;
} else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() { } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
// Case 5: Tensor + Tensor (non-LLM) // Case 6: Tensor + TensorBased (non-LLM)
// No KV cache concepts - not an LLM model // No KV cache concepts - not an LLM model
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateTensorRequest, NvCreateTensorRequest,
...@@ -627,18 +636,31 @@ impl ModelWatcher { ...@@ -627,18 +636,31 @@ impl ModelWatcher {
self.manager self.manager
.add_tensor_model(card.name(), checksum, engine)?; .add_tensor_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_images() { } else if card.model_input == ModelInput::Text && card.model_type.supports_images() {
// Case: Text + Images (diffusion models) // Case: Text + Images (e.g. vLLM-Omni, diffusion models)
// Takes text prompts as input, generates images // Takes text prompts as input, generates images. Images models also support
let push_router = PushRouter::< // chat completions (see model_type.rs as_endpoint_types).
let images_router = PushRouter::<
NvCreateImageRequest, NvCreateImageRequest,
Annotated<NvImagesResponse>, Annotated<NvImagesResponse>,
>::from_client_with_threshold( >::from_client_with_threshold(
client, self.router_config.router_mode, None, None client.clone(), self.router_config.router_mode, None, None
) )
.await?; .await?;
let engine = Arc::new(push_router);
self.manager self.manager
.add_images_model(card.name(), checksum, engine)?; .add_images_model(card.name(), checksum, Arc::new(images_router))?;
let chat_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client, self.router_config.router_mode, None, None
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
} else if card.model_type.supports_prefill() { } else if card.model_type.supports_prefill() {
// Case 6: Prefill // Case 6: Prefill
// Guardrail: Verify model_input is Tokens // Guardrail: Verify model_input is Tokens
......
...@@ -7,7 +7,7 @@ use crate::types::openai::chat_completions::{ ...@@ -7,7 +7,7 @@ use crate::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
}; };
use anyhow::Context as _; use anyhow::Context as _;
use dynamo_async_openai::types::FinishReason; use dynamo_async_openai::types::{ChatCompletionMessageContent, FinishReason};
use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken}; use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken};
use futures::StreamExt; use futures::StreamExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -241,7 +241,15 @@ async fn evaluate( ...@@ -241,7 +241,15 @@ async fn evaluate(
let choice = data.choices.first(); let choice = data.choices.first();
let chat_comp = choice.as_ref().unwrap(); let chat_comp = choice.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content { if let Some(c) = &chat_comp.delta.content {
output += c; match c {
ChatCompletionMessageContent::Text(text) => {
output += text;
}
ChatCompletionMessageContent::Parts(_) => {
// Multimodal content - skip for now in batch processing
// (ayushag) TODO: Handle multimodal content in batch mode
}
}
} }
entry.finish_reason = chat_comp.finish_reason; entry.finish_reason = chat_comp.finish_reason;
if chat_comp.finish_reason.is_some() { if chat_comp.finish_reason.is_some() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::entrypoint::EngineConfig;
use crate::entrypoint::input::common;
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{ use crate::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine, NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
}; };
use dynamo_async_openai::types::ChatCompletionMessageContent;
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::Context; use dynamo_runtime::pipeline::Context;
use futures::StreamExt; use futures::StreamExt;
use std::io::{ErrorKind, Write}; use std::io::{ErrorKind, Write};
use crate::entrypoint::EngineConfig;
use crate::entrypoint::input::common;
/// Max response tokens for each single query. Must be less than model context size. /// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to overwrite this /// TODO: Cmd line flag to overwrite this
const MAX_TOKENS: u32 = 8192; const MAX_TOKENS: u32 = 8192;
...@@ -140,9 +140,19 @@ async fn main_loop( ...@@ -140,9 +140,19 @@ async fn main_loop(
let entry = data.choices.first(); let entry = data.choices.first();
let chat_comp = entry.as_ref().unwrap(); let chat_comp = entry.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content { if let Some(c) = &chat_comp.delta.content {
let _ = stdout.write(c.as_bytes()); match c {
let _ = stdout.flush(); ChatCompletionMessageContent::Text(text) => {
assistant_message += c; let _ = stdout.write(text.as_bytes());
let _ = stdout.flush();
assistant_message += text;
}
ChatCompletionMessageContent::Parts(_) => {
// (ayushag) TODO: Handle multimodal content for multiturn conversations
// Multimodal content - for now just print a placeholder
let _ = stdout.write(b"[multimodal content]");
let _ = stdout.flush();
}
}
} }
if let Some(reason) = chat_comp.finish_reason { if let Some(reason) = chat_comp.finish_reason {
tracing::trace!("finish reason: {reason:?}"); tracing::trace!("finish reason: {reason:?}");
......
...@@ -384,6 +384,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -384,6 +384,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
token_ids: vec![token_id], token_ids: vec![token_id],
tokens: None, // Let backend handle detokenization tokens: None, // Let backend handle detokenization
text: None, text: None,
output_type: Default::default(),
content_parts: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
......
...@@ -126,8 +126,10 @@ impl ModelType { ...@@ -126,8 +126,10 @@ impl ModelType {
if self.contains(Self::Embedding) { if self.contains(Self::Embedding) {
endpoint_types.push(crate::endpoint_type::EndpointType::Embedding); endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
} }
// Images models are served on both the images and chat endpoints
if self.contains(Self::Images) { if self.contains(Self::Images) {
endpoint_types.push(crate::endpoint_type::EndpointType::Images); endpoint_types.push(crate::endpoint_type::EndpointType::Images);
endpoint_types.push(crate::endpoint_type::EndpointType::Chat);
} }
// [gluo NOTE] ModelType::Tensor doesn't map to any endpoint type, // [gluo NOTE] ModelType::Tensor doesn't map to any endpoint type,
// current use of endpoint type is LLM specific and so does the HTTP // current use of endpoint type is LLM specific and so does the HTTP
......
...@@ -953,7 +953,11 @@ mod tests { ...@@ -953,7 +953,11 @@ mod tests {
choices: vec![ChatChoiceStream { choices: vec![ChatChoiceStream {
index: 0, index: 0,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
content: Some("test".to_string()), content: Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
"test".to_string(),
),
),
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
...@@ -987,7 +991,11 @@ mod tests { ...@@ -987,7 +991,11 @@ mod tests {
.map(|(i, token_logprobs)| ChatChoiceStream { .map(|(i, token_logprobs)| ChatChoiceStream {
index: i as u32, index: i as u32,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
content: Some("test".to_string()), content: Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
"test".to_string(),
),
),
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
...@@ -1337,7 +1345,11 @@ mod tests { ...@@ -1337,7 +1345,11 @@ mod tests {
choices: vec![ChatChoiceStream { choices: vec![ChatChoiceStream {
index: 0, index: 0,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
content: Some("test".to_string()), content: Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
"test".to_string(),
),
),
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
......
...@@ -945,14 +945,23 @@ impl OpenAIPreprocessor { ...@@ -945,14 +945,23 @@ impl OpenAIPreprocessor {
response.map_data(|mut data| { response.map_data(|mut data| {
// Process all choices, not just the first one // Process all choices, not just the first one
for choice in data.choices.iter_mut() { for choice in data.choices.iter_mut() {
if let Some(text) = choice.delta.content.as_ref() { // Reasoning parsing only applies to text content
if let Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
text,
),
) = choice.delta.content.as_ref()
{
let parser_result = let parser_result =
parser.parse_reasoning_streaming_incremental(text, &[]); parser.parse_reasoning_streaming_incremental(text, &[]);
// Update this specific choice with parsed content // Update this specific choice with parsed content
choice.delta.content = parser_result.get_some_normal_text(); choice.delta.content = parser_result.get_some_normal_text().map(
dynamo_async_openai::types::ChatCompletionMessageContent::Text,
);
choice.delta.reasoning_content = parser_result.get_some_reasoning(); choice.delta.reasoning_content = parser_result.get_some_reasoning();
} }
// For multimodal content, pass through unchanged
} }
Ok(data) Ok(data)
}) })
......
...@@ -13,6 +13,45 @@ use dynamo_runtime::protocols::maybe_error::MaybeError; ...@@ -13,6 +13,45 @@ use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>; pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>; pub type LogProbs = Vec<f64>;
/// Discriminates which modality an output carries.
///
/// Serde writes and reads each variant as its lowercase name (`"text"`,
/// `"image"`, `"video"`, `"audio"`). `Text` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OutputType {
    /// Plain text output; the default variant.
    #[default]
    Text,
    /// Image output.
    Image,
    /// Video output.
    Video,
    /// Audio output.
    Audio,
}
/// Holds the URL of an image included in a response.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageUrlData {
    /// Location of the image.
    // NOTE(review): accepted URL schemes (e.g. `https:` vs. `data:`) are not
    // visible from this definition — confirm against the producer.
    pub url: String,
}
/// Holds the URL of a video included in a response.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VideoUrlData {
    /// Location of the video.
    // NOTE(review): accepted URL schemes are not visible from this
    // definition — confirm against the producer.
    pub url: String,
}
/// Holds the URL of an audio clip included in a response.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AudioUrlData {
    /// Location of the audio.
    // NOTE(review): accepted URL schemes are not visible from this
    // definition — confirm against the producer.
    pub url: String,
}
/// One piece of a multimodal output (internal representation).
///
/// Serialized as an internally tagged object: the `type` field carries the
/// snake_case variant name (`text`, `image_url`, `video_url`, `audio_url`)
/// and the variant's own field sits next to it in the same object.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// A text fragment.
    Text {
        text: String,
    },
    /// A reference to an image by URL.
    ImageUrl {
        image_url: ImageUrlData,
    },
    /// A reference to a video by URL.
    VideoUrl {
        video_url: VideoUrlData,
    },
    /// A reference to an audio clip by URL.
    AudioUrl {
        audio_url: AudioUrlData,
    },
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TopLogprob { pub struct TopLogprob {
pub rank: u32, pub rank: u32,
...@@ -86,6 +125,14 @@ pub struct LLMEngineOutput { ...@@ -86,6 +125,14 @@ pub struct LLMEngineOutput {
// decoded text - // decoded text -
pub text: Option<String>, pub text: Option<String>,
/// Output type discriminator (text, image, video, audio)
#[serde(default)]
pub output_type: OutputType,
/// Multimodal content parts (for non-text outputs)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content_parts: Option<Vec<ContentPart>>,
/// cumulative log probabilities /// cumulative log probabilities
pub cum_log_probs: Option<f64>, pub cum_log_probs: Option<f64>,
...@@ -125,6 +172,8 @@ impl LLMEngineOutput { ...@@ -125,6 +172,8 @@ impl LLMEngineOutput {
token_ids: vec![], token_ids: vec![],
tokens: None, tokens: None,
text: None, text: None,
output_type: OutputType::default(),
content_parts: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
...@@ -142,6 +191,8 @@ impl LLMEngineOutput { ...@@ -142,6 +191,8 @@ impl LLMEngineOutput {
token_ids: vec![], token_ids: vec![],
tokens: None, tokens: None,
text: None, text: None,
output_type: OutputType::default(),
content_parts: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
...@@ -159,6 +210,8 @@ impl LLMEngineOutput { ...@@ -159,6 +210,8 @@ impl LLMEngineOutput {
token_ids: vec![], token_ids: vec![],
tokens: None, tokens: None,
text: None, text: None,
output_type: OutputType::default(),
content_parts: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
...@@ -176,6 +229,8 @@ impl LLMEngineOutput { ...@@ -176,6 +229,8 @@ impl LLMEngineOutput {
token_ids: vec![], token_ids: vec![],
tokens: None, tokens: None,
text: None, text: None,
output_type: OutputType::default(),
content_parts: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
......
...@@ -12,7 +12,7 @@ use crate::protocols::{ ...@@ -12,7 +12,7 @@ use crate::protocols::{
openai::ParsingOptions, openai::ParsingOptions,
}; };
use dynamo_async_openai::types::StopReason; use dynamo_async_openai::types::{ChatCompletionMessageContent, StopReason};
use dynamo_runtime::engine::DataStream; use dynamo_runtime::engine::DataStream;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
...@@ -59,6 +59,9 @@ struct DeltaChoice { ...@@ -59,6 +59,9 @@ struct DeltaChoice {
/// Optional reasoning content for the chat choice. /// Optional reasoning content for the chat choice.
reasoning_content: Option<String>, reasoning_content: Option<String>,
/// Accumulated content parts for multimodal responses
content_parts: Vec<dynamo_async_openai::types::ChatCompletionResponseContentPart>,
} }
impl Default for DeltaAggregator { impl Default for DeltaAggregator {
...@@ -166,10 +169,18 @@ impl DeltaAggregator { ...@@ -166,10 +169,18 @@ impl DeltaAggregator {
logprobs: None, logprobs: None,
tool_calls: None, tool_calls: None,
reasoning_content: None, reasoning_content: None,
content_parts: Vec::new(),
}); });
// Append content if available. // Handle content based on type
if let Some(content) = &choice.delta.content { if let Some(content) = &choice.delta.content {
state_choice.text.push_str(content); match content {
ChatCompletionMessageContent::Text(text) => {
state_choice.text.push_str(text);
}
ChatCompletionMessageContent::Parts(parts) => {
state_choice.content_parts.extend(parts.clone());
}
}
} }
if let Some(reasoning_content) = &choice.delta.reasoning_content { if let Some(reasoning_content) = &choice.delta.reasoning_content {
...@@ -289,14 +300,21 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice { ...@@ -289,14 +300,21 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
delta.finish_reason delta.finish_reason
}; };
// Determine content format based on what we accumulated
let content = if !delta.content_parts.is_empty() {
// Multimodal response with content parts
Some(ChatCompletionMessageContent::Parts(delta.content_parts))
} else if !delta.text.is_empty() {
// Text-only response (backward compatible)
Some(ChatCompletionMessageContent::Text(delta.text))
} else {
None
};
dynamo_async_openai::types::ChatChoice { dynamo_async_openai::types::ChatChoice {
message: dynamo_async_openai::types::ChatCompletionResponseMessage { message: dynamo_async_openai::types::ChatCompletionResponseMessage {
role: delta.role.expect("delta should have a Role"), role: delta.role.expect("delta should have a Role"),
content: if delta.text.is_empty() { content,
None
} else {
Some(delta.text)
},
tool_calls: delta.tool_calls, tool_calls: delta.tool_calls,
refusal: None, refusal: None,
function_call: None, function_call: None,
...@@ -396,7 +414,7 @@ mod tests { ...@@ -396,7 +414,7 @@ mod tests {
}; };
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: Some(text.to_string()), content: Some(ChatCompletionMessageContent::Text(text.to_string())),
function_call: None, function_call: None,
tool_calls: tool_call_chunks, tool_calls: tool_call_chunks,
role, role,
...@@ -496,7 +514,10 @@ mod tests { ...@@ -496,7 +514,10 @@ mod tests {
assert_eq!(response.choices.len(), 1); assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0]; let choice = &response.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello,"); assert_eq!(
choice.message.content.as_ref().unwrap(),
&ChatCompletionMessageContent::Text("Hello,".to_string())
);
assert!(choice.finish_reason.is_none()); assert!(choice.finish_reason.is_none());
assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User); assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User);
assert!(response.service_tier.is_none()); assert!(response.service_tier.is_none());
...@@ -539,7 +560,10 @@ mod tests { ...@@ -539,7 +560,10 @@ mod tests {
assert_eq!(response.choices.len(), 1); assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0]; let choice = &response.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello, world!"); assert_eq!(
choice.message.content.as_ref().unwrap(),
&ChatCompletionMessageContent::Text("Hello, world!".to_string())
);
assert_eq!( assert_eq!(
choice.finish_reason, choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop) Some(dynamo_async_openai::types::FinishReason::Stop)
...@@ -605,7 +629,12 @@ mod tests { ...@@ -605,7 +629,12 @@ mod tests {
let choice = &response.choices[0]; let choice = &response.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_deref(), Some("Hello world")); assert_eq!(
choice.message.content.as_ref(),
Some(&ChatCompletionMessageContent::Text(
"Hello world".to_string()
))
);
assert_eq!( assert_eq!(
choice.finish_reason, choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop) Some(dynamo_async_openai::types::FinishReason::Stop)
...@@ -630,7 +659,7 @@ mod tests { ...@@ -630,7 +659,7 @@ mod tests {
index: 0, index: 0,
delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta { delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(dynamo_async_openai::types::Role::Assistant), role: Some(dynamo_async_openai::types::Role::Assistant),
content: Some("Choice 0".to_string()), content: Some(ChatCompletionMessageContent::Text("Choice 0".to_string())),
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
refusal: None, refusal: None,
...@@ -644,7 +673,7 @@ mod tests { ...@@ -644,7 +673,7 @@ mod tests {
index: 1, index: 1,
delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta { delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(dynamo_async_openai::types::Role::Assistant), role: Some(dynamo_async_openai::types::Role::Assistant),
content: Some("Choice 1".to_string()), content: Some(ChatCompletionMessageContent::Text("Choice 1".to_string())),
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
refusal: None, refusal: None,
...@@ -680,7 +709,10 @@ mod tests { ...@@ -680,7 +709,10 @@ mod tests {
response.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered response.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
let choice0 = &response.choices[0]; let choice0 = &response.choices[0];
assert_eq!(choice0.index, 0); assert_eq!(choice0.index, 0);
assert_eq!(choice0.message.content.as_ref().unwrap(), "Choice 0"); assert_eq!(
choice0.message.content.as_ref().unwrap(),
&ChatCompletionMessageContent::Text("Choice 0".to_string())
);
assert_eq!( assert_eq!(
choice0.finish_reason, choice0.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop) Some(dynamo_async_openai::types::FinishReason::Stop)
...@@ -692,7 +724,10 @@ mod tests { ...@@ -692,7 +724,10 @@ mod tests {
let choice1 = &response.choices[1]; let choice1 = &response.choices[1];
assert_eq!(choice1.index, 1); assert_eq!(choice1.index, 1);
assert_eq!(choice1.message.content.as_ref().unwrap(), "Choice 1"); assert_eq!(
choice1.message.content.as_ref().unwrap(),
&ChatCompletionMessageContent::Text("Choice 1".to_string())
);
assert_eq!( assert_eq!(
choice1.finish_reason, choice1.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop) Some(dynamo_async_openai::types::FinishReason::Stop)
...@@ -962,7 +997,9 @@ mod tests { ...@@ -962,7 +997,9 @@ mod tests {
assert_eq!( assert_eq!(
choice.message.content.as_ref().unwrap(), choice.message.content.as_ref().unwrap(),
"Hey Dude ! What's the weather in San Francisco in Fahrenheit?" &ChatCompletionMessageContent::Text(
"Hey Dude ! What's the weather in San Francisco in Fahrenheit?".to_string()
)
); );
// The finish_reason should be ToolCalls // The finish_reason should be ToolCalls
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment