"examples/vscode:/vscode.git/clone" did not exist on "93702e445622e9710b0cb154ca747f21bdfc52de"
Unverified Commit 9f76d060 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: text to image vLLM Omni (#5912)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent d14d6ff4
......@@ -265,7 +265,7 @@ def parse_args() -> Config:
"--stage-configs-path",
type=str,
default=None,
help="Path to vLLM-Omni stage configuration YAML file. Required for --omni.",
help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).",
)
parser.add_argument(
"--store-kv",
......@@ -388,9 +388,9 @@ def parse_args() -> Config:
)
# Validate omni worker requirements
if args.omni and not args.stage_configs_path:
if args.stage_configs_path and not args.omni:
raise ValueError(
"--stage-configs-path is required when using --omni. "
"--stage-configs-path is only allowed when using --omni. "
"Specify a YAML file containing stage configurations for the multi-stage pipeline."
)
......@@ -452,7 +452,8 @@ def parse_args() -> Config:
config.request_plane = args.request_plane
config.event_plane = args.event_plane
config.enable_local_indexer = not args.durable_kv_events
config.use_vllm_tokenizer = args.use_vllm_tokenizer
# For omni mode, use vLLM (AsyncOmni) tokenizer on backend
config.use_vllm_tokenizer = args.use_vllm_tokenizer or args.omni
config.sleep_mode_level = args.sleep_mode_level
# use_kv_events is set later in overwrite_args() based on kv_events_config
......
......@@ -1215,18 +1215,12 @@ async def init_omni(
component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint)
# Load default sampling params from model config (same as other workers)
default_sampling_params = (
config.engine_args.create_model_config().get_diff_sampling_param()
)
logger.info(f"Loaded default sampling params: {default_sampling_params}")
# Initialize OmniHandler with Omni orchestrator
handler = OmniHandler(
runtime=runtime,
component=component,
config=config,
default_sampling_params=default_sampling_params,
default_sampling_params={},
shutdown_event=shutdown_event,
)
......@@ -1241,11 +1235,9 @@ async def init_omni(
return
# TODO: extend for multi-stage pipelines
# Register as Chat endpoint for text-to-text generation
# Use Tokens input since we're doing token-based processing
await register_llm(
ModelInput.Tokens,
ModelType.Chat,
ModelInput.Text,
ModelType.Images,
generate_endpoint,
config.model,
config.served_model_name,
......@@ -1254,7 +1246,6 @@ async def init_omni(
logger.info("Starting to serve Omni worker endpoint...")
# Create health check payload (extracts BOS token from AsyncOmni)
health_check_payload = (
await VllmOmniHealthCheckPayload.create(handler.engine_client)
).to_dict()
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Omni handler for text-to-text generation using vLLM-Omni orchestrator."""
import asyncio
import logging
import time
from typing import Any, AsyncGenerator, Dict
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm_omni.entrypoints import AsyncOmni
from vllm_omni.inputs.data import OmniTextPrompt, OmniTokensPrompt
from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params
......@@ -32,10 +30,6 @@ class OmniHandler(BaseWorkerHandler):
f"Initializing OmniHandler for multi-stage pipelines with model: {config.model}"
)
# Initialize AsyncOmni with stage configuration
# Note: stage_configs_path is validated as required in args.py
logger.info(f"Using stage config from: {config.stage_configs_path}")
omni_kwargs = {
"model": config.model,
"trust_remote_code": config.engine_args.trust_remote_code,
......@@ -54,95 +48,242 @@ class OmniHandler(BaseWorkerHandler):
self.config = config
self.model_max_len = config.engine_args.max_model_len
self.shutdown_event = shutdown_event
self.use_vllm_tokenizer = config.use_vllm_tokenizer
logger.info("OmniHandler initialized successfully for text-to-text generation")
async def generate(
self, request: Dict[str, Any], context
) -> AsyncGenerator[Dict, None]:
"""Generate text using AsyncOmni orchestrator. Currently supports text-to-text only."""
"""Generate outputs using AsyncOmni orchestrator with OpenAI-compatible format.
Supports text-to-text and text-to-image generation based on stage configuration.
Returns OpenAI-compatible streaming chunks with detokenized text.
"""
request_id = context.id()
logger.debug(f"Omni Request ID: {request_id}")
# Extract token_ids from internal protocol format
if self.use_vllm_tokenizer:
async for chunk in self._generate_openai_mode(request, context, request_id):
yield chunk
else:
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
# Not used right now
async def _generate_token_mode(self, request, context, request_id):
"""
This mode returns token-ids as output
Text input -> Token-ids output
"""
token_ids = request.get("token_ids")
if not token_ids:
logger.error(f"Request {request_id}: No token_ids found in request")
prompt = OmniTokensPrompt(token_ids=token_ids)
num_output_tokens_so_far = 0
try:
async for stage_output in self.engine_client.generate(
prompt=prompt,
request_id=request_id,
):
vllm_output = stage_output.request_output
if not vllm_output.outputs:
logger.warning(f"Request {request_id} returned no outputs")
yield {
"finish_reason": "error: No outputs from vLLM engine",
"token_ids": [],
}
break
output = vllm_output.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = self._normalize_finish_reason(
output.finish_reason
)
out["completion_usage"] = self._build_completion_usage(vllm_output)
logger.debug(
f"Completed generation for request {request_id}: "
f"{next_total_toks} output tokens, finish_reason={output.finish_reason}"
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except GeneratorExit:
# Shutdown was triggered during generation
logger.info(f"Request {request_id} aborted due to shutdown")
raise
except Exception as e:
logger.error(f"Error during generation for request {request_id}: {e}")
yield {
"finish_reason": "error: No token_ids in request",
"finish_reason": f"error: {str(e)}",
"token_ids": [],
}
return
logger.info(
f"Request {request_id}: Generating text for {len(token_ids)} input tokens"
)
async def _generate_openai_mode(self, request, context, request_id):
"""
This mode returns OpenAI-compatible streaming chunks
Text input -> Text output / Image output
"""
# (ayushag) TODO: Support all types of OmniPrompt. Right now this works only for text prompts.
# (ayushag) TODO: Document all I/O formats from vLLM-Omni.
# OmniTextPrompt also supports negative prompts; those need to be supported here as well.
# Multimodal content needs support too: that involves applying the tokenizer to the prompt and loading images, following the general multimodal support pattern.
prompt = self._extract_text_prompt(request)
prompt = OmniTextPrompt(prompt=prompt)
# Build sampling parameters from request
sampling_params = self._build_sampling_params(request)
sampling_params_list = [sampling_params]
# (ayushag) TODO: Need to add proper multi-stage sampling param support
# sampling_params = self._build_sampling_params(request)
# sampling_params_list = [sampling_params]
tokens_prompt: TokensPrompt = {
"prompt_token_ids": token_ids,
}
previous_text = ""
async with self._abort_monitor(context, request_id):
try:
num_output_tokens_so_far = 0
async for stage_output in self.engine_client.generate(
prompt=tokens_prompt, # Pass TokensPrompt format
prompt=prompt,
request_id=request_id,
sampling_params_list=sampling_params_list,
# sampling_params_list=sampling_params_list,
):
# stage_output is OmniRequestOutput
# For text generation: stage_output.request_output is a single vLLM RequestOutput
if (
stage_output.final_output_type == "text"
and stage_output.request_output
):
vllm_output = stage_output.request_output
if not vllm_output.outputs:
logger.warning(f"Request {request_id} returned no outputs")
yield {
"finish_reason": "error: No outputs from vLLM engine",
"token_ids": [],
}
break
output = vllm_output.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = self._normalize_finish_reason(
output.finish_reason
)
out["completion_usage"] = self._build_completion_usage(
vllm_output
)
logger.debug(
f"Completed generation for request {request_id}: "
f"{next_total_toks} output tokens, finish_reason={output.finish_reason}"
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
# Text generation (LLM stage)
chunk = self._format_text_chunk(
stage_output.request_output,
request_id,
previous_text,
)
if chunk:
# Update previous_text for delta calculation
output = stage_output.request_output.outputs[0]
previous_text = output.text
yield chunk
elif (
stage_output.final_output_type == "image"
and stage_output.images
):
# Image generation (diffusion stage)
chunk = self._format_image_chunk(
stage_output.images,
request_id,
)
if chunk:
yield chunk
except GeneratorExit:
# Shutdown was triggered during generation
logger.info(f"Request {request_id} aborted due to shutdown")
raise
except Exception as e:
logger.error(f"Error during generation for request {request_id}: {e}")
yield {
"finish_reason": f"error: {str(e)}",
"token_ids": [],
yield self._error_chunk(request_id, str(e))
def _format_text_chunk(
self,
request_output,
request_id: str,
previous_text: str,
) -> Dict[str, Any] | None:
"""Format text output as OpenAI chat completion chunk."""
if not request_output.outputs:
return self._error_chunk(request_id, "No outputs from engine")
output = request_output.outputs[0]
# Calculate delta text (new text since last chunk)
delta_text = output.text[len(previous_text) :]
chunk = {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self.config.served_model_name or self.config.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": delta_text,
},
"finish_reason": self._normalize_finish_reason(output.finish_reason)
if output.finish_reason
else None,
}
],
}
# Add usage on final chunk
if output.finish_reason:
chunk["usage"] = self._build_completion_usage(request_output)
return chunk
def _format_image_chunk(
self,
images: list,
request_id: str,
) -> Dict[str, Any] | None:
"""Format image output as OpenAI chat completion chunk with base64 data URLs."""
import base64
from io import BytesIO
if not images:
return self._error_chunk(request_id, "No images generated")
# Convert images to base64 data URLs
data_urls = []
for idx, img in enumerate(images):
# Convert PIL image to base64
buffer = BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Create data URL (can be opened directly in browser)
data_url = f"data:image/png;base64,{img_base64}"
data_urls.append(data_url)
logger.info(f"Generated image {idx} for request {request_id}")
chunk = {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self.config.served_model_name or self.config.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": [
{"type": "image_url", "image_url": {"url": data_url}}
for data_url in data_urls
],
},
"finish_reason": "stop",
}
],
}
return chunk
def _extract_text_prompt(self, request: Dict[str, Any]) -> str | None:
"""Extract text prompt from request."""
# OpenAI messages format - extract text content only
messages = request.get("messages", [])
# Assumes single user message
for message in messages:
if message.get("role") == "user":
return message.get("content")
return ""
def _build_sampling_params(self, request: Dict[str, Any]) -> SamplingParams:
"""Build sampling params using shared handler utility."""
......@@ -150,6 +291,25 @@ class OmniHandler(BaseWorkerHandler):
request, self.default_sampling_params, self.model_max_len
)
def _error_chunk(self, request_id: str, error_message: str) -> Dict[str, Any]:
"""Create an error chunk in OpenAI format."""
return {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self.config.served_model_name or self.config.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": f"Error: {error_message}",
},
"finish_reason": "error",
}
],
}
def cleanup(self):
"""Cleanup AsyncOmni orchestrator resources."""
try:
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->
# [Experimental] Running Omni Models with vLLM
Dynamo supports omni (multimodal generation) models via the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This enables multi-stage pipelines for tasks like text-to-text and text-to-image generation through an OpenAI-compatible API.
## Prerequisites
This guide assumes familiarity with deploying Dynamo with vLLM as described in [README.md](/docs/backends/vllm/README.md).
## Quick Start
### Text-to-Text
Launch an aggregated deployment (frontend + omni worker) using the provided script:
```bash
bash examples/backends/vllm/launch/agg_omni.sh
```
This starts `Qwen/Qwen2.5-Omni-7B` with a single-stage thinker config on one GPU. Override the model with:
```bash
bash examples/backends/vllm/launch/agg_omni.sh --model <your-model>
```
Test the deployment:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2.5-Omni-7B",
"messages": [{"role": "user", "content": "What is 2+2?"}],
"max_tokens": 50,
"stream": false
}'
```
### Text-to-Image
Text-to-image uses vLLM-Omni's built-in default stage configs (no custom YAML needed). Launch without a stage config path so vLLM-Omni loads the model's default multi-stage pipeline:
```bash
# Start frontend
python -m dynamo.frontend &
# Start omni worker (vLLM-Omni loads default stage configs for the model)
DYN_SYSTEM_PORT=8081 python -m dynamo.vllm \
--model <your-text-to-image-model> \
--omni \
--connector none
```
Images are returned as base64-encoded PNGs in the response:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "<your-text-to-image-model>",
"messages": [{"role": "user", "content": "A cat sitting on a windowsill"}],
"stream": false
}'
```
The response contains image data URLs in the content field:
```json
{
"choices": [{
"delta": {
"content": [
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
]
}
}]
}
```
## Key Flags
| Flag | Description |
|------|-------------|
| `--omni` | Enable vLLM-Omni orchestrator (required) |
| `--stage-configs-path <path>` | Path to stage config YAML (optional; vLLM-Omni uses model defaults if omitted) |
| `--connector none` | Disable KV connector (recommended for omni) |
## Stage Configuration
Omni pipelines are configured via YAML stage configs. See [`examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml`](/examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml) for an example. Key fields:
- **`model_stage`**: Pipeline stage name (e.g., `thinker`, `talker`, `code2wav`)
- **`final_output_type`**: Output format — `text` or `image`
- **`is_comprehension`**: Whether this stage processes input text/multimodal content
For full documentation on stage config format, supported fields, and multi-stage pipeline examples, see the [vLLM-Omni Stage Configs documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/).
## Current Limitations
- Only text prompts are supported (no multimodal input yet)
- KV cache events are not published for omni workers
......@@ -42,6 +42,7 @@
backends/vllm/multi-node.md
backends/vllm/prometheus.md
backends/vllm/prompt-embeddings.md
backends/vllm/vllm-omni.md
backends/sglang/expert-distribution-eplb.md
backends/sglang/gpt-oss.md
......
......@@ -36,7 +36,7 @@ stage_args:
temperature: 0.0
top_p: 1.0
top_k: -1
max_tokens: 2048
max_tokens: 20
repetition_penalty: 1.1
seed: 42
detokenize: false # Token-based processing for Dynamo
detokenize: true # Detokenize in the engine; the Omni handler streams text deltas
......@@ -375,6 +375,77 @@ pub enum ChatCompletionRequestToolMessageContent {
Array(Vec<ChatCompletionRequestToolMessageContentPart>),
}
// Omni Specific Multimodal Content Types
// These types are used for assistant message responses that contain multimodal content
/// Response content part for text in assistant messages
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ChatCompletionResponseContentPartText {
pub text: String,
}
/// Response content part for image URLs in assistant messages
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ChatCompletionResponseContentPartImageUrl {
pub image_url: ImageUrlResponse,
}
/// Response content part for video URLs in assistant messages
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ChatCompletionResponseContentPartVideoUrl {
pub video_url: VideoUrlResponse,
}
/// Response content part for audio URLs in assistant messages
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ChatCompletionResponseContentPartAudioUrl {
pub audio_url: AudioUrlResponse,
}
/// Image URL in response messages (supports data URLs with base64)
///
/// Carried inside `ChatCompletionResponseContentPartImageUrl` as its
/// `image_url` field.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct ImageUrlResponse {
/// The URL of the image, either a URL or a data URL (data:image/png;base64,...)
pub url: String,
/// Optional detail level (for compatibility with OpenAI).
/// Left out of the serialized output entirely when it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
/// Video URL in response messages
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct VideoUrlResponse {
/// The URL of the video, either a URL or a data URL
pub url: String,
}
/// Audio URL in response messages
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct AudioUrlResponse {
/// The URL of the audio, either a URL or a data URL
pub url: String,
}
/// Content parts for assistant responses supporting multiple modalities (text, images, videos, audio)
///
/// Serialized as an internally tagged object: the variant name, converted to
/// snake_case, is written to a `type` field (`text`, `image_url`,
/// `video_url`, `audio_url`) next to the fields of the wrapped struct, e.g.
/// `{"type": "image_url", "image_url": {"url": "..."}}`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChatCompletionResponseContentPart {
// `type: "text"` with a `text` field.
Text(ChatCompletionResponseContentPartText),
// `type: "image_url"` with an `image_url` object.
ImageUrl(ChatCompletionResponseContentPartImageUrl),
// `type: "video_url"` with a `video_url` object.
VideoUrl(ChatCompletionResponseContentPartVideoUrl),
// `type: "audio_url"` with an `audio_url` object.
AudioUrl(ChatCompletionResponseContentPartAudioUrl),
}
/// Assistant message content - can be a simple string or an array of content parts
///
/// The enum is `#[serde(untagged)]`, so no discriminator is written: `Text`
/// serializes as a bare JSON string and `Parts` as a JSON array. When
/// deserializing, serde tries the variants in declaration order, so a string
/// is read as `Text` and an array of content parts as `Parts`.
#[derive(ToSchema, Clone, Serialize, Debug, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ChatCompletionMessageContent {
/// Simple text content (backward compatible)
Text(String),
/// Array of content parts (for multimodal responses)
Parts(Vec<ChatCompletionResponseContentPart>),
}
#[derive(ToSchema, Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)]
#[builder(name = "ChatCompletionRequestUserMessageArgs")]
#[builder(pattern = "mutable")]
......@@ -486,9 +557,9 @@ pub struct ChatCompletionResponseMessageAudio {
/// A chat completion message generated by the model.
#[derive(ToSchema, Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ChatCompletionResponseMessage {
/// The contents of the message.
/// The contents of the message - can be a string or array of content parts
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
pub content: Option<ChatCompletionMessageContent>,
/// The refusal message generated by the model.
#[serde(skip_serializing_if = "Option::is_none")]
pub refusal: Option<String>,
......@@ -1094,8 +1165,8 @@ pub struct ChatCompletionMessageToolCallChunk {
/// A chat completion delta generated by streamed model responses.
#[derive(ToSchema, Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponseDelta {
/// The contents of the chunk message.
pub content: Option<String>,
/// The contents of the chunk message - can be a string or array of content parts
pub content: Option<ChatCompletionMessageContent>,
/// Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model.
#[deprecated]
pub function_call: Option<FunctionCallStream>,
......
......@@ -311,7 +311,7 @@ fn register_llm<'p>(
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// For TensorBased and Images models, skip HuggingFace downloads and register directly
// These model types don't require tokenizers
// Images models (vLLM-Omni) handle model loading internally, no tokenizer extraction needed
if is_tensor_based || is_images {
let model_name = model_name.unwrap_or_else(|| source_path.clone());
let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name);
......
......@@ -405,7 +405,7 @@ impl
delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta{
//role: c.choices[0].delta.role,
role: Some(dynamo_async_openai::types::Role::Assistant),
content: Some(from_assistant),
content: Some(dynamo_async_openai::types::ChatCompletionMessageContent::Text(from_assistant)),
tool_calls: None,
refusal: None,
function_call: None,
......
......@@ -246,7 +246,8 @@ pub fn final_response_to_one_chunk_stream(
mod tests {
use super::*;
use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason, Role,
ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionStreamResponseDelta,
FinishReason, Role,
};
use futures::StreamExt;
use futures::stream;
......@@ -261,7 +262,7 @@ mod tests {
index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some(content),
content: Some(ChatCompletionMessageContent::Text(content)),
tool_calls: None,
function_call: None,
refusal: None,
......@@ -337,7 +338,10 @@ mod tests {
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
.cloned()
.and_then(|content| match content {
ChatCompletionMessageContent::Text(text) => Some(text.clone()),
ChatCompletionMessageContent::Parts(_) => None,
})
.unwrap_or_default()
}
......@@ -423,7 +427,9 @@ mod tests {
index: 0,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some("Content".to_string()),
content: Some(ChatCompletionMessageContent::Text(
"Content".to_string(),
)),
tool_calls: None,
function_call: None,
refusal: None,
......
......@@ -169,6 +169,7 @@ impl ModelManager {
.chain(self.list_completions_models())
.chain(self.list_embeddings_models())
.chain(self.list_tensor_models())
.chain(self.list_images_models())
.chain(self.list_prefill_models())
.collect()
}
......@@ -193,6 +194,10 @@ impl ModelManager {
self.prefill_engines.read().list()
}
/// Returns the listing produced by the images engine registry, read under
/// its shared (read) lock.
pub fn list_images_models(&self) -> Vec<String> {
    let images_registry = self.images_engines.read();
    images_registry.list()
}
pub fn add_completions_model(
&self,
model: &str,
......
......@@ -71,6 +71,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Completions,
ModelType::Embedding,
ModelType::TensorBased,
ModelType::Images,
ModelType::Prefill,
];
......@@ -284,12 +285,14 @@ impl ModelWatcher {
let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let images_model_remove_err = self.manager.remove_images_model(&model_name);
let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
let mut chat_model_removed = false;
let mut completions_model_removed = false;
let mut embeddings_model_removed = false;
let mut tensor_model_removed = false;
let mut images_model_removed = false;
let mut prefill_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
......@@ -305,6 +308,9 @@ impl ModelWatcher {
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true;
}
if images_model_remove_err.is_ok() && self.manager.list_images_models().is_empty() {
images_model_removed = true;
}
if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
prefill_model_removed = true;
}
......@@ -313,15 +319,17 @@ impl ModelWatcher {
&& !completions_model_removed
&& !embeddings_model_removed
&& !tensor_model_removed
&& !images_model_removed
&& !prefill_model_removed
{
tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}",
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}, images_model_removed: {}, prefill_model_removed: {}",
model_name,
chat_model_removed,
completions_model_removed,
embeddings_model_removed,
tensor_model_removed,
images_model_removed,
prefill_model_removed
);
} else {
......@@ -330,6 +338,7 @@ impl ModelWatcher {
|| (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding)
|| (tensor_model_removed && *model_type == ModelType::TensorBased)
|| (images_model_removed && *model_type == ModelType::Images)
|| (prefill_model_removed && *model_type == ModelType::Prefill))
&& let Some(tx) = &self.model_update_tx
{
......@@ -614,7 +623,7 @@ impl ModelWatcher {
self.manager
.add_embeddings_model(card.name(), checksum, embedding_engine)?;
} else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
// Case 5: Tensor + Tensor (non-LLM)
// Case 6: Tensor + TensorBased (non-LLM)
// No KV cache concepts - not an LLM model
let push_router = PushRouter::<
NvCreateTensorRequest,
......@@ -627,18 +636,31 @@ impl ModelWatcher {
self.manager
.add_tensor_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_images() {
// Case: Text + Images (diffusion models)
// Takes text prompts as input, generates images
let push_router = PushRouter::<
// Case: Text + Images (e.g. vLLM-Omni, diffusion models)
// Takes text prompts as input, generates images. Images models also support
// chat completions (see model_type.rs as_endpoint_types).
let images_router = PushRouter::<
NvCreateImageRequest,
Annotated<NvImagesResponse>,
>::from_client_with_threshold(
client, self.router_config.router_mode, None, None
client.clone(), self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_images_model(card.name(), checksum, engine)?;
.add_images_model(card.name(), checksum, Arc::new(images_router))?;
let chat_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client, self.router_config.router_mode, None, None
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
} else if card.model_type.supports_prefill() {
// Case 6: Prefill
// Guardrail: Verify model_input is Tokens
......
......@@ -7,7 +7,7 @@ use crate::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
use anyhow::Context as _;
use dynamo_async_openai::types::FinishReason;
use dynamo_async_openai::types::{ChatCompletionMessageContent, FinishReason};
use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
......@@ -241,7 +241,15 @@ async fn evaluate(
let choice = data.choices.first();
let chat_comp = choice.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content {
output += c;
match c {
ChatCompletionMessageContent::Text(text) => {
output += text;
}
ChatCompletionMessageContent::Parts(_) => {
// Multimodal content - skip for now in batch processing
// (ayushag) TODO: Handle multimodal content in batch mode
}
}
}
entry.finish_reason = chat_comp.finish_reason;
if chat_comp.finish_reason.is_some() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::entrypoint::EngineConfig;
use crate::entrypoint::input::common;
use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
use dynamo_async_openai::types::ChatCompletionMessageContent;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::Context;
use futures::StreamExt;
use std::io::{ErrorKind, Write};
use crate::entrypoint::EngineConfig;
use crate::entrypoint::input::common;
/// Max response tokens for each single query. Must be less than model context size.
/// TODO: Cmd line flag to overwrite this
const MAX_TOKENS: u32 = 8192;
......@@ -140,9 +140,19 @@ async fn main_loop(
let entry = data.choices.first();
let chat_comp = entry.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content {
let _ = stdout.write(c.as_bytes());
let _ = stdout.flush();
assistant_message += c;
match c {
ChatCompletionMessageContent::Text(text) => {
let _ = stdout.write(text.as_bytes());
let _ = stdout.flush();
assistant_message += text;
}
ChatCompletionMessageContent::Parts(_) => {
// (ayushag) TODO: Handle multimodal content for multiturn conversations
// Multimodal content - for now just print a placeholder
let _ = stdout.write(b"[multimodal content]");
let _ = stdout.flush();
}
}
}
if let Some(reason) = chat_comp.finish_reason {
tracing::trace!("finish reason: {reason:?}");
......
......@@ -384,6 +384,8 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
token_ids: vec![token_id],
tokens: None, // Let backend handle detokenization
text: None,
output_type: Default::default(),
content_parts: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
......
......@@ -126,8 +126,10 @@ impl ModelType {
if self.contains(Self::Embedding) {
endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
}
// Images models are exposed on both the images and chat endpoints
if self.contains(Self::Images) {
endpoint_types.push(crate::endpoint_type::EndpointType::Images);
endpoint_types.push(crate::endpoint_type::EndpointType::Chat);
}
// [gluo NOTE] ModelType::Tensor doesn't map to any endpoint type,
// current use of endpoint type is LLM specific and so does the HTTP
......
......@@ -953,7 +953,11 @@ mod tests {
choices: vec![ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: Some("test".to_string()),
content: Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
"test".to_string(),
),
),
function_call: None,
tool_calls: None,
role: Some(Role::Assistant),
......@@ -987,7 +991,11 @@ mod tests {
.map(|(i, token_logprobs)| ChatChoiceStream {
index: i as u32,
delta: ChatCompletionStreamResponseDelta {
content: Some("test".to_string()),
content: Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
"test".to_string(),
),
),
function_call: None,
tool_calls: None,
role: Some(Role::Assistant),
......@@ -1337,7 +1345,11 @@ mod tests {
choices: vec![ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: Some("test".to_string()),
content: Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
"test".to_string(),
),
),
function_call: None,
tool_calls: None,
role: Some(Role::Assistant),
......
......@@ -945,14 +945,23 @@ impl OpenAIPreprocessor {
response.map_data(|mut data| {
// Process all choices, not just the first one
for choice in data.choices.iter_mut() {
if let Some(text) = choice.delta.content.as_ref() {
// Reasoning parsing only applies to text content
if let Some(
dynamo_async_openai::types::ChatCompletionMessageContent::Text(
text,
),
) = choice.delta.content.as_ref()
{
let parser_result =
parser.parse_reasoning_streaming_incremental(text, &[]);
// Update this specific choice with parsed content
choice.delta.content = parser_result.get_some_normal_text();
choice.delta.content = parser_result.get_some_normal_text().map(
dynamo_async_openai::types::ChatCompletionMessageContent::Text,
);
choice.delta.reasoning_content = parser_result.get_some_reasoning();
}
// For multimodal content, pass through unchanged
}
Ok(data)
})
......
......@@ -13,6 +13,45 @@ use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>;
/// Output type discriminator for different modalities.
///
/// With `rename_all = "lowercase"` each variant is (de)serialized as its
/// lowercase name: `"text"`, `"image"`, `"video"`, `"audio"`.
/// `Default::default()` yields [`OutputType::Text`].
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum OutputType {
/// Text output; this is the default variant.
#[default]
Text,
/// Image output.
Image,
/// Video output.
Video,
/// Audio output.
Audio,
}
/// Image URL data for responses.
///
/// Payload of [`ContentPart::ImageUrl`]; holds a single `url` string.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ImageUrlData {
/// URL of the image.
// NOTE(review): accepted schemes (http(s) vs. `data:` URI) are not
// established by this type — confirm against the producer.
pub url: String,
}
/// Video URL data for responses.
///
/// Payload of [`ContentPart::VideoUrl`]; holds a single `url` string.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct VideoUrlData {
/// URL of the video.
// NOTE(review): accepted schemes (http(s) vs. `data:` URI) are not
// established by this type — confirm against the producer.
pub url: String,
}
/// Audio URL data for responses.
///
/// Payload of [`ContentPart::AudioUrl`]; holds a single `url` string.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct AudioUrlData {
/// URL of the audio.
// NOTE(review): accepted schemes (http(s) vs. `data:` URI) are not
// established by this type — confirm against the producer.
pub url: String,
}
/// Content part for multimodal outputs (internal representation).
///
/// Serde representation is internally tagged: every value carries a `type`
/// field holding the snake_case variant name (`text`, `image_url`,
/// `video_url`, `audio_url`) next to the variant's own field, e.g.
/// `{"type": "image_url", "image_url": {"url": "..."}}`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// Text part; tagged `"text"`.
Text { text: String },
/// Image reference; tagged `"image_url"`.
ImageUrl { image_url: ImageUrlData },
/// Video reference; tagged `"video_url"`.
VideoUrl { video_url: VideoUrlData },
/// Audio reference; tagged `"audio_url"`.
AudioUrl { audio_url: AudioUrlData },
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TopLogprob {
pub rank: u32,
......@@ -86,6 +125,14 @@ pub struct LLMEngineOutput {
// decoded text -
pub text: Option<String>,
/// Output type discriminator (text, image, video, audio)
#[serde(default)]
pub output_type: OutputType,
/// Multimodal content parts (for non-text outputs)
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content_parts: Option<Vec<ContentPart>>,
/// cumulative log probabilities
pub cum_log_probs: Option<f64>,
......@@ -125,6 +172,8 @@ impl LLMEngineOutput {
token_ids: vec![],
tokens: None,
text: None,
output_type: OutputType::default(),
content_parts: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
......@@ -142,6 +191,8 @@ impl LLMEngineOutput {
token_ids: vec![],
tokens: None,
text: None,
output_type: OutputType::default(),
content_parts: None,
cum_log_probs: None,
log_probs: None,
finish_reason: Some(FinishReason::Stop),
......@@ -159,6 +210,8 @@ impl LLMEngineOutput {
token_ids: vec![],
tokens: None,
text: None,
output_type: OutputType::default(),
content_parts: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
......@@ -176,6 +229,8 @@ impl LLMEngineOutput {
token_ids: vec![],
tokens: None,
text: None,
output_type: OutputType::default(),
content_parts: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
......
......@@ -12,7 +12,7 @@ use crate::protocols::{
openai::ParsingOptions,
};
use dynamo_async_openai::types::StopReason;
use dynamo_async_openai::types::{ChatCompletionMessageContent, StopReason};
use dynamo_runtime::engine::DataStream;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
......@@ -59,6 +59,9 @@ struct DeltaChoice {
/// Optional reasoning content for the chat choice.
reasoning_content: Option<String>,
/// Accumulated content parts for multimodal responses
content_parts: Vec<dynamo_async_openai::types::ChatCompletionResponseContentPart>,
}
impl Default for DeltaAggregator {
......@@ -166,10 +169,18 @@ impl DeltaAggregator {
logprobs: None,
tool_calls: None,
reasoning_content: None,
content_parts: Vec::new(),
});
// Append content if available.
// Handle content based on type
if let Some(content) = &choice.delta.content {
state_choice.text.push_str(content);
match content {
ChatCompletionMessageContent::Text(text) => {
state_choice.text.push_str(text);
}
ChatCompletionMessageContent::Parts(parts) => {
state_choice.content_parts.extend(parts.clone());
}
}
}
if let Some(reasoning_content) = &choice.delta.reasoning_content {
......@@ -289,14 +300,21 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
delta.finish_reason
};
// Determine content format based on what we accumulated
let content = if !delta.content_parts.is_empty() {
// Multimodal response with content parts
Some(ChatCompletionMessageContent::Parts(delta.content_parts))
} else if !delta.text.is_empty() {
// Text-only response (backward compatible)
Some(ChatCompletionMessageContent::Text(delta.text))
} else {
None
};
dynamo_async_openai::types::ChatChoice {
message: dynamo_async_openai::types::ChatCompletionResponseMessage {
role: delta.role.expect("delta should have a Role"),
content: if delta.text.is_empty() {
None
} else {
Some(delta.text)
},
content,
tool_calls: delta.tool_calls,
refusal: None,
function_call: None,
......@@ -396,7 +414,7 @@ mod tests {
};
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: Some(text.to_string()),
content: Some(ChatCompletionMessageContent::Text(text.to_string())),
function_call: None,
tool_calls: tool_call_chunks,
role,
......@@ -496,7 +514,10 @@ mod tests {
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello,");
assert_eq!(
choice.message.content.as_ref().unwrap(),
&ChatCompletionMessageContent::Text("Hello,".to_string())
);
assert!(choice.finish_reason.is_none());
assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User);
assert!(response.service_tier.is_none());
......@@ -539,7 +560,10 @@ mod tests {
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello, world!");
assert_eq!(
choice.message.content.as_ref().unwrap(),
&ChatCompletionMessageContent::Text("Hello, world!".to_string())
);
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop)
......@@ -605,7 +629,12 @@ mod tests {
let choice = &response.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_deref(), Some("Hello world"));
assert_eq!(
choice.message.content.as_ref(),
Some(&ChatCompletionMessageContent::Text(
"Hello world".to_string()
))
);
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop)
......@@ -630,7 +659,7 @@ mod tests {
index: 0,
delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(dynamo_async_openai::types::Role::Assistant),
content: Some("Choice 0".to_string()),
content: Some(ChatCompletionMessageContent::Text("Choice 0".to_string())),
function_call: None,
tool_calls: None,
refusal: None,
......@@ -644,7 +673,7 @@ mod tests {
index: 1,
delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(dynamo_async_openai::types::Role::Assistant),
content: Some("Choice 1".to_string()),
content: Some(ChatCompletionMessageContent::Text("Choice 1".to_string())),
function_call: None,
tool_calls: None,
refusal: None,
......@@ -680,7 +709,10 @@ mod tests {
response.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
let choice0 = &response.choices[0];
assert_eq!(choice0.index, 0);
assert_eq!(choice0.message.content.as_ref().unwrap(), "Choice 0");
assert_eq!(
choice0.message.content.as_ref().unwrap(),
&ChatCompletionMessageContent::Text("Choice 0".to_string())
);
assert_eq!(
choice0.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop)
......@@ -692,7 +724,10 @@ mod tests {
let choice1 = &response.choices[1];
assert_eq!(choice1.index, 1);
assert_eq!(choice1.message.content.as_ref().unwrap(), "Choice 1");
assert_eq!(
choice1.message.content.as_ref().unwrap(),
&ChatCompletionMessageContent::Text("Choice 1".to_string())
);
assert_eq!(
choice1.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop)
......@@ -962,7 +997,9 @@ mod tests {
assert_eq!(
choice.message.content.as_ref().unwrap(),
"Hey Dude ! What's the weather in San Francisco in Fahrenheit?"
&ChatCompletionMessageContent::Text(
"Hey Dude ! What's the weather in San Francisco in Fahrenheit?".to_string()
)
);
// The finish_reason should be ToolCalls
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment