Unverified Commit 2be83be2 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat: add video generation support (T2V) (#5793)

parent 14eceb43
...@@ -115,6 +115,9 @@ core ...@@ -115,6 +115,9 @@ core
/CLAUDE.md /CLAUDE.md
/CLAUDE.md.bak /CLAUDE.md.bak
# AI generated worklogs
worklogs/
# Benchmarks # Benchmarks
benchmarks/results benchmarks/results
profiling_results* profiling_results*
......
...@@ -137,13 +137,17 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = { ...@@ -137,13 +137,17 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": None, "default": None,
"help": "Filesystem URL for storing generated images using fsspec (e.g., s3://bucket/path, gs://bucket/path, file:///local/path). Supports any fsspec-compatible filesystem.", "help": "Filesystem URL for storing generated images using fsspec (e.g., s3://bucket/path, gs://bucket/path, file:///local/path). Supports any fsspec-compatible filesystem.",
}, },
"image-diffusion-base-url": { "video-generation-worker": {
"flags": ["--image-diffusion-base-url"], "flags": ["--video-generation-worker"],
"action": "store_true",
"default": False,
"help": "Run as video generation worker for video generation (T2V/I2V)",
},
"video-generation-fs-url": {
"flags": ["--video-generation-fs-url"],
"type": str, "type": str,
"default": os.environ.get( "default": None,
"DYN_IMAGE_DIFFUSION_BASE_URL", "http://localhost:8008/" "help": "Filesystem URL for storing generated videos using fsspec (e.g., s3://bucket/path, gs://bucket/path, file:///local/path). Supports any fsspec-compatible filesystem.",
),
"help": "Base URL for rewriting image URLs in responses (e.g., http://localhost:8008/). When set, generated image URLs will use this base instead of filesystem URLs. Can be set via DYN_IMAGE_DIFFUSION_URL_BASE env var.",
}, },
} }
...@@ -189,7 +193,10 @@ class DynamoArgs: ...@@ -189,7 +193,10 @@ class DynamoArgs:
# image diffusion options # image diffusion options
image_diffusion_worker: bool = False image_diffusion_worker: bool = False
image_diffusion_fs_url: Optional[str] = None image_diffusion_fs_url: Optional[str] = None
image_diffusion_base_url: Optional[str] = None
# video generation options
video_generation_worker: bool = False
video_generation_fs_url: Optional[str] = None
class DisaggregationMode(Enum): class DisaggregationMode(Enum):
...@@ -434,6 +441,8 @@ async def parse_args(args: list[str]) -> Config: ...@@ -434,6 +441,8 @@ async def parse_args(args: list[str]) -> Config:
endpoint = f"dyn://{namespace}.backend.generate" endpoint = f"dyn://{namespace}.backend.generate"
elif getattr(parsed_args, "image_diffusion_worker", False): elif getattr(parsed_args, "image_diffusion_worker", False):
endpoint = f"dyn://{namespace}.backend.generate" endpoint = f"dyn://{namespace}.backend.generate"
elif getattr(parsed_args, "video_generation_worker", False):
endpoint = f"dyn://{namespace}.backend.generate"
elif ( elif (
hasattr(parsed_args, "disaggregation_mode") hasattr(parsed_args, "disaggregation_mode")
and parsed_args.disaggregation_mode == "prefill" and parsed_args.disaggregation_mode == "prefill"
...@@ -511,32 +520,38 @@ async def parse_args(args: list[str]) -> Config: ...@@ -511,32 +520,38 @@ async def parse_args(args: list[str]) -> Config:
# fetch_llm (download the model) here, in `parse_args`. `parse_args` should not # fetch_llm (download the model) here, in `parse_args`. `parse_args` should not
# contain code to download a model, it should only parse the args. # contain code to download a model, it should only parse the args.
# For diffusion workers, create a minimal dummy ServerArgs since diffusion # For diffusion/video workers, create a minimal dummy ServerArgs since diffusion
# doesn't use transformer models or sglang Engine - it uses DiffGenerator directly # doesn't use transformer models or sglang Engine - it uses DiffGenerator directly
image_diffusion_worker = getattr(parsed_args, "image_diffusion_worker", False) image_diffusion_worker = getattr(parsed_args, "image_diffusion_worker", False)
video_generation_worker = getattr(parsed_args, "video_generation_worker", False)
if image_diffusion_worker: if image_diffusion_worker or video_generation_worker:
logging.info(f"Image diffusion worker detected with model: {model_path}") worker_type = (
"image diffusion" if image_diffusion_worker else "video generation"
# Need to use ServerArgs not intended for sglang[diffusion], multimodal_gen has its own ServerArgs. )
server_args = ServerArgs("none") # HACK: Avoid triggering __post_init__ logging.info(
f"{worker_type.title()} worker detected with model: {model_path}, creating minimal ServerArgs stub"
)
# Create a minimal ServerArgs-like object that bypasses model config loading
# Diffusion/video workers don't actually use ServerArgs - they use DiffGenerator
import types
server_args = types.SimpleNamespace()
# Copy over any attrs that might be needed, but avoid triggering __post_init__
server_args.model_path = model_path server_args.model_path = model_path
server_args.served_model_name = parsed_args.served_model_name server_args.served_model_name = parsed_args.served_model_name
server_args.enable_metrics = getattr(parsed_args, "enable_metrics", False) server_args.enable_metrics = getattr(parsed_args, "enable_metrics", False)
server_args.log_level = getattr(parsed_args, "log_level", "info") server_args.log_level = getattr(parsed_args, "log_level", "info")
server_args.skip_tokenizer_init = True
server_args.kv_events_config = getattr(parsed_args, "kv_events_config", None) server_args.kv_events_config = getattr(parsed_args, "kv_events_config", None)
server_args.tp_size = getattr(parsed_args, "tp_size", 1)
server_args.dp_size = getattr(parsed_args, "dp_size", 1)
server_args.speculative_algorithm = None server_args.speculative_algorithm = None
server_args.disaggregation_mode = None server_args.disaggregation_mode = None
server_args.dllm_algorithm = False server_args.dllm_algorithm = False
server_args.tp_size = getattr(parsed_args, "tensor_parallel_size", 1) server_args.load_format = None
server_args.dp_size = getattr(parsed_args, "data_parallel_size", 1)
parsed_args.use_sglang_tokenizer = True
parsed_args.dyn_endpoint_types = "images"
logging.info( logging.info(
f"Created stub ServerArgs for diffusion: model_path={server_args.model_path}" f"Created stub ServerArgs for {worker_type}: model_path={server_args.model_path}"
) )
else: else:
server_args = ServerArgs.from_cli_args(parsed_args) server_args = ServerArgs.from_cli_args(parsed_args)
...@@ -595,7 +610,8 @@ async def parse_args(args: list[str]) -> Config: ...@@ -595,7 +610,8 @@ async def parse_args(args: list[str]) -> Config:
diffusion_worker=diffusion_worker, diffusion_worker=diffusion_worker,
image_diffusion_worker=getattr(parsed_args, "image_diffusion_worker", False), image_diffusion_worker=getattr(parsed_args, "image_diffusion_worker", False),
image_diffusion_fs_url=getattr(parsed_args, "image_diffusion_fs_url", None), image_diffusion_fs_url=getattr(parsed_args, "image_diffusion_fs_url", None),
image_diffusion_base_url=getattr(parsed_args, "image_diffusion_base_url", None), video_generation_worker=getattr(parsed_args, "video_generation_worker", False),
video_generation_fs_url=getattr(parsed_args, "video_generation_fs_url", None),
dump_config_to=parsed_args.dump_config_to, dump_config_to=parsed_args.dump_config_to,
enable_local_indexer=not parsed_args.durable_kv_events, enable_local_indexer=not parsed_args.durable_kv_events,
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
......
...@@ -144,3 +144,33 @@ class ImageDiffusionHealthCheckPayload(HealthCheckPayload): ...@@ -144,3 +144,33 @@ class ImageDiffusionHealthCheckPayload(HealthCheckPayload):
} }
super().__init__() super().__init__()
class VideoGenerationHealthCheckPayload(HealthCheckPayload):
"""Video generation-specific health check payload for video generation workers.
Sends a minimal video generation request to verify the video worker
is responding and the model is loaded. Uses minimal resources for fast checks.
"""
def __init__(self, model_path: str):
"""Initialize video health check payload with minimal generation request.
Args:
model_path: The video generation model being served.
"""
self.default_payload = {
"prompt": "test", # Minimal prompt
"model": model_path,
"seconds": 1,
"size": "256x256", # Small size for fast health check
"response_format": "b64_json", # Don't require filesystem for health check
"nvext": {
"fps": 8,
"num_frames": 8, # Minimal frames for fast health check
"num_inference_steps": 1, # Just 1 step (fast but low quality)
"guidance_scale": 5.0, # Standard guidance scale for video
},
}
super().__init__()
...@@ -27,11 +27,13 @@ from dynamo.sglang.health_check import ( ...@@ -27,11 +27,13 @@ from dynamo.sglang.health_check import (
ImageDiffusionHealthCheckPayload, ImageDiffusionHealthCheckPayload,
SglangHealthCheckPayload, SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload, SglangPrefillHealthCheckPayload,
VideoGenerationHealthCheckPayload,
) )
from dynamo.sglang.publisher import DynamoSglangPublisher, setup_sgl_metrics from dynamo.sglang.publisher import DynamoSglangPublisher, setup_sgl_metrics
from dynamo.sglang.register import ( from dynamo.sglang.register import (
register_image_diffusion_model, register_image_diffusion_model,
register_llm_with_readiness_gate, register_llm_with_readiness_gate,
register_video_generation_model,
) )
from dynamo.sglang.request_handlers import ( from dynamo.sglang.request_handlers import (
DecodeWorkerHandler, DecodeWorkerHandler,
...@@ -43,6 +45,7 @@ from dynamo.sglang.request_handlers import ( ...@@ -43,6 +45,7 @@ from dynamo.sglang.request_handlers import (
MultimodalProcessorHandler, MultimodalProcessorHandler,
MultimodalWorkerHandler, MultimodalWorkerHandler,
PrefillWorkerHandler, PrefillWorkerHandler,
VideoGenerationWorkerHandler,
) )
configure_dynamo_logging() configure_dynamo_logging()
...@@ -213,6 +216,8 @@ async def worker(): ...@@ -213,6 +216,8 @@ async def worker():
if config.dynamo_args.image_diffusion_worker: if config.dynamo_args.image_diffusion_worker:
await init_image_diffusion(runtime, config) await init_image_diffusion(runtime, config)
elif config.dynamo_args.video_generation_worker:
await init_video_generation(runtime, config)
elif config.dynamo_args.embedding_worker: elif config.dynamo_args.embedding_worker:
await init_embedding(runtime, config, shutdown_event) await init_embedding(runtime, config, shutdown_event)
elif config.dynamo_args.multimodal_processor: elif config.dynamo_args.multimodal_processor:
...@@ -643,6 +648,86 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config): ...@@ -643,6 +648,86 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
await RUN_DEFERRED_HANDLERS() await RUN_DEFERRED_HANDLERS()
async def init_video_generation(runtime: DistributedRuntime, config: Config):
"""Initialize video generation worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
# Initialize DiffGenerator (not sgl.Engine) - same as image diffusion
from sglang.multimodal_gen import DiffGenerator
if not server_args.model_path:
raise ValueError("--model is required for video generation workers")
# Parallelism configuration
tp_size = getattr(server_args, "tp_size", 1)
dp_size = getattr(server_args, "dp_size", 1)
num_gpus = tp_size * dp_size
# Distributed configuration
dist_timeout = getattr(server_args, "dist_timeout", None)
generator = DiffGenerator.from_pretrained(
model_path=server_args.model_path,
# Parallelism configuration
num_gpus=num_gpus,
tp_size=tp_size,
dp_size=dp_size,
# Distributed configuration
dist_timeout=dist_timeout,
)
# Initialize fsspec filesystems for video storage
fs_url = dynamo_args.video_generation_fs_url
# Initialize primary filesystem
if not fs_url:
raise ValueError(
"--video-generation-fs-url is required for video generation workers"
)
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
)
generate_endpoint = component.endpoint(dynamo_args.endpoint)
handler = VideoGenerationWorkerHandler(
component,
generator,
config,
publisher=None,
fs=get_fs(fs_url),
)
# Create proper health check payload that sends a minimal video request
health_check_payload = VideoGenerationHealthCheckPayload(
model_path=server_args.model_path
).to_dict()
ready_event = asyncio.Event()
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[], # No LLM metrics labels
health_check_payload=health_check_payload,
),
register_video_generation_model(
generator,
generate_endpoint,
server_args,
readiness_gate=ready_event,
),
)
except Exception as e:
logging.error(f"Failed to serve video generation endpoints: {e}")
raise
finally:
handler.cleanup()
async def init_multimodal_processor( async def init_multimodal_processor(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
): ):
......
...@@ -176,3 +176,52 @@ class ImagesResponse(BaseModel): ...@@ -176,3 +176,52 @@ class ImagesResponse(BaseModel):
created: int # Unix timestamp created: int # Unix timestamp
data: list[ImageData] data: list[ImageData]
# ============================================================================
# Video Generation Protocol Types
# ============================================================================
class VideoNvExt(BaseModel):
"""NVIDIA extensions for video generation requests."""
annotations: Optional[list[str]] = None
fps: Optional[int] = 24
num_frames: Optional[int] = None # Override: if set, ignores fps * seconds
negative_prompt: Optional[str] = None
num_inference_steps: Optional[int] = 50
guidance_scale: float = 5.0
seed: Optional[int] = None
class CreateVideoRequest(BaseModel):
"""Request for /v1/videos endpoint"""
prompt: str
model: str
input_reference: Optional[str] = None # For I2V (image-to-video) - image path/url
seconds: Optional[int] = 4
size: Optional[str] = "832x480" # WxH format (Wan default: 832x480)
user: Optional[str] = None
response_format: Optional[str] = "url" # url or b64_json
nvext: Optional[VideoNvExt] = None
class VideoData(BaseModel):
url: Optional[str] = None
b64_json: Optional[str] = None
class VideoGenerationResponse(BaseModel):
"""Response for video generation"""
id: str
object: str = "video"
model: str
status: str = "completed"
progress: int = 100
created: int
data: list[VideoData] = []
error: Optional[str] = None
inference_time_s: Optional[float] = None
...@@ -312,3 +312,43 @@ async def register_image_diffusion_model( ...@@ -312,3 +312,43 @@ async def register_image_diffusion_model(
readiness_gate.set() readiness_gate.set()
logging.info(f"Image diffusion model ready: {model_name}") logging.info(f"Image diffusion model ready: {model_name}")
async def register_video_generation_model(
generator: Any, # DiffGenerator
endpoint: Endpoint,
server_args: ServerArgs,
readiness_gate: Optional[asyncio.Event] = None,
) -> None:
"""Register video generation model with Dynamo runtime.
Args:
generator: The SGLang DiffGenerator instance (used for video generation).
endpoint: The Dynamo endpoint for generation requests.
server_args: SGLang server configuration.
readiness_gate: Optional event to signal when registration completes.
Note:
Video generation models use ModelInput.Text (text prompts) and ModelType.Videos.
"""
# Use model_path as the model name (video workers don't have served_model_name)
model_name = server_args.model_path
try:
await register_llm(
ModelInput.Text,
ModelType.Videos,
endpoint,
model_name,
model_name,
)
logging.info(f"Successfully registered video generation model: {model_name}")
except Exception as e:
logging.error(f"Failed to register video generation model: {e}")
raise RuntimeError("Video generation model registration failed")
# Signal readiness
if readiness_gate:
readiness_gate.set()
logging.info(f"Video generation model ready: {model_name}")
...@@ -21,6 +21,9 @@ from .multimodal import ( ...@@ -21,6 +21,9 @@ from .multimodal import (
MultimodalWorkerHandler, MultimodalWorkerHandler,
) )
# Video generation handlers
from .video_generation import VideoGenerationWorkerHandler
__all__ = [ __all__ = [
# Base handlers # Base handlers
"BaseGenerativeHandler", "BaseGenerativeHandler",
...@@ -33,6 +36,8 @@ __all__ = [ ...@@ -33,6 +36,8 @@ __all__ = [
"EmbeddingWorkerHandler", "EmbeddingWorkerHandler",
# Image diffusion handlers # Image diffusion handlers
"ImageDiffusionWorkerHandler", "ImageDiffusionWorkerHandler",
# Video generation handlers
"VideoGenerationWorkerHandler",
# Multimodal handlers # Multimodal handlers
"MultimodalEncodeWorkerHandler", "MultimodalEncodeWorkerHandler",
"MultimodalPrefillWorkerHandler", "MultimodalPrefillWorkerHandler",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .video_generation_handler import VideoGenerationWorkerHandler
__all__ = ["VideoGenerationWorkerHandler"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import io
import logging
import random
import time
from typing import Any, AsyncGenerator, Optional
import torch
from dynamo._core import Component, Context
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import (
CreateVideoRequest,
VideoData,
VideoGenerationResponse,
VideoNvExt,
)
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseGenerativeHandler
logger = logging.getLogger(__name__)
class VideoGenerationWorkerHandler(BaseGenerativeHandler):
"""Handler for video generation (T2V/I2V).
Inherits from BaseGenerativeHandler for common infrastructure like
tracing, metrics publishing, and cancellation support.
"""
def __init__(
self,
component: Component,
generator: Any, # DiffGenerator, not sgl.Engine
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
fs: Any = None, # fsspec.AbstractFileSystem for primary storage
):
"""Initialize video generation worker handler.
Args:
component: The Dynamo runtime component.
generator: The SGLang DiffGenerator instance.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher (not used for video currently).
fs: Optional fsspec filesystem for primary video storage.
"""
# Call parent constructor for common setup
super().__init__(component, config, publisher)
# Video generation-specific initialization
self.generator = generator # DiffGenerator, not Engine
self._generate_lock = asyncio.Lock() # Serialize generator access
self.fs = fs
self.fs_url = config.dynamo_args.video_generation_fs_url
logger.info(
f"Video generation worker handler initialized with fs_url={self.fs_url}"
)
def cleanup(self) -> None:
"""Cleanup generator resources"""
if self.generator is not None:
del self.generator
torch.cuda.empty_cache()
logger.info("Video generation generator cleanup complete")
# Call parent cleanup for any base class cleanup
super().cleanup()
async def generate(
self, request: dict[str, Any], context: Context
) -> AsyncGenerator[dict[str, Any], None]:
"""
Generate video from text/image prompt.
Unlike LLM streaming, video returns complete video at end.
Args:
request: Request dict with prompt and generation parameters.
context: Context object for cancellation handling.
Yields:
Response dict with generated video (OpenAI-compatible format).
"""
start_time = time.time()
# Get trace header for distributed tracing (for logging/observability)
trace_header = self._get_trace_header(context)
if trace_header:
logger.debug(f"Video generation request with trace: {trace_header}")
try:
req = CreateVideoRequest(**request)
nvext = req.nvext or VideoNvExt()
logger.info(
f"Video generation request: model={req.model}, "
f"size={req.size}, steps={nvext.num_inference_steps}"
)
# Parse size
width, height = self._parse_size(req.size)
# Calculate num_frames if not explicitly provided
num_frames = nvext.num_frames
if num_frames is None:
num_frames = nvext.fps * req.seconds
# Generate video
video_bytes = await self._generate_video(
prompt=req.prompt,
width=width,
height=height,
num_frames=num_frames,
fps=nvext.fps,
num_inference_steps=nvext.num_inference_steps,
guidance_scale=nvext.guidance_scale,
seed=nvext.seed,
request_id=context.id(),
negative_prompt=nvext.negative_prompt,
input_reference=req.input_reference,
)
video_data = []
if req.response_format == "url":
url = await self._upload_to_fs(video_bytes, context.id())
video_data.append(VideoData(url=url))
else: # b64_json
b64 = self._encode_base64(video_bytes)
video_data.append(VideoData(b64_json=b64))
inference_time = time.time() - start_time
response = VideoGenerationResponse(
id=f"video-{context.id()}",
model=req.model,
created=int(time.time()),
data=video_data,
inference_time_s=inference_time,
)
yield response.model_dump()
except Exception as e:
logger.error(f"Error in video generation: {e}", exc_info=True)
# Return error response
error_response = VideoGenerationResponse(
id=f"video-{context.id()}",
model=request.get("model", "unknown"),
created=int(time.time()),
status="failed",
progress=0,
data=[],
error=str(e),
)
yield error_response.model_dump()
async def _generate_video(
self,
prompt: str,
width: int,
height: int,
num_frames: int,
fps: int,
num_inference_steps: int,
guidance_scale: float,
seed: Optional[int],
request_id: str,
negative_prompt: Optional[str] = None,
input_reference: Optional[str] = None,
) -> bytes:
"""Generate video using SGLang DiffGenerator.
Args:
prompt: Text prompt for video generation.
width: Video width in pixels.
height: Video height in pixels.
num_frames: Number of frames to generate.
fps: Frames per second for output video.
num_inference_steps: Number of denoising steps.
guidance_scale: CFG scale for generation.
seed: Random seed for reproducibility.
request_id: Request ID for logging.
negative_prompt: Optional negative prompt.
input_reference: Optional image path for I2V.
Returns:
Video bytes (mp4 format).
"""
# Build args for DiffGenerator
args = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"num_frames": num_frames,
"fps": fps,
"num_inference_steps": num_inference_steps,
"save_output": False, # We handle saving ourselves
"guidance_scale": guidance_scale,
"seed": seed if seed is not None else random.randint(0, 1000000),
}
# Add image_path for I2V if provided
if input_reference:
args["image_path"] = input_reference
logger.info(
f"Generating video with {num_frames} frames at {width}x{height}, "
f"{num_inference_steps} steps, request_id={request_id}"
)
# Serialize access -- DiffGenerator has mutable state (CUDA graph
# caches, shared config objects) and is not thread-safe.
async with self._generate_lock:
# Run in thread pool to avoid blocking event loop
result = await asyncio.to_thread(
self.generator.generate,
sampling_params_kwargs=args,
)
# Result contains 'frames' with list of frames
frames = result.get("frames", [])
if not frames:
raise RuntimeError("DiffGenerator returned no frames")
# Convert frames to video bytes
video_bytes = await self._frames_to_video(frames, fps)
return video_bytes
async def _frames_to_video(
self, frames: list, fps: int, codec: str = "libx264"
) -> bytes:
"""Convert list of frames to video bytes.
Args:
frames: List of frames (PIL Images or numpy arrays).
fps: Frames per second.
codec: Video codec to use.
Returns:
Video bytes in mp4 format.
"""
try:
import numpy as np
from PIL import Image
# Convert frames to numpy arrays if needed
np_frames = []
for frame in frames:
if isinstance(frame, Image.Image):
np_frames.append(np.array(frame))
elif isinstance(frame, np.ndarray):
np_frames.append(frame)
else:
raise ValueError(f"Unsupported frame type: {type(frame)}")
# Use imageio to write video
import imageio
output_buffer = io.BytesIO()
with imageio.get_writer(
output_buffer,
format="mp4",
fps=fps,
codec=codec,
output_params=["-pix_fmt", "yuv420p"],
) as writer:
for frame in np_frames:
writer.append_data(frame)
output_buffer.seek(0)
return output_buffer.read()
except ImportError as e:
raise RuntimeError(
f"Missing dependency for video encoding: {e}. "
"Install with: pip install imageio imageio-ffmpeg"
)
def _parse_size(self, size_str: str) -> tuple[int, int]:
"""Parse 'WxH' -> (width, height)"""
try:
w, h = size_str.split("x")
return int(w), int(h)
except (ValueError, TypeError) as e:
raise ValueError(
f"Invalid size format '{size_str}', expected 'WxH' (e.g. '832x480')"
) from e
async def _upload_to_fs(self, video_bytes: bytes, request_id: str) -> str:
"""Upload video to filesystem and return URL.
Args:
video_bytes: Video data as bytes.
request_id: Request context ID.
Returns:
URL for the uploaded video.
"""
storage_path = f"{request_id}.mp4"
# DirFileSystem handles root path and protocol internally
await asyncio.to_thread(self.fs.pipe, storage_path, video_bytes)
return f"{self.fs_url}/{storage_path}"
def _encode_base64(self, video_bytes: bytes) -> str:
"""Encode video as base64 string"""
return base64.b64encode(video_bytes).decode("utf-8")
...@@ -12,25 +12,14 @@ from typing import Optional ...@@ -12,25 +12,14 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
class NvCreateVideoRequest(BaseModel): class VideoNvExt(BaseModel):
"""Request for video generation (/v1/videos/generations endpoint). """NVIDIA extensions for video generation requests.
Matches Rust NvCreateVideoRequest in lib/llm/src/protocols/openai/videos.rs. Matches Rust NvExt in lib/llm/src/protocols/openai/videos/nvext.rs.
""" """
# Required fields annotations: Optional[list[str]] = None
prompt: str """Annotations for SSE stream events."""
"""The text prompt for video generation."""
model: str
"""The model to use for video generation."""
# Optional fields
input_reference: Optional[str] = None
"""Optional input reference for I2V (image path/url)."""
seconds: Optional[int] = None
"""Duration in seconds (default: 4)."""
fps: Optional[int] = None fps: Optional[int] = None
"""Frames per second (default: 24).""" """Frames per second (default: 24)."""
...@@ -38,8 +27,8 @@ class NvCreateVideoRequest(BaseModel): ...@@ -38,8 +27,8 @@ class NvCreateVideoRequest(BaseModel):
num_frames: Optional[int] = None num_frames: Optional[int] = None
"""Number of frames to generate (overrides fps * seconds if set).""" """Number of frames to generate (overrides fps * seconds if set)."""
size: Optional[str] = None negative_prompt: Optional[str] = None
"""Video size in WxH format (default: '832x480').""" """Optional negative prompt."""
num_inference_steps: Optional[int] = None num_inference_steps: Optional[int] = None
"""Number of denoising steps (default: 50).""" """Number of denoising steps (default: 50)."""
...@@ -47,18 +36,42 @@ class NvCreateVideoRequest(BaseModel): ...@@ -47,18 +36,42 @@ class NvCreateVideoRequest(BaseModel):
guidance_scale: Optional[float] = None guidance_scale: Optional[float] = None
"""CFG guidance scale (default: 5.0).""" """CFG guidance scale (default: 5.0)."""
negative_prompt: Optional[str] = None
"""Optional negative prompt."""
seed: Optional[int] = None seed: Optional[int] = None
"""Random seed for reproducibility.""" """Random seed for reproducibility."""
class NvCreateVideoRequest(BaseModel):
"""Request for video generation (/v1/videos endpoint).
Matches Rust NvCreateVideoRequest in lib/llm/src/protocols/openai/videos.rs.
"""
# Required fields
prompt: str
"""The text prompt for video generation."""
model: str
"""The model to use for video generation."""
# Optional fields
input_reference: Optional[str] = None
"""Optional image reference that guides generation (for I2V)."""
seconds: Optional[int] = None
"""Clip duration in seconds."""
size: Optional[str] = None
"""Video size in WxH format (default: '832x480')."""
user: Optional[str] = None user: Optional[str] = None
"""Optional user identifier.""" """Optional user identifier."""
response_format: Optional[str] = None response_format: Optional[str] = None
"""Response format: 'url' or 'b64_json' (default: 'url').""" """Response format: 'url' or 'b64_json' (default: 'url')."""
nvext: Optional[VideoNvExt] = None
"""NVIDIA extensions."""
class VideoData(BaseModel): class VideoData(BaseModel):
"""Video data in response. """Video data in response.
......
...@@ -20,6 +20,7 @@ from dynamo.trtllm.protocols.video_protocol import ( ...@@ -20,6 +20,7 @@ from dynamo.trtllm.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
NvVideosResponse, NvVideosResponse,
VideoData, VideoData,
VideoNvExt,
) )
from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler
from dynamo.trtllm.request_handlers.video_diffusion.video_utils import ( from dynamo.trtllm.request_handlers.video_diffusion.video_utils import (
...@@ -118,36 +119,37 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -118,36 +119,37 @@ class VideoGenerationHandler(BaseGenerativeHandler):
f"To allow larger sizes, increase --max-width and/or --max-height." f"To allow larger sizes, increase --max-width and/or --max-height."
) )
def _compute_num_frames(self, req: NvCreateVideoRequest) -> int: def _compute_num_frames(self, req: NvCreateVideoRequest, nvext: VideoNvExt) -> int:
"""Compute num_frames from request parameters. """Compute num_frames from request parameters.
Priority: Priority:
1. num_frames if explicitly set 1. nvext.num_frames if explicitly set
2. seconds * fps 2. req.seconds * nvext.fps
3. config defaults 3. config defaults
Args: Args:
req: The video generation request. req: The video generation request (contains seconds).
nvext: The NVIDIA extension parameters (contains fps, num_frames).
Returns: Returns:
Number of frames to generate. Number of frames to generate.
""" """
# Priority 1: Explicit num_frames takes precedence # Priority 1: Explicit num_frames takes precedence
if req.num_frames is not None: if nvext.num_frames is not None:
return req.num_frames return nvext.num_frames
# Priority 2: If user provided seconds and/or fps, calculate frame count # Priority 2: If user provided seconds and/or fps, calculate frame count
# Use config defaults for any unspecified value # Use config defaults for any unspecified value
seconds = ( seconds = (
req.seconds if req.seconds is not None else self.config.default_seconds req.seconds if req.seconds is not None else self.config.default_seconds
) )
fps = req.fps if req.fps is not None else self.config.default_fps fps = nvext.fps if nvext.fps is not None else self.config.default_fps
computed = seconds * fps computed = seconds * fps
# Priority 3: If user provided NEITHER seconds NOR fps, use config default # Priority 3: If user provided NEITHER seconds NOR fps, use config default
# This allows config.default_num_frames to take effect only when the user # This allows config.default_num_frames to take effect only when the user
# didn't specify any duration-related parameters # didn't specify any duration-related parameters
if req.seconds is None and req.fps is None: if req.seconds is None and nvext.fps is None:
return self.config.default_num_frames return self.config.default_num_frames
# User provided at least one of (seconds, fps), so use computed value # User provided at least one of (seconds, fps), so use computed value
...@@ -175,18 +177,19 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -175,18 +177,19 @@ class VideoGenerationHandler(BaseGenerativeHandler):
try: try:
# Parse request # Parse request
req = NvCreateVideoRequest(**request) req = NvCreateVideoRequest(**request)
nvext = req.nvext or VideoNvExt()
# Parse parameters # Parse parameters
width, height = self._parse_size(req.size) width, height = self._parse_size(req.size)
num_frames = self._compute_num_frames(req) num_frames = self._compute_num_frames(req, nvext)
num_inference_steps = ( num_inference_steps = (
req.num_inference_steps nvext.num_inference_steps
if req.num_inference_steps is not None if nvext.num_inference_steps is not None
else self.config.default_num_inference_steps else self.config.default_num_inference_steps
) )
guidance_scale = ( guidance_scale = (
req.guidance_scale nvext.guidance_scale
if req.guidance_scale is not None if nvext.guidance_scale is not None
else self.config.default_guidance_scale else self.config.default_guidance_scale
) )
...@@ -205,18 +208,18 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -205,18 +208,18 @@ class VideoGenerationHandler(BaseGenerativeHandler):
frames = await asyncio.to_thread( frames = await asyncio.to_thread(
self.engine.generate, self.engine.generate,
prompt=req.prompt, prompt=req.prompt,
negative_prompt=req.negative_prompt, negative_prompt=nvext.negative_prompt,
height=height, height=height,
width=width, width=width,
num_frames=num_frames, num_frames=num_frames,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale, guidance_scale=guidance_scale,
seed=req.seed, seed=nvext.seed,
) )
# Determine output format # Determine output format
response_format = req.response_format or "url" response_format = req.response_format or "url"
fps = req.fps or self.config.default_fps fps = nvext.fps or self.config.default_fps
if response_format == "url": if response_format == "url":
# Encode to MP4 and save to file # Encode to MP4 and save to file
......
...@@ -24,6 +24,7 @@ from dynamo.trtllm.protocols.video_protocol import ( ...@@ -24,6 +24,7 @@ from dynamo.trtllm.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
NvVideosResponse, NvVideosResponse,
VideoData, VideoData,
VideoNvExt,
) )
pytestmark = [ pytestmark = [
...@@ -266,53 +267,39 @@ class TestVideoHandlerComputeNumFrames: ...@@ -266,53 +267,39 @@ class TestVideoHandlerComputeNumFrames:
def test_compute_num_frames_explicit(self): def test_compute_num_frames_explicit(self):
"""Test that explicit num_frames takes priority.""" """Test that explicit num_frames takes priority."""
req = NvCreateVideoRequest( req = NvCreateVideoRequest(prompt="test", model="test-model", seconds=10)
prompt="test", nvext = VideoNvExt(
model="test-model",
num_frames=100, num_frames=100,
seconds=10, # Should be ignored
fps=30, # Should be ignored fps=30, # Should be ignored
) )
assert self.handler._compute_num_frames(req) == 100 assert self.handler._compute_num_frames(req, nvext) == 100
def test_compute_num_frames_from_seconds_fps(self): def test_compute_num_frames_from_seconds_fps(self):
"""Test computation from seconds * fps.""" """Test computation from seconds * fps."""
req = NvCreateVideoRequest( req = NvCreateVideoRequest(prompt="test", model="test-model", seconds=4)
prompt="test", nvext = VideoNvExt(fps=24)
model="test-model", assert self.handler._compute_num_frames(req, nvext) == 96 # 4 * 24
seconds=4,
fps=24,
)
assert self.handler._compute_num_frames(req) == 96 # 4 * 24
def test_compute_num_frames_only_seconds(self): def test_compute_num_frames_only_seconds(self):
"""Test seconds with default fps (24).""" """Test seconds with default fps (24)."""
req = NvCreateVideoRequest( req = NvCreateVideoRequest(prompt="test", model="test-model", seconds=5)
prompt="test", nvext = VideoNvExt()
model="test-model",
seconds=5,
)
# seconds=5, default fps=24 -> 5 * 24 = 120 # seconds=5, default fps=24 -> 5 * 24 = 120
assert self.handler._compute_num_frames(req) == 120 assert self.handler._compute_num_frames(req, nvext) == 120
def test_compute_num_frames_only_fps(self): def test_compute_num_frames_only_fps(self):
"""Test fps with default seconds (4).""" """Test fps with default seconds (4)."""
req = NvCreateVideoRequest( req = NvCreateVideoRequest(prompt="test", model="test-model")
prompt="test", nvext = VideoNvExt(fps=30)
model="test-model",
fps=30,
)
# default seconds=4, fps=30 -> 4 * 30 = 120 # default seconds=4, fps=30 -> 4 * 30 = 120
assert self.handler._compute_num_frames(req) == 120 assert self.handler._compute_num_frames(req, nvext) == 120
def test_compute_num_frames_defaults(self): def test_compute_num_frames_defaults(self):
"""Test all None uses config default.""" """Test all None uses config default."""
req = NvCreateVideoRequest( req = NvCreateVideoRequest(prompt="test", model="test-model")
prompt="test", nvext = VideoNvExt()
model="test-model",
)
assert ( assert (
self.handler._compute_num_frames(req) self.handler._compute_num_frames(req, nvext)
== MockDiffusionConfig.default_num_frames == MockDiffusionConfig.default_num_frames
) )
...@@ -345,43 +332,41 @@ class TestNvCreateVideoRequest: ...@@ -345,43 +332,41 @@ class TestNvCreateVideoRequest:
"""Test that optional fields default to None.""" """Test that optional fields default to None."""
req = NvCreateVideoRequest(prompt="A cat", model="wan_t2v") req = NvCreateVideoRequest(prompt="A cat", model="wan_t2v")
assert req.size is None assert req.input_reference is None
assert req.seconds is None assert req.seconds is None
assert req.fps is None assert req.size is None
assert req.num_frames is None
assert req.num_inference_steps is None
assert req.guidance_scale is None
assert req.negative_prompt is None
assert req.seed is None
assert req.response_format is None assert req.response_format is None
assert req.nvext is None
def test_full_request_valid(self): def test_full_request_valid(self):
"""Test a fully populated request.""" """Test a fully populated request with nvext."""
req = NvCreateVideoRequest( req = NvCreateVideoRequest(
prompt="A majestic lion", prompt="A majestic lion",
model="wan_t2v", model="wan_t2v",
size="1920x1080",
seconds=5, seconds=5,
size="1920x1080",
response_format="b64_json",
nvext=VideoNvExt(
fps=30, fps=30,
num_frames=150, num_frames=150,
num_inference_steps=30, num_inference_steps=30,
guidance_scale=7.5, guidance_scale=7.5,
negative_prompt="blurry, low quality", negative_prompt="blurry, low quality",
seed=42, seed=42,
response_format="b64_json", ),
) )
assert req.prompt == "A majestic lion" assert req.prompt == "A majestic lion"
assert req.model == "wan_t2v" assert req.model == "wan_t2v"
assert req.size == "1920x1080"
assert req.seconds == 5 assert req.seconds == 5
assert req.fps == 30 assert req.size == "1920x1080"
assert req.num_frames == 150
assert req.num_inference_steps == 30
assert req.guidance_scale == 7.5
assert req.negative_prompt == "blurry, low quality"
assert req.seed == 42
assert req.response_format == "b64_json" assert req.response_format == "b64_json"
assert req.nvext.fps == 30
assert req.nvext.num_frames == 150
assert req.nvext.num_inference_steps == 30
assert req.nvext.guidance_scale == 7.5
assert req.nvext.negative_prompt == "blurry, low quality"
assert req.nvext.seed == 42
class TestVideoData: class TestVideoData:
......
...@@ -243,17 +243,19 @@ python -m dynamo.trtllm \ ...@@ -243,17 +243,19 @@ python -m dynamo.trtllm \
### API Endpoint ### API Endpoint
Video generation uses the `/v1/videos/generations` endpoint: Video generation uses the `/v1/videos` endpoint:
```bash ```bash
curl -X POST http://localhost:8000/v1/videos/generations \ curl -X POST http://localhost:8000/v1/videos \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"prompt": "A cat playing piano", "prompt": "A cat playing piano",
"model": "wan_t2v", "model": "wan_t2v",
"size": "832x480",
"seconds": 4, "seconds": 4,
"size": "832x480",
"nvext": {
"fps": 24 "fps": 24
}
}' }'
``` ```
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $FRONTEND_PID 2>/dev/null || true
wait $FRONTEND_PID 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
# Defaults
WAN_SIZE="1b"
FS_URL="file:///tmp/dynamo_videos"
HTTP_PORT="${HTTP_PORT:-8000}"
NUM_FRAMES=17
HEIGHT=480
WIDTH=832
NUM_INFERENCE_STEPS=50
# Parse command line arguments
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
case $1 in
--wan-size)
WAN_SIZE="$2"
shift 2
;;
--fs-url)
FS_URL="$2"
shift 2
;;
--http-port)
HTTP_PORT="$2"
shift 2
;;
--num-frames)
NUM_FRAMES="$2"
shift 2
;;
--height)
HEIGHT="$2"
shift 2
;;
--width)
WIDTH="$2"
shift 2
;;
--num-inference-steps)
NUM_INFERENCE_STEPS="$2"
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo ""
echo "Launch a Dynamo T2V (text-to-video) worker with Wan models."
echo ""
echo "Options:"
echo " --wan-size <1b|14b> Model size (default: 1b)"
echo " --fs-url <url> Filesystem URL for video storage (default: file:///tmp/dynamo_videos)"
echo " --http-port <port> Frontend HTTP port (default: 8000)"
echo " --num-frames <n> Default frame count for health check (default: 17)"
echo " --height <n> Video height (default: 480)"
echo " --width <n> Video width (default: 832)"
echo " --num-inference-steps <n> Denoising steps (default: 50)"
echo " -h, --help Show this help message"
echo ""
echo "Additional flags are forwarded to dynamo.sglang."
exit 0
;;
*)
EXTRA_ARGS+=("$1")
shift
;;
esac
done
# Select model and TP based on size
case "$WAN_SIZE" in
1b|1B)
MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
TP_SIZE=1
;;
14b|14B)
MODEL_PATH="Wan-AI/Wan2.1-T2V-14B-Diffusers"
TP_SIZE=2
;;
*)
echo "Error: --wan-size must be '1b' or '14b', got '$WAN_SIZE'"
exit 1
;;
esac
echo "=========================================="
echo "Launching T2V Video Generation Worker"
echo "=========================================="
echo "Model: $MODEL_PATH"
echo "TP Size: $TP_SIZE"
echo "Frontend: http://localhost:$HTTP_PORT"
echo "FS URL: $FS_URL"
echo "Resolution: ${WIDTH}x${HEIGHT}"
echo "=========================================="
echo ""
echo "Example test command:"
echo ""
echo " curl http://localhost:${HTTP_PORT}/v1/videos \\"
echo " -H 'Content-Type: application/json' \\"
echo " -d '{"
echo " \"prompt\": \"A curious raccoon exploring a garden\","
echo " \"model\": \"${MODEL_PATH}\","
echo " \"seconds\": 2,"
echo " \"size\": \"${WIDTH}x${HEIGHT}\","
echo " \"response_format\": \"url\","
echo " \"nvext\": {"
echo " \"fps\": 8,"
echo " \"num_frames\": ${NUM_FRAMES},"
echo " \"num_inference_steps\": ${NUM_INFERENCE_STEPS}"
echo " }"
echo " }'"
echo ""
echo "=========================================="
# Launch frontend
echo "Starting Dynamo Frontend on port $HTTP_PORT..."
python3 -m dynamo.frontend \
--http-port "$HTTP_PORT" &
FRONTEND_PID=$!
sleep 2
# Launch video generation worker
echo "Starting T2V Worker ($WAN_SIZE)..."
python3 -m dynamo.sglang \
--model-path "$MODEL_PATH" \
--served-model-name "$MODEL_PATH" \
--tp "$TP_SIZE" \
--video-generation-worker \
--video-generation-fs-url "$FS_URL" \
--trust-remote-code \
--skip-tokenizer-init \
--enable-metrics \
"${EXTRA_ARGS[@]}"
...@@ -263,6 +263,7 @@ fn register_llm<'p>( ...@@ -263,6 +263,7 @@ fn register_llm<'p>(
let is_tensor_based = model_type.inner.supports_tensor(); let is_tensor_based = model_type.inner.supports_tensor();
let is_images = model_type.inner.supports_images(); let is_images = model_type.inner.supports_images();
let is_videos = model_type.inner.supports_videos();
let model_type_obj = model_type.inner; let model_type_obj = model_type.inner;
...@@ -311,9 +312,9 @@ fn register_llm<'p>( ...@@ -311,9 +312,9 @@ fn register_llm<'p>(
.or_else(|| Some(source_path.clone())); .or_else(|| Some(source_path.clone()));
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
// For TensorBased and Images models, skip HuggingFace downloads and register directly // For TensorBased, Images, and Videos models, skip HuggingFace downloads and register directly
// Images models (vLLM-Omni) handle model loading internally, no tokenizer extraction needed // These model types handle model loading internally, no tokenizer extraction needed
if is_tensor_based || is_images { if is_tensor_based || is_images || is_videos {
let model_name = model_name.unwrap_or_else(|| source_path.clone()); let model_name = model_name.unwrap_or_else(|| source_path.clone());
let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name); let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name);
card.model_type = model_type_obj; card.model_type = model_type_obj;
...@@ -527,6 +528,10 @@ impl ModelType { ...@@ -527,6 +528,10 @@ impl ModelType {
const Images: Self = ModelType { const Images: Self = ModelType {
inner: llm_rs::model_type::ModelType::Images, inner: llm_rs::model_type::ModelType::Images,
}; };
#[classattr]
const Videos: Self = ModelType {
inner: llm_rs::model_type::ModelType::Videos,
};
fn supports_chat(&self) -> bool { fn supports_chat(&self) -> bool {
self.inner.supports_chat() self.inner.supports_chat()
......
...@@ -915,13 +915,14 @@ class ModelInput: ...@@ -915,13 +915,14 @@ class ModelInput:
... ...
class ModelType: class ModelType:
"""What type of request this model needs: Chat, Completions, Embedding, Tensor, Images or Prefill""" """What type of request this model needs: Chat, Completions, Embedding, Tensor, Images, Videos or Prefill"""
Chat: ModelType Chat: ModelType
Completions: ModelType Completions: ModelType
Embedding: ModelType Embedding: ModelType
TensorBased: ModelType TensorBased: ModelType
Prefill: ModelType Prefill: ModelType
Images: ModelType Images: ModelType
Videos: ModelType
... ...
class RouterMode: class RouterMode:
......
...@@ -34,6 +34,7 @@ use crate::{ ...@@ -34,6 +34,7 @@ use crate::{
chat_completions::OpenAIChatCompletionsStreamingEngine, chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
videos::OpenAIVideosStreamingEngine,
}, },
}, },
}; };
...@@ -67,6 +68,7 @@ pub struct ModelManager { ...@@ -67,6 +68,7 @@ pub struct ModelManager {
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>, chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>, embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
images_engines: RwLock<ModelEngines<OpenAIImagesStreamingEngine>>, images_engines: RwLock<ModelEngines<OpenAIImagesStreamingEngine>>,
videos_engines: RwLock<ModelEngines<OpenAIVideosStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>, tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// Prefill models don't have engines - they're only tracked for discovery/lifecycle // Prefill models don't have engines - they're only tracked for discovery/lifecycle
prefill_engines: RwLock<ModelEngines<()>>, prefill_engines: RwLock<ModelEngines<()>>,
...@@ -93,6 +95,7 @@ impl ModelManager { ...@@ -93,6 +95,7 @@ impl ModelManager {
chat_completion_engines: RwLock::new(ModelEngines::default()), chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()), embeddings_engines: RwLock::new(ModelEngines::default()),
images_engines: RwLock::new(ModelEngines::default()), images_engines: RwLock::new(ModelEngines::default()),
videos_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()), tensor_engines: RwLock::new(ModelEngines::default()),
prefill_engines: RwLock::new(ModelEngines::default()), prefill_engines: RwLock::new(ModelEngines::default()),
cards: DashMap::new(), cards: DashMap::new(),
...@@ -117,6 +120,7 @@ impl ModelManager { ...@@ -117,6 +120,7 @@ impl ModelManager {
ModelType::Embedding => self.embeddings_engines.read().checksum(model_name), ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
ModelType::TensorBased => self.tensor_engines.read().checksum(model_name), ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
ModelType::Images => self.images_engines.read().checksum(model_name), ModelType::Images => self.images_engines.read().checksum(model_name),
ModelType::Videos => self.videos_engines.read().checksum(model_name),
ModelType::Prefill => self.prefill_engines.read().checksum(model_name), ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
_ => { _ => {
continue; continue;
...@@ -168,8 +172,9 @@ impl ModelManager { ...@@ -168,8 +172,9 @@ impl ModelManager {
.into_iter() .into_iter()
.chain(self.list_completions_models()) .chain(self.list_completions_models())
.chain(self.list_embeddings_models()) .chain(self.list_embeddings_models())
.chain(self.list_tensor_models())
.chain(self.list_images_models()) .chain(self.list_images_models())
.chain(self.list_videos_models())
.chain(self.list_tensor_models())
.chain(self.list_prefill_models()) .chain(self.list_prefill_models())
.collect() .collect()
} }
...@@ -198,6 +203,10 @@ impl ModelManager { ...@@ -198,6 +203,10 @@ impl ModelManager {
self.images_engines.read().list() self.images_engines.read().list()
} }
pub fn list_videos_models(&self) -> Vec<String> {
self.videos_engines.read().list()
}
pub fn add_completions_model( pub fn add_completions_model(
&self, &self,
model: &str, model: &str,
...@@ -248,6 +257,16 @@ impl ModelManager { ...@@ -248,6 +257,16 @@ impl ModelManager {
clients.add(model, card_checksum, engine) clients.add(model, card_checksum, engine)
} }
pub fn add_videos_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIVideosStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.videos_engines.write();
clients.add(model, card_checksum, engine)
}
pub fn add_prefill_model( pub fn add_prefill_model(
&self, &self,
model: &str, model: &str,
...@@ -282,6 +301,11 @@ impl ModelManager { ...@@ -282,6 +301,11 @@ impl ModelManager {
clients.remove(model) clients.remove(model)
} }
pub fn remove_videos_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.videos_engines.write();
clients.remove(model)
}
pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write(); let mut clients = self.prefill_engines.write();
clients.remove(model) clients.remove(model)
...@@ -342,6 +366,17 @@ impl ModelManager { ...@@ -342,6 +366,17 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
pub fn get_videos_engine(
&self,
model: &str,
) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
self.videos_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is /// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
/// deleted. /// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> { pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
......
...@@ -40,6 +40,7 @@ use crate::{ ...@@ -40,6 +40,7 @@ use crate::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
images::{NvCreateImageRequest, NvImagesResponse}, images::{NvCreateImageRequest, NvImagesResponse},
videos::{NvCreateVideoRequest, NvVideosResponse},
}, },
tensor::{NvCreateTensorRequest, NvCreateTensorResponse}, tensor::{NvCreateTensorRequest, NvCreateTensorResponse},
}, },
...@@ -70,8 +71,9 @@ const ALL_MODEL_TYPES: &[ModelType] = &[ ...@@ -70,8 +71,9 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Chat, ModelType::Chat,
ModelType::Completions, ModelType::Completions,
ModelType::Embedding, ModelType::Embedding,
ModelType::TensorBased,
ModelType::Images, ModelType::Images,
ModelType::Videos,
ModelType::TensorBased,
ModelType::Prefill, ModelType::Prefill,
]; ];
...@@ -284,15 +286,17 @@ impl ModelWatcher { ...@@ -284,15 +286,17 @@ impl ModelWatcher {
let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name); let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
let completions_model_remove_err = self.manager.remove_completions_model(&model_name); let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name); let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let images_model_remove_err = self.manager.remove_images_model(&model_name); let images_model_remove_err = self.manager.remove_images_model(&model_name);
let videos_model_remove_err = self.manager.remove_videos_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name); let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
let mut chat_model_removed = false; let mut chat_model_removed = false;
let mut completions_model_removed = false; let mut completions_model_removed = false;
let mut embeddings_model_removed = false; let mut embeddings_model_removed = false;
let mut tensor_model_removed = false;
let mut images_model_removed = false; let mut images_model_removed = false;
let mut videos_model_removed = false;
let mut tensor_model_removed = false;
let mut prefill_model_removed = false; let mut prefill_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() { if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
...@@ -305,12 +309,15 @@ impl ModelWatcher { ...@@ -305,12 +309,15 @@ impl ModelWatcher {
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() { if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
embeddings_model_removed = true; embeddings_model_removed = true;
} }
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true;
}
if images_model_remove_err.is_ok() && self.manager.list_images_models().is_empty() { if images_model_remove_err.is_ok() && self.manager.list_images_models().is_empty() {
images_model_removed = true; images_model_removed = true;
} }
if videos_model_remove_err.is_ok() && self.manager.list_videos_models().is_empty() {
videos_model_removed = true;
}
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true;
}
if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() { if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
prefill_model_removed = true; prefill_model_removed = true;
} }
...@@ -318,18 +325,20 @@ impl ModelWatcher { ...@@ -318,18 +325,20 @@ impl ModelWatcher {
if !chat_model_removed if !chat_model_removed
&& !completions_model_removed && !completions_model_removed
&& !embeddings_model_removed && !embeddings_model_removed
&& !tensor_model_removed
&& !images_model_removed && !images_model_removed
&& !videos_model_removed
&& !tensor_model_removed
&& !prefill_model_removed && !prefill_model_removed
{ {
tracing::debug!( tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}, images_model_removed: {}, prefill_model_removed: {}", "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, images_model_removed: {}, videos_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}",
model_name, model_name,
chat_model_removed, chat_model_removed,
completions_model_removed, completions_model_removed,
embeddings_model_removed, embeddings_model_removed,
tensor_model_removed,
images_model_removed, images_model_removed,
videos_model_removed,
tensor_model_removed,
prefill_model_removed prefill_model_removed
); );
} else { } else {
...@@ -337,8 +346,9 @@ impl ModelWatcher { ...@@ -337,8 +346,9 @@ impl ModelWatcher {
if ((chat_model_removed && *model_type == ModelType::Chat) if ((chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completions) || (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding) || (embeddings_model_removed && *model_type == ModelType::Embedding)
|| (tensor_model_removed && *model_type == ModelType::TensorBased)
|| (images_model_removed && *model_type == ModelType::Images) || (images_model_removed && *model_type == ModelType::Images)
|| (videos_model_removed && *model_type == ModelType::Videos)
|| (tensor_model_removed && *model_type == ModelType::TensorBased)
|| (prefill_model_removed && *model_type == ModelType::Prefill)) || (prefill_model_removed && *model_type == ModelType::Prefill))
&& let Some(tx) = &self.model_update_tx && let Some(tx) = &self.model_update_tx
{ {
...@@ -675,6 +685,19 @@ impl ModelWatcher { ...@@ -675,6 +685,19 @@ impl ModelWatcher {
checksum, checksum,
Arc::new(chat_router), Arc::new(chat_router),
)?; )?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_videos() {
// Case: Text + Videos (video generation models)
// Takes text prompts as input, generates videos
let push_router = PushRouter::<
NvCreateVideoRequest,
Annotated<NvVideosResponse>,
>::from_client_with_threshold(
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_videos_model(card.name(), checksum, engine)?;
} else if card.model_type.supports_prefill() { } else if card.model_type.supports_prefill() {
// Case 6: Prefill // Case 6: Prefill
// Guardrail: Verify model_input is Tokens // Guardrail: Verify model_input is Tokens
......
...@@ -14,6 +14,8 @@ pub enum EndpointType { ...@@ -14,6 +14,8 @@ pub enum EndpointType {
Embedding, Embedding,
/// Images API (Diffusion/DALL-E) /// Images API (Diffusion/DALL-E)
Images, Images,
/// Videos API (Video Generation)
Videos,
/// Responses API /// Responses API
Responses, Responses,
} }
...@@ -25,6 +27,7 @@ impl EndpointType { ...@@ -25,6 +27,7 @@ impl EndpointType {
Self::Completion => "completion", Self::Completion => "completion",
Self::Embedding => "embedding", Self::Embedding => "embedding",
Self::Images => "images", Self::Images => "images",
Self::Videos => "videos",
Self::Responses => "responses", Self::Responses => "responses",
} }
} }
...@@ -35,6 +38,7 @@ impl EndpointType { ...@@ -35,6 +38,7 @@ impl EndpointType {
Self::Completion, Self::Completion,
Self::Embedding, Self::Embedding,
Self::Images, Self::Images,
Self::Videos,
Self::Responses, Self::Responses,
] ]
} }
......
...@@ -291,6 +291,9 @@ pub enum Endpoint { ...@@ -291,6 +291,9 @@ pub enum Endpoint {
/// OAI Images /// OAI Images
Images, Images,
/// OAI Videos
Videos,
/// OAI Responses /// OAI Responses
Responses, Responses,
...@@ -943,6 +946,7 @@ impl std::fmt::Display for Endpoint { ...@@ -943,6 +946,7 @@ impl std::fmt::Display for Endpoint {
Endpoint::ChatCompletions => write!(f, "chat_completions"), Endpoint::ChatCompletions => write!(f, "chat_completions"),
Endpoint::Embeddings => write!(f, "embeddings"), Endpoint::Embeddings => write!(f, "embeddings"),
Endpoint::Images => write!(f, "images"), Endpoint::Images => write!(f, "images"),
Endpoint::Videos => write!(f, "videos"),
Endpoint::Responses => write!(f, "responses"), Endpoint::Responses => write!(f, "responses"),
Endpoint::Tensor => write!(f, "tensor"), Endpoint::Tensor => write!(f, "tensor"),
} }
...@@ -956,6 +960,7 @@ impl Endpoint { ...@@ -956,6 +960,7 @@ impl Endpoint {
Endpoint::ChatCompletions => "chat_completions", Endpoint::ChatCompletions => "chat_completions",
Endpoint::Embeddings => "embeddings", Endpoint::Embeddings => "embeddings",
Endpoint::Images => "images", Endpoint::Images => "images",
Endpoint::Videos => "videos",
Endpoint::Responses => "responses", Endpoint::Responses => "responses",
Endpoint::Tensor => "tensor", Endpoint::Tensor => "tensor",
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment