Unverified Commit d6c49779 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: unify filesystem usage across all frameworks and workers (#6391)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 6dd3ce2e
......@@ -2,6 +2,8 @@
.vs/
.vscode/
.helix
issues/
plans/
*rendered.Dockerfile
[Bb]inlog/
[Bb][Uu][Ii][Ll][Dd]/
......
......@@ -31,6 +31,8 @@ class DynamoRuntimeConfig(ConfigBase):
dump_config_to: Optional[str] = None
multimodal_embedding_cache_capacity_gb: float
output_modalities: List[str]
media_output_fs_url: str = "file:///tmp/dynamo_media"
media_output_http_url: Optional[str] = None
def validate(self) -> None:
# TODO get a better way for spot fixes like this.
......@@ -176,3 +178,19 @@ class DynamoRuntimeArgGroup(ArgGroup):
help="Output modalities for omni/diffusion mode (e.g., --output-modalities text image audio video).",
nargs="*",
)
# Media storage (generated images and videos)
add_argument(
g,
flag_name="--media-output-fs-url",
env_var="DYN_MEDIA_OUTPUT_FS_URL",
default="file:///tmp/dynamo_media",
help="Filesystem URL for storing generated images and videos (e.g. file:///tmp/dynamo_media, s3://bucket/path).",
)
add_argument(
g,
flag_name="--media-output-http-url",
env_var="DYN_MEDIA_OUTPUT_HTTP_URL",
default=None,
help="Base URL for rewriting media file paths in responses (e.g. http://localhost:8000/media). If unset, returns raw filesystem paths.",
)
......@@ -32,6 +32,9 @@ S3:
"""
import asyncio
from typing import Optional
import fsspec
from fsspec.implementations.dirfs import DirFileSystem
......@@ -64,3 +67,51 @@ def get_fs(fs_url: str) -> DirFileSystem:
fs_opts = {"auto_mkdir": True}
return DirFileSystem(fs=fsspec.filesystem(protocol, **fs_opts), path=root_path)
def get_media_url(
fs: DirFileSystem, storage_path: str, base_url: Optional[str] = None
) -> str:
"""Build a public URL for a file stored in the media filesystem.
Args:
fs: The DirFileSystem returned by ``get_fs()``.
storage_path: Relative path within the filesystem (e.g. "videos/req-id.mp4").
base_url: Optional CDN / proxy base URL. When set, the returned URL is
``{base_url}/{storage_path}``. When *None*, the URL is constructed
from the filesystem's protocol and root path.
Returns:
Public URL string for the uploaded file.
"""
if base_url:
return f"{base_url.rstrip('/')}/{storage_path}"
protocol = fs.fs.protocol
if isinstance(protocol, (list, tuple)):
protocol = protocol[0]
return f"{protocol}://{fs.path}/{storage_path}"
async def upload_to_fs(
fs: DirFileSystem,
storage_path: str,
data: bytes,
base_url: Optional[str] = None,
) -> str:
"""Upload bytes to the media filesystem and return the public URL.
This is the canonical helper for all backends (vLLM, SGLang, TRT-LLM)
to store generated images/videos and produce a response URL.
Args:
fs: The DirFileSystem returned by ``get_fs()``.
storage_path: Relative path within the filesystem (e.g. "images/req-id/file.png").
data: Raw bytes to upload.
base_url: Optional CDN / proxy base URL for URL rewriting.
Returns:
Public URL string for the uploaded file.
"""
await asyncio.to_thread(fs.pipe, storage_path, data)
return get_media_url(fs, storage_path, base_url)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for dynamo.common.storage module."""
from unittest.mock import MagicMock, patch
import pytest
from dynamo.common.storage import get_fs, get_media_url, upload_to_fs
pytestmark = [
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
class TestGetFs:
"""Tests for get_fs() filesystem initialization."""
def test_local_file_url(self, tmp_path):
"""Test file:// URL returns DirFileSystem with correct path."""
media_dir = tmp_path / "test_media"
fs = get_fs(f"file://{media_dir}")
assert fs.path == str(media_dir)
protocol = fs.fs.protocol
if isinstance(protocol, (list, tuple)):
protocol = protocol[0]
assert protocol == "file"
def test_local_file_url_auto_mkdir(self, tmp_path):
"""Test file:// URL enables auto_mkdir on underlying filesystem."""
media_dir = tmp_path / "test_media"
fs = get_fs(f"file://{media_dir}")
assert fs.fs.auto_mkdir is True
def test_no_protocol_defaults_to_file(self, tmp_path):
"""Test URL without protocol defaults to file."""
media_dir = tmp_path / "test_media"
fs = get_fs(str(media_dir))
protocol = fs.fs.protocol
if isinstance(protocol, (list, tuple)):
protocol = protocol[0]
assert protocol == "file"
def test_s3_url_protocol(self):
"""Test s3:// URL extracts correct protocol and bucket path."""
with patch("dynamo.common.storage.fsspec.filesystem") as mock_fsspec, patch(
"dynamo.common.storage.DirFileSystem"
) as mock_dirfs:
mock_inner_fs = MagicMock(protocol="s3")
mock_fsspec.return_value = mock_inner_fs
get_fs("s3://my-bucket/prefix")
mock_fsspec.assert_called_once_with("s3")
mock_dirfs.assert_called_once_with(
fs=mock_inner_fs, path="my-bucket/prefix"
)
def test_gs_url_protocol(self):
"""Test gs:// URL extracts correct protocol and path."""
with patch("dynamo.common.storage.fsspec.filesystem") as mock_fsspec, patch(
"dynamo.common.storage.DirFileSystem"
) as mock_dirfs:
mock_inner_fs = MagicMock(protocol="gs")
mock_fsspec.return_value = mock_inner_fs
get_fs("gs://my-gcs-bucket/data")
mock_fsspec.assert_called_once_with("gs")
mock_dirfs.assert_called_once_with(
fs=mock_inner_fs, path="my-gcs-bucket/data"
)
class TestGetMediaUrl:
"""Tests for get_media_url() URL construction."""
def _make_fs(self, protocol="file", path="/tmp/media"): # noqa: S108
"""Create a mock DirFileSystem."""
fs = MagicMock()
fs.fs.protocol = protocol
fs.path = path
return fs
def test_base_url_rewrite(self):
"""Test that base_url takes precedence over protocol fallback."""
fs = self._make_fs()
url = get_media_url(
fs, "videos/req-123.mp4", base_url="https://cdn.example.com/media"
)
assert url == "https://cdn.example.com/media/videos/req-123.mp4"
def test_base_url_trailing_slash_stripped(self):
"""Test that trailing slash on base_url is normalized."""
fs = self._make_fs()
url = get_media_url(fs, "images/test.png", base_url="https://cdn.example.com/")
assert url == "https://cdn.example.com/images/test.png"
def test_protocol_fallback_file(self):
"""Test URL construction from file:// protocol when no base_url."""
fs = self._make_fs(protocol="file", path="/tmp/dynamo_media") # noqa: S108
url = get_media_url(fs, "videos/req-123.mp4")
assert url == "file:///tmp/dynamo_media/videos/req-123.mp4"
def test_protocol_fallback_s3(self):
"""Test URL construction from s3:// protocol when no base_url."""
fs = self._make_fs(protocol="s3", path="my-bucket/prefix")
url = get_media_url(fs, "images/img.png")
assert url == "s3://my-bucket/prefix/images/img.png"
def test_tuple_protocol_uses_first(self):
"""Test that tuple protocol (e.g., ('s3', 's3a')) uses first element."""
fs = self._make_fs()
fs.fs.protocol = ("s3", "s3a")
fs.path = "my-bucket"
url = get_media_url(fs, "file.mp4")
assert url == "s3://my-bucket/file.mp4"
def test_list_protocol_uses_first(self):
"""Test that list protocol uses first element."""
fs = self._make_fs()
fs.fs.protocol = ["gs", "gcs"]
fs.path = "my-gcs-bucket"
url = get_media_url(fs, "file.mp4")
assert url == "gs://my-gcs-bucket/file.mp4"
def test_base_url_none_falls_back_to_protocol(self):
"""Test that None base_url triggers protocol fallback."""
fs = self._make_fs()
url = get_media_url(fs, "test.png", base_url=None)
assert url == "file:///tmp/media/test.png"
def test_base_url_empty_string_falls_back_to_protocol(self):
"""Test that empty string base_url triggers protocol fallback."""
fs = self._make_fs()
url = get_media_url(fs, "test.png", base_url="")
assert url == "file:///tmp/media/test.png"
class TestUploadToFs:
"""Tests for upload_to_fs() async upload + URL construction."""
def _make_fs(self, protocol="file", path="/tmp/media"): # noqa: S108
"""Create a mock DirFileSystem."""
fs = MagicMock()
fs.fs.protocol = protocol
fs.path = path
fs.pipe = MagicMock()
return fs
@pytest.mark.asyncio
async def test_calls_pipe_with_correct_args(self):
"""Test that fs.pipe is called with storage_path and data."""
fs = self._make_fs()
data = b"test image bytes"
await upload_to_fs(fs, "images/test.png", data)
fs.pipe.assert_called_once_with("images/test.png", data)
@pytest.mark.asyncio
async def test_returns_url_with_base_url(self):
"""Test that returned URL uses base_url when provided."""
fs = self._make_fs()
url = await upload_to_fs(
fs, "videos/req-123.mp4", b"video", base_url="https://cdn.example.com"
)
assert url == "https://cdn.example.com/videos/req-123.mp4"
@pytest.mark.asyncio
async def test_returns_url_with_protocol_fallback(self):
"""Test that returned URL uses protocol when no base_url."""
fs = self._make_fs(protocol="s3", path="my-bucket/prefix")
url = await upload_to_fs(fs, "images/img.png", b"image")
assert url == "s3://my-bucket/prefix/images/img.png"
@pytest.mark.asyncio
async def test_pipe_called_before_url_returned(self):
"""Test that pipe() is called (data uploaded) before URL is returned."""
fs = self._make_fs()
call_order = []
original_pipe = fs.pipe
def tracking_pipe(*args, **kwargs):
call_order.append("pipe")
return original_pipe(*args, **kwargs)
fs.pipe = tracking_pipe
url = await upload_to_fs(fs, "test.png", b"data")
call_order.append("url_returned")
assert call_order == ["pipe", "url_returned"]
assert url
......@@ -51,7 +51,7 @@ def compute_num_frames(
) -> int:
"""Compute the number of video frames.
Priority: num_frames > seconds × fps > default_num_frames.
Priority: num_frames > seconds x fps > default_num_frames.
"""
if num_frames is not None:
return num_frames
......
......@@ -73,20 +73,6 @@ class DynamoSGLangArgGroup(ArgGroup):
default=False,
help="Run as image diffusion worker for image generation.",
)
add_argument(
g,
flag_name="--image-diffusion-fs-url",
env_var="DYN_SGL_IMAGE_DIFFUSION_FS_URL",
default=None,
help="Filesystem URL for storing generated images using fsspec (e.g., s3://bucket/path, gs://bucket/path, file:///local/path). Supports any fsspec-compatible filesystem.",
)
add_argument(
g,
flag_name="--image-diffusion-base-url",
env_var="DYN_SGL_IMAGE_DIFFUSION_BASE_URL",
default="http://localhost:8008/",
help="Base URL for rewriting image URLs in responses (e.g., http://localhost:8008/). When set, generated image URLs will use this base instead of filesystem URLs.",
)
add_argument(
g,
flag_name="--disagg-config",
......@@ -108,13 +94,6 @@ class DynamoSGLangArgGroup(ArgGroup):
default=False,
help="Run as video generation worker for video generation (T2V/I2V).",
)
add_argument(
g,
flag_name="--video-generation-fs-url",
env_var="DYN_SGL_VIDEO_GENERATION_FS_URL",
default=None,
help="Filesystem URL for storing generated videos using fsspec (e.g., s3://bucket/path, gs://bucket/path, file:///local/path). Supports any fsspec-compatible filesystem.",
)
class DynamoSGLangConfig(ConfigBase):
......@@ -126,14 +105,11 @@ class DynamoSGLangConfig(ConfigBase):
multimodal_worker: bool
embedding_worker: bool
image_diffusion_worker: bool
image_diffusion_fs_url: Optional[str] = None
image_diffusion_base_url: Optional[str] = None
disagg_config: Optional[str] = None
disagg_config_key: Optional[str] = None
video_generation_worker: bool
video_generation_fs_url: Optional[str] = None
def validate(self) -> None:
if (self.disagg_config is not None) ^ (self.disagg_config_key is not None):
......
......@@ -583,12 +583,7 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
dist_timeout=dist_timeout,
)
# Initialize fsspec filesystems for image storage
fs_url = dynamo_args.image_diffusion_fs_url
# Initialize primary filesystem
if not fs_url:
raise ValueError("--image-diffusion-fs-url is required for diffusion workers")
fs_url = dynamo_args.media_output_fs_url
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
......@@ -667,14 +662,7 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config):
dist_timeout=dist_timeout,
)
# Initialize fsspec filesystems for video storage
fs_url = dynamo_args.video_generation_fs_url
# Initialize primary filesystem
if not fs_url:
raise ValueError(
"--video-generation-fs-url is required for video generation workers"
)
fs_url = dynamo_args.media_output_fs_url
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
......
......@@ -14,6 +14,7 @@ import torch
from PIL import Image
from dynamo._core import Component, Context
from dynamo.common.storage import upload_to_fs
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -52,8 +53,8 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
self.generator = generator # DiffGenerator, not Engine
self.fs = fs
self.fs_url = config.dynamo_args.image_diffusion_fs_url
self.base_url = config.dynamo_args.image_diffusion_base_url
self.fs_url = config.dynamo_args.media_output_fs_url
self.base_url = config.dynamo_args.media_output_http_url
logger.info(
f"Image diffusion worker handler initialized with fs_url={self.fs_url}, url_base={self.base_url}"
......@@ -225,10 +226,7 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
# Per-user storage path
storage_path = f"users/{user_id}/generations/{request_id}/{image_filename}"
# send image to filesystem
await asyncio.to_thread(self.fs.pipe, storage_path, image_bytes)
return f"{self.base_url}/{storage_path}"
return await upload_to_fs(self.fs, storage_path, image_bytes, self.base_url)
def _encode_base64(self, image_bytes: bytes) -> str:
"""Encode image as base64 string"""
......
......@@ -12,6 +12,7 @@ from typing import Any, AsyncGenerator, Optional
import torch
from dynamo._core import Component, Context
from dynamo.common.storage import upload_to_fs
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import (
CreateVideoRequest,
......@@ -56,7 +57,8 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
self.generator = generator # DiffGenerator, not Engine
self._generate_lock = asyncio.Lock() # Serialize generator access
self.fs = fs
self.fs_url = config.dynamo_args.video_generation_fs_url
self.fs_url = config.dynamo_args.media_output_fs_url
self.base_url = config.dynamo_args.media_output_http_url
logger.info(
f"Video generation worker handler initialized with fs_url={self.fs_url}"
......@@ -303,11 +305,7 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
URL for the uploaded video.
"""
storage_path = f"{request_id}.mp4"
# DirFileSystem handles root path and protocol internally
await asyncio.to_thread(self.fs.pipe, storage_path, video_bytes)
return f"{self.fs_url}/{storage_path}"
return await upload_to_fs(self.fs, storage_path, video_bytes, self.base_url)
def _encode_base64(self, video_bytes: bytes) -> str:
"""Encode video as base64 string"""
......
......@@ -42,8 +42,8 @@ def mock_config():
"""Mock Config object."""
config = MagicMock()
config.dynamo_args = MagicMock()
config.dynamo_args.image_diffusion_fs_url = "file:///tmp/images"
config.dynamo_args.image_diffusion_base_url = "file:///tmp/images"
config.dynamo_args.media_output_fs_url = "file:///tmp/images"
config.dynamo_args.media_output_http_url = "file:///tmp/images"
return config
......@@ -96,8 +96,8 @@ class TestImageDiffusionWorkerHandler:
"""Test handler initialization with URL base."""
config = MagicMock()
config.dynamo_args = MagicMock()
config.dynamo_args.image_diffusion_fs_url = "s3://my-bucket/images"
config.dynamo_args.image_diffusion_base_url = "http://localhost:8008/images"
config.dynamo_args.media_output_fs_url = "s3://my-bucket/images"
config.dynamo_args.media_output_http_url = "http://localhost:8008/images"
handler = ImageDiffusionWorkerHandler(
component=mock_component,
......
......@@ -195,13 +195,6 @@ class DynamoTrtllmArgGroup(ArgGroup):
"Diffusion Options [Experimental]",
"Options for video_diffusion modality",
)
add_argument(
diffusion_group,
flag_name="--output-dir",
env_var="DYN_TRTLLM_OUTPUT_DIR",
default="/tmp/dynamo_videos",
help="Directory to store generated videos/images.",
)
add_argument(
diffusion_group,
flag_name="--default-height",
......@@ -377,7 +370,6 @@ class DynamoTrtllmConfig(ConfigBase):
allowed_local_media_path: str
max_file_size_mb: int
output_dir: str
default_height: int
default_width: int
default_num_frames: int
......
......@@ -41,8 +41,9 @@ class DiffusionConfig:
# float16 can be used on older GPUs (V100, etc.)
torch_dtype: str = "bfloat16"
# Output config
output_dir: str = "/tmp/dynamo_videos"
# Media storage
media_output_fs_url: str = "file:///tmp/dynamo_media"
media_output_http_url: Optional[str] = None
# Default generation parameters
default_height: int = 480
......@@ -85,7 +86,7 @@ class DiffusionConfig:
f"endpoint={self.endpoint}, "
f"model_path={self.model_path}, "
f"served_model_name={self.served_model_name}, "
f"output_dir={self.output_dir}, "
f"media_output_fs_url={self.media_output_fs_url}, "
f"default_height={self.default_height}, "
f"default_width={self.default_width}, "
f"default_num_frames={self.default_num_frames}, "
......
......@@ -20,7 +20,8 @@ from dynamo.common.protocols.video_protocol import (
VideoData,
VideoNvExt,
)
from dynamo.common.utils.video_utils import encode_to_mp4, encode_to_mp4_bytes
from dynamo.common.storage import get_fs, upload_to_fs
from dynamo.common.utils.video_utils import encode_to_mp4_bytes
from dynamo.trtllm.configs.diffusion_config import DiffusionConfig
from dynamo.trtllm.engines.diffusion_engine import DiffusionEngine
from dynamo.trtllm.request_handlers.base_generative_handler import BaseGenerativeHandler
......@@ -55,6 +56,12 @@ class VideoGenerationHandler(BaseGenerativeHandler):
self.component = component
self.engine = engine
self.config = config
if not config.media_output_fs_url:
raise ValueError(
"media_output_fs_url must be set; use --media-output-fs-url or DYN_MEDIA_OUTPUT_FS_URL."
)
self.media_output_fs = get_fs(config.media_output_fs_url)
self.media_output_http_url = config.media_output_http_url
# Serialize pipeline access — visual_gen is not thread-safe (global
# singleton configs, mutable instance state, unprotected CUDA graph cache).
# asyncio.Lock suspends waiting coroutines cooperatively so the event
......@@ -218,21 +225,21 @@ class VideoGenerationHandler(BaseGenerativeHandler):
response_format = req.response_format or "url"
fps = nvext.fps or self.config.default_fps
# Encode frames to MP4 bytes in memory
video_bytes = await asyncio.to_thread(encode_to_mp4_bytes, frames, fps=fps)
if response_format == "url":
# Encode to MP4 and save to file
output_path = await asyncio.to_thread(
encode_to_mp4,
frames,
self.config.output_dir,
request_id,
fps=fps,
# Upload via filesystem
storage_path = f"videos/{request_id}.mp4"
video_url = await upload_to_fs(
self.media_output_fs,
storage_path,
video_bytes,
self.media_output_http_url,
)
video_data = VideoData(url=output_path)
video_data = VideoData(url=video_url)
else:
# Encode to base64
video_bytes = await asyncio.to_thread(
encode_to_mp4_bytes, frames, fps=fps
)
b64_video = base64.b64encode(video_bytes).decode("utf-8")
video_data = VideoData(b64_json=b64_video)
......
......@@ -98,8 +98,9 @@ class TestDiffusionConfig:
assert config.default_num_inference_steps == 50
assert config.default_guidance_scale == 5.0
# Model defaults
assert config.output_dir == "/tmp/dynamo_videos"
# Media storage defaults
assert config.media_output_fs_url == "file:///tmp/dynamo_media"
assert config.media_output_http_url is None
# Optimization defaults
assert config.enable_teacache is False
......@@ -126,6 +127,16 @@ class TestDiffusionConfig:
assert config.enable_teacache is True
assert config.dit_tp_size == 2
def test_custom_media_storage(self):
"""Test that media storage fields can be overridden."""
config = DiffusionConfig(
media_output_fs_url="s3://my-bucket/videos",
media_output_http_url="https://cdn.example.com/videos",
)
assert config.media_output_fs_url == "s3://my-bucket/videos"
assert config.media_output_http_url == "https://cdn.example.com/videos"
def test_str_representation(self):
"""Test that __str__ includes key fields."""
config = DiffusionConfig(
......@@ -576,16 +587,20 @@ class TestVideoHandlerConcurrency:
mock_engine.generate = tracker.generate
config = DiffusionConfig(
output_dir="/tmp/test_videos",
media_output_fs_url="file:///tmp/test_media",
default_fps=24,
default_seconds=4,
)
handler = VideoGenerationHandler(
component=MagicMock(),
engine=mock_engine,
config=config,
)
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.get_fs",
return_value=MagicMock(),
):
handler = VideoGenerationHandler(
component=MagicMock(),
engine=mock_engine,
config=config,
)
return handler, tracker
......@@ -616,8 +631,11 @@ class TestVideoHandlerConcurrency:
requests = [self._make_request() for _ in range(3)]
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4",
return_value="/tmp/test.mp4",
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4_bytes",
return_value=b"fake_mp4_bytes",
), patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.upload_to_fs",
return_value="http://fake/video.mp4",
):
await asyncio.gather(
*(self._drain_generator(handler, req) for req in requests)
......@@ -631,3 +649,156 @@ class TestVideoHandlerConcurrency:
f"Expected max_concurrent=1 (serialized), got {tracker.max_concurrent}. "
"Pipeline was accessed concurrently — this would corrupt visual_gen state."
)
# =============================================================================
# Part 6: VideoGenerationHandler Response Format Tests
# =============================================================================
class TestVideoHandlerResponseFormats:
"""Tests for VideoGenerationHandler generate() response format branching."""
def _make_handler(self):
"""Create a handler with mocked engine and fs."""
import numpy as np
from dynamo.trtllm.request_handlers.video_diffusion.video_handler import (
VideoGenerationHandler,
)
mock_engine = MagicMock()
mock_engine.generate = MagicMock(
return_value=np.zeros((4, 64, 64, 3), dtype=np.uint8)
)
config = DiffusionConfig(
media_output_fs_url="file:///tmp/test_media",
media_output_http_url="https://cdn.example.com/media",
default_fps=24,
default_seconds=4,
)
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.get_fs",
return_value=MagicMock(),
):
handler = VideoGenerationHandler(
component=MagicMock(),
engine=mock_engine,
config=config,
)
return handler
@pytest.mark.asyncio
async def test_url_response_format(self):
"""Test generate() with url response format calls upload_to_fs."""
handler = self._make_handler()
request = {
"prompt": "a test video",
"model": "test-model",
"response_format": "url",
}
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4_bytes",
return_value=b"fake_mp4",
), patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.upload_to_fs",
return_value="https://cdn.example.com/media/videos/test.mp4",
) as mock_upload:
results = []
async for result in handler.generate(request, MagicMock()):
results.append(result)
assert len(results) == 1
response = results[0]
assert response["status"] == "completed"
assert len(response["data"]) == 1
assert (
response["data"][0]["url"]
== "https://cdn.example.com/media/videos/test.mp4"
)
mock_upload.assert_called_once()
@pytest.mark.asyncio
async def test_b64_response_format(self):
"""Test generate() with b64_json response format returns base64 encoded video."""
handler = self._make_handler()
request = {
"prompt": "a test video",
"model": "test-model",
"response_format": "b64_json",
}
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4_bytes",
return_value=b"fake_mp4_bytes",
):
results = []
async for result in handler.generate(request, MagicMock()):
results.append(result)
assert len(results) == 1
response = results[0]
assert response["status"] == "completed"
assert len(response["data"]) == 1
assert response["data"][0]["b64_json"] is not None
assert response["data"][0].get("url") is None
# Verify valid base64
import base64
decoded = base64.b64decode(response["data"][0]["b64_json"])
assert decoded == b"fake_mp4_bytes"
@pytest.mark.asyncio
async def test_default_response_format_is_url(self):
"""Test that generate() defaults to url response format."""
handler = self._make_handler()
request = {
"prompt": "a test video",
"model": "test-model",
# No response_format specified
}
with patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.encode_to_mp4_bytes",
return_value=b"fake_mp4",
), patch(
"dynamo.trtllm.request_handlers.video_diffusion.video_handler.upload_to_fs",
return_value="https://cdn.example.com/media/videos/test.mp4",
) as mock_upload:
results = []
async for result in handler.generate(request, MagicMock()):
results.append(result)
assert len(results) == 1
# Default should be "url" format, so upload_to_fs should be called
mock_upload.assert_called_once()
assert results[0]["data"][0]["url"] is not None
@pytest.mark.asyncio
async def test_error_response_on_failure(self):
"""Test that generate() returns error response on engine failure."""
handler = self._make_handler()
handler.engine.generate = MagicMock(side_effect=RuntimeError("GPU OOM"))
request = {
"prompt": "a test video",
"model": "test-model",
}
results = []
async for result in handler.generate(request, MagicMock()):
results.append(result)
assert len(results) == 1
response = results[0]
assert response["status"] == "failed"
assert response["error"] == "GPU OOM"
assert response["data"] == []
......@@ -60,7 +60,8 @@ async def init_video_diffusion_worker(
event_plane=config.event_plane,
model_path=config.model,
served_model_name=config.served_model_name,
output_dir=config.output_dir,
media_output_fs_url=config.media_output_fs_url,
media_output_http_url=config.media_output_http_url,
default_height=config.default_height,
default_width=config.default_width,
default_num_frames=config.default_num_frames,
......
......@@ -138,15 +138,7 @@ class DynamoVllmArgGroup(ArgGroup):
help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).",
)
# Video diffusion output
# TODO: Propose an alternate design to switch to AsyncOmniEngine args while using vLLM-Omni
add_argument(
g,
flag_name="--video-output-dir",
env_var="DYN_VLLM_VIDEO_OUTPUT_DIR",
default="/tmp/dynamo_videos", # noqa: S108
help="Directory to save generated video MP4 files.",
)
# Video encoding
add_argument(
g,
flag_name="--default-video-fps",
......@@ -240,6 +232,13 @@ class DynamoVllmArgGroup(ArgGroup):
default=False,
help="Enable CPU offloading for diffusion models to reduce GPU memory usage.",
)
add_negatable_bool_argument(
g,
flag_name="--enforce-eager",
env_var="DYN_VLLM_ENFORCE_EAGER",
default=False,
help="Disable torch.compile and force eager execution for diffusion models.",
)
# Diffusion parallel configuration
add_argument(
g,
......@@ -300,8 +299,7 @@ class DynamoVllmConfig(ConfigBase):
omni: bool
stage_configs_path: Optional[str] = None
# Video diffusion output
video_output_dir: str = "/tmp/dynamo_videos" # noqa: S108
# Video encoding
default_video_fps: int = 16
# Diffusion engine-level parameters (passed to AsyncOmni constructor)
......
......@@ -17,6 +17,7 @@ from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config
from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.common.utils.prometheus import (
......@@ -907,7 +908,6 @@ async def init_omni(
Supports text-to-text, text-to-image, and text-to-video generation
through a single unified OmniHandler.
"""
# Lazy import to avoid loading vllm-omni unless explicitly needed
from dynamo.vllm.omni import OmniHandler
generate_endpoint = runtime.endpoint(
......@@ -915,6 +915,11 @@ async def init_omni(
)
component = generate_endpoint.component()
# Initialize media filesystem for storing generated images/videos
media_fs = (
get_fs(config.media_output_fs_url) if config.media_output_fs_url else None
)
# Initialize unified OmniHandler
handler = OmniHandler(
runtime=runtime,
......@@ -922,6 +927,8 @@ async def init_omni(
config=config,
default_sampling_params={},
shutdown_event=shutdown_event,
media_output_fs=media_fs,
media_output_http_url=config.media_output_http_url,
)
logger.info(f"Omni worker initialized for model: {config.model}")
......
......@@ -131,7 +131,6 @@ class BaseOmniHandler(BaseWorkerHandler):
) -> AsyncGenerator[Dict, None]:
"""Generate outputs using AsyncOmni orchestrator with OpenAI-compatible format.
Routes to OpenAI mode (detokenized text) or token mode based on config.
Subclasses should override ``_generate_openai_mode`` for custom output handling.
"""
request_id = context.id()
......
......@@ -3,14 +3,15 @@
import asyncio
import base64
import logging
import os
import tempfile
import time
import uuid
from dataclasses import dataclass
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, Union
from typing import Any, AsyncGenerator, Dict, Optional, Union
from diffusers.utils import export_to_video
from fsspec.implementations.dirfs import DirFileSystem
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo.common.protocols.image_protocol import (
......@@ -23,6 +24,7 @@ from dynamo.common.protocols.video_protocol import (
NvVideosResponse,
VideoData,
)
from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.output_modalities import RequestType, parse_request_type
from dynamo.common.utils.video_utils import (
compute_num_frames,
......@@ -33,9 +35,7 @@ from dynamo.vllm.omni.base_handler import BaseOmniHandler
logger = logging.getLogger(__name__)
# TODO: Migrate to fs_url based approach in another PR
DEFAULT_VIDEO_FPS = 16
DEFAULT_VIDEO_OUTPUT_DIR = "/tmp/dynamo_videos" # noqa: S108
@dataclass
......@@ -59,40 +59,6 @@ class EngineInputs:
response_format: str | None = None
def prepare_image_output(images: list, response_format: str | None = None):
"""Prepare image output for response.
Args:
images: List of PIL Image objects.
response_format: Response format.
Returns:
List of image URLs or base64 strings.
"""
## This is a temporary function to prepare image output for response.
## Right now, there are different utilities across components that uploads image/video outputs to urls or b64_json.
## (ayushag) TODO: follow up, move all the utilities to common
outlist = []
for img in images:
if response_format == "url":
output_dir = "/tmp/dynamo_images" # noqa: S108
os.makedirs(output_dir, exist_ok=True)
img_path = os.path.join(output_dir, f"{uuid.uuid4()}.png")
img.save(img_path)
outlist.append(img_path)
elif response_format == "b64_json" or response_format is None:
# convert image to base64
buffer = BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
data_url = f"data:image/png;base64,{img_base64}"
outlist.append(data_url)
else:
raise ValueError(f"Invalid response format: {response_format}")
return outlist
class OmniHandler(BaseOmniHandler):
"""Unified handler for multi-stage pipelines using vLLM-Omni.
......@@ -106,6 +72,8 @@ class OmniHandler(BaseOmniHandler):
config,
default_sampling_params: Dict[str, Any],
shutdown_event: asyncio.Event | None = None,
media_output_fs: Optional[DirFileSystem] = None,
media_output_http_url: Optional[str] = None,
):
"""Initialize the unified Omni handler.
......@@ -115,6 +83,8 @@ class OmniHandler(BaseOmniHandler):
config: Parsed Config object from args.py.
default_sampling_params: Default sampling parameters dict.
shutdown_event: Optional asyncio event for graceful shutdown.
media_output_fs: Filesystem for storing generated images/videos.
media_output_http_url: Base URL for rewriting media paths in responses.
"""
super().__init__(
runtime=runtime,
......@@ -123,6 +93,8 @@ class OmniHandler(BaseOmniHandler):
default_sampling_params=default_sampling_params,
shutdown_event=shutdown_event,
)
self.media_output_fs = media_output_fs
self.media_output_http_url = media_output_http_url
async def generate(
self, request: Dict[str, Any], context
......@@ -194,7 +166,7 @@ class OmniHandler(BaseOmniHandler):
fps=inputs.fps,
)
else:
chunk = self._format_image_chunk(
chunk = await self._format_image_chunk(
stage_output.images,
request_id,
response_format=inputs.response_format,
......@@ -339,14 +311,51 @@ class OmniHandler(BaseOmniHandler):
fps=fps,
)
def _format_image_chunk(
async def _prepare_image_output(
self, images: list, request_id: str, response_format: str | None = None
) -> list:
"""Prepare image output for response.
Args:
images: List of PIL Image objects.
request_id: Unique request identifier.
response_format: Response format ("url" or "b64_json").
Returns:
List of image URLs or base64 data-URL strings.
"""
outlist = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="PNG")
image_bytes = buffer.getvalue()
if response_format == "url":
storage_path = f"images/{request_id}/{uuid.uuid4()}.png"
url = await upload_to_fs(
self.media_output_fs,
storage_path,
image_bytes,
self.media_output_http_url,
)
outlist.append(url)
elif response_format == "b64_json" or response_format is None:
img_base64 = base64.b64encode(image_bytes).decode("utf-8")
data_url = f"data:image/png;base64,{img_base64}"
outlist.append(data_url)
else:
raise ValueError(f"Invalid response format: {response_format}")
return outlist
async def _format_image_chunk(
self,
images: list,
request_id: str,
response_format: str | None = None,
request_type: RequestType = RequestType.IMAGE_GENERATION,
) -> Dict[str, Any] | None:
"""Format image output as OpenAI chat completion chunk with base64 data URLs.
"""Format image output for the appropriate endpoint response.
Args:
images: List of PIL Image objects generated by AsyncOmni engine.
......@@ -355,17 +364,16 @@ class OmniHandler(BaseOmniHandler):
request_type: Request type (chat completion, image generation).
Returns:
Dict[str, Any] | None: Formatted chunk, or None if no images generated.
Formatted response dict, or None if no images generated.
"""
if not images:
return self._error_chunk(request_id, "No images generated")
data_urls = prepare_image_output(images, response_format)
data_urls = await self._prepare_image_output(
images, request_id, response_format
)
if request_type == RequestType.CHAT_COMPLETION:
# This branch is used when user send request via /v1/chat/completions endpoint.
# We need to return chat completion chunk with image_url content part.
chunk = {
"id": request_id,
"created": int(time.time()),
......@@ -387,14 +395,11 @@ class OmniHandler(BaseOmniHandler):
}
return chunk
elif request_type == RequestType.IMAGE_GENERATION:
# This branch is used when user send request via /v1/images/generations endpoint.
# This will return NvImagesResponse with list of ImageData objects.
image_data_list = []
for data_url in data_urls:
if response_format == "url":
image_data_list.append(ImageData(url=data_url))
elif response_format == "b64_json" or response_format is None:
# strip explicit prefix if present
if data_url.startswith("data:image"):
_, b64_part = data_url.split(",", 1)
image_data_list.append(ImageData(b64_json=b64_part))
......@@ -435,19 +440,21 @@ class OmniHandler(BaseOmniHandler):
f"(fps={fps})"
)
os.makedirs(DEFAULT_VIDEO_OUTPUT_DIR, exist_ok=True)
video_path = os.path.join(DEFAULT_VIDEO_OUTPUT_DIR, f"{request_id}.mp4")
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
export_to_video,
frame_list,
video_path,
fps,
# Encode frames to MP4 via temp file, then read bytes for upload
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=True) as tmp:
await asyncio.to_thread(export_to_video, frame_list, tmp.name, fps)
video_bytes = tmp.read()
# Upload via filesystem
storage_path = f"videos/{request_id}.mp4"
video_url = await upload_to_fs(
self.media_output_fs,
storage_path,
video_bytes,
self.media_output_http_url,
)
logger.info(f"Video saved to {video_path} for request {request_id}")
logger.info(f"Video uploaded to {video_url} for request {request_id}")
inference_time = time.time() - start_time
......@@ -458,7 +465,7 @@ class OmniHandler(BaseOmniHandler):
status="completed",
progress=100,
created=int(time.time()),
data=[VideoData(url=video_path)],
data=[VideoData(url=video_url)],
inference_time_s=inference_time,
)
return response.model_dump()
......
......@@ -9,7 +9,6 @@ from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest
from dynamo.common.utils.output_modalities import RequestType
# TODO: Install vLLM omni dependencies in CI container so this skip is no longer needed.
try:
from dynamo.vllm.omni.omni_handler import (
EngineInputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment