Unverified Commit 6dd3ce2e authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: vllm omni text to video generation pipeline (#6104)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 23de4e86
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Protocol types for TensorRT-LLM backend. """Shared protocol types used across multiple Dynamo backends.
This module provides protocol types for various modalities: This module provides protocol types for various modalities:
- video_protocol: NvCreateVideoRequest, NvVideosResponse for video generation - video_protocol: NvCreateVideoRequest, NvVideosResponse for video generation
- image_protocol: (future) Protocol types for image generation
""" """
from dynamo.trtllm.protocols.video_protocol import ( from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
NvVideosResponse, NvVideosResponse,
VideoData, VideoData,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from pydantic import BaseModel
# Omni models require parsing the raw request and emitting JSON output, so these protocol types are defined here for serialization and deserialization.
# TODO: Replace these Pydantic models with Python bindings to the Rust protocol types once PyO3 bindings are available.
class ImageNvExt(BaseModel):
    """NVIDIA extensions for image generation requests.

    Matches Rust NvExt in lib/llm/src/protocols/openai/images/nvext.rs.
    Every field is optional and defaults to ``None``.
    """

    annotations: Optional[list[str]] = None
    """Annotations for SSE stream events."""
    negative_prompt: Optional[str] = None
    """Optional negative prompt."""
    num_inference_steps: Optional[int] = None
    """Number of denoising steps."""
    guidance_scale: Optional[float] = None
    """CFG guidance scale."""
    seed: Optional[int] = None
    """Random seed for reproducibility."""
class NvCreateImageRequest(BaseModel):
    """Request for image generation (/v1/images/generations endpoint).

    Matches the flattened Rust NvCreateImageRequest in lib/llm/src/protocols/openai/images.rs

    Only ``prompt`` is required; every other field defaults to ``None``.
    """

    prompt: str
    """The text prompt for image generation."""
    model: Optional[str] = None
    """The model to use for image generation."""
    # NOTE(review): no validator in this class enforces the documented 1-10 range.
    n: Optional[int] = None
    """Number of images to generate (1-10)."""
    # The string-valued options below are plain ``str`` fields; the allowed
    # values listed in their descriptions are not checked by this class.
    quality: Optional[str] = None
    """Image quality: standard, hd, high, medium, low, auto."""
    response_format: Optional[str] = None
    """Response format: url or b64_json."""
    size: Optional[str] = None
    """Image size in WxH format (e.g. 1024x1024)."""
    style: Optional[str] = None
    """Image style: vivid or natural."""
    user: Optional[str] = None
    """Optional user identifier."""
    moderation: Optional[str] = None
    """Content moderation level: auto or low."""
    nvext: Optional[ImageNvExt] = None
    """NVIDIA extensions."""
class ImageData(BaseModel):
    """Individual image data in a response.

    Matches the flattened Rust Image enum in lib/async-openai/src/types/image.rs.

    All three fields are optional and default to ``None``; this class does not
    itself enforce that exactly one of ``url`` / ``b64_json`` is populated.
    """

    url: Optional[str] = None
    """URL of the generated image (if response_format is url)."""
    b64_json: Optional[str] = None
    """Base64-encoded image (if response_format is b64_json)."""
    revised_prompt: Optional[str] = None
    """Revised prompt, when the model rewrites the original prompt."""
class NvImagesResponse(BaseModel):
    """Response structure for image generation.

    Matches the flattened Rust NvImagesResponse in lib/llm/src/protocols/openai/images.rs

    ``created`` is required; ``data`` defaults to an empty list.
    """

    created: int
    """Unix timestamp of creation."""
    # NOTE(review): mutable default; presumably safe because pydantic copies
    # field defaults per instance -- confirm for the pydantic version in use.
    data: list[ImageData] = []
    """List of generated images."""
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
These types match the Rust protocol types in lib/llm/src/protocols/openai/videos.rs These types match the Rust protocol types in lib/llm/src/protocols/openai/videos.rs
to ensure compatibility with the Dynamo HTTP frontend. to ensure compatibility with the Dynamo HTTP frontend.
""" """
# TODO: Replace these Pydantic models with Python bindings to the Rust protocol types once PyO3 bindings are available.
from typing import Optional from typing import Optional
......
...@@ -2,8 +2,12 @@ ...@@ -2,8 +2,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from enum import Enum from enum import Enum
from typing import List, Optional from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest
from dynamo.llm import ModelType from dynamo.llm import ModelType
...@@ -32,6 +36,15 @@ class OutputModality(Enum): ...@@ -32,6 +36,15 @@ class OutputModality(Enum):
return {m.name.lower() for m in cls} return {m.name.lower() for m in cls}
class RequestType(Enum):
    """Identifies the parsed request type returned by parse_request_type.

    Each member's value is the lowercase snake_case form of its name.
    """

    # Raw chat-style request passed through as a dict.
    CHAT_COMPLETION = "chat_completion"
    # Request parsed into NvCreateImageRequest.
    IMAGE_GENERATION = "image_generation"
    # Request parsed into NvCreateVideoRequest.
    VIDEO_GENERATION = "video_generation"
    # Audio request; parse_request_type passes the raw dict through for this type.
    AUDIO_GENERATION = "audio_generation"
def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[ModelType]: def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[ModelType]:
""" """
Get the combined ModelType flags for omni models based on CLI input. Get the combined ModelType flags for omni models based on CLI input.
...@@ -52,3 +65,34 @@ def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[Mod ...@@ -52,3 +65,34 @@ def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[Mod
flag if output_modalities is None else output_modalities | flag flag if output_modalities is None else output_modalities | flag
) )
return output_modalities return output_modalities
def parse_request_type(
    raw_request: Dict[str, Any],
    output_modalities: List[str],
) -> Tuple[Union[BaseModel, Dict[str, Any]], RequestType]:
    """Choose the request type for a raw request and build its typed model when one exists.

    Only the first entry of ``output_modalities`` is consulted; the current
    assumption is that a caller asks for a single modality at a time.

    Args:
        raw_request: Request payload as a plain dictionary.
        output_modalities: Output modality names; must not be empty.

    Returns:
        A ``(request, request_type)`` pair. ``request`` is an
        ``NvCreateImageRequest`` or ``NvCreateVideoRequest`` for image and
        video generation, and the untouched ``raw_request`` in every other case.

    Raises:
        ValueError: If ``output_modalities`` is empty.
    """
    if not output_modalities:
        raise ValueError("output_modalities must not be empty")

    selected = OutputModality.from_name(output_modalities[0])

    # Start from the pass-through result (text modality / chat completion) and
    # specialise it below when the modality calls for something else.
    request: Union[BaseModel, Dict[str, Any]] = raw_request
    request_type = RequestType.CHAT_COMPLETION

    if selected is OutputModality.IMAGE:
        # A payload carrying "messages" is chat-style and stays a raw dict
        # even when the output modality is image.
        if "messages" not in raw_request:
            request = NvCreateImageRequest(**raw_request)
            request_type = RequestType.IMAGE_GENERATION
    elif selected is OutputModality.VIDEO:
        request = NvCreateVideoRequest(**raw_request)
        request_type = RequestType.VIDEO_GENERATION
    elif selected is OutputModality.AUDIO:
        # Audio protocol types are not yet defined; pass through the raw dict.
        request_type = RequestType.AUDIO_GENERATION

    return request, request_type
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Video encoding utilities for TensorRT-LLM video diffusion. """Video utilities for video diffusion.
This module provides utilities for encoding numpy video frames to MP4 format. Provides helpers for parsing video request parameters and encoding numpy
video frames to MP4 format.
""" """
import io import io
import logging import logging
import os import os
from typing import Tuple
import numpy as np import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_VIDEO_WIDTH = 832
DEFAULT_VIDEO_HEIGHT = 480
DEFAULT_VIDEO_FPS = 16
DEFAULT_VIDEO_NUM_FRAMES = 97
def parse_size(
size: str | None,
default_w: int = DEFAULT_VIDEO_WIDTH,
default_h: int = DEFAULT_VIDEO_HEIGHT,
) -> Tuple[int, int]:
"""Parse a 'WxH' string into (width, height).
Falls back to default_w x default_h when size is None or malformed.
"""
if not size:
return default_w, default_h
try:
w, h = size.split("x")
return int(w), int(h)
except (ValueError, AttributeError):
logger.warning("Invalid size format: %s, using defaults", size)
return default_w, default_h
def compute_num_frames(
    num_frames: int | None = None,
    seconds: int | None = None,
    fps: int | None = None,
    default_fps: int = DEFAULT_VIDEO_FPS,
    default_num_frames: int = DEFAULT_VIDEO_NUM_FRAMES,
) -> int:
    """Resolve how many frames a video should contain.

    Priority: num_frames > seconds × fps > default_num_frames.

    Args:
        num_frames: Explicit frame count; returned unchanged when given.
        seconds: Clip duration; falls back to 4 when only ``fps`` is given.
        fps: Frame rate; falls back to ``default_fps`` when only ``seconds`` is given.
        default_fps: Frame rate used when ``fps`` is ``None``.
        default_num_frames: Result when none of the three inputs is given.

    Returns:
        The resolved frame count.
    """
    # An explicit frame count always wins.
    if num_frames is not None:
        return num_frames

    # Nothing to derive a count from: use the configured default.
    if seconds is None and fps is None:
        return default_num_frames

    duration = 4 if seconds is None else seconds
    rate = default_fps if fps is None else fps
    return duration * rate
def normalize_video_frames(images) -> list:
    """Turn ``stage_output.images`` into a flat list of frames.

    Args:
        images: Sequence of outputs. When it holds exactly one element, that
            element is taken to be the whole video; otherwise the sequence
            itself is treated as the frames.

    Returns:
        A list of frames. A 5-dimensional ``np.ndarray`` has its leading axis
        dropped (first entry kept) before being split along the next axis.
    """
    video = images if len(images) != 1 else images[0]

    # Only numpy arrays get the 5-D unwrap; other containers are listed as-is.
    has_leading_axis = isinstance(video, np.ndarray) and video.ndim == 5
    if has_leading_axis:
        video = video[0]

    return list(video)
def frames_to_numpy(images: list) -> np.ndarray:
    """Stack a list of image frames into one array for video encoding.

    Args:
        images: Frame objects exposing a PIL-style ``convert("RGB")`` method.

    Returns:
        The per-frame arrays stacked along a new leading axis, i.e. one entry
        per input frame in input order.

    Raises:
        ValueError: If ``images`` is empty, or the converted frames do not all
            share the same shape.
    """
    if not images:
        raise ValueError("No images provided for video encoding")

    frames = [np.array(image.convert("RGB")) for image in images]

    # Every frame must have the same shape before stacking.
    shapes = set(frame.shape for frame in frames)
    if len(shapes) > 1:
        raise ValueError(
            f"Inconsistent frame sizes detected: {shapes}. "
            "All frames must have the same dimensions."
        )

    return np.stack(frames, axis=0)
def encode_to_mp4( def encode_to_mp4(
frames: np.ndarray, frames: np.ndarray,
output_dir: str, output_dir: str,
......
...@@ -14,19 +14,16 @@ import uuid ...@@ -14,19 +14,16 @@ import uuid
from typing import Any, AsyncGenerator, Optional from typing import Any, AsyncGenerator, Optional
from dynamo._core import Component, Context from dynamo._core import Component, Context
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig from dynamo.common.protocols.video_protocol import (
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
from dynamo.trtllm.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
NvVideosResponse, NvVideosResponse,
VideoData, VideoData,
VideoNvExt, VideoNvExt,
) )
from dynamo.common.utils.video_utils import encode_to_mp4, encode_to_mp4_bytes
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler
from dynamo.trtllm.request_handlers.video_diffusion.video_utils import (
encode_to_mp4,
encode_to_mp4_bytes,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -18,14 +18,14 @@ from unittest.mock import MagicMock, patch ...@@ -18,14 +18,14 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig from dynamo.common.protocols.video_protocol import (
from dynamo.trtllm.constants import Modality
from dynamo.trtllm.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
NvVideosResponse, NvVideosResponse,
VideoData, VideoData,
VideoNvExt, VideoNvExt,
) )
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.constants import Modality
pytestmark = [ pytestmark = [
pytest.mark.unit, pytest.mark.unit,
......
...@@ -138,6 +138,135 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -138,6 +138,135 @@ class DynamoVllmArgGroup(ArgGroup):
help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).", help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).",
) )
# Video diffusion output
# TODO: Propose an alternate design to switch to AsyncOmniEngine args while using vLLM-Omni
add_argument(
g,
flag_name="--video-output-dir",
env_var="DYN_VLLM_VIDEO_OUTPUT_DIR",
default="/tmp/dynamo_videos", # noqa: S108
help="Directory to save generated video MP4 files.",
)
add_argument(
g,
flag_name="--default-video-fps",
env_var="DYN_VLLM_DEFAULT_VIDEO_FPS",
default=16,
arg_type=int,
help="Default frames per second for generated videos.",
)
# Diffusion engine-level args (passed to AsyncOmni constructor)
add_negatable_bool_argument(
g,
flag_name="--enable-layerwise-offload",
env_var="DYN_VLLM_ENABLE_LAYERWISE_OFFLOAD",
default=False,
help="Enable layerwise (blockwise) offloading on DiT modules to reduce GPU memory.",
)
add_argument(
g,
flag_name="--layerwise-num-gpu-layers",
env_var="DYN_VLLM_LAYERWISE_NUM_GPU_LAYERS",
default=1,
arg_type=int,
help="Number of ready layers (blocks) to keep on GPU during generation.",
)
add_negatable_bool_argument(
g,
flag_name="--vae-use-slicing",
env_var="DYN_VLLM_VAE_USE_SLICING",
default=False,
help="Enable VAE slicing for memory optimization in diffusion models.",
)
add_negatable_bool_argument(
g,
flag_name="--vae-use-tiling",
env_var="DYN_VLLM_VAE_USE_TILING",
default=False,
help="Enable VAE tiling for memory optimization in diffusion models.",
)
add_argument(
g,
flag_name="--boundary-ratio",
env_var="DYN_VLLM_BOUNDARY_RATIO",
default=0.875,
arg_type=float,
help=(
"Boundary split ratio for low/high DiT transformers. "
"Default 0.875 uses both transformers for best quality. "
"Set to 1.0 to load only the low-noise transformer (saves memory). "
"Only used with --omni."
),
)
add_argument(
g,
flag_name="--flow-shift",
env_var="DYN_VLLM_FLOW_SHIFT",
default=None,
arg_type=float,
help="Scheduler flow_shift parameter (5.0 for 720p, 12.0 for 480p). Only used with --omni.",
)
add_argument(
g,
flag_name="--diffusion-cache-backend",
env_var="DYN_VLLM_DIFFUSION_CACHE_BACKEND",
default=None,
choices=["cache_dit", "tea_cache"],
help=(
"Cache backend for diffusion acceleration. "
"'cache_dit' enables DBCache + SCM + TaylorSeer. "
"'tea_cache' enables TeaCache. Only used with --omni."
),
)
add_argument(
g,
flag_name="--diffusion-cache-config",
env_var="DYN_VLLM_DIFFUSION_CACHE_CONFIG",
default=None,
help="Cache configuration as JSON string (overrides defaults). Only used with --omni.",
)
add_negatable_bool_argument(
g,
flag_name="--enable-cache-dit-summary",
env_var="DYN_VLLM_ENABLE_CACHE_DIT_SUMMARY",
default=False,
help="Enable cache-dit summary logging after diffusion forward passes.",
)
add_negatable_bool_argument(
g,
flag_name="--enable-cpu-offload",
env_var="DYN_VLLM_ENABLE_CPU_OFFLOAD",
default=False,
help="Enable CPU offloading for diffusion models to reduce GPU memory usage.",
)
# Diffusion parallel configuration
add_argument(
g,
flag_name="--ulysses-degree",
env_var="DYN_VLLM_ULYSSES_DEGREE",
default=1,
arg_type=int,
help="Number of GPUs used for Ulysses sequence parallelism in diffusion.",
)
add_argument(
g,
flag_name="--ring-degree",
env_var="DYN_VLLM_RING_DEGREE",
default=1,
arg_type=int,
help="Number of GPUs used for ring sequence parallelism in diffusion.",
)
add_argument(
g,
flag_name="--cfg-parallel-size",
env_var="DYN_VLLM_CFG_PARALLEL_SIZE",
default=1,
arg_type=int,
choices=[1, 2],
help="Number of GPUs used for classifier free guidance parallelism.",
)
# ModelExpress P2P # ModelExpress P2P
add_argument( add_argument(
g, g,
...@@ -171,6 +300,27 @@ class DynamoVllmConfig(ConfigBase): ...@@ -171,6 +300,27 @@ class DynamoVllmConfig(ConfigBase):
omni: bool omni: bool
stage_configs_path: Optional[str] = None stage_configs_path: Optional[str] = None
# Video diffusion output
video_output_dir: str = "/tmp/dynamo_videos" # noqa: S108
default_video_fps: int = 16
# Diffusion engine-level parameters (passed to AsyncOmni constructor)
enable_layerwise_offload: bool = False
layerwise_num_gpu_layers: int = 1
vae_use_slicing: bool = False
vae_use_tiling: bool = False
boundary_ratio: float = 0.875
flow_shift: Optional[float] = None
diffusion_cache_backend: Optional[str] = None
diffusion_cache_config: Optional[str] = None
enable_cache_dit_summary: bool = False
enable_cpu_offload: bool = False
# Diffusion parallel configuration
ulysses_degree: int = 1
ring_degree: int = 1
cfg_parallel_size: int = 1
# ModelExpress P2P # ModelExpress P2P
model_express_url: Optional[str] = None model_express_url: Optional[str] = None
......
...@@ -902,11 +902,10 @@ def get_engine_cache_info(engine: AsyncLLM): ...@@ -902,11 +902,10 @@ def get_engine_cache_info(engine: AsyncLLM):
async def init_omni( async def init_omni(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
): ):
""" """Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni.
Initialize Omni worker for text-to-text generation using vLLM-Omni orchestrator.
Uses vLLM-Omni's Omni class for single-stage text generation pipeline. Supports text-to-text, text-to-image, and text-to-video generation
For now, supports text-to-text only (no multimodal). through a single unified OmniHandler.
""" """
# Lazy import to avoid loading vllm-omni unless explicitly needed # Lazy import to avoid loading vllm-omni unless explicitly needed
from dynamo.vllm.omni import OmniHandler from dynamo.vllm.omni import OmniHandler
...@@ -916,7 +915,7 @@ async def init_omni( ...@@ -916,7 +915,7 @@ async def init_omni(
) )
component = generate_endpoint.component() component = generate_endpoint.component()
# Initialize OmniHandler with Omni orchestrator # Initialize unified OmniHandler
handler = OmniHandler( handler = OmniHandler(
runtime=runtime, runtime=runtime,
component=component, component=component,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""vLLM-Omni integration for Dynamo.""" """vLLM-Omni integration for Dynamo."""
from .base_handler import BaseOmniHandler
from .omni_handler import OmniHandler from .omni_handler import OmniHandler
__all__ = ["OmniHandler"] __all__ = ["BaseOmniHandler", "OmniHandler"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Base handler for vLLM-Omni multi-stage pipelines."""
import asyncio
import logging
import time
from typing import Any, AsyncGenerator, Dict
from vllm import SamplingParams
from vllm_omni.entrypoints import AsyncOmni
try:
from vllm_omni.diffusion.data import DiffusionParallelConfig
except ImportError:
DiffusionParallelConfig = None # type: ignore[assignment, misc]
from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params
logger = logging.getLogger(__name__)
class BaseOmniHandler(BaseWorkerHandler):
    """Common base for handlers driving vLLM-Omni's AsyncOmni orchestrator.

    Owns the AsyncOmni client plus small request/response helpers. Subclasses
    supply the actual generation logic by overriding ``_generate_openai_mode``.
    """

    def __init__(
        self,
        runtime,
        component,
        config,
        default_sampling_params: Dict[str, Any],
        shutdown_event: asyncio.Event | None = None,
    ):
        """Create the AsyncOmni client and store handler state.

        Args:
            runtime: Dynamo distributed runtime.
            component: Dynamo component handle.
            config: Parsed Config object from args.py.
            default_sampling_params: Default sampling parameters dict.
            shutdown_event: Optional asyncio event for graceful shutdown.
        """
        logger.info(
            f"Initializing {self.__class__.__name__} for multi-stage pipelines "
            f"with model: {config.model}"
        )

        self.engine_client = AsyncOmni(**self._build_omni_kwargs(config))

        # super().__init__() is deliberately not called: VllmEngineMonitor
        # expects an AsyncLLM, whereas AsyncOmni manages its own engines.
        # The attributes BaseWorkerHandler would provide are set by hand.
        # TODO: Kv publishers not supported yet
        # TODO: Adopt the BaseWorkerHandler initialization pattern
        self.config = config
        self.runtime = runtime
        self.component = component
        self.shutdown_event = shutdown_event
        self.default_sampling_params = default_sampling_params
        self.model_max_len = config.engine_args.max_model_len
        self.use_vllm_tokenizer = config.use_vllm_tokenizer

        logger.info(f"{self.__class__.__name__} initialized successfully")

    def _build_omni_kwargs(self, config) -> Dict[str, Any]:
        """Assemble the keyword arguments passed to the AsyncOmni constructor.

        Always includes ``model`` and ``trust_remote_code``. Adds
        ``stage_configs_path`` when set, every diffusion engine parameter that
        exists on ``config`` with a non-``None`` value, and a
        ``parallel_config`` when ``DiffusionParallelConfig`` could be imported
        and ``config`` has ``ulysses_degree``.

        Args:
            config: Parsed Config object.

        Returns:
            Dictionary of keyword arguments for AsyncOmni.
        """
        kwargs: Dict[str, Any] = {
            "model": config.model,
            "trust_remote_code": config.engine_args.trust_remote_code,
        }

        stage_configs_path = config.stage_configs_path
        if stage_configs_path:
            kwargs["stage_configs_path"] = stage_configs_path

        # Config attribute names that are forwarded under a different kwarg name.
        renamed = {
            "diffusion_cache_backend": "cache_backend",
            "diffusion_cache_config": "cache_config",
        }
        forwarded_attrs = (
            "enable_layerwise_offload",
            "layerwise_num_gpu_layers",
            "vae_use_slicing",
            "vae_use_tiling",
            "boundary_ratio",
            "flow_shift",
            "diffusion_cache_backend",
            "diffusion_cache_config",
            "enable_cache_dit_summary",
            "enable_cpu_offload",
        )
        for attr in forwarded_attrs:
            # Missing attributes and explicit None are both skipped.
            value = getattr(config, attr, None)
            if value is None:
                continue
            kwargs[renamed.get(attr, attr)] = value

        if DiffusionParallelConfig is None:
            logger.warning(
                "DiffusionParallelConfig not available; "
                "skipping parallel config for AsyncOmni"
            )
        elif hasattr(config, "ulysses_degree"):
            kwargs["parallel_config"] = DiffusionParallelConfig(
                ulysses_degree=getattr(config, "ulysses_degree", 1),
                ring_degree=getattr(config, "ring_degree", 1),
                cfg_parallel_size=getattr(config, "cfg_parallel_size", 1),
            )

        return kwargs

    async def generate(
        self, request: Dict[str, Any], context
    ) -> AsyncGenerator[Dict, None]:
        """Stream response chunks for one request.

        Reads the request id from ``context`` and relays every chunk produced
        by ``_generate_openai_mode``. Subclasses should override
        ``_generate_openai_mode`` for custom output handling.
        """
        request_id = context.id()
        logger.debug(f"Omni Request ID: {request_id}")

        chunk_stream = self._generate_openai_mode(request, context, request_id)
        async for chunk in chunk_stream:
            yield chunk

    async def _generate_openai_mode(
        self, request, context, request_id
    ) -> AsyncGenerator[Dict, None]:
        """Produce OpenAI-compatible streaming chunks.

        Extension point for subclasses; this base version always raises
        NotImplementedError.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement _generate_openai_mode"
        )

    def _extract_text_prompt(self, request: Dict[str, Any]) -> str | None:
        """Return the ``content`` of the most recent user message.

        Scans ``request["messages"]`` from last to first and returns the
        content of the first message whose role is ``"user"``, or ``None``
        when there is no such message.
        """
        newest_first = reversed(request.get("messages", []))
        return next(
            (msg.get("content") for msg in newest_first if msg.get("role") == "user"),
            None,
        )

    def _extract_extra_body(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Return the request's ``extra_body`` mapping, or ``{}`` when absent or empty.

        The extra_body is passed through by the OpenAI client and contains
        model-specific parameters (e.g. diffusion sampling params).
        """
        extra_body = request.get("extra_body")
        return extra_body or {}

    def _build_sampling_params(self, request: Dict[str, Any]) -> SamplingParams:
        """Build sampling params through the shared ``build_sampling_params`` helper."""
        defaults = self.default_sampling_params
        max_len = self.model_max_len
        return build_sampling_params(request, defaults, max_len)

    def _error_chunk(self, request_id: str, error_message: str) -> Dict[str, Any]:
        """Wrap an error message in an OpenAI ``chat.completion.chunk`` dictionary."""
        created_at = int(time.time())
        model_name = self.config.served_model_name or self.config.model
        choice = {
            "index": 0,
            "delta": {
                "role": "assistant",
                "content": f"Error: {error_message}",
            },
            "finish_reason": "error",
        }
        return {
            "id": request_id,
            "created": created_at,
            "object": "chat.completion.chunk",
            "model": model_name,
            "choices": [choice],
        }

    def cleanup(self):
        """Close the AsyncOmni client if it was created; failures are logged, not raised."""
        try:
            # engine_client is absent when __init__ failed before creating it.
            if not hasattr(self, "engine_client"):
                return
            self.engine_client.close()
            logger.info("AsyncOmni orchestrator closed")
        except Exception as e:
            logger.error(f"Error closing AsyncOmni orchestrator: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import base64
import logging import logging
import os
import time import time
from typing import Any, AsyncGenerator, Dict import uuid
from dataclasses import dataclass
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, Union
from diffusers.utils import export_to_video
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo.common.protocols.image_protocol import (
ImageData,
NvCreateImageRequest,
NvImagesResponse,
)
from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest,
NvVideosResponse,
VideoData,
)
from dynamo.common.utils.output_modalities import RequestType, parse_request_type
from dynamo.common.utils.video_utils import (
compute_num_frames,
normalize_video_frames,
parse_size,
)
from dynamo.vllm.omni.base_handler import BaseOmniHandler
from vllm import SamplingParams logger = logging.getLogger(__name__)
from vllm_omni.entrypoints import AsyncOmni
from vllm_omni.inputs.data import OmniTextPrompt, OmniTokensPrompt
from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params # TODO: Migrate to fs_url based approach in another PR
DEFAULT_VIDEO_FPS = 16
DEFAULT_VIDEO_OUTPUT_DIR = "/tmp/dynamo_videos" # noqa: S108
logger = logging.getLogger(__name__)
@dataclass
class EngineInputs:
    """Parsed engine inputs ready for AsyncOmni.generate().

    Only ``prompt`` is required; the remaining fields have defaults.

    Attributes:
        prompt: OmniTextPrompt dict for the engine.
        sampling_params_list: Per-stage sampling parameters, or None for defaults.
        request_type: The resolved request type (may differ from the initial parse
            when a chat completion request carries video params).
        fps: Frames per second, only meaningful for video requests.
        response_format: Desired response format (e.g. "url" or "b64_json" for
            image requests). None means use the default for the request type.
    """

    prompt: OmniTextPrompt
    sampling_params_list: list | None = None
    # Defaults to the chat-completion request type.
    request_type: RequestType = RequestType.CHAT_COMPLETION
    # Default of 0; per the docstring this only matters for video requests.
    fps: int = 0
    response_format: str | None = None
class OmniHandler(BaseWorkerHandler): def prepare_image_output(images: list, response_format: str | None = None):
"""Handler for multi-stage pipelines using vLLM-Omni's AsyncOmni orchestrator.""" """Prepare image output for response.
Args:
images: List of PIL Image objects.
response_format: Response format.
Returns:
List of image URLs or base64 strings.
"""
## This is a temporary function to prepare image output for response.
## Right now, there are different utilities across components that uploads image/video outputs to urls or b64_json.
## (ayushag) TODO: follow up, move all the utilities to common
outlist = []
for img in images:
if response_format == "url":
output_dir = "/tmp/dynamo_images" # noqa: S108
os.makedirs(output_dir, exist_ok=True)
img_path = os.path.join(output_dir, f"{uuid.uuid4()}.png")
img.save(img_path)
outlist.append(img_path)
elif response_format == "b64_json" or response_format is None:
# convert image to base64
buffer = BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
data_url = f"data:image/png;base64,{img_base64}"
outlist.append(data_url)
else:
raise ValueError(f"Invalid response format: {response_format}")
return outlist
class OmniHandler(BaseOmniHandler):
"""Unified handler for multi-stage pipelines using vLLM-Omni.
Handles text-to-text, text-to-image, and text-to-video generation.
"""
def __init__( def __init__(
self, self,
...@@ -25,145 +107,75 @@ class OmniHandler(BaseWorkerHandler): ...@@ -25,145 +107,75 @@ class OmniHandler(BaseWorkerHandler):
default_sampling_params: Dict[str, Any], default_sampling_params: Dict[str, Any],
shutdown_event: asyncio.Event | None = None, shutdown_event: asyncio.Event | None = None,
): ):
"""Initialize handler with AsyncOmni orchestrator.""" """Initialize the unified Omni handler.
logger.info(
f"Initializing OmniHandler for multi-stage pipelines with model: {config.model}" Args:
runtime: Dynamo distributed runtime.
component: Dynamo component handle.
config: Parsed Config object from args.py.
default_sampling_params: Default sampling parameters dict.
shutdown_event: Optional asyncio event for graceful shutdown.
"""
super().__init__(
runtime=runtime,
component=component,
config=config,
default_sampling_params=default_sampling_params,
shutdown_event=shutdown_event,
) )
omni_kwargs = {
"model": config.model,
"trust_remote_code": config.engine_args.trust_remote_code,
"stage_configs_path": config.stage_configs_path,
}
self.engine_client = AsyncOmni(**omni_kwargs)
# Initialize attributes needed from BaseWorkerHandler
# We don't call super().__init__() because VllmEngineMonitor expects AsyncLLM,
# but AsyncOmni manages its own engines internally
# TODO: Kv publishers not supported yet
# TODO: Adopt to baseworker initialization pattern
self.default_sampling_params = default_sampling_params
self.config = config
self.model_max_len = config.engine_args.max_model_len
self.shutdown_event = shutdown_event
self.use_vllm_tokenizer = config.use_vllm_tokenizer
logger.info("OmniHandler initialized successfully for text-to-text generation")
async def generate( async def generate(
self, request: Dict[str, Any], context self, request: Dict[str, Any], context
) -> AsyncGenerator[Dict, None]: ) -> AsyncGenerator[Dict, None]:
"""Generate outputs using AsyncOmni orchestrator with OpenAI-compatible format. """Generate outputs via the unified OpenAI mode.
Args:
request: Raw request dictionary from the Rust frontend.
context: Dynamo context for request tracking.
Supports text-to-text and text-to-image generation based on stage configuration. Yields:
Returns OpenAI-compatible streaming chunks with detokenized text. Response dictionaries.
""" """
request_id = context.id() request_id = context.id()
logger.debug(f"Omni Request ID: {request_id}") logger.debug(f"Omni Request ID: {request_id}")
if self.use_vllm_tokenizer:
async for chunk in self._generate_openai_mode(request, context, request_id): async for chunk in self._generate_openai_mode(request, context, request_id):
yield chunk yield chunk
else:
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
# Not used right now async def _generate_openai_mode(
async def _generate_token_mode(self, request, context, request_id): self, request: Dict[str, Any], context, request_id: str
""" ) -> AsyncGenerator[Dict[str, Any], None]:
This mode returns token-ids as output """Single generation path for all request protocols and output modalities."""
Text input -> Token-ids output
"""
token_ids = request.get("token_ids")
prompt = OmniTokensPrompt(token_ids=token_ids)
num_output_tokens_so_far = 0
try:
async for stage_output in self.engine_client.generate(
prompt=prompt,
request_id=request_id,
):
vllm_output = stage_output.request_output
if not vllm_output.outputs: parsed_request, request_type = parse_request_type(
logger.warning(f"Request {request_id} returned no outputs") request, self.config.output_modalities
yield {
"finish_reason": "error: No outputs from vLLM engine",
"token_ids": [],
}
break
output = vllm_output.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = self._normalize_finish_reason(
output.finish_reason
)
out["completion_usage"] = self._build_completion_usage(vllm_output)
logger.debug(
f"Completed generation for request {request_id}: "
f"{next_total_toks} output tokens, finish_reason={output.finish_reason}"
) )
inputs = self.build_engine_inputs(parsed_request, request_type)
if output.stop_reason: generate_kwargs: Dict[str, Any] = {
out["stop_reason"] = output.stop_reason "prompt": inputs.prompt,
"request_id": request_id,
yield out
num_output_tokens_so_far = next_total_toks
except GeneratorExit:
# Shutdown was triggered during generation
logger.info(f"Request {request_id} aborted due to shutdown")
raise
except Exception as e:
logger.error(f"Error during generation for request {request_id}: {e}")
yield {
"finish_reason": f"error: {str(e)}",
"token_ids": [],
} }
if inputs.sampling_params_list is not None:
async def _generate_openai_mode(self, request, context, request_id): generate_kwargs["sampling_params_list"] = inputs.sampling_params_list
"""
This mode returns OpenAI-compatible streaming chunks
Text input -> Text output / Image output
"""
# (ayushag) TODO: Support all type of OmniPrompt. Right now it works for only text prompts
# (ayushag) TODO: Document all I/O formats from vllm omni
# OmniText prompt support additional negative prompts as well. need to support that as well.
# Support multimodal content as well. That will involve applying tokenizer to the prompt and loading images. Follow general multimodal support pattern.
prompt = self._extract_text_prompt(request)
prompt = OmniTextPrompt(prompt=prompt)
# Build sampling parameters from request
# (ayushag) TODO: Need to add proper multi-stage sampling param support
# sampling_params = self._build_sampling_params(request)
# sampling_params_list = [sampling_params]
previous_text = "" previous_text = ""
async with self._abort_monitor(context, request_id): async with self._abort_monitor(context, request_id):
try: try:
async for stage_output in self.engine_client.generate( async for stage_output in self.engine_client.generate(
prompt=prompt, **generate_kwargs,
request_id=request_id,
# sampling_params_list=sampling_params_list,
): ):
if ( if (
stage_output.final_output_type == "text" stage_output.final_output_type == "text"
and stage_output.request_output and stage_output.request_output
): ):
# Text generation (LLM stage)
chunk = self._format_text_chunk( chunk = self._format_text_chunk(
stage_output.request_output, stage_output.request_output,
request_id, request_id,
previous_text, previous_text,
) )
if chunk: if chunk:
# Update previous_text for delta calculation
output = stage_output.request_output.outputs[0] output = stage_output.request_output.outputs[0]
previous_text = output.text previous_text = output.text
yield chunk yield chunk
...@@ -172,10 +184,21 @@ class OmniHandler(BaseWorkerHandler): ...@@ -172,10 +184,21 @@ class OmniHandler(BaseWorkerHandler):
stage_output.final_output_type == "image" stage_output.final_output_type == "image"
and stage_output.images and stage_output.images
): ):
# Image generation (diffusion stage) # vllm-omni uses final_output_type="image" for both
# image and video diffusion outputs. Use the parsed
# request type to route to the correct formatter.
if inputs.request_type == RequestType.VIDEO_GENERATION:
chunk = await self._format_video_chunk(
stage_output.images,
request_id,
fps=inputs.fps,
)
else:
chunk = self._format_image_chunk( chunk = self._format_image_chunk(
stage_output.images, stage_output.images,
request_id, request_id,
response_format=inputs.response_format,
request_type=inputs.request_type,
) )
if chunk: if chunk:
yield chunk yield chunk
...@@ -187,71 +210,162 @@ class OmniHandler(BaseWorkerHandler): ...@@ -187,71 +210,162 @@ class OmniHandler(BaseWorkerHandler):
logger.error(f"Error during generation for request {request_id}: {e}") logger.error(f"Error during generation for request {request_id}: {e}")
yield self._error_chunk(request_id, str(e)) yield self._error_chunk(request_id, str(e))
def _format_text_chunk( def build_engine_inputs(
self, self,
request_output, parsed_request: Union[
request_id: str, NvCreateImageRequest, NvCreateVideoRequest, Dict[str, Any]
previous_text: str, ],
) -> Dict[str, Any] | None: request_type: RequestType,
"""Format text output as OpenAI chat completion chunk.""" ) -> EngineInputs:
if not request_output.outputs: """Convert a parsed request into AsyncOmni engine inputs.
return self._error_chunk(request_id, "No outputs from engine")
output = request_output.outputs[0] Args:
parsed_request: Output from parse_request_type -- a Pydantic model
for image/video requests, or a raw dict for chat completions.
request_type: The RequestType determined by parse_request_type.
# Calculate delta text (new text since last chunk) Returns:
delta_text = output.text[len(previous_text) :] EngineInputs ready for engine_client.generate().
"""
if request_type == RequestType.CHAT_COMPLETION:
return self._engine_inputs_from_chat(parsed_request)
elif request_type == RequestType.IMAGE_GENERATION:
return self._engine_inputs_from_image(parsed_request)
elif request_type == RequestType.VIDEO_GENERATION:
return self._engine_inputs_from_video(parsed_request)
chunk = { elif request_type == RequestType.AUDIO_GENERATION:
"id": request_id, raise NotImplementedError("Audio generation is not yet supported")
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self.config.served_model_name or self.config.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": delta_text,
},
"finish_reason": self._normalize_finish_reason(output.finish_reason)
if output.finish_reason
else None,
}
],
}
# Add usage on final chunk raise ValueError(f"Unknown request type: {request_type}")
if output.finish_reason:
chunk["usage"] = self._build_completion_usage(request_output)
return chunk def _engine_inputs_from_chat(self, request: Dict[str, Any]) -> EngineInputs:
"""Build engine inputs from a chat completions request dict."""
# Chat completions request does not support extra_body passthrough
# So, we can't extract any diffusion related params from the raw_request
# It falls back to default sampling params
text_prompt = self._extract_text_prompt(request)
if text_prompt is None:
raise ValueError("No user message found in chat completion request")
prompt = OmniTextPrompt(prompt=text_prompt)
sampling_params_list = None
return EngineInputs(
prompt=prompt,
sampling_params_list=sampling_params_list,
request_type=RequestType.CHAT_COMPLETION,
fps=0,
)
def _engine_inputs_from_image(self, req: NvCreateImageRequest) -> EngineInputs:
    """Translate an NvCreateImageRequest into AsyncOmni engine inputs.

    Args:
        req: Parsed image-generation request. ``req.nvext`` is optional and,
            when present, may override diffusion sampling settings.

    Returns:
        EngineInputs holding a text prompt, exactly one diffusion sampling
        params object, the IMAGE_GENERATION request type, and the caller's
        requested response format.
    """
    # Resolve output resolution first; 1024x1024 is the fallback passed to parse_size.
    width, height = parse_size(req.size, default_w=1024, default_h=1024)
    ext = req.nvext

    # A missing extension block or an empty/None negative prompt both become "".
    if ext:
        negative = ext.negative_prompt or ""
    else:
        negative = ""
    text_prompt = OmniTextPrompt(prompt=req.prompt, negative_prompt=negative)

    diffusion_params = OmniDiffusionSamplingParams(height=height, width=width)
    if req.n is not None:
        diffusion_params.num_outputs_per_prompt = req.n

    # Copy only the overrides the caller actually supplied, so the sampling
    # params keep their own defaults for everything else.
    if ext:
        for field in ("num_inference_steps", "guidance_scale", "seed"):
            override = getattr(ext, field)
            if override is not None:
                setattr(diffusion_params, field, override)

    return EngineInputs(
        prompt=text_prompt,
        sampling_params_list=[diffusion_params],
        request_type=RequestType.IMAGE_GENERATION,
        response_format=req.response_format,
    )
def _engine_inputs_from_video(self, req: NvCreateVideoRequest) -> EngineInputs:
"""Build engine inputs from an NvCreateVideoRequest.

Resolves the frame size from ``req.size``, derives the frame count through
``compute_num_frames``, and wraps a text prompt plus a single
``OmniDiffusionSamplingParams`` in an ``EngineInputs`` tagged
``RequestType.VIDEO_GENERATION``.

Args:
req: Parsed video-generation request; ``req.nvext`` may be absent.

Returns:
EngineInputs whose ``fps`` is the nvext value when given, otherwise
``DEFAULT_VIDEO_FPS``.
"""
width, height = parse_size(req.size)
nvext = req.nvext
# nvext is optional: without it, fps and num_frames are passed on as None.
nvext_fps = nvext.fps if nvext else None
nvext_num_frames = nvext.num_frames if nvext else None
# compute_num_frames receives the explicit frame count, the requested
# duration, and the fps (with its default).
# NOTE(review): precedence between num_frames and seconds * fps lives in
# compute_num_frames, which is not visible here — confirm there.
num_frames = compute_num_frames(
num_frames=nvext_num_frames,
seconds=req.seconds,
fps=nvext_fps,
default_fps=DEFAULT_VIDEO_FPS,
)
# Effective fps: the caller's value if provided, else the module default.
fps = nvext_fps if nvext_fps is not None else DEFAULT_VIDEO_FPS
prompt = OmniTextPrompt(
prompt=req.prompt,
# Parses as ``(nvext.negative_prompt or "") if nvext else ""``: a missing
# nvext and a None/empty negative prompt both yield "".
negative_prompt=nvext.negative_prompt or "" if nvext else "",
)
sp = OmniDiffusionSamplingParams(
height=height,
width=width,
num_frames=num_frames,
)
# Apply only the optional overrides that were actually supplied.
if nvext:
if nvext.num_inference_steps is not None:
sp.num_inference_steps = nvext.num_inference_steps
if nvext.guidance_scale is not None:
sp.guidance_scale = nvext.guidance_scale
if nvext.seed is not None:
sp.seed = nvext.seed
# fps was resolved above, so it is None only if DEFAULT_VIDEO_FPS is None.
if fps is not None:
sp.fps = fps
# Only the first 50 characters of the prompt are written to the log.
logger.info(
f"Video diffusion request: prompt='{req.prompt[:50]}...', "
f"size={width}x{height}, frames={num_frames}, fps={fps}"
)
return EngineInputs(
prompt=prompt,
sampling_params_list=[sp],
request_type=RequestType.VIDEO_GENERATION,
fps=fps,
)
def _format_image_chunk( def _format_image_chunk(
self, self,
images: list, images: list,
request_id: str, request_id: str,
response_format: str | None = None,
request_type: RequestType = RequestType.IMAGE_GENERATION,
) -> Dict[str, Any] | None: ) -> Dict[str, Any] | None:
"""Format image output as OpenAI chat completion chunk with base64 data URLs.""" """Format image output as OpenAI chat completion chunk with base64 data URLs.
import base64
from io import BytesIO Args:
images: List of PIL Image objects generated by AsyncOmni engine.
request_id: Unique request identifier.
response_format: Response format (url, b64_json, None).
request_type: Request type (chat completion, image generation).
Returns:
Dict[str, Any] | None: Formatted chunk, or None if no images generated.
"""
if not images: if not images:
return self._error_chunk(request_id, "No images generated") return self._error_chunk(request_id, "No images generated")
# Convert images to base64 data URLs data_urls = prepare_image_output(images, response_format)
data_urls = []
for idx, img in enumerate(images):
# Convert PIL image to base64
buffer = BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Create data URL (can be opened directly in browser)
data_url = f"data:image/png;base64,{img_base64}"
data_urls.append(data_url)
logger.info(f"Generated image {idx} for request {request_id}")
if request_type == RequestType.CHAT_COMPLETION:
# This branch is used when user send request via /v1/chat/completions endpoint.
# We need to return chat completion chunk with image_url content part.
chunk = { chunk = {
"id": request_id, "id": request_id,
"created": int(time.time()), "created": int(time.time()),
...@@ -271,29 +385,114 @@ class OmniHandler(BaseWorkerHandler): ...@@ -271,29 +385,114 @@ class OmniHandler(BaseWorkerHandler):
} }
], ],
} }
return chunk return chunk
elif request_type == RequestType.IMAGE_GENERATION:
# This branch is used when user send request via /v1/images/generations endpoint.
# This will return NvImagesResponse with list of ImageData objects.
image_data_list = []
for data_url in data_urls:
if response_format == "url":
image_data_list.append(ImageData(url=data_url))
elif response_format == "b64_json" or response_format is None:
# strip explicit prefix if present
if data_url.startswith("data:image"):
_, b64_part = data_url.split(",", 1)
image_data_list.append(ImageData(b64_json=b64_part))
else:
image_data_list.append(ImageData(b64_json=data_url))
else:
raise ValueError(f"Invalid response format: {response_format}")
output = NvImagesResponse(created=int(time.time()), data=image_data_list)
return output.model_dump()
async def _format_video_chunk(
self,
images: list,
request_id: str,
fps: int,
) -> Dict[str, Any] | None:
"""Convert diffusion output frames to MP4 and return as NvVideosResponse.
def _extract_text_prompt(self, request: Dict[str, Any]) -> str | None: Args:
"""Extract text prompt from request.""" images: List of PIL Image frames from the diffusion stage.
request_id: Unique request identifier.
# OpenAI messages format - extract text content only fps: Frames per second for the output video.
messages = request.get("messages", [])
# Assumes single user message Returns:
for message in messages: ``NvVideosResponse.model_dump()`` dict, or ``None`` if no frames.
if message.get("role") == "user": """
return message.get("content") if not images:
return "" return None
def _build_sampling_params(self, request: Dict[str, Any]) -> SamplingParams: try:
"""Build sampling params using shared handler utility.""" start_time = time.time()
return build_sampling_params(
request, self.default_sampling_params, self.model_max_len frame_list = normalize_video_frames(images)
logger.info(
f"Encoding {len(frame_list)} frames to MP4 for request {request_id} "
f"(fps={fps})"
)
os.makedirs(DEFAULT_VIDEO_OUTPUT_DIR, exist_ok=True)
video_path = os.path.join(DEFAULT_VIDEO_OUTPUT_DIR, f"{request_id}.mp4")
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
export_to_video,
frame_list,
video_path,
fps,
) )
def _error_chunk(self, request_id: str, error_message: str) -> Dict[str, Any]: logger.info(f"Video saved to {video_path} for request {request_id}")
"""Create an error chunk in OpenAI format."""
return { inference_time = time.time() - start_time
response = NvVideosResponse(
id=request_id,
object="video",
model=self.config.served_model_name or self.config.model,
status="completed",
progress=100,
created=int(time.time()),
data=[VideoData(url=video_path)],
inference_time_s=inference_time,
)
return response.model_dump()
except Exception as e:
logger.error(f"Failed to encode video for request {request_id}: {e}")
error_response = NvVideosResponse(
id=request_id,
object="video",
model=self.config.served_model_name or self.config.model,
status="failed",
progress=0,
created=int(time.time()),
data=[],
error=str(e),
)
return error_response.model_dump()
def _format_text_chunk(
self,
request_output,
request_id: str,
previous_text: str,
) -> Dict[str, Any] | None:
"""Format text output as OpenAI chat completion chunk."""
if not request_output.outputs:
return self._error_chunk(request_id, "No outputs from engine")
output = request_output.outputs[0]
# Calculate delta text (new text since last chunk)
delta_text = output.text[len(previous_text) :]
chunk = {
"id": request_id, "id": request_id,
"created": int(time.time()), "created": int(time.time()),
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
...@@ -303,18 +502,17 @@ class OmniHandler(BaseWorkerHandler): ...@@ -303,18 +502,17 @@ class OmniHandler(BaseWorkerHandler):
"index": 0, "index": 0,
"delta": { "delta": {
"role": "assistant", "role": "assistant",
"content": f"Error: {error_message}", "content": delta_text,
}, },
"finish_reason": "error", "finish_reason": self._normalize_finish_reason(output.finish_reason)
if output.finish_reason
else None,
} }
], ],
} }
def cleanup(self): # Add usage on final chunk
"""Cleanup AsyncOmni orchestrator resources.""" if output.finish_reason:
try: chunk["usage"] = self._build_completion_usage(request_output)
if hasattr(self, "engine_client"):
self.engine_client.close() return chunk
logger.info("AsyncOmni orchestrator closed")
except Exception as e:
logger.error(f"Error closing AsyncOmni orchestrator: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import MagicMock, patch
import pytest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest
from dynamo.common.utils.output_modalities import RequestType
# TODO: Install vLLM omni dependencies in CI container so this skip is no longer needed.
try:
from dynamo.vllm.omni.omni_handler import (
EngineInputs,
OmniHandler,
prepare_image_output,
)
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
# Module-level markers: pytest applies every marker listed in ``pytestmark``
# to each test collected from this file.
# NOTE(review): the meaning of the individual markers (unit, vllm, gpu_0,
# pre_merge) is defined by the project's pytest configuration — confirm there.
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
def _make_handler():
    """Build an OmniHandler for unit tests without running any constructor.

    ``OmniHandler.__new__`` allocates the instance without invoking
    ``__init__``, so no engine is created and there is nothing to patch on
    the base class. Only ``config`` is attached; tests that need further
    attributes or helper methods assign them on the returned instance.

    Returns:
        A bare OmniHandler whose ``config`` is a MagicMock reporting model
        "test-model", no served model name, and text-only output modalities.
    """
    handler = OmniHandler.__new__(OmniHandler)

    config = MagicMock()
    config.model = "test-model"
    config.served_model_name = None
    config.output_modalities = ["text"]

    handler.config = config
    return handler
class TestEngineInputs:
    """Checks the values EngineInputs falls back to when only a prompt is given."""

    def test_defaults(self):
        """Request type is CHAT_COMPLETION, fps is 0, and the optional fields are None."""
        engine_inputs = EngineInputs(prompt={"prompt": "hello"})

        assert engine_inputs.request_type == RequestType.CHAT_COMPLETION
        assert engine_inputs.fps == 0
        # Both optional fields must be left unset.
        assert engine_inputs.sampling_params_list is None
        assert engine_inputs.response_format is None
class TestPrepareImageOutput:
    """Covers prepare_image_output for valid, default, and invalid formats."""

    @staticmethod
    def _stub_image(payload):
        """Return a mock image whose ``save`` writes ``payload`` into the buffer."""
        image = MagicMock()
        image.save = lambda buf, format: buf.write(payload)
        return image

    def test_b64_json(self):
        """Requesting b64_json yields one PNG data URI per image."""
        encoded = prepare_image_output(
            [self._stub_image(b"fake_png_data")], "b64_json"
        )
        assert len(encoded) == 1
        assert encoded[0].startswith("data:image/png;base64,")

    def test_b64_default_when_none(self):
        """Leaving response_format unset still produces a base64 data URI."""
        encoded = prepare_image_output([self._stub_image(b"data")], None)
        assert encoded[0].startswith("data:image/png;base64,")

    def test_invalid_format(self):
        """An unrecognised response_format is rejected with ValueError."""
        with pytest.raises(ValueError, match="Invalid response format"):
            prepare_image_output([MagicMock()], "invalid")

    def test_multiple_images(self):
        """Three input images give three output entries."""
        batch = [self._stub_image(b"px") for _ in range(3)]
        encoded = prepare_image_output(batch, "b64_json")
        assert len(encoded) == 3
class TestBuildEngineInputs:
    """Exercises OmniHandler.build_engine_inputs once per request type."""

    def test_chat_completion(self):
        """A chat request yields the user text as prompt and no sampling params."""
        chat_request = {"messages": [{"role": "user", "content": "hello"}]}
        result = _make_handler().build_engine_inputs(
            chat_request, RequestType.CHAT_COMPLETION
        )
        assert result.request_type == RequestType.CHAT_COMPLETION
        assert result.prompt["prompt"] == "hello"
        assert result.sampling_params_list is None

    def test_image_generation(self):
        """An image request carries its prompt and a single 512x512 params entry."""
        image_request = NvCreateImageRequest(prompt="a cat", size="512x512")
        result = _make_handler().build_engine_inputs(
            image_request, RequestType.IMAGE_GENERATION
        )
        assert result.request_type == RequestType.IMAGE_GENERATION
        assert result.prompt["prompt"] == "a cat"
        assert len(result.sampling_params_list) == 1
        diffusion_params = result.sampling_params_list[0]
        assert diffusion_params.height == 512
        assert diffusion_params.width == 512

    def test_video_generation(self):
        """A video request carries its prompt and resolves a positive fps."""
        video_request = NvCreateVideoRequest(
            prompt="a drone", model="test", size="832x480", seconds=2
        )
        result = _make_handler().build_engine_inputs(
            video_request, RequestType.VIDEO_GENERATION
        )
        assert result.request_type == RequestType.VIDEO_GENERATION
        assert result.prompt["prompt"] == "a drone"
        assert result.fps > 0

    def test_audio_not_implemented(self):
        """Audio generation is rejected with NotImplementedError."""
        handler = _make_handler()
        with pytest.raises(NotImplementedError):
            handler.build_engine_inputs({}, RequestType.AUDIO_GENERATION)
class TestFormatTextChunk:
    """Covers delta calculation, the empty-output path, and the final chunk."""

    def _make_output(self, text="hello world", finish_reason=None):
        """Build a mock request output wrapping one completion with the given text."""
        completion = MagicMock(text=text, finish_reason=finish_reason)
        return MagicMock(outputs=[completion], prompt_token_ids=[1, 2, 3])

    def test_delta_text(self):
        """Only the text beyond ``previous_text`` appears in the delta."""
        chunk = _make_handler()._format_text_chunk(
            self._make_output("hello world"), "req-1", "hello "
        )
        delta = chunk["choices"][0]["delta"]
        assert delta["content"] == "world"

    def test_no_outputs_returns_error(self):
        """A request output without completions is reported as an error chunk."""
        empty_output = MagicMock(outputs=[])
        chunk = _make_handler()._format_text_chunk(empty_output, "req-1", "")
        assert "Error" in chunk["choices"][0]["delta"]["content"]

    def test_finish_reason_included(self):
        """The final chunk reports finish_reason and attaches usage."""

        def passthrough(reason):
            return reason

        def fixed_usage(request_output):
            return {"prompt_tokens": 3, "completion_tokens": 1}

        handler = _make_handler()
        # Stub the helpers the formatter calls once generation has finished.
        handler._normalize_finish_reason = passthrough
        handler._build_completion_usage = fixed_usage

        finished = self._make_output("done", finish_reason="stop")
        chunk = handler._format_text_chunk(finished, "req-1", "")
        assert chunk["choices"][0]["finish_reason"] == "stop"
        assert "usage" in chunk
class TestFormatImageChunk:
    """Covers both response shapes of _format_image_chunk and its empty-input path."""

    @staticmethod
    def _stub_image():
        """Return a mock image whose ``save`` writes two placeholder bytes."""
        image = MagicMock()
        image.save = lambda buf, format: buf.write(b"px")
        return image

    def test_chat_completion_format(self):
        """The chat route wraps images as image_url content parts."""
        chunk = _make_handler()._format_image_chunk(
            [self._stub_image()],
            "req-1",
            request_type=RequestType.CHAT_COMPLETION,
        )
        assert chunk["object"] == "chat.completion.chunk"
        first_part = chunk["choices"][0]["delta"]["content"][0]
        assert first_part["type"] == "image_url"

    def test_image_generation_b64_format(self):
        """The images route with b64_json populates the b64_json field."""
        chunk = _make_handler()._format_image_chunk(
            [self._stub_image()],
            "req-1",
            response_format="b64_json",
            request_type=RequestType.IMAGE_GENERATION,
        )
        assert chunk["data"][0]["b64_json"] is not None

    def test_image_generation_default_format_returns_b64(self):
        """The images route with no response_format also populates b64_json."""
        chunk = _make_handler()._format_image_chunk(
            [self._stub_image()],
            "req-1",
            response_format=None,
            request_type=RequestType.IMAGE_GENERATION,
        )
        assert chunk["data"][0]["b64_json"] is not None

    def test_empty_images_returns_error(self):
        """No images at all results in an error chunk."""
        chunk = _make_handler()._format_image_chunk([], "req-1")
        assert "Error" in chunk["choices"][0]["delta"]["content"]
class TestFormatVideoChunk:
    """Covers the no-frames and encoding-failure paths of _format_video_chunk."""

    @pytest.mark.asyncio
    async def test_empty_frames_returns_none(self):
        """With no frames there is nothing to encode and None is returned."""
        outcome = await _make_handler()._format_video_chunk([], "req-1", fps=16)
        assert outcome is None

    @pytest.mark.asyncio
    async def test_error_returns_failed_status(self):
        """A failure while preparing frames is reported as a failed response."""
        handler = _make_handler()
        # Make frame normalisation blow up so the handler takes its error branch.
        failing_normalizer = patch(
            "dynamo.vllm.omni.omni_handler.normalize_video_frames",
            side_effect=RuntimeError("boom"),
        )
        with failing_normalizer:
            chunk = await handler._format_video_chunk([MagicMock()], "req-1", fps=16)
        assert chunk["status"] == "failed"
        assert "boom" in chunk["error"]
...@@ -4,34 +4,53 @@ ...@@ -4,34 +4,53 @@
title: vLLM-Omni title: vLLM-Omni
--- ---
# [Experimental] Running Omni Models with vLLM # [Experimental] Omni Models with vLLM
Dynamo supports omni (multimodal generation) models via the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This enables multi-stage pipelines for tasks like text-to-text and text-to-image generation through an OpenAI-compatible API. Dynamo supports multimodal generation through the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This integration exposes text-to-text, text-to-image, and text-to-video capabilities via OpenAI-compatible API endpoints.
## Prerequisites ## Prerequisites
This guide assumes familiarity with deploying Dynamo with vLLM as described in [README.md](/docs/pages/backends/vllm/README.md). This guide assumes familiarity with deploying Dynamo with vLLM as described in the [vLLM backend guide](/docs/pages/backends/vllm/README.md).
## Quick Start ## Supported Modalities
### Text-to-Text | Modality | Endpoint(s) | `--output-modalities` |
|---|---|---|
| Text-to-Text | `/v1/chat/completions` | `text` (default) |
| Text-to-Image | `/v1/chat/completions`, `/v1/images/generations` | `image` |
| Text-to-Video | `/v1/videos` | `video` |
Launch an aggregated deployment (frontend + omni worker) using the provided script: The `--output-modalities` flag determines which endpoint(s) the worker registers. When set to `image`, both `/v1/chat/completions` (returns inline base64 images) and `/v1/images/generations` are available. When set to `video`, the worker serves `/v1/videos`.
## Tested Models
| Modality | Models |
|---|---|
| Text-to-Text | `Qwen/Qwen2.5-Omni-7B` |
| Text-to-Image | `Qwen/Qwen-Image`, `AIDC-AI/Ovis-Image-7B` |
| Text-to-Video | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
To run a non-default model, pass `--model` to any launch script:
```bash ```bash
bash examples/backends/vllm/launch/agg_omni.sh bash examples/backends/vllm/launch/agg_omni_image.sh --model AIDC-AI/Ovis-Image-7B
bash examples/backends/vllm/launch/agg_omni_video.sh --model Wan-AI/Wan2.2-T2V-A14B-Diffusers
``` ```
This starts `Qwen/Qwen2.5-Omni-7B` with a single-stage thinker config on one GPU. Override the model with: ## Text-to-Text
Launch an aggregated deployment (frontend + omni worker):
```bash ```bash
bash examples/backends/vllm/launch/agg_omni.sh --model <your-model> bash examples/backends/vllm/launch/agg_omni.sh
``` ```
Test the deployment: This starts `Qwen/Qwen2.5-Omni-7B` with a single-stage thinker config on one GPU.
Verify the deployment:
```bash ```bash
curl -X POST http://localhost:8000/v1/chat/completions \ curl -s http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "Qwen/Qwen2.5-Omni-7B", "model": "Qwen/Qwen2.5-Omni-7B",
...@@ -41,34 +60,29 @@ curl -X POST http://localhost:8000/v1/chat/completions \ ...@@ -41,34 +60,29 @@ curl -X POST http://localhost:8000/v1/chat/completions \
}' }'
``` ```
### Text-to-Image This script uses a custom stage config (`stage_configs/single_stage_llm.yaml`) that configures the thinker stage for text generation. See [Stage Configuration](#stage-configuration) for details.
## Text-to-Image
Text-to-image uses vLLM-Omni's built-in default stage configs (no custom YAML needed). Launch without a stage config path so vLLM-Omni loads the model's default multi-stage pipeline: Launch using the provided script with `Qwen/Qwen-Image`:
```bash ```bash
# Start frontend bash examples/backends/vllm/launch/agg_omni_image.sh
python -m dynamo.frontend &
# Start omni worker (vLLM-Omni loads default stage configs for the model)
DYN_SYSTEM_PORT=8081 python -m dynamo.vllm \
--model <your-text-to-image-model> \
--omni \
--connector none
``` ```
Images are returned as base64-encoded PNGs in the response: ### Via `/v1/chat/completions`
```bash ```bash
curl -X POST http://localhost:8000/v1/chat/completions \ curl -s http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "<your-text-to-image-model>", "model": "Qwen/Qwen-Image",
"messages": [{"role": "user", "content": "A cat sitting on a windowsill"}], "messages": [{"role": "user", "content": "A cat sitting on a windowsill"}],
"stream": false "stream": false
}' }'
``` ```
The response contains image data URLs in the content field: The response includes base64-encoded images inline:
```json ```json
{ {
...@@ -82,25 +96,79 @@ The response contains image data URLs in the content field: ...@@ -82,25 +96,79 @@ The response contains image data URLs in the content field:
} }
``` ```
## Key Flags ### Via `/v1/images/generations`
```bash
curl -s http://localhost:8000/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen-Image",
"prompt": "A cat sitting on a windowsill",
"size": "1024x1024",
"response_format": "url"
}'
```
## Text-to-Video
Launch using the provided script with `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`:
```bash
bash examples/backends/vllm/launch/agg_omni_video.sh
```
Generate a video via `/v1/videos`:
```bash
curl -s http://localhost:8000/v1/videos \
-H "Content-Type: application/json" \
-d '{
"model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"prompt": "A drone flyover of a mountain landscape",
"seconds": 2,
"size": "832x480",
"response_format": "url"
}'
```
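The response reports the generated video in `data[].url`. The vLLM-Omni worker currently encodes the frames to an MP4 file on its local filesystem and returns that file path as the URL (`response_format` is not yet used to select base64 output):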
```json
{
"id": "...",
"object": "video",
"model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"status": "completed",
"data": [{"url": "/tmp/generated_video.mp4"}]
}
```
The `/v1/videos` endpoint also accepts NVIDIA extensions via the `nvext` field for fine-grained control:
| Field | Description | Default |
|---|---|---|
| `nvext.fps` | Frames per second | 24 |
| `nvext.num_frames` | Number of frames (overrides `fps * seconds`) | -- |
| `nvext.negative_prompt` | Negative prompt for guidance | -- |
| `nvext.num_inference_steps` | Number of denoising steps | 50 |
| `nvext.guidance_scale` | CFG guidance scale | 5.0 |
| `nvext.seed` | Random seed for reproducibility | -- |
## CLI Reference
| Flag | Description | | Flag | Description |
|------|-------------| |---|---|
| `--omni` | Enable vLLM-Omni orchestrator (required) | | `--omni` | Enable the vLLM-Omni orchestrator (required for all omni workloads) |
| `--output-modalities <modality>` | Output modality: `text`, `image`, or `video` |
| `--stage-configs-path <path>` | Path to stage config YAML (optional; vLLM-Omni uses model defaults if omitted) | | `--stage-configs-path <path>` | Path to stage config YAML (optional; vLLM-Omni uses model defaults if omitted) |
| `--connector none` | Disable KV connector (recommended for omni) | | `--connector none` | Disable KV connector (recommended for omni workers) |
## Stage Configuration ## Stage Configuration
Omni pipelines are configured via YAML stage configs. See [`examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml`](/examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml) for an example. Key fields: Omni pipelines are configured via YAML stage configs. See [`examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml`](/examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml) for an example. For full documentation on stage config format and multi-stage pipelines, refer to the [vLLM-Omni Stage Configs documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/).
- **`model_stage`**: Pipeline stage name (e.g., `thinker`, `talker`, `code2wav`)
- **`final_output_type`**: Output format — `text` or `image`
- **`is_comprehension`**: Whether this stage processes input text/multimodal content
For full documentation on stage config format, supported fields, and multi-stage pipeline examples, see the [vLLM-Omni Stage Configs documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/).
## Current Limitations ## Current Limitations
- Only text prompts are supported (no multimodal input yet) - Only text prompts are supported as input (no multimodal input yet).
- KV cache events are not published for omni workers - KV cache events are not published for omni workers.
- Each worker supports a single output modality at a time.
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Launches an aggregated vLLM-Omni text-to-image deployment: a Dynamo frontend
# in the background plus one omni worker in the foreground.
#
# Usage: agg_omni_image.sh [--model <model>] [extra dynamo.vllm args...]
set -e
# On exit (normal or error) signal the whole process group so the background
# frontend does not outlive this script.
trap 'echo Cleaning up...; kill 0' EXIT

MODEL="Qwen/Qwen-Image"

# Parse command line arguments; anything unrecognised is forwarded to the worker.
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            # Without this check a trailing "--model" makes `shift 2` fail and
            # `set -e` terminates the script with no explanation.
            if [[ $# -lt 2 ]]; then
                echo "Error: --model requires a value" >&2
                exit 1
            fi
            MODEL="$2"
            shift 2
            ;;
        *)
            EXTRA_ARGS+=("$1")
            shift
            ;;
    esac
done

echo "=========================================="
echo "Starting vLLM-Omni Worker"
echo "Model: $MODEL"
echo "=========================================="

echo "Starting frontend on port ${DYN_HTTP_PORT:-8000}..."
python -m dynamo.frontend &
FRONTEND_PID=$!

# Give the frontend a moment to start before the worker comes up.
sleep 2

echo "Starting Omni worker..."
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \
python -m dynamo.vllm \
    --model "$MODEL" \
    --omni \
    --connector none \
    --output-modalities image \
    "${EXTRA_ARGS[@]}"
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Launches an aggregated vLLM-Omni text-to-video deployment: a Dynamo frontend
# in the background plus one omni worker in the foreground.
#
# Usage: agg_omni_video.sh [--model <model>] [extra dynamo.vllm args...]
set -e
# On exit (normal or error) signal the whole process group so the background
# frontend does not outlive this script.
trap 'echo Cleaning up...; kill 0' EXIT

MODEL="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# Parse command line arguments; anything unrecognised is forwarded to the worker.
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            # Without this check a trailing "--model" makes `shift 2` fail and
            # `set -e` terminates the script with no explanation.
            if [[ $# -lt 2 ]]; then
                echo "Error: --model requires a value" >&2
                exit 1
            fi
            MODEL="$2"
            shift 2
            ;;
        *)
            EXTRA_ARGS+=("$1")
            shift
            ;;
    esac
done

echo "=========================================="
echo "Starting vLLM-Omni Worker"
echo "Model: $MODEL"
echo "=========================================="

echo "Starting frontend on port ${DYN_HTTP_PORT:-8000}..."
python -m dynamo.frontend &
FRONTEND_PID=$!

# Give the frontend a moment to start before the worker comes up.
sleep 2

echo "Starting Omni worker..."
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \
python -m dynamo.vllm \
    --model "$MODEL" \
    --omni \
    --connector none \
    --output-modalities video \
    "${EXTRA_ARGS[@]}"
\ No newline at end of file
...@@ -589,8 +589,65 @@ impl ModelWatcher { ...@@ -589,8 +589,65 @@ impl ModelWatcher {
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_embeddings_model(card.name(), checksum, engine)?; .add_embeddings_model(card.name(), checksum, engine)?;
}
// Case: Text + (Images, Audio, Videos)
// Must come before the plain Text+Chat / Text+Completions branches because
// diffusion models often set both Images and Chat flags. The branch below
// handles the chat registration internally when supports_chat() is true.
else if card.model_input == ModelInput::Text
&& (card.model_type.supports_images()
|| card.model_type.supports_audios()
|| card.model_type.supports_videos())
{
// Image Models can support chat completions (vllm omni way)
// So register chat_completions model as well
if card.model_type.supports_chat() {
let chat_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client.clone(),
self.router_config.router_mode,
None,
None,
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
}
// This is ModelType::Images : registers /v1/images/* endpoints
if card.model_type.supports_images() {
let images_router = PushRouter::<
NvCreateImageRequest,
Annotated<NvImagesResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_images_model(card.name(), checksum, Arc::new(images_router))?;
}
// This is ModelType::Videos : registers /v1/videos/* endpoints
if card.model_type.supports_videos() {
let videos_router = PushRouter::<
NvCreateVideoRequest,
Annotated<NvVideosResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
}
// TODO: add audio models support
} else if card.model_input == ModelInput::Text && card.model_type.supports_chat() { } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
// Case 3: Text + Chat // Case: Text + Chat (pure text-to-text, no diffusion)
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>, Annotated<NvCreateChatCompletionStreamResponse>,
...@@ -602,7 +659,7 @@ impl ModelWatcher { ...@@ -602,7 +659,7 @@ impl ModelWatcher {
self.manager self.manager
.add_chat_completions_model(card.name(), checksum, engine)?; .add_chat_completions_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_completions() { } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
// Case 2: Text + Completions // Case: Text + Completions
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
Annotated<NvCreateCompletionResponse>, Annotated<NvCreateCompletionResponse>,
...@@ -660,60 +717,6 @@ impl ModelWatcher { ...@@ -660,60 +717,6 @@ impl ModelWatcher {
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_tensor_model(card.name(), checksum, engine)?; .add_tensor_model(card.name(), checksum, engine)?;
}
// Case: Text + (Images, Audio, Videos)
else if card.model_input == ModelInput::Text
&& (card.model_type.supports_images()
|| card.model_type.supports_audios()
|| card.model_type.supports_videos())
{
// Image Models can support chat completions (vllm omni way)
// So register chat_completions model as well
if card.model_type.supports_chat() {
let chat_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client.clone(),
self.router_config.router_mode,
None,
None,
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
}
// This is ModelType::Images : registers /v1/images/* endpoints
if card.model_type.supports_images() {
let images_router = PushRouter::<
NvCreateImageRequest,
Annotated<NvImagesResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_images_model(card.name(), checksum, Arc::new(images_router))?;
}
// This is ModelType::Videos : registers /v1/videos/* endpoints
if card.model_type.supports_videos() {
let videos_router = PushRouter::<
NvCreateVideoRequest,
Annotated<NvVideosResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
}
// TODO: add audio models support
} else if card.model_type.supports_prefill() { } else if card.model_type.supports_prefill() {
// Case 6: Prefill // Case 6: Prefill
// Guardrail: Verify model_input is Tokens // Guardrail: Verify model_input is Tokens
......
...@@ -154,7 +154,6 @@ impl ModelType { ...@@ -154,7 +154,6 @@ impl ModelType {
if self.contains(Self::Embedding) { if self.contains(Self::Embedding) {
endpoint_types.push(crate::endpoint_type::EndpointType::Embedding); endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
} }
// Images models support both chat and completions endpoints
if self.contains(Self::Images) { if self.contains(Self::Images) {
endpoint_types.push(crate::endpoint_type::EndpointType::Images); endpoint_types.push(crate::endpoint_type::EndpointType::Images);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment