Unverified commit 2d86b81d, authored by Ayush Agarwal, committed by GitHub
Browse files

chore: flux benchmarking script + code clean (#8083)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 2e7a1e6c
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Generic aiperf benchmark for vLLM-Omni text-to-image generation.
# Assumes the server (Dynamo or native vllm-omni) is already running.
#
# Usage:
# bash aiperf_image_gen.sh [OPTIONS]
#
# Options:
# --model <model> Model to benchmark (default: black-forest-labs/FLUX.2-klein-4B)
# --url <url> Server URL (default: http://localhost:8000)
# --concurrency <n> Number of concurrent requests (default: 1)
# --request-count <n> Total requests to send (default: 10)
# --warmup-count <n> Warmup requests before measurement (default: 2)
# --image-size <WxH> Generated image size (default: 1024x1024)
# --response-format <fmt> Response format: url or b64_json (default: url)
# --prompt-tokens-mean <n> Mean synthetic prompt length in tokens (default: 50)
# --prompt-tokens-stddev <n> Stddev of synthetic prompt length (default: 10)
# --artifact-dir <dir> Forwarded to aiperf as --artifact-dir (default: unset, flag omitted)
# -h, --help Show this help message
#
# Examples:
# bash aiperf_image_gen.sh
# bash aiperf_image_gen.sh --model zai-org/GLM-Image --concurrency 4
# bash aiperf_image_gen.sh --model Qwen/Qwen-Image --image-size 512x512 --request-count 20

# Abort on command failures, unset variables, and failures inside pipelines.
set -euo pipefail

# Defaults; each one can be overridden by the matching command-line option.
MODEL="black-forest-labs/FLUX.2-klein-4B"
URL="http://localhost:8000"
CONCURRENCY=1
REQUEST_COUNT=10
WARMUP_COUNT=2
IMAGE_SIZE="1024x1024"
RESPONSE_FORMAT="url"
PROMPT_TOKENS_MEAN=50
PROMPT_TOKENS_STDDEV=10
ARTIFACT_DIR=""

# Exit with an error unless the option is followed by a value.
#   $1 - option name (for the message)
#   $2 - number of positional arguments still left, counting the option itself
# Without this guard a trailing value-less option makes `shift 2` fail without
# shifting anything, so the parsing loop below would never terminate.
require_value() {
    if [[ $2 -lt 2 ]]; then
        echo "Error: $1 requires a value" >&2
        exit 1
    fi
}

while [[ $# -gt 0 ]]; do
    case "$1" in
        --model) require_value "$1" $#; MODEL=$2; shift 2 ;;
        --url) require_value "$1" $#; URL=$2; shift 2 ;;
        --concurrency) require_value "$1" $#; CONCURRENCY=$2; shift 2 ;;
        --request-count) require_value "$1" $#; REQUEST_COUNT=$2; shift 2 ;;
        --warmup-count) require_value "$1" $#; WARMUP_COUNT=$2; shift 2 ;;
        --image-size) require_value "$1" $#; IMAGE_SIZE=$2; shift 2 ;;
        --response-format) require_value "$1" $#; RESPONSE_FORMAT=$2; shift 2 ;;
        --prompt-tokens-mean) require_value "$1" $#; PROMPT_TOKENS_MEAN=$2; shift 2 ;;
        --prompt-tokens-stddev) require_value "$1" $#; PROMPT_TOKENS_STDDEV=$2; shift 2 ;;
        --artifact-dir) require_value "$1" $#; ARTIFACT_DIR=$2; shift 2 ;;
        -h|--help)
            # Print the header comment from "# Usage" up to the first
            # non-comment line, with the leading "# " removed.
            sed -n '/^# Usage/,/^[^#]/p' "$0" | grep '^#' | sed 's/^# \?//'
            exit 0 ;;
        *)
            echo "Unknown option: $1" >&2
            exit 1 ;;
    esac
done

# Assemble the command as an array so values containing spaces stay intact.
AIPERF_ARGS=(
    aiperf profile
    --model "$MODEL"
    --tokenizer gpt2
    --url "$URL"
    --endpoint-type image-generation
    --synthetic-input-tokens-mean "$PROMPT_TOKENS_MEAN"
    --synthetic-input-tokens-stddev "$PROMPT_TOKENS_STDDEV"
    --extra-inputs "size:${IMAGE_SIZE}"
    --extra-inputs "response_format:${RESPONSE_FORMAT}"
    --concurrency "$CONCURRENCY"
    --request-count "$REQUEST_COUNT"
    --warmup-request-count "$WARMUP_COUNT"
    --ui none
    --no-server-metrics
)

# Only pass --artifact-dir when the caller supplied one.
if [[ -n "$ARTIFACT_DIR" ]]; then
    AIPERF_ARGS+=(--artifact-dir "$ARTIFACT_DIR")
fi

"${AIPERF_ARGS[@]}"
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""Omni-specific argument parsing for python -m dynamo.vllm.omni.""" """Omni-specific argument parsing for python -m dynamo.vllm.omni."""
import argparse import argparse
import dataclasses
import logging import logging
from typing import Optional from typing import Optional
...@@ -24,15 +25,50 @@ from dynamo.common.configuration.utils import add_argument, add_negatable_bool_a ...@@ -24,15 +25,50 @@ from dynamo.common.configuration.utils import add_argument, add_negatable_bool_a
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OmniArgGroup(ArgGroup): @dataclasses.dataclass
"""Diffusion pipeline kwargs passed through to AsyncOmni() constructor. class OmniDiffusionKwargs:
"""AsyncOmni constructor kwargs for diffusion engine configuration.
Every field here is passed directly to AsyncOmni(**kwargs) and consumed by
_create_default_diffusion_stage_cfg() in vllm-omni. Adding a new vllm-omni
diffusion flag only requires adding it here and to OmniArgGroup — the
passthrough in base_handler is automatic.
"""
enable_layerwise_offload: bool = False
layerwise_num_gpu_layers: int = 1
vae_use_slicing: bool = False
vae_use_tiling: bool = False
boundary_ratio: float = 0.875
flow_shift: Optional[float] = None
cache_backend: Optional[str] = None
cache_config: Optional[str] = None
enable_cache_dit_summary: bool = False
enable_cpu_offload: bool = False
enforce_eager: bool = False
These are NOT part of OmniEngineArgs (which handles vLLM engine-level @dataclasses.dataclass
args like model, tp, max_model_len). Instead they are direct constructor class OmniParallelKwargs:
kwargs for AsyncOmni and need Dynamo-side env-var (DYN_OMNI_*) support, """Diffusion parallelism configuration passed to DiffusionParallelConfig.
so we define them here rather than relying on the upstream arg parser.
Every field here maps 1:1 to a DiffusionParallelConfig field (excluding
tensor_parallel_size which comes from engine_args, and fixed/derived fields).
Adding a new parallelism field only requires adding it here and to OmniArgGroup.
""" """
ulysses_degree: int = 1
ring_degree: int = 1
cfg_parallel_size: int = 1
vae_patch_parallel_size: int = 1
use_hsdp: bool = False
hsdp_shard_size: int = -1
hsdp_replicate_size: int = 1
class OmniArgGroup(ArgGroup):
"""CLI argument definitions for Dynamo vLLM-Omni."""
name = "dynamo-omni" name = "dynamo-omni"
def add_arguments(self, parser) -> None: def add_arguments(self, parser) -> None:
...@@ -49,7 +85,6 @@ class OmniArgGroup(ArgGroup): ...@@ -49,7 +85,6 @@ class OmniArgGroup(ArgGroup):
help="Path to vLLM-Omni stage configuration YAML file (optional).", help="Path to vLLM-Omni stage configuration YAML file (optional).",
) )
# Video encoding
add_argument( add_argument(
g, g,
flag_name="--default-video-fps", flag_name="--default-video-fps",
...@@ -59,7 +94,7 @@ class OmniArgGroup(ArgGroup): ...@@ -59,7 +94,7 @@ class OmniArgGroup(ArgGroup):
help="Default frames per second for generated videos.", help="Default frames per second for generated videos.",
) )
# Layerwise offloading # OmniDiffusionKwargs fields
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--enable-layerwise-offload", flag_name="--enable-layerwise-offload",
...@@ -75,8 +110,6 @@ class OmniArgGroup(ArgGroup): ...@@ -75,8 +110,6 @@ class OmniArgGroup(ArgGroup):
arg_type=int, arg_type=int,
help="Number of ready layers (blocks) to keep on GPU during generation.", help="Number of ready layers (blocks) to keep on GPU during generation.",
) )
# VAE optimization
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--vae-use-slicing", flag_name="--vae-use-slicing",
...@@ -91,8 +124,6 @@ class OmniArgGroup(ArgGroup): ...@@ -91,8 +124,6 @@ class OmniArgGroup(ArgGroup):
default=False, default=False,
help="Enable VAE tiling for memory optimization in diffusion models.", help="Enable VAE tiling for memory optimization in diffusion models.",
) )
# Diffusion scheduling
add_argument( add_argument(
g, g,
flag_name="--boundary-ratio", flag_name="--boundary-ratio",
...@@ -113,8 +144,6 @@ class OmniArgGroup(ArgGroup): ...@@ -113,8 +144,6 @@ class OmniArgGroup(ArgGroup):
arg_type=float, arg_type=float,
help="Scheduler flow_shift parameter (5.0 for 720p, 12.0 for 480p).", help="Scheduler flow_shift parameter (5.0 for 720p, 12.0 for 480p).",
) )
# Cache acceleration
add_argument( add_argument(
g, g,
flag_name="--cache-backend", flag_name="--cache-backend",
...@@ -141,8 +170,6 @@ class OmniArgGroup(ArgGroup): ...@@ -141,8 +170,6 @@ class OmniArgGroup(ArgGroup):
default=False, default=False,
help="Enable cache-dit summary logging after diffusion forward passes.", help="Enable cache-dit summary logging after diffusion forward passes.",
) )
# Execution mode
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--enable-cpu-offload", flag_name="--enable-cpu-offload",
...@@ -204,7 +231,7 @@ class OmniArgGroup(ArgGroup): ...@@ -204,7 +231,7 @@ class OmniArgGroup(ArgGroup):
help="Maximum size in bytes for reference audio files (default: 50MB).", help="Maximum size in bytes for reference audio files (default: 50MB).",
) )
# Diffusion parallel configuration # OmniParallelKwargs fields
add_argument( add_argument(
g, g,
flag_name="--ulysses-degree", flag_name="--ulysses-degree",
...@@ -227,9 +254,43 @@ class OmniArgGroup(ArgGroup): ...@@ -227,9 +254,43 @@ class OmniArgGroup(ArgGroup):
env_var="DYN_OMNI_CFG_PARALLEL_SIZE", env_var="DYN_OMNI_CFG_PARALLEL_SIZE",
default=1, default=1,
arg_type=int, arg_type=int,
choices=[1, 2], choices=[1, 2, 3],
help="Number of GPUs used for classifier free guidance parallelism.", help="Number of GPUs used for classifier free guidance parallelism.",
) )
add_argument(
g,
flag_name="--vae-patch-parallel-size",
env_var="DYN_OMNI_VAE_PATCH_PARALLEL_SIZE",
default=1,
arg_type=int,
help="Number of ranks used for VAE patch/tile parallelism during decode/encode.",
)
add_negatable_bool_argument(
g,
flag_name="--use-hsdp",
env_var="DYN_OMNI_USE_HSDP",
default=False,
help=(
"Enable Hybrid Sharded Data Parallel (HSDP) for diffusion models. "
"Shards model weights across GPUs to reduce per-GPU memory usage."
),
)
add_argument(
g,
flag_name="--hsdp-shard-size",
env_var="DYN_OMNI_HSDP_SHARD_SIZE",
default=-1,
arg_type=int,
help="Number of GPUs to shard model weights across when using HSDP (-1 = auto).",
)
add_argument(
g,
flag_name="--hsdp-replicate-size",
env_var="DYN_OMNI_HSDP_REPLICATE_SIZE",
default=1,
arg_type=int,
help="Number of HSDP replica groups (default: 1).",
)
# Disaggregated stage worker flags # Disaggregated stage worker flags
add_argument( add_argument(
...@@ -244,7 +305,6 @@ class OmniArgGroup(ArgGroup): ...@@ -244,7 +305,6 @@ class OmniArgGroup(ArgGroup):
"Requires --stage-configs-path." "Requires --stage-configs-path."
), ),
) )
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--omni-router", flag_name="--omni-router",
...@@ -263,30 +323,18 @@ class OmniConfig(DynamoRuntimeConfig): ...@@ -263,30 +323,18 @@ class OmniConfig(DynamoRuntimeConfig):
component: str = "backend" component: str = "backend"
endpoint: Optional[str] = None endpoint: Optional[str] = None
# mirror vLLM
model: str model: str
served_model_name: Optional[str] = None served_model_name: Optional[str] = None
# vLLM-Omni engine args
engine_args: OmniEngineArgs engine_args: OmniEngineArgs
# OmniArgGroup fields (populated by from_cli_args)
stage_configs_path: Optional[str] = None stage_configs_path: Optional[str] = None
default_video_fps: int = 16 default_video_fps: int = 16
enable_layerwise_offload: bool = False
layerwise_num_gpu_layers: int = 1 # Nested structs — each group of fields has a clear destination
vae_use_slicing: bool = False diffusion: OmniDiffusionKwargs = dataclasses.field(
vae_use_tiling: bool = False default_factory=OmniDiffusionKwargs
boundary_ratio: float = 0.875 )
flow_shift: Optional[float] = None parallel: OmniParallelKwargs = dataclasses.field(default_factory=OmniParallelKwargs)
cache_backend: Optional[str] = None
cache_config: Optional[str] = None
enable_cache_dit_summary: bool = False
enable_cpu_offload: bool = False
enforce_eager: bool = False
ulysses_degree: int = 1
ring_degree: int = 1
cfg_parallel_size: int = 1
# TTS parameters # TTS parameters
tts_max_instructions_length: int = 500 tts_max_instructions_length: int = 500
...@@ -299,15 +347,36 @@ class OmniConfig(DynamoRuntimeConfig): ...@@ -299,15 +347,36 @@ class OmniConfig(DynamoRuntimeConfig):
stage_id: Optional[int] = None stage_id: Optional[int] = None
omni_router: bool = False omni_router: bool = False
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "OmniConfig":
    """Build the config via the parent class, then fill the nested kwargs structs.

    For each field of OmniDiffusionKwargs and OmniParallelKwargs, the value is
    copied from ``args`` when ``args`` has an attribute of that name; fields
    missing from ``args`` keep the dataclass default.

    Args:
        args: Parsed command-line namespace.

    Returns:
        The config with ``diffusion`` and ``parallel`` populated.
    """
    config = super().from_cli_args(args)
    # (attribute on config, dataclass type); diffusion is filled before parallel.
    nested_structs = (
        ("diffusion", OmniDiffusionKwargs),
        ("parallel", OmniParallelKwargs),
    )
    for attr_name, kwargs_cls in nested_structs:
        supplied = {}
        for spec in dataclasses.fields(kwargs_cls):
            if hasattr(args, spec.name):
                supplied[spec.name] = getattr(args, spec.name)
        setattr(config, attr_name, dataclasses.replace(kwargs_cls(), **supplied))
    return config
def validate(self) -> None: def validate(self) -> None:
DynamoRuntimeConfig.validate(self) DynamoRuntimeConfig.validate(self)
if self.default_video_fps <= 0: if self.default_video_fps <= 0:
raise ValueError("--default-video-fps must be > 0") raise ValueError("--default-video-fps must be > 0")
if self.ulysses_degree <= 0: if self.parallel.ulysses_degree <= 0:
raise ValueError("--ulysses-degree must be > 0") raise ValueError("--ulysses-degree must be > 0")
if self.ring_degree <= 0: if self.parallel.ring_degree <= 0:
raise ValueError("--ring-degree must be > 0") raise ValueError("--ring-degree must be > 0")
if not (0 < self.boundary_ratio <= 1): if not (0 < self.diffusion.boundary_ratio <= 1):
raise ValueError("--boundary-ratio must be in (0, 1]") raise ValueError("--boundary-ratio must be in (0, 1]")
if self.stage_configs_path is None: if self.stage_configs_path is None:
if self.stage_id is not None: if self.stage_id is not None:
...@@ -334,7 +403,6 @@ def parse_omni_args() -> OmniConfig: ...@@ -334,7 +403,6 @@ def parse_omni_args() -> OmniConfig:
dynamo_runtime_argspec.add_arguments(parser) dynamo_runtime_argspec.add_arguments(parser)
omni_argspec.add_arguments(parser) omni_argspec.add_arguments(parser)
# Add vLLM-Omni engine args
vg = parser.add_argument_group( vg = parser.add_argument_group(
"vLLM-Omni Engine Options. Please refer to vLLM-Omni documentation for more details." "vLLM-Omni Engine Options. Please refer to vLLM-Omni documentation for more details."
) )
...@@ -349,7 +417,6 @@ def parse_omni_args() -> OmniConfig: ...@@ -349,7 +417,6 @@ def parse_omni_args() -> OmniConfig:
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
config = OmniConfig.from_cli_args(args) config = OmniConfig.from_cli_args(args)
# Default endpoint to "generate" if not explicitly provided by user
if config.endpoint is None: if config.endpoint is None:
config.endpoint = "generate" config.endpoint = "generate"
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""Base handler for vLLM-Omni multi-stage pipelines.""" """Base handler for vLLM-Omni multi-stage pipelines."""
import asyncio import asyncio
import dataclasses
import logging import logging
import time import time
from typing import Any, AsyncGenerator, Dict from typing import Any, AsyncGenerator, Dict
...@@ -74,31 +75,17 @@ class BaseOmniHandler(BaseWorkerHandler[Dict[str, Any], Dict[str, Any]]): ...@@ -74,31 +75,17 @@ class BaseOmniHandler(BaseWorkerHandler[Dict[str, Any], Dict[str, Any]]):
if config.stage_configs_path: if config.stage_configs_path:
omni_kwargs["stage_configs_path"] = config.stage_configs_path omni_kwargs["stage_configs_path"] = config.stage_configs_path
# Diffusion engine-level params — read directly from config namespace for field, value in dataclasses.asdict(config.diffusion).items():
diffusion_fields = [
"enable_layerwise_offload",
"layerwise_num_gpu_layers",
"vae_use_slicing",
"vae_use_tiling",
"boundary_ratio",
"flow_shift",
"cache_backend",
"cache_config",
"enable_cache_dit_summary",
"enable_cpu_offload",
"enforce_eager",
]
for field in diffusion_fields:
value = getattr(config, field, None)
if value is not None: if value is not None:
omni_kwargs[field] = value omni_kwargs[field] = value
# Build DiffusionParallelConfig if available # tensor_parallel_size comes from engine_args (vLLM's --tensor-parallel-size)
if DiffusionParallelConfig is not None: if DiffusionParallelConfig is not None:
parallel_config = DiffusionParallelConfig( parallel_config = DiffusionParallelConfig(
ulysses_degree=getattr(config, "ulysses_degree", 1), tensor_parallel_size=getattr(
ring_degree=getattr(config, "ring_degree", 1), config.engine_args, "tensor_parallel_size", 1
cfg_parallel_size=getattr(config, "cfg_parallel_size", 1), ),
**dataclasses.asdict(config.parallel),
) )
omni_kwargs["parallel_config"] = parallel_config omni_kwargs["parallel_config"] = parallel_config
else: else:
......
...@@ -2,18 +2,20 @@ ...@@ -2,18 +2,20 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import logging import logging
import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional, Union, cast from typing import Any, AsyncGenerator, Dict, Optional, Union, cast
import PIL.Image import PIL.Image
from fsspec.implementations.dirfs import DirFileSystem from fsspec.implementations.dirfs import DirFileSystem
from vllm.sampling_params import SamplingParams
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.multimodal import ImageLoader from dynamo.common.multimodal import ImageLoader
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest from dynamo.common.protocols.image_protocol import ImageNvExt, NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest from dynamo.common.protocols.video_protocol import NvCreateVideoRequest, VideoNvExt
from dynamo.common.utils.output_modalities import RequestType, parse_request_type from dynamo.common.utils.output_modalities import RequestType, parse_request_type
from dynamo.common.utils.video_utils import compute_num_frames, parse_size from dynamo.common.utils.video_utils import compute_num_frames, parse_size
from dynamo.llm.exceptions import EngineShutdown from dynamo.llm.exceptions import EngineShutdown
...@@ -255,35 +257,60 @@ class OmniHandler(BaseOmniHandler): ...@@ -255,35 +257,60 @@ class OmniHandler(BaseOmniHandler):
fps=0, fps=0,
) )
@staticmethod
def _update_if_not_none(object: Any, key: str, val: Any) -> None:
if val is not None:
setattr(object, key, val)
def _build_sampling_params_list(
    self, diffusion_sp: OmniDiffusionSamplingParams
) -> list:
    """Return one sampling-params entry per engine stage.

    Walks the engine client's default sampling params list. A stage whose
    metadata reports ``stage_type == "diffusion"`` gets ``diffusion_sp``; any
    other stage gets a clone of its default when the default has a ``clone``
    attribute, otherwise a fresh ``SamplingParams()``. A stage with no
    ``stage_type`` in its metadata is treated as ``"llm"``.

    Args:
        diffusion_sp: Sampling params to use for diffusion stages.

    Returns:
        The per-stage list, or ``[diffusion_sp]`` when the client has no
        default sampling params.
    """
    # This is in sync with how vllm-omni builds sampling params currently.
    client = self.engine_client
    stage_defaults = list(client.default_sampling_params_list or [])
    if not stage_defaults:
        return [diffusion_sp]

    per_stage = []
    for stage_idx, stage_default in enumerate(stage_defaults):
        metadata = client.engine.get_stage_metadata(stage_idx)
        if metadata.get("stage_type", "llm") == "diffusion":
            per_stage.append(diffusion_sp)
        elif hasattr(stage_default, "clone"):
            per_stage.append(stage_default.clone())
        else:
            per_stage.append(SamplingParams())
    return per_stage
def _engine_inputs_from_image(self, req: NvCreateImageRequest) -> EngineInputs: def _engine_inputs_from_image(self, req: NvCreateImageRequest) -> EngineInputs:
"""Build engine inputs from an NvCreateImageRequest.""" """Build engine inputs from an NvCreateImageRequest."""
width, height = parse_size(req.size, default_w=1024, default_h=1024) width, height = parse_size(req.size, default_w=1024, default_h=1024)
nvext = req.nvext nvext = req.nvext or ImageNvExt()
prompt = OmniTextPrompt( prompt = OmniTextPrompt(prompt=req.prompt)
prompt=req.prompt, if nvext and nvext.negative_prompt is not None:
negative_prompt=( prompt.negative_prompt = nvext.negative_prompt
nvext.negative_prompt if nvext and nvext.negative_prompt else None
),
)
sp = OmniDiffusionSamplingParams( sp = OmniDiffusionSamplingParams(
height=height, height=height,
width=width, width=width,
) )
if req.n is not None:
sp.num_outputs_per_prompt = req.n # TODO: Apply LoRA Request params here and move to shared utilities for disaggregated stages to use as well.
if nvext:
if nvext.num_inference_steps is not None: self._update_if_not_none(sp, "num_outputs_per_prompt", req.n)
sp.num_inference_steps = nvext.num_inference_steps
if nvext.guidance_scale is not None: self._update_if_not_none(sp, "num_inference_steps", nvext.num_inference_steps)
sp.guidance_scale = nvext.guidance_scale self._update_if_not_none(sp, "guidance_scale", nvext.guidance_scale)
if nvext.seed is not None: # If seed is not provided, generate a random one to ensure
sp.seed = nvext.seed # a proper generator is initialized in the backend.
# This fixes issues where using the default global generator
# might produce blurry images in some environments.
sp.seed = (
nvext.seed if nvext.seed is not None else random.randint(0, 2**32 - 1)
)
return EngineInputs( return EngineInputs(
prompt=prompt, prompt=prompt,
sampling_params_list=[sp], sampling_params_list=self._build_sampling_params_list(sp),
request_type=RequestType.IMAGE_GENERATION, request_type=RequestType.IMAGE_GENERATION,
response_format=req.response_format, response_format=req.response_format,
) )
...@@ -302,25 +329,19 @@ class OmniHandler(BaseOmniHandler): ...@@ -302,25 +329,19 @@ class OmniHandler(BaseOmniHandler):
I2V pipeline pre-process can use it. I2V pipeline pre-process can use it.
""" """
width, height = parse_size(req.size) width, height = parse_size(req.size)
nvext = req.nvext nvext = req.nvext or VideoNvExt()
nvext_fps = nvext.fps if nvext else None
nvext_num_frames = nvext.num_frames if nvext else None
num_frames = compute_num_frames( num_frames = compute_num_frames(
num_frames=nvext_num_frames, num_frames=nvext.num_frames,
seconds=req.seconds, seconds=req.seconds,
fps=nvext_fps, fps=nvext.fps,
default_fps=DEFAULT_VIDEO_FPS, default_fps=DEFAULT_VIDEO_FPS,
) )
fps = nvext_fps if nvext_fps is not None else DEFAULT_VIDEO_FPS fps = nvext.fps if nvext.fps is not None else DEFAULT_VIDEO_FPS
prompt = OmniTextPrompt( prompt = OmniTextPrompt(prompt=req.prompt)
prompt=req.prompt, if nvext.negative_prompt is not None:
negative_prompt=( prompt.negative_prompt = nvext.negative_prompt
nvext.negative_prompt if nvext and nvext.negative_prompt else None
),
)
if image is not None: if image is not None:
prompt["multi_modal_data"] = {"image": image} prompt["multi_modal_data"] = {"image": image}
...@@ -335,19 +356,14 @@ class OmniHandler(BaseOmniHandler): ...@@ -335,19 +356,14 @@ class OmniHandler(BaseOmniHandler):
width=width, width=width,
num_frames=num_frames, num_frames=num_frames,
) )
if nvext: self._update_if_not_none(sp, "num_inference_steps", nvext.num_inference_steps)
if nvext.num_inference_steps is not None: self._update_if_not_none(sp, "guidance_scale", nvext.guidance_scale)
sp.num_inference_steps = nvext.num_inference_steps sp.seed = (
if nvext.guidance_scale is not None: nvext.seed if nvext.seed is not None else random.randint(0, 2**32 - 1)
sp.guidance_scale = nvext.guidance_scale )
if nvext.seed is not None: self._update_if_not_none(sp, "boundary_ratio", nvext.boundary_ratio)
sp.seed = nvext.seed self._update_if_not_none(sp, "guidance_scale_2", nvext.guidance_scale_2)
if nvext.boundary_ratio is not None: self._update_if_not_none(sp, "fps", fps)
sp.boundary_ratio = nvext.boundary_ratio
if nvext.guidance_scale_2 is not None:
sp.guidance_scale_2 = nvext.guidance_scale_2
if fps is not None:
sp.fps = fps
logger.info( logger.info(
f"Video diffusion request: prompt='{req.prompt[:50]}...', " f"Video diffusion request: prompt='{req.prompt[:50]}...', "
...@@ -356,7 +372,7 @@ class OmniHandler(BaseOmniHandler): ...@@ -356,7 +372,7 @@ class OmniHandler(BaseOmniHandler):
return EngineInputs( return EngineInputs(
prompt=prompt, prompt=prompt,
sampling_params_list=[sp], sampling_params_list=self._build_sampling_params_list(sp),
request_type=RequestType.VIDEO_GENERATION, request_type=RequestType.VIDEO_GENERATION,
fps=fps, fps=fps,
) )
...@@ -3,12 +3,17 @@ ...@@ -3,12 +3,17 @@
"""Unit tests for OmniConfig validation.""" """Unit tests for OmniConfig validation."""
import dataclasses
from types import SimpleNamespace from types import SimpleNamespace
import pytest import pytest
try: try:
from dynamo.vllm.omni.args import OmniConfig from dynamo.vllm.omni.args import (
OmniConfig,
OmniDiffusionKwargs,
OmniParallelKwargs,
)
except ImportError: except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True) pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
...@@ -19,11 +24,25 @@ pytestmark = [ ...@@ -19,11 +24,25 @@ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
] ]
_DIFFUSION_FIELDS = {f.name for f in dataclasses.fields(OmniDiffusionKwargs)}
_PARALLEL_FIELDS = {f.name for f in dataclasses.fields(OmniParallelKwargs)}
def _make_omni_config(**overrides) -> OmniConfig: def _make_omni_config(**overrides) -> OmniConfig:
"""Build a minimal OmniConfig with valid defaults, applying overrides.""" """Build a minimal OmniConfig with valid defaults, applying overrides.
defaults: dict = {
# DynamoRuntimeConfig fields Overrides for diffusion fields (e.g. boundary_ratio) and parallel fields
(e.g. ulysses_degree) are automatically routed to the correct nested struct.
"""
diffusion_overrides = {k: v for k, v in overrides.items() if k in _DIFFUSION_FIELDS}
parallel_overrides = {k: v for k, v in overrides.items() if k in _PARALLEL_FIELDS}
flat_overrides = {
k: v
for k, v in overrides.items()
if k not in _DIFFUSION_FIELDS and k not in _PARALLEL_FIELDS
}
flat_defaults: dict = {
"namespace": "dynamo", "namespace": "dynamo",
"component": "backend", "component": "backend",
"endpoint": None, "endpoint": None,
...@@ -42,45 +61,36 @@ def _make_omni_config(**overrides) -> OmniConfig: ...@@ -42,45 +61,36 @@ def _make_omni_config(**overrides) -> OmniConfig:
"output_modalities": None, "output_modalities": None,
"media_output_fs_url": "file:///tmp/dynamo_media", "media_output_fs_url": "file:///tmp/dynamo_media",
"media_output_http_url": None, "media_output_http_url": None,
# OmniConfig fields
"model": "test-model", "model": "test-model",
"served_model_name": None, "served_model_name": None,
"engine_args": SimpleNamespace(), "engine_args": SimpleNamespace(),
"stage_configs_path": None, "stage_configs_path": None,
"default_video_fps": 16, "default_video_fps": 16,
"enable_layerwise_offload": False, "tts_max_instructions_length": 500,
"layerwise_num_gpu_layers": 1, "tts_max_new_tokens_min": 1,
"vae_use_slicing": False, "tts_max_new_tokens_max": 4096,
"vae_use_tiling": False, "tts_ref_audio_timeout": 15,
"boundary_ratio": 0.875, "tts_ref_audio_max_bytes": 50 * 1024 * 1024,
"flow_shift": None,
"cache_backend": None,
"cache_config": None,
"enable_cache_dit_summary": False,
"enable_cpu_offload": False,
"enforce_eager": False,
"ulysses_degree": 1,
"ring_degree": 1,
"cfg_parallel_size": 1,
"stage_id": None, "stage_id": None,
"omni_router": False, "omni_router": False,
} }
defaults.update(overrides) flat_defaults.update(flat_overrides)
obj = OmniConfig.__new__(OmniConfig) obj = OmniConfig.__new__(OmniConfig)
for k, v in defaults.items(): for k, v in flat_defaults.items():
setattr(obj, k, v) setattr(obj, k, v)
obj.diffusion = dataclasses.replace(OmniDiffusionKwargs(), **diffusion_overrides)
obj.parallel = dataclasses.replace(OmniParallelKwargs(), **parallel_overrides)
return obj return obj
def test_omni_config_valid_defaults(): def test_omni_config_valid_defaults():
"""Config with valid defaults passes validation."""
config = _make_omni_config() config = _make_omni_config()
config.validate() # should not raise config.validate()
@pytest.mark.parametrize("fps", [0, -1, -100]) @pytest.mark.parametrize("fps", [0, -1, -100])
def test_omni_config_invalid_video_fps(fps): def test_omni_config_invalid_video_fps(fps):
"""Non-positive FPS must be rejected."""
config = _make_omni_config(default_video_fps=fps) config = _make_omni_config(default_video_fps=fps)
with pytest.raises(ValueError, match="--default-video-fps must be > 0"): with pytest.raises(ValueError, match="--default-video-fps must be > 0"):
config.validate() config.validate()
...@@ -88,7 +98,6 @@ def test_omni_config_invalid_video_fps(fps): ...@@ -88,7 +98,6 @@ def test_omni_config_invalid_video_fps(fps):
@pytest.mark.parametrize("degree", [0, -1]) @pytest.mark.parametrize("degree", [0, -1])
def test_omni_config_invalid_ulysses_degree(degree): def test_omni_config_invalid_ulysses_degree(degree):
"""Non-positive ulysses_degree must be rejected."""
config = _make_omni_config(ulysses_degree=degree) config = _make_omni_config(ulysses_degree=degree)
with pytest.raises(ValueError, match="--ulysses-degree must be > 0"): with pytest.raises(ValueError, match="--ulysses-degree must be > 0"):
config.validate() config.validate()
...@@ -96,7 +105,6 @@ def test_omni_config_invalid_ulysses_degree(degree): ...@@ -96,7 +105,6 @@ def test_omni_config_invalid_ulysses_degree(degree):
@pytest.mark.parametrize("degree", [0, -1]) @pytest.mark.parametrize("degree", [0, -1])
def test_omni_config_invalid_ring_degree(degree): def test_omni_config_invalid_ring_degree(degree):
"""Non-positive ring_degree must be rejected."""
config = _make_omni_config(ring_degree=degree) config = _make_omni_config(ring_degree=degree)
with pytest.raises(ValueError, match="--ring-degree must be > 0"): with pytest.raises(ValueError, match="--ring-degree must be > 0"):
config.validate() config.validate()
...@@ -104,7 +112,6 @@ def test_omni_config_invalid_ring_degree(degree): ...@@ -104,7 +112,6 @@ def test_omni_config_invalid_ring_degree(degree):
@pytest.mark.parametrize("ratio", [0, -0.1, 1.01, 2.0]) @pytest.mark.parametrize("ratio", [0, -0.1, 1.01, 2.0])
def test_omni_config_invalid_boundary_ratio(ratio): def test_omni_config_invalid_boundary_ratio(ratio):
"""boundary_ratio outside (0, 1] must be rejected."""
config = _make_omni_config(boundary_ratio=ratio) config = _make_omni_config(boundary_ratio=ratio)
with pytest.raises(ValueError, match=r"--boundary-ratio must be in \(0, 1\]"): with pytest.raises(ValueError, match=r"--boundary-ratio must be in \(0, 1\]"):
config.validate() config.validate()
...@@ -112,12 +119,8 @@ def test_omni_config_invalid_boundary_ratio(ratio): ...@@ -112,12 +119,8 @@ def test_omni_config_invalid_boundary_ratio(ratio):
@pytest.mark.parametrize("ratio", [0.001, 0.5, 0.875, 1.0]) @pytest.mark.parametrize("ratio", [0.001, 0.5, 0.875, 1.0])
def test_omni_config_valid_boundary_ratio(ratio): def test_omni_config_valid_boundary_ratio(ratio):
"""boundary_ratio within (0, 1] should pass."""
config = _make_omni_config(boundary_ratio=ratio) config = _make_omni_config(boundary_ratio=ratio)
config.validate() # should not raise config.validate()
# --- disaggregated stage flag validation ---
def test_negative_stage_id_rejected(): def test_negative_stage_id_rejected():
...@@ -150,22 +153,20 @@ def test_stage_id_with_stage_configs_path_valid(tmp_path): ...@@ -150,22 +153,20 @@ def test_stage_id_with_stage_configs_path_valid(tmp_path):
config = _make_omni_config( config = _make_omni_config(
stage_id=0, stage_configs_path=str(tmp_path / "stages.yaml") stage_id=0, stage_configs_path=str(tmp_path / "stages.yaml")
) )
config.validate() # should not raise config.validate()
def test_omni_router_with_stage_configs_path_valid(tmp_path): def test_omni_router_with_stage_configs_path_valid(tmp_path):
config = _make_omni_config( config = _make_omni_config(
omni_router=True, stage_configs_path=str(tmp_path / "stages.yaml") omni_router=True, stage_configs_path=str(tmp_path / "stages.yaml")
) )
config.validate() # should not raise config.validate()
# --- vllm_omni API compatibility guards --- # --- vllm_omni API compatibility guards ---
# These tests catch regressions when vllm_omni is upgraded.
def test_omni_engine_args_importable(): def test_omni_engine_args_importable():
"""vllm_omni.engine.arg_utils must export a usable engine args class."""
from vllm_omni.engine.arg_utils import OmniEngineArgs from vllm_omni.engine.arg_utils import OmniEngineArgs
assert hasattr(OmniEngineArgs, "add_cli_args") assert hasattr(OmniEngineArgs, "add_cli_args")
...@@ -173,21 +174,17 @@ def test_omni_engine_args_importable(): ...@@ -173,21 +174,17 @@ def test_omni_engine_args_importable():
def test_omni_engine_args_add_cli_args_no_extra_params():
    """add_cli_args must be callable with just a parser instance."""
    from vllm_omni.engine.arg_utils import OmniEngineArgs

    # FlexibleArgumentParser is looked up in vllm.utils first, with
    # vllm.utils.argparse_utils as the fallback location.
    try:
        from vllm.utils import FlexibleArgumentParser
    except ImportError:
        from vllm.utils.argparse_utils import FlexibleArgumentParser

    OmniEngineArgs.add_cli_args(FlexibleArgumentParser(add_help=False))
def test_omni_config_imports_cleanly(): def test_omni_config_imports_cleanly():
"""OmniConfig and parse_omni_args must be importable without error."""
from dynamo.vllm.omni.args import OmniConfig, parse_omni_args from dynamo.vllm.omni.args import OmniConfig, parse_omni_args
assert OmniConfig is not None assert OmniConfig is not None
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests that every DiffusionParallelConfig field is either exposed in Dynamo or intentionally skipped."""
import dataclasses
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
try:
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.engine.arg_utils import OmniEngineArgs
from dynamo.vllm.omni.args import OmniDiffusionKwargs, OmniParallelKwargs
from dynamo.vllm.omni.base_handler import BaseOmniHandler
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
# Module-level pytest markers: applied to every test collected from this file.
pytestmark = [
    pytest.mark.unit,
    pytest.mark.vllm,
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
]
# DiffusionParallelConfig field names that the coverage test deliberately ignores.
# They are not exposed in OmniParallelKwargs because they are derived from other fields.
_SKIP_FIELDS = {
    "sequence_parallel_size",
    "enable_expert_parallel",
    "ulysses_mode",
}
def _diffusion_parallel_fields() -> set:
    """Return the set of dataclass field names declared on DiffusionParallelConfig."""
    names = set()
    for field in dataclasses.fields(DiffusionParallelConfig):
        names.add(field.name)
    return names
def _engine_args_fields() -> set:
    """Return every annotated attribute name found along OmniEngineArgs' MRO."""
    # Classes without annotations contribute an empty mapping; iterating a
    # mapping yields its keys, so the union collects annotation names only.
    per_class_annotations = (
        getattr(klass, "__annotations__", {}) for klass in OmniEngineArgs.__mro__
    )
    return set().union(*per_class_annotations)
def _make_config(**parallel_overrides):
    """Build a MagicMock config for the handler tests.

    The mock carries a fixed model name, no stage config path,
    ``trust_remote_code`` disabled, default ``OmniDiffusionKwargs`` and a
    default ``OmniParallelKwargs`` with ``parallel_overrides`` applied on top.
    """
    parallel_defaults = OmniParallelKwargs()

    mock_config = MagicMock()
    mock_config.model = "test-model"
    mock_config.stage_configs_path = None
    mock_config.engine_args.trust_remote_code = False
    mock_config.diffusion = OmniDiffusionKwargs()
    mock_config.parallel = dataclasses.replace(parallel_defaults, **parallel_overrides)
    return mock_config
def _build_kwargs(config):
    """Call ``_build_omni_kwargs(config)`` on a handler created without ``__init__``."""
    # __new__ skips BaseOmniHandler.__init__, so no constructor logic runs.
    bare_handler = BaseOmniHandler.__new__(BaseOmniHandler)
    omni_kwargs = bare_handler._build_omni_kwargs(config)
    return omni_kwargs
class TestDiffusionParallelConfigCoverage:
    """Checks that Dynamo's parallelism options stay aligned with DiffusionParallelConfig."""

    def test_all_diffusion_parallel_config_fields_covered(self):
        """Every DiffusionParallelConfig field must be in OmniParallelKwargs, engine_args, or _SKIP_FIELDS.

        When vllm-omni adds a new parallelism field to DiffusionParallelConfig, this test fails.
        Fix by adding it to OmniParallelKwargs and OmniArgGroup, or to _SKIP_FIELDS.
        """
        parallel_kwarg_fields = {f.name for f in dataclasses.fields(OmniParallelKwargs)}
        engine_fields = _engine_args_fields()
        # Sorted so the failure message is stable across runs; plain set
        # iteration order is not.
        uncovered = sorted(
            f
            for f in _diffusion_parallel_fields()
            if f not in _SKIP_FIELDS
            and f not in parallel_kwarg_fields
            and f not in engine_fields
        )
        assert not uncovered, (
            f"DiffusionParallelConfig fields not covered: {uncovered}. "
            "Add to OmniParallelKwargs and OmniArgGroup, or add to _SKIP_FIELDS with a reason."
        )

    def test_tensor_parallel_size_read_from_engine_args(self):
        """tensor_parallel_size must come from engine_args (vLLM's --tensor-parallel-size),
        not from OmniParallelKwargs, so it applies to both LLM encoder and diffusion transformer.
        """
        config = _make_config()
        config.engine_args.tensor_parallel_size = 4

        with patch("dynamo.vllm.omni.base_handler.DiffusionParallelConfig") as MockCfg:
            MockCfg.return_value = SimpleNamespace()
            _build_kwargs(config)

        # If the handler never built a DiffusionParallelConfig, call_args is
        # None; fail with a clear assertion instead of a TypeError on unpacking.
        MockCfg.assert_called()
        _, kwargs = MockCfg.call_args
        assert kwargs.get("tensor_parallel_size") == 4
...@@ -7,6 +7,8 @@ import pytest ...@@ -7,6 +7,8 @@ import pytest
try: try:
from PIL import Image from PIL import Image
from vllm.sampling_params import SamplingParams
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest from dynamo.common.protocols.image_protocol import NvCreateImageRequest
...@@ -25,7 +27,7 @@ pytestmark = [ ...@@ -25,7 +27,7 @@ pytestmark = [
] ]
def _make_handler(): def _make_handler(stage_types=("diffusion",)):
with patch( with patch(
"dynamo.vllm.omni.omni_handler.BaseOmniHandler.__init__", return_value=None "dynamo.vllm.omni.omni_handler.BaseOmniHandler.__init__", return_value=None
): ):
...@@ -36,6 +38,22 @@ def _make_handler(): ...@@ -36,6 +38,22 @@ def _make_handler():
config.served_model_name = None config.served_model_name = None
config.output_modalities = ["text"] config.output_modalities = ["text"]
handler.config = config handler.config = config
defaults = []
for st in stage_types:
if st == "diffusion":
defaults.append(OmniDiffusionSamplingParams())
else:
llm_default = MagicMock(spec=SamplingParams)
llm_default.clone.return_value = SamplingParams()
defaults.append(llm_default)
engine_client = MagicMock()
engine_client.default_sampling_params_list = defaults
engine_client.engine.get_stage_metadata.side_effect = lambda i: {
"stage_type": stage_types[i]
}
handler.engine_client = engine_client
return handler return handler
...@@ -167,6 +185,36 @@ class TestI2VEngineInputs: ...@@ -167,6 +185,36 @@ class TestI2VEngineInputs:
assert empty.guidance_scale_2 is None assert empty.guidance_scale_2 is None
class TestBuildSamplingParamsList:
    """Exercises ``_build_sampling_params_list`` against different stage layouts."""

    def test_single_diffusion_stage(self):
        """One diffusion stage: the list holds only the caller's params object."""
        handler = _make_handler(stage_types=("diffusion",))
        diffusion_params = OmniDiffusionSamplingParams(height=512, width=512)

        params_list = handler._build_sampling_params_list(diffusion_params)

        assert len(params_list) == 1
        assert params_list[0] is diffusion_params

    def test_llm_then_diffusion(self):
        """LLM + diffusion: slot 0 is a SamplingParams, slot 1 is the caller's object."""
        handler = _make_handler(stage_types=("llm", "diffusion"))
        diffusion_params = OmniDiffusionSamplingParams(height=512, width=512)

        params_list = handler._build_sampling_params_list(diffusion_params)

        assert len(params_list) == 2
        assert isinstance(params_list[0], SamplingParams)
        assert params_list[1] is diffusion_params

    def test_fallback_when_defaults_empty(self):
        """With no engine defaults, the result is just the caller's params."""
        handler = _make_handler()
        handler.engine_client.default_sampling_params_list = []
        diffusion_params = OmniDiffusionSamplingParams(height=512, width=512)

        params_list = handler._build_sampling_params_list(diffusion_params)

        assert params_list == [diffusion_params]

    def test_llm_default_is_cloned(self):
        """The LLM stage default must be cloned exactly once."""
        handler = _make_handler(stage_types=("llm", "diffusion"))
        diffusion_params = OmniDiffusionSamplingParams()

        handler._build_sampling_params_list(diffusion_params)

        llm_default = handler.engine_client.default_sampling_params_list[0]
        llm_default.clone.assert_called_once()
class TestBuildOriginalPrompt: class TestBuildOriginalPrompt:
"""build_original_prompt only carries prompt/negative_prompt/multi_modal_data. """build_original_prompt only carries prompt/negative_prompt/multi_modal_data.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment