"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "57cdb9a1230a86a308a03fcf8ba2a5ce92403894"
Unverified Commit 6dd3ce2e authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: vllm omni text to video generation pipeline (#6104)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 23de4e86
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Protocol types for TensorRT-LLM backend. """Shared protocol types used across multiple Dynamo backends.
This module provides protocol types for various modalities: This module provides protocol types for various modalities:
- video_protocol: NvCreateVideoRequest, NvVideosResponse for video generation - video_protocol: NvCreateVideoRequest, NvVideosResponse for video generation
- image_protocol: (future) Protocol types for image generation
""" """
from dynamo.trtllm.protocols.video_protocol import ( from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
NvVideosResponse, NvVideosResponse,
VideoData, VideoData,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from pydantic import BaseModel
# For omni models, we need to support raw_request parsing and json output format. We need to have these protocols defined here for serialization and deserialization.
# TODO: Replace these Pydantic models with Python bindings to the Rust protocol types once PyO3 bindings are available.
class ImageNvExt(BaseModel):
"""NVIDIA extensions for image generation requests.
Matches Rust NvExt in lib/llm/src/protocols/openai/images/nvext.rs.
"""
annotations: Optional[list[str]] = None
"""Annotations for SSE stream events."""
negative_prompt: Optional[str] = None
"""Optional negative prompt."""
num_inference_steps: Optional[int] = None
"""Number of denoising steps."""
guidance_scale: Optional[float] = None
"""CFG guidance scale."""
seed: Optional[int] = None
"""Random seed for reproducibility."""
class NvCreateImageRequest(BaseModel):
"""Request for image generation (/v1/images/generations endpoint).
Matches the flattened Rust NvCreateImageRequest in lib/llm/src/protocols/openai/images.rs
"""
prompt: str
"""The text prompt for image generation."""
model: Optional[str] = None
"""The model to use for image generation."""
n: Optional[int] = None
"""Number of images to generate (1-10)."""
quality: Optional[str] = None
"""Image quality: standard, hd, high, medium, low, auto."""
response_format: Optional[str] = None
"""Response format: url or b64_json."""
size: Optional[str] = None
"""Image size in WxH format (e.g. 1024x1024)."""
style: Optional[str] = None
"""Image style: vivid or natural."""
user: Optional[str] = None
"""Optional user identifier."""
moderation: Optional[str] = None
"""Content moderation level: auto or low."""
nvext: Optional[ImageNvExt] = None
"""NVIDIA extensions."""
class ImageData(BaseModel):
"""Individual image data in a response.
Matches the flattened Rust Image enum in lib/async-openai/src/types/image.rs.
"""
url: Optional[str] = None
"""URL of the generated image (if response_format is url)."""
b64_json: Optional[str] = None
"""Base64-encoded image (if response_format is b64_json)."""
revised_prompt: Optional[str] = None
"""Revised prompt, when the model rewrites the original prompt."""
class NvImagesResponse(BaseModel):
"""Response structure for image generation.
Matches the flattened Rust NvImagesResponse in lib/llm/src/protocols/openai/images.rs
"""
created: int
"""Unix timestamp of creation."""
data: list[ImageData] = []
"""List of generated images."""
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
These types match the Rust protocol types in lib/llm/src/protocols/openai/videos.rs These types match the Rust protocol types in lib/llm/src/protocols/openai/videos.rs
to ensure compatibility with the Dynamo HTTP frontend. to ensure compatibility with the Dynamo HTTP frontend.
""" """
# TODO: Replace these Pydantic models with Python bindings to the Rust protocol types once PyO3 bindings are available.
from typing import Optional from typing import Optional
......
...@@ -2,8 +2,12 @@ ...@@ -2,8 +2,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from enum import Enum from enum import Enum
from typing import List, Optional from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest
from dynamo.llm import ModelType from dynamo.llm import ModelType
...@@ -32,6 +36,15 @@ class OutputModality(Enum): ...@@ -32,6 +36,15 @@ class OutputModality(Enum):
return {m.name.lower() for m in cls} return {m.name.lower() for m in cls}
class RequestType(Enum):
"""Identifies the parsed request type returned by parse_request_type."""
CHAT_COMPLETION = "chat_completion"
IMAGE_GENERATION = "image_generation"
VIDEO_GENERATION = "video_generation"
AUDIO_GENERATION = "audio_generation"
def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[ModelType]: def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[ModelType]:
""" """
Get the combined ModelType flags for omni models based on CLI input. Get the combined ModelType flags for omni models based on CLI input.
...@@ -52,3 +65,34 @@ def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[Mod ...@@ -52,3 +65,34 @@ def get_output_modalities(cli_input: List[str], model_repo: str) -> Optional[Mod
flag if output_modalities is None else output_modalities | flag flag if output_modalities is None else output_modalities | flag
) )
return output_modalities return output_modalities
def parse_request_type(
raw_request: Dict[str, Any],
output_modalities: List[str],
) -> Tuple[Union[BaseModel, Dict[str, Any]], RequestType]:
"""
Classify the endpoint based on the output modality and serialize the request if necessary.
Assumption: Right now we only consider user passes only one modality at a time.
"""
# Fetch the first output modality from the list.
if not output_modalities:
raise ValueError("output_modalities must not be empty")
output_modality = output_modalities[0]
modality = OutputModality.from_name(output_modality)
if modality is OutputModality.IMAGE:
if "messages" in raw_request:
return raw_request, RequestType.CHAT_COMPLETION
return NvCreateImageRequest(**raw_request), RequestType.IMAGE_GENERATION
if modality is OutputModality.VIDEO:
return NvCreateVideoRequest(**raw_request), RequestType.VIDEO_GENERATION
if modality is OutputModality.AUDIO:
# Audio protocol types are not yet defined; pass through the raw dict.
return raw_request, RequestType.AUDIO_GENERATION
# Text Modality
return raw_request, RequestType.CHAT_COMPLETION
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Video encoding utilities for TensorRT-LLM video diffusion. """Video utilities for video diffusion.
This module provides utilities for encoding numpy video frames to MP4 format. Provides helpers for parsing video request parameters and encoding numpy
video frames to MP4 format.
""" """
import io import io
import logging import logging
import os import os
from typing import Tuple
import numpy as np import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_VIDEO_WIDTH = 832
DEFAULT_VIDEO_HEIGHT = 480
DEFAULT_VIDEO_FPS = 16
DEFAULT_VIDEO_NUM_FRAMES = 97
def parse_size(
size: str | None,
default_w: int = DEFAULT_VIDEO_WIDTH,
default_h: int = DEFAULT_VIDEO_HEIGHT,
) -> Tuple[int, int]:
"""Parse a 'WxH' string into (width, height).
Falls back to default_w x default_h when size is None or malformed.
"""
if not size:
return default_w, default_h
try:
w, h = size.split("x")
return int(w), int(h)
except (ValueError, AttributeError):
logger.warning("Invalid size format: %s, using defaults", size)
return default_w, default_h
def compute_num_frames(
num_frames: int | None = None,
seconds: int | None = None,
fps: int | None = None,
default_fps: int = DEFAULT_VIDEO_FPS,
default_num_frames: int = DEFAULT_VIDEO_NUM_FRAMES,
) -> int:
"""Compute the number of video frames.
Priority: num_frames > seconds × fps > default_num_frames.
"""
if num_frames is not None:
return num_frames
if seconds is not None or fps is not None:
_seconds = seconds if seconds is not None else 4
_fps = fps if fps is not None else default_fps
return _seconds * _fps
return default_num_frames
def normalize_video_frames(images) -> list:
"""Normalize stage_output.images into a frame list for export_to_video.
Args:
images: stage_output.images -- a list that may contain a single
torch.Tensor or np.ndarray representing the full video.
Returns:
List of frames suitable for diffusers export_to_video.
"""
frames = images[0] if len(images) == 1 else images
if isinstance(frames, np.ndarray):
if frames.ndim == 5:
frames = frames[0]
return list(frames)
return list(frames)
def frames_to_numpy(images: list) -> np.ndarray:
"""Convert a list of PIL Images to a numpy array suitable for video encoding.
Args:
images: List of PIL Image objects (video frames).
Returns:
Numpy array of shape ``(num_frames, height, width, 3)`` with dtype
``uint8`` and values in ``[0, 255]``.
Raises:
ValueError: If no images are provided or images have inconsistent sizes.
"""
if not images:
raise ValueError("No images provided for video encoding")
frames = []
for img in images:
arr = np.array(img.convert("RGB"))
frames.append(arr)
# Validate consistent sizes
shapes = {f.shape for f in frames}
if len(shapes) > 1:
raise ValueError(
f"Inconsistent frame sizes detected: {shapes}. "
"All frames must have the same dimensions."
)
return np.stack(frames, axis=0)
def encode_to_mp4( def encode_to_mp4(
frames: np.ndarray, frames: np.ndarray,
output_dir: str, output_dir: str,
......
...@@ -14,19 +14,16 @@ import uuid ...@@ -14,19 +14,16 @@ import uuid
from typing import Any, AsyncGenerator, Optional from typing import Any, AsyncGenerator, Optional
from dynamo._core import Component, Context from dynamo._core import Component, Context
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig from dynamo.common.protocols.video_protocol import (
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
from dynamo.trtllm.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
NvVideosResponse, NvVideosResponse,
VideoData, VideoData,
VideoNvExt, VideoNvExt,
) )
from dynamo.common.utils.video_utils import encode_to_mp4, encode_to_mp4_bytes
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler
from dynamo.trtllm.request_handlers.video_diffusion.video_utils import (
encode_to_mp4,
encode_to_mp4_bytes,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -18,14 +18,14 @@ from unittest.mock import MagicMock, patch ...@@ -18,14 +18,14 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig from dynamo.common.protocols.video_protocol import (
from dynamo.trtllm.constants import Modality
from dynamo.trtllm.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
NvVideosResponse, NvVideosResponse,
VideoData, VideoData,
VideoNvExt, VideoNvExt,
) )
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.constants import Modality
pytestmark = [ pytestmark = [
pytest.mark.unit, pytest.mark.unit,
......
...@@ -138,6 +138,135 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -138,6 +138,135 @@ class DynamoVllmArgGroup(ArgGroup):
help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).", help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).",
) )
# Video diffusion output
# TODO: Propose an alternate design to switch to AsyncOmniEngine args while using vLLM-Omni
add_argument(
g,
flag_name="--video-output-dir",
env_var="DYN_VLLM_VIDEO_OUTPUT_DIR",
default="/tmp/dynamo_videos", # noqa: S108
help="Directory to save generated video MP4 files.",
)
add_argument(
g,
flag_name="--default-video-fps",
env_var="DYN_VLLM_DEFAULT_VIDEO_FPS",
default=16,
arg_type=int,
help="Default frames per second for generated videos.",
)
# Diffusion engine-level args (passed to AsyncOmni constructor)
add_negatable_bool_argument(
g,
flag_name="--enable-layerwise-offload",
env_var="DYN_VLLM_ENABLE_LAYERWISE_OFFLOAD",
default=False,
help="Enable layerwise (blockwise) offloading on DiT modules to reduce GPU memory.",
)
add_argument(
g,
flag_name="--layerwise-num-gpu-layers",
env_var="DYN_VLLM_LAYERWISE_NUM_GPU_LAYERS",
default=1,
arg_type=int,
help="Number of ready layers (blocks) to keep on GPU during generation.",
)
add_negatable_bool_argument(
g,
flag_name="--vae-use-slicing",
env_var="DYN_VLLM_VAE_USE_SLICING",
default=False,
help="Enable VAE slicing for memory optimization in diffusion models.",
)
add_negatable_bool_argument(
g,
flag_name="--vae-use-tiling",
env_var="DYN_VLLM_VAE_USE_TILING",
default=False,
help="Enable VAE tiling for memory optimization in diffusion models.",
)
add_argument(
g,
flag_name="--boundary-ratio",
env_var="DYN_VLLM_BOUNDARY_RATIO",
default=0.875,
arg_type=float,
help=(
"Boundary split ratio for low/high DiT transformers. "
"Default 0.875 uses both transformers for best quality. "
"Set to 1.0 to load only the low-noise transformer (saves memory). "
"Only used with --omni."
),
)
add_argument(
g,
flag_name="--flow-shift",
env_var="DYN_VLLM_FLOW_SHIFT",
default=None,
arg_type=float,
help="Scheduler flow_shift parameter (5.0 for 720p, 12.0 for 480p). Only used with --omni.",
)
add_argument(
g,
flag_name="--diffusion-cache-backend",
env_var="DYN_VLLM_DIFFUSION_CACHE_BACKEND",
default=None,
choices=["cache_dit", "tea_cache"],
help=(
"Cache backend for diffusion acceleration. "
"'cache_dit' enables DBCache + SCM + TaylorSeer. "
"'tea_cache' enables TeaCache. Only used with --omni."
),
)
add_argument(
g,
flag_name="--diffusion-cache-config",
env_var="DYN_VLLM_DIFFUSION_CACHE_CONFIG",
default=None,
help="Cache configuration as JSON string (overrides defaults). Only used with --omni.",
)
add_negatable_bool_argument(
g,
flag_name="--enable-cache-dit-summary",
env_var="DYN_VLLM_ENABLE_CACHE_DIT_SUMMARY",
default=False,
help="Enable cache-dit summary logging after diffusion forward passes.",
)
add_negatable_bool_argument(
g,
flag_name="--enable-cpu-offload",
env_var="DYN_VLLM_ENABLE_CPU_OFFLOAD",
default=False,
help="Enable CPU offloading for diffusion models to reduce GPU memory usage.",
)
# Diffusion parallel configuration
add_argument(
g,
flag_name="--ulysses-degree",
env_var="DYN_VLLM_ULYSSES_DEGREE",
default=1,
arg_type=int,
help="Number of GPUs used for Ulysses sequence parallelism in diffusion.",
)
add_argument(
g,
flag_name="--ring-degree",
env_var="DYN_VLLM_RING_DEGREE",
default=1,
arg_type=int,
help="Number of GPUs used for ring sequence parallelism in diffusion.",
)
add_argument(
g,
flag_name="--cfg-parallel-size",
env_var="DYN_VLLM_CFG_PARALLEL_SIZE",
default=1,
arg_type=int,
choices=[1, 2],
help="Number of GPUs used for classifier free guidance parallelism.",
)
# ModelExpress P2P # ModelExpress P2P
add_argument( add_argument(
g, g,
...@@ -171,6 +300,27 @@ class DynamoVllmConfig(ConfigBase): ...@@ -171,6 +300,27 @@ class DynamoVllmConfig(ConfigBase):
omni: bool omni: bool
stage_configs_path: Optional[str] = None stage_configs_path: Optional[str] = None
# Video diffusion output
video_output_dir: str = "/tmp/dynamo_videos" # noqa: S108
default_video_fps: int = 16
# Diffusion engine-level parameters (passed to AsyncOmni constructor)
enable_layerwise_offload: bool = False
layerwise_num_gpu_layers: int = 1
vae_use_slicing: bool = False
vae_use_tiling: bool = False
boundary_ratio: float = 0.875
flow_shift: Optional[float] = None
diffusion_cache_backend: Optional[str] = None
diffusion_cache_config: Optional[str] = None
enable_cache_dit_summary: bool = False
enable_cpu_offload: bool = False
# Diffusion parallel configuration
ulysses_degree: int = 1
ring_degree: int = 1
cfg_parallel_size: int = 1
# ModelExpress P2P # ModelExpress P2P
model_express_url: Optional[str] = None model_express_url: Optional[str] = None
......
...@@ -902,11 +902,10 @@ def get_engine_cache_info(engine: AsyncLLM): ...@@ -902,11 +902,10 @@ def get_engine_cache_info(engine: AsyncLLM):
async def init_omni( async def init_omni(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
): ):
""" """Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni.
Initialize Omni worker for text-to-text generation using vLLM-Omni orchestrator.
Uses vLLM-Omni's Omni class for single-stage text generation pipeline. Supports text-to-text, text-to-image, and text-to-video generation
For now, supports text-to-text only (no multimodal). through a single unified OmniHandler.
""" """
# Lazy import to avoid loading vllm-omni unless explicitly needed # Lazy import to avoid loading vllm-omni unless explicitly needed
from dynamo.vllm.omni import OmniHandler from dynamo.vllm.omni import OmniHandler
...@@ -916,7 +915,7 @@ async def init_omni( ...@@ -916,7 +915,7 @@ async def init_omni(
) )
component = generate_endpoint.component() component = generate_endpoint.component()
# Initialize OmniHandler with Omni orchestrator # Initialize unified OmniHandler
handler = OmniHandler( handler = OmniHandler(
runtime=runtime, runtime=runtime,
component=component, component=component,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""vLLM-Omni integration for Dynamo.""" """vLLM-Omni integration for Dynamo."""
from .base_handler import BaseOmniHandler
from .omni_handler import OmniHandler from .omni_handler import OmniHandler
__all__ = ["OmniHandler"] __all__ = ["BaseOmniHandler", "OmniHandler"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Base handler for vLLM-Omni multi-stage pipelines."""
import asyncio
import logging
import time
from typing import Any, AsyncGenerator, Dict
from vllm import SamplingParams
from vllm_omni.entrypoints import AsyncOmni
try:
from vllm_omni.diffusion.data import DiffusionParallelConfig
except ImportError:
DiffusionParallelConfig = None # type: ignore[assignment, misc]
from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params
logger = logging.getLogger(__name__)
class BaseOmniHandler(BaseWorkerHandler):
"""Base handler for multi-stage pipelines using vLLM-Omni's AsyncOmni orchestrator."""
def __init__(
self,
runtime,
component,
config,
default_sampling_params: Dict[str, Any],
shutdown_event: asyncio.Event | None = None,
):
"""Initialize handler with AsyncOmni orchestrator.
Args:
runtime: Dynamo distributed runtime.
component: Dynamo component handle.
config: Parsed Config object from args.py.
default_sampling_params: Default sampling parameters dict.
shutdown_event: Optional asyncio event for graceful shutdown.
"""
logger.info(
f"Initializing {self.__class__.__name__} for multi-stage pipelines "
f"with model: {config.model}"
)
omni_kwargs = self._build_omni_kwargs(config)
self.engine_client = AsyncOmni(**omni_kwargs)
# Initialize attributes needed from BaseWorkerHandler
# We don't call super().__init__() because VllmEngineMonitor expects AsyncLLM,
# but AsyncOmni manages its own engines internally
# TODO: Kv publishers not supported yet
# TODO: Adopt to baseworker initialization pattern
self.runtime = runtime
self.component = component
self.default_sampling_params = default_sampling_params
self.config = config
self.model_max_len = config.engine_args.max_model_len
self.shutdown_event = shutdown_event
self.use_vllm_tokenizer = config.use_vllm_tokenizer
logger.info(f"{self.__class__.__name__} initialized successfully")
def _build_omni_kwargs(self, config) -> Dict[str, Any]:
"""Build keyword arguments for AsyncOmni constructor.
Constructs the full kwargs dict including engine-level diffusion
parameters and parallel configuration when available.
Args:
config: Parsed Config object.
Returns:
Dictionary of keyword arguments for AsyncOmni.
"""
omni_kwargs: Dict[str, Any] = {
"model": config.model,
"trust_remote_code": config.engine_args.trust_remote_code,
}
if config.stage_configs_path:
omni_kwargs["stage_configs_path"] = config.stage_configs_path
# Add diffusion engine-level params if present on config
diffusion_params = [
"enable_layerwise_offload",
"layerwise_num_gpu_layers",
"vae_use_slicing",
"vae_use_tiling",
"boundary_ratio",
"flow_shift",
"diffusion_cache_backend",
"diffusion_cache_config",
"enable_cache_dit_summary",
"enable_cpu_offload",
]
for param in diffusion_params:
if hasattr(config, param):
value = getattr(config, param)
if value is not None:
# Map config attribute names to AsyncOmni kwarg names
kwarg_name = param
if param == "diffusion_cache_backend":
kwarg_name = "cache_backend"
elif param == "diffusion_cache_config":
kwarg_name = "cache_config"
omni_kwargs[kwarg_name] = value
# Build DiffusionParallelConfig if parallel params are present
if DiffusionParallelConfig is not None and hasattr(config, "ulysses_degree"):
parallel_config = DiffusionParallelConfig(
ulysses_degree=getattr(config, "ulysses_degree", 1),
ring_degree=getattr(config, "ring_degree", 1),
cfg_parallel_size=getattr(config, "cfg_parallel_size", 1),
)
omni_kwargs["parallel_config"] = parallel_config
elif DiffusionParallelConfig is None:
logger.warning(
"DiffusionParallelConfig not available; "
"skipping parallel config for AsyncOmni"
)
return omni_kwargs
async def generate(
self, request: Dict[str, Any], context
) -> AsyncGenerator[Dict, None]:
"""Generate outputs using AsyncOmni orchestrator with OpenAI-compatible format.
Routes to OpenAI mode (detokenized text) or token mode based on config.
Subclasses should override ``_generate_openai_mode`` for custom output handling.
"""
request_id = context.id()
logger.debug(f"Omni Request ID: {request_id}")
async for chunk in self._generate_openai_mode(request, context, request_id):
yield chunk
async def _generate_openai_mode(
self, request, context, request_id
) -> AsyncGenerator[Dict, None]:
"""Generate OpenAI-compatible streaming chunks.
Subclasses should override this to handle their specific output types.
The base implementation raises NotImplementedError.
"""
raise NotImplementedError(
f"{self.__class__.__name__} must implement _generate_openai_mode"
)
def _extract_text_prompt(self, request: Dict[str, Any]) -> str | None:
"""Extract text prompt from OpenAI messages format.
Looks for the last user message and returns its text content.
"""
messages = request.get("messages", [])
for message in reversed(messages):
if message.get("role") == "user":
return message.get("content")
return None
def _extract_extra_body(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""Extract extra_body parameters from the request.
The extra_body is passed through by the OpenAI client and contains
model-specific parameters (e.g. diffusion sampling params).
"""
return request.get("extra_body", {}) or {}
def _build_sampling_params(self, request: Dict[str, Any]) -> SamplingParams:
"""Build sampling params using shared handler utility."""
return build_sampling_params(
request, self.default_sampling_params, self.model_max_len
)
def _error_chunk(self, request_id: str, error_message: str) -> Dict[str, Any]:
"""Create an error chunk in OpenAI format."""
return {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self.config.served_model_name or self.config.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": f"Error: {error_message}",
},
"finish_reason": "error",
}
],
}
def cleanup(self):
"""Cleanup AsyncOmni orchestrator resources."""
try:
if hasattr(self, "engine_client"):
self.engine_client.close()
logger.info("AsyncOmni orchestrator closed")
except Exception as e:
logger.error(f"Error closing AsyncOmni orchestrator: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import MagicMock, patch
import pytest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest
from dynamo.common.utils.output_modalities import RequestType
# TODO: Install vLLM omni dependencies in CI container so this skip is no longer needed.
try:
from dynamo.vllm.omni.omni_handler import (
EngineInputs,
OmniHandler,
prepare_image_output,
)
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
def _make_handler():
with patch(
"dynamo.vllm.omni.omni_handler.BaseOmniHandler.__init__", return_value=None
):
handler = OmniHandler.__new__(OmniHandler)
config = MagicMock()
config.model = "test-model"
config.served_model_name = None
config.output_modalities = ["text"]
handler.config = config
return handler
class TestEngineInputs:
def test_defaults(self):
"""EngineInputs uses CHAT_COMPLETION, fps=0, and None optionals by default."""
ei = EngineInputs(prompt={"prompt": "hello"})
assert ei.request_type == RequestType.CHAT_COMPLETION
assert ei.fps == 0
assert ei.sampling_params_list is None
assert ei.response_format is None
class TestPrepareImageOutput:
def test_b64_json(self):
"""b64_json format returns data URI with base64 prefix."""
img = MagicMock()
img.save = lambda b, format: b.write(b"fake_png_data")
results = prepare_image_output([img], "b64_json")
assert len(results) == 1
assert results[0].startswith("data:image/png;base64,")
def test_b64_default_when_none(self):
"""None response_format defaults to base64 encoding."""
img = MagicMock()
img.save = lambda b, format: b.write(b"data")
results = prepare_image_output([img], None)
assert results[0].startswith("data:image/png;base64,")
def test_invalid_format(self):
"""Unsupported response_format raises ValueError."""
with pytest.raises(ValueError, match="Invalid response format"):
prepare_image_output([MagicMock()], "invalid")
def test_multiple_images(self):
"""Multiple input images produce one output entry each."""
imgs = [MagicMock() for _ in range(3)]
for img in imgs:
img.save = lambda b, format: b.write(b"px")
results = prepare_image_output(imgs, "b64_json")
assert len(results) == 3
class TestBuildEngineInputs:
def test_chat_completion(self):
"""Chat request extracts text prompt with no sampling params."""
handler = _make_handler()
raw = {"messages": [{"role": "user", "content": "hello"}]}
inputs = handler.build_engine_inputs(raw, RequestType.CHAT_COMPLETION)
assert inputs.request_type == RequestType.CHAT_COMPLETION
assert inputs.prompt["prompt"] == "hello"
assert inputs.sampling_params_list is None
def test_image_generation(self):
"""Image request parses prompt, size, and creates diffusion sampling params."""
handler = _make_handler()
req = NvCreateImageRequest(prompt="a cat", size="512x512")
inputs = handler.build_engine_inputs(req, RequestType.IMAGE_GENERATION)
assert inputs.request_type == RequestType.IMAGE_GENERATION
assert inputs.prompt["prompt"] == "a cat"
assert len(inputs.sampling_params_list) == 1
sp = inputs.sampling_params_list[0]
assert sp.height == 512
assert sp.width == 512
def test_video_generation(self):
"""Video request parses prompt, size, seconds, and sets fps."""
handler = _make_handler()
req = NvCreateVideoRequest(
prompt="a drone", model="test", size="832x480", seconds=2
)
inputs = handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION)
assert inputs.request_type == RequestType.VIDEO_GENERATION
assert inputs.prompt["prompt"] == "a drone"
assert inputs.fps > 0
def test_audio_not_implemented(self):
"""Audio generation raises NotImplementedError."""
handler = _make_handler()
with pytest.raises(NotImplementedError):
handler.build_engine_inputs({}, RequestType.AUDIO_GENERATION)
class TestFormatTextChunk:
def _make_output(self, text="hello world", finish_reason=None):
output = MagicMock()
output.text = text
output.finish_reason = finish_reason
request_output = MagicMock()
request_output.outputs = [output]
request_output.prompt_token_ids = [1, 2, 3]
return request_output
def test_delta_text(self):
"""Delta content is the diff between current and previous text."""
handler = _make_handler()
ro = self._make_output("hello world")
chunk = handler._format_text_chunk(ro, "req-1", "hello ")
assert chunk["choices"][0]["delta"]["content"] == "world"
def test_no_outputs_returns_error(self):
"""Empty engine outputs produce an error chunk."""
handler = _make_handler()
ro = MagicMock()
ro.outputs = []
chunk = handler._format_text_chunk(ro, "req-1", "")
assert "Error" in chunk["choices"][0]["delta"]["content"]
def test_finish_reason_included(self):
"""Final chunk includes finish_reason and usage stats."""
handler = _make_handler()
handler._normalize_finish_reason = lambda r: r
handler._build_completion_usage = lambda ro: {
"prompt_tokens": 3,
"completion_tokens": 1,
}
ro = self._make_output("done", finish_reason="stop")
chunk = handler._format_text_chunk(ro, "req-1", "")
assert chunk["choices"][0]["finish_reason"] == "stop"
assert "usage" in chunk
class TestFormatImageChunk:
def test_chat_completion_format(self):
"""Chat completion route returns image_url content parts."""
handler = _make_handler()
img = MagicMock()
img.save = lambda b, format: b.write(b"px")
chunk = handler._format_image_chunk(
[img], "req-1", request_type=RequestType.CHAT_COMPLETION
)
assert chunk["object"] == "chat.completion.chunk"
assert chunk["choices"][0]["delta"]["content"][0]["type"] == "image_url"
def test_image_generation_b64_format(self):
"""Image generation with b64_json format returns base64 data."""
handler = _make_handler()
img = MagicMock()
img.save = lambda b, format: b.write(b"px")
chunk = handler._format_image_chunk(
[img],
"req-1",
response_format="b64_json",
request_type=RequestType.IMAGE_GENERATION,
)
assert chunk["data"][0]["b64_json"] is not None
def test_image_generation_default_format_returns_b64(self):
"""Image generation with response_format=None defaults to b64_json."""
handler = _make_handler()
img = MagicMock()
img.save = lambda b, format: b.write(b"px")
chunk = handler._format_image_chunk(
[img],
"req-1",
response_format=None,
request_type=RequestType.IMAGE_GENERATION,
)
assert chunk["data"][0]["b64_json"] is not None
def test_empty_images_returns_error(self):
"""Empty image list produces an error chunk."""
handler = _make_handler()
chunk = handler._format_image_chunk([], "req-1")
assert "Error" in chunk["choices"][0]["delta"]["content"]
class TestFormatVideoChunk:
@pytest.mark.asyncio
async def test_empty_frames_returns_none(self):
"""Empty frame list returns None."""
handler = _make_handler()
result = await handler._format_video_chunk([], "req-1", fps=16)
assert result is None
@pytest.mark.asyncio
async def test_error_returns_failed_status(self):
"""Encoding failure returns NvVideosResponse with failed status and error."""
handler = _make_handler()
with patch(
"dynamo.vllm.omni.omni_handler.normalize_video_frames",
side_effect=RuntimeError("boom"),
):
chunk = await handler._format_video_chunk([MagicMock()], "req-1", fps=16)
assert chunk["status"] == "failed"
assert "boom" in chunk["error"]
...@@ -4,34 +4,53 @@ ...@@ -4,34 +4,53 @@
title: vLLM-Omni title: vLLM-Omni
--- ---
# [Experimental] Running Omni Models with vLLM # [Experimental] Omni Models with vLLM
Dynamo supports omni (multimodal generation) models via the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This enables multi-stage pipelines for tasks like text-to-text and text-to-image generation through an OpenAI-compatible API. Dynamo supports multimodal generation through the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This integration exposes text-to-text, text-to-image, and text-to-video capabilities via OpenAI-compatible API endpoints.
## Prerequisites ## Prerequisites
This guide assumes familiarity with deploying Dynamo with vLLM as described in [README.md](/docs/pages/backends/vllm/README.md). This guide assumes familiarity with deploying Dynamo with vLLM as described in the [vLLM backend guide](/docs/pages/backends/vllm/README.md).
## Quick Start ## Supported Modalities
### Text-to-Text | Modality | Endpoint(s) | `--output-modalities` |
|---|---|---|
| Text-to-Text | `/v1/chat/completions` | `text` (default) |
| Text-to-Image | `/v1/chat/completions`, `/v1/images/generations` | `image` |
| Text-to-Video | `/v1/videos` | `video` |
Launch an aggregated deployment (frontend + omni worker) using the provided script: The `--output-modalities` flag determines which endpoint(s) the worker registers. When set to `image`, both `/v1/chat/completions` (returns inline base64 images) and `/v1/images/generations` are available. When set to `video`, the worker serves `/v1/videos`.
## Tested Models
| Modality | Models |
|---|---|
| Text-to-Text | `Qwen/Qwen2.5-Omni-7B` |
| Text-to-Image | `Qwen/Qwen-Image`, `AIDC-AI/Ovis-Image-7B` |
| Text-to-Video | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
To run a non-default model, pass `--model` to any launch script:
```bash ```bash
bash examples/backends/vllm/launch/agg_omni.sh bash examples/backends/vllm/launch/agg_omni_image.sh --model AIDC-AI/Ovis-Image-7B
bash examples/backends/vllm/launch/agg_omni_video.sh --model Wan-AI/Wan2.2-T2V-A14B-Diffusers
``` ```
This starts `Qwen/Qwen2.5-Omni-7B` with a single-stage thinker config on one GPU. Override the model with: ## Text-to-Text
Launch an aggregated deployment (frontend + omni worker):
```bash ```bash
bash examples/backends/vllm/launch/agg_omni.sh --model <your-model> bash examples/backends/vllm/launch/agg_omni.sh
``` ```
Test the deployment: This starts `Qwen/Qwen2.5-Omni-7B` with a single-stage thinker config on one GPU.
Verify the deployment:
```bash ```bash
curl -X POST http://localhost:8000/v1/chat/completions \ curl -s http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "Qwen/Qwen2.5-Omni-7B", "model": "Qwen/Qwen2.5-Omni-7B",
...@@ -41,34 +60,29 @@ curl -X POST http://localhost:8000/v1/chat/completions \ ...@@ -41,34 +60,29 @@ curl -X POST http://localhost:8000/v1/chat/completions \
}' }'
``` ```
### Text-to-Image This script uses a custom stage config (`stage_configs/single_stage_llm.yaml`) that configures the thinker stage for text generation. See [Stage Configuration](#stage-configuration) for details.
## Text-to-Image
Text-to-image uses vLLM-Omni's built-in default stage configs (no custom YAML needed). Launch without a stage config path so vLLM-Omni loads the model's default multi-stage pipeline: Launch using the provided script with `Qwen/Qwen-Image`:
```bash ```bash
# Start frontend bash examples/backends/vllm/launch/agg_omni_image.sh
python -m dynamo.frontend &
# Start omni worker (vLLM-Omni loads default stage configs for the model)
DYN_SYSTEM_PORT=8081 python -m dynamo.vllm \
--model <your-text-to-image-model> \
--omni \
--connector none
``` ```
Images are returned as base64-encoded PNGs in the response: ### Via `/v1/chat/completions`
```bash ```bash
curl -X POST http://localhost:8000/v1/chat/completions \ curl -s http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "<your-text-to-image-model>", "model": "Qwen/Qwen-Image",
"messages": [{"role": "user", "content": "A cat sitting on a windowsill"}], "messages": [{"role": "user", "content": "A cat sitting on a windowsill"}],
"stream": false "stream": false
}' }'
``` ```
The response contains image data URLs in the content field: The response includes base64-encoded images inline:
```json ```json
{ {
...@@ -82,25 +96,79 @@ The response contains image data URLs in the content field: ...@@ -82,25 +96,79 @@ The response contains image data URLs in the content field:
} }
``` ```
## Key Flags ### Via `/v1/images/generations`
```bash
curl -s http://localhost:8000/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen-Image",
"prompt": "A cat sitting on a windowsill",
"size": "1024x1024",
"response_format": "url"
}'
```
## Text-to-Video
Launch using the provided script with `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`:
```bash
bash examples/backends/vllm/launch/agg_omni_video.sh
```
Generate a video via `/v1/videos`:
```bash
curl -s http://localhost:8000/v1/videos \
-H "Content-Type: application/json" \
-d '{
"model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"prompt": "A drone flyover of a mountain landscape",
"seconds": 2,
"size": "832x480",
"response_format": "url"
}'
```
The response returns a video URL or base64 data depending on `response_format`:
```json
{
"id": "...",
"object": "video",
"model": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"status": "completed",
"data": [{"url": "/tmp/generated_video.mp4"}]
}
```
The `/v1/videos` endpoint also accepts NVIDIA extensions via the `nvext` field for fine-grained control:
| Field | Description | Default |
|---|---|---|
| `nvext.fps` | Frames per second | 24 |
| `nvext.num_frames` | Number of frames (overrides `fps * seconds`) | -- |
| `nvext.negative_prompt` | Negative prompt for guidance | -- |
| `nvext.num_inference_steps` | Number of denoising steps | 50 |
| `nvext.guidance_scale` | CFG guidance scale | 5.0 |
| `nvext.seed` | Random seed for reproducibility | -- |
## CLI Reference
| Flag | Description | | Flag | Description |
|------|-------------| |---|---|
| `--omni` | Enable vLLM-Omni orchestrator (required) | | `--omni` | Enable the vLLM-Omni orchestrator (required for all omni workloads) |
| `--output-modalities <modality>` | Output modality: `text`, `image`, or `video` |
| `--stage-configs-path <path>` | Path to stage config YAML (optional; vLLM-Omni uses model defaults if omitted) | | `--stage-configs-path <path>` | Path to stage config YAML (optional; vLLM-Omni uses model defaults if omitted) |
| `--connector none` | Disable KV connector (recommended for omni) | | `--connector none` | Disable KV connector (recommended for omni workers) |
## Stage Configuration ## Stage Configuration
Omni pipelines are configured via YAML stage configs. See [`examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml`](/examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml) for an example. Key fields: Omni pipelines are configured via YAML stage configs. See [`examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml`](/examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml) for an example. For full documentation on stage config format and multi-stage pipelines, refer to the [vLLM-Omni Stage Configs documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/).
- **`model_stage`**: Pipeline stage name (e.g., `thinker`, `talker`, `code2wav`)
- **`final_output_type`**: Output format — `text` or `image`
- **`is_comprehension`**: Whether this stage processes input text/multimodal content
For full documentation on stage config format, supported fields, and multi-stage pipeline examples, see the [vLLM-Omni Stage Configs documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/).
## Current Limitations ## Current Limitations
- Only text prompts are supported (no multimodal input yet) - Only text prompts are supported as input (no multimodal input yet).
- KV cache events are not published for omni workers - KV cache events are not published for omni workers.
- Each worker supports a single output modality at a time.
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
MODEL="Qwen/Qwen-Image"
# Parse command line arguments
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL="$2"
shift 2
;;
*)
EXTRA_ARGS+=("$1")
shift
;;
esac
done
echo "=========================================="
echo "Starting vLLM-Omni Worker"
echo "Model: $MODEL"
echo "=========================================="
echo "Starting frontend on port ${DYN_HTTP_PORT:-8000}..."
python -m dynamo.frontend &
FRONTEND_PID=$!
sleep 2
echo "Starting Omni worker..."
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \
python -m dynamo.vllm \
--model "$MODEL" \
--omni \
--connector none \
--output-modalities image \
"${EXTRA_ARGS[@]}"
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
MODEL="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
# Parse command line arguments
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL="$2"
shift 2
;;
*)
EXTRA_ARGS+=("$1")
shift
;;
esac
done
echo "=========================================="
echo "Starting vLLM-Omni Worker"
echo "Model: $MODEL"
echo "=========================================="
echo "Starting frontend on port ${DYN_HTTP_PORT:-8000}..."
python -m dynamo.frontend &
FRONTEND_PID=$!
sleep 2
echo "Starting Omni worker..."
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \
python -m dynamo.vllm \
--model "$MODEL" \
--omni \
--connector none \
--output-modalities video \
"${EXTRA_ARGS[@]}"
\ No newline at end of file
...@@ -589,8 +589,65 @@ impl ModelWatcher { ...@@ -589,8 +589,65 @@ impl ModelWatcher {
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_embeddings_model(card.name(), checksum, engine)?; .add_embeddings_model(card.name(), checksum, engine)?;
}
// Case: Text + (Images, Audio, Videos)
// Must come before the plain Text+Chat / Text+Completions branches because
// diffusion models often set both Images and Chat flags. The branch below
// handles the chat registration internally when supports_chat() is true.
else if card.model_input == ModelInput::Text
&& (card.model_type.supports_images()
|| card.model_type.supports_audios()
|| card.model_type.supports_videos())
{
// Image Models can support chat completions (vllm omni way)
// So register chat_completions model as well
if card.model_type.supports_chat() {
let chat_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client.clone(),
self.router_config.router_mode,
None,
None,
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
}
// This is ModelType::Images : registers /v1/images/* endpoints
if card.model_type.supports_images() {
let images_router = PushRouter::<
NvCreateImageRequest,
Annotated<NvImagesResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_images_model(card.name(), checksum, Arc::new(images_router))?;
}
// This is ModelType::Videos : registers /v1/videos/* endpoints
if card.model_type.supports_videos() {
let videos_router = PushRouter::<
NvCreateVideoRequest,
Annotated<NvVideosResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
}
// TODO: add audio models support
} else if card.model_input == ModelInput::Text && card.model_type.supports_chat() { } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
// Case 3: Text + Chat // Case: Text + Chat (pure text-to-text, no diffusion)
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>, Annotated<NvCreateChatCompletionStreamResponse>,
...@@ -602,7 +659,7 @@ impl ModelWatcher { ...@@ -602,7 +659,7 @@ impl ModelWatcher {
self.manager self.manager
.add_chat_completions_model(card.name(), checksum, engine)?; .add_chat_completions_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_completions() { } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
// Case 2: Text + Completions // Case: Text + Completions
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
Annotated<NvCreateCompletionResponse>, Annotated<NvCreateCompletionResponse>,
...@@ -660,60 +717,6 @@ impl ModelWatcher { ...@@ -660,60 +717,6 @@ impl ModelWatcher {
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_tensor_model(card.name(), checksum, engine)?; .add_tensor_model(card.name(), checksum, engine)?;
}
// Case: Text + (Images, Audio, Videos)
else if card.model_input == ModelInput::Text
&& (card.model_type.supports_images()
|| card.model_type.supports_audios()
|| card.model_type.supports_videos())
{
// Image Models can support chat completions (vllm omni way)
// So register chat_completions model as well
if card.model_type.supports_chat() {
let chat_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(
client.clone(),
self.router_config.router_mode,
None,
None,
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
}
// This is ModelType::Images : registers /v1/images/* endpoints
if card.model_type.supports_images() {
let images_router = PushRouter::<
NvCreateImageRequest,
Annotated<NvImagesResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_images_model(card.name(), checksum, Arc::new(images_router))?;
}
// This is ModelType::Videos : registers /v1/videos/* endpoints
if card.model_type.supports_videos() {
let videos_router = PushRouter::<
NvCreateVideoRequest,
Annotated<NvVideosResponse>,
>::from_client_with_threshold(
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
}
// TODO: add audio models support
} else if card.model_type.supports_prefill() { } else if card.model_type.supports_prefill() {
// Case 6: Prefill // Case 6: Prefill
// Guardrail: Verify model_input is Tokens // Guardrail: Verify model_input is Tokens
......
...@@ -154,7 +154,6 @@ impl ModelType { ...@@ -154,7 +154,6 @@ impl ModelType {
if self.contains(Self::Embedding) { if self.contains(Self::Embedding) {
endpoint_types.push(crate::endpoint_type::EndpointType::Embedding); endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
} }
// Images models support both chat and completions endpoints
if self.contains(Self::Images) { if self.contains(Self::Images) {
endpoint_types.push(crate::endpoint_type::EndpointType::Images); endpoint_types.push(crate::endpoint_type::EndpointType::Images);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment