Unverified Commit 2d86b81d authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: flux benchmarking script + code clean (#8083)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 2e7a1e6c
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Generic aiperf benchmark for vLLM-Omni text-to-image generation.
# Assumes the server (Dynamo or native vllm-omni) is already running.
#
# Usage:
# bash aiperf_image_gen.sh [OPTIONS]
#
# Options:
# --model <model> Model to benchmark (default: black-forest-labs/FLUX.2-klein-4B)
# --url <url> Server URL (default: http://localhost:8000)
# --concurrency <n> Number of concurrent requests (default: 1)
# --request-count <n> Total requests to send (default: 10)
# --warmup-count <n> Warmup requests before measurement (default: 2)
# --image-size <WxH> Generated image size (default: 1024x1024)
# --response-format <fmt> Response format: url or b64_json (default: url)
# --prompt-tokens-mean <n> Mean synthetic prompt length in tokens (default: 50)
# --prompt-tokens-stddev <n> Stddev of synthetic prompt length (default: 10)
# --artifact-dir <dir> Directory passed to aiperf as --artifact-dir (default: not passed)
# -h, --help Show this help message
#
# Examples:
# bash aiperf_image_gen.sh
# bash aiperf_image_gen.sh --model zai-org/GLM-Image --concurrency 4
# bash aiperf_image_gen.sh --model Qwen/Qwen-Image --image-size 512x512 --request-count 20
MODEL="black-forest-labs/FLUX.2-klein-4B"
URL="http://localhost:8000"
CONCURRENCY=1
REQUEST_COUNT=10
WARMUP_COUNT=2
IMAGE_SIZE="1024x1024"
RESPONSE_FORMAT="url"
PROMPT_TOKENS_MEAN=50
PROMPT_TOKENS_STDDEV=10
# Empty means --artifact-dir is not forwarded to aiperf at all.
ARTIFACT_DIR=""

# Abort unless the option named in $1 is followed by a value.
#   $1 - option name (used in the error message)
#   $2 - number of remaining arguments, including the option itself
# Without this guard, a value-taking option given as the very last argument
# makes `shift 2` fail without shifting anything, so the parsing loop below
# would spin forever.
require_value() {
    if [[ $2 -lt 2 ]]; then
        echo "Missing value for option: $1" >&2
        exit 1
    fi
}

# Parse command-line options; every option except -h/--help takes one value.
while [[ $# -gt 0 ]]; do
    case $1 in
        --model) require_value "$1" $#; MODEL=$2; shift 2 ;;
        --url) require_value "$1" $#; URL=$2; shift 2 ;;
        --concurrency) require_value "$1" $#; CONCURRENCY=$2; shift 2 ;;
        --request-count) require_value "$1" $#; REQUEST_COUNT=$2; shift 2 ;;
        --warmup-count) require_value "$1" $#; WARMUP_COUNT=$2; shift 2 ;;
        --image-size) require_value "$1" $#; IMAGE_SIZE=$2; shift 2 ;;
        --response-format) require_value "$1" $#; RESPONSE_FORMAT=$2; shift 2 ;;
        --prompt-tokens-mean) require_value "$1" $#; PROMPT_TOKENS_MEAN=$2; shift 2 ;;
        --prompt-tokens-stddev) require_value "$1" $#; PROMPT_TOKENS_STDDEV=$2; shift 2 ;;
        --artifact-dir) require_value "$1" $#; ARTIFACT_DIR=$2; shift 2 ;;
        -h|--help)
            # Print the comment header of this script (from the Usage line up
            # to the first non-comment line) with the leading "# " removed.
            sed -n '/^# Usage/,/^[^#]/p' "$0" | grep '^#' | sed 's/^# \?//'
            exit 0 ;;
        *) echo "Unknown option: $1" >&2; exit 1 ;;
    esac
done

# Build the command as an array so values containing spaces stay single arguments.
AIPERF_ARGS=(
    aiperf profile
    --model "$MODEL"
    --tokenizer gpt2
    --url "$URL"
    --endpoint-type image-generation
    --synthetic-input-tokens-mean "$PROMPT_TOKENS_MEAN"
    --synthetic-input-tokens-stddev "$PROMPT_TOKENS_STDDEV"
    --extra-inputs "size:${IMAGE_SIZE}"
    --extra-inputs "response_format:${RESPONSE_FORMAT}"
    --concurrency "$CONCURRENCY"
    --request-count "$REQUEST_COUNT"
    --warmup-request-count "$WARMUP_COUNT"
    --ui none
    --no-server-metrics
)

# Only forward --artifact-dir when the caller asked for one.
if [[ -n "$ARTIFACT_DIR" ]]; then
    AIPERF_ARGS+=(--artifact-dir "$ARTIFACT_DIR")
fi

"${AIPERF_ARGS[@]}"
......@@ -4,6 +4,7 @@
"""Omni-specific argument parsing for python -m dynamo.vllm.omni."""
import argparse
import dataclasses
import logging
from typing import Optional
......@@ -24,15 +25,50 @@ from dynamo.common.configuration.utils import add_argument, add_negatable_bool_a
logger = logging.getLogger(__name__)
class OmniArgGroup(ArgGroup):
"""Diffusion pipeline kwargs passed through to AsyncOmni() constructor.
@dataclasses.dataclass
class OmniDiffusionKwargs:
    """AsyncOmni constructor kwargs for diffusion engine configuration.

    Every field here is passed directly to AsyncOmni(**kwargs) and consumed by
    _create_default_diffusion_stage_cfg() in vllm-omni. Adding a new vllm-omni
    diffusion flag only requires adding it here and to OmniArgGroup — the
    passthrough in base_handler is automatic.

    Field names match the argparse destinations of the corresponding CLI
    flags, so they can be copied by name from the parsed namespace.
    """

    # Layerwise offloading (--enable-layerwise-offload).
    enable_layerwise_offload: bool = False
    # Number of ready layers (blocks) to keep on GPU during generation.
    layerwise_num_gpu_layers: int = 1

    # VAE memory optimizations (--vae-use-slicing and VAE tiling).
    vae_use_slicing: bool = False
    vae_use_tiling: bool = False

    # Must lie in (0, 1]; enforced by OmniConfig.validate().
    boundary_ratio: float = 0.875
    # Scheduler flow_shift parameter (CLI help: 5.0 for 720p, 12.0 for 480p).
    flow_shift: Optional[float] = None

    # Cache acceleration (--cache-backend and its configuration).
    cache_backend: Optional[str] = None
    cache_config: Optional[str] = None
    # Enables cache-dit summary logging after diffusion forward passes.
    enable_cache_dit_summary: bool = False

    # Execution mode.
    enable_cpu_offload: bool = False
    enforce_eager: bool = False
These are NOT part of OmniEngineArgs (which handles vLLM engine-level
args like model, tp, max_model_len). Instead they are direct constructor
kwargs for AsyncOmni and need Dynamo-side env-var (DYN_OMNI_*) support,
so we define them here rather than relying on the upstream arg parser.
@dataclasses.dataclass
class OmniParallelKwargs:
    """Diffusion parallelism configuration passed to DiffusionParallelConfig.

    Every field here maps 1:1 to a DiffusionParallelConfig field (excluding
    tensor_parallel_size which comes from engine_args, and fixed/derived fields).
    Adding a new parallelism field only requires adding it here and to OmniArgGroup.
    """

    # Must be > 0; enforced by OmniConfig.validate().
    ulysses_degree: int = 1
    # Must be > 0; enforced by OmniConfig.validate().
    ring_degree: int = 1
    # Number of GPUs used for classifier free guidance parallelism.
    cfg_parallel_size: int = 1
    # Number of ranks used for VAE patch/tile parallelism during decode/encode.
    vae_patch_parallel_size: int = 1
    # Enables Hybrid Sharded Data Parallel (HSDP): shards model weights across
    # GPUs to reduce per-GPU memory usage.
    use_hsdp: bool = False
    # Number of GPUs to shard model weights across when using HSDP (-1 = auto).
    hsdp_shard_size: int = -1
    # Number of HSDP replica groups.
    hsdp_replicate_size: int = 1
class OmniArgGroup(ArgGroup):
"""CLI argument definitions for Dynamo vLLM-Omni."""
name = "dynamo-omni"
def add_arguments(self, parser) -> None:
......@@ -49,7 +85,6 @@ class OmniArgGroup(ArgGroup):
help="Path to vLLM-Omni stage configuration YAML file (optional).",
)
# Video encoding
add_argument(
g,
flag_name="--default-video-fps",
......@@ -59,7 +94,7 @@ class OmniArgGroup(ArgGroup):
help="Default frames per second for generated videos.",
)
# Layerwise offloading
# OmniDiffusionKwargs fields
add_negatable_bool_argument(
g,
flag_name="--enable-layerwise-offload",
......@@ -75,8 +110,6 @@ class OmniArgGroup(ArgGroup):
arg_type=int,
help="Number of ready layers (blocks) to keep on GPU during generation.",
)
# VAE optimization
add_negatable_bool_argument(
g,
flag_name="--vae-use-slicing",
......@@ -91,8 +124,6 @@ class OmniArgGroup(ArgGroup):
default=False,
help="Enable VAE tiling for memory optimization in diffusion models.",
)
# Diffusion scheduling
add_argument(
g,
flag_name="--boundary-ratio",
......@@ -113,8 +144,6 @@ class OmniArgGroup(ArgGroup):
arg_type=float,
help="Scheduler flow_shift parameter (5.0 for 720p, 12.0 for 480p).",
)
# Cache acceleration
add_argument(
g,
flag_name="--cache-backend",
......@@ -141,8 +170,6 @@ class OmniArgGroup(ArgGroup):
default=False,
help="Enable cache-dit summary logging after diffusion forward passes.",
)
# Execution mode
add_negatable_bool_argument(
g,
flag_name="--enable-cpu-offload",
......@@ -204,7 +231,7 @@ class OmniArgGroup(ArgGroup):
help="Maximum size in bytes for reference audio files (default: 50MB).",
)
# Diffusion parallel configuration
# OmniParallelKwargs fields
add_argument(
g,
flag_name="--ulysses-degree",
......@@ -227,9 +254,43 @@ class OmniArgGroup(ArgGroup):
env_var="DYN_OMNI_CFG_PARALLEL_SIZE",
default=1,
arg_type=int,
choices=[1, 2],
choices=[1, 2, 3],
help="Number of GPUs used for classifier free guidance parallelism.",
)
add_argument(
g,
flag_name="--vae-patch-parallel-size",
env_var="DYN_OMNI_VAE_PATCH_PARALLEL_SIZE",
default=1,
arg_type=int,
help="Number of ranks used for VAE patch/tile parallelism during decode/encode.",
)
add_negatable_bool_argument(
g,
flag_name="--use-hsdp",
env_var="DYN_OMNI_USE_HSDP",
default=False,
help=(
"Enable Hybrid Sharded Data Parallel (HSDP) for diffusion models. "
"Shards model weights across GPUs to reduce per-GPU memory usage."
),
)
add_argument(
g,
flag_name="--hsdp-shard-size",
env_var="DYN_OMNI_HSDP_SHARD_SIZE",
default=-1,
arg_type=int,
help="Number of GPUs to shard model weights across when using HSDP (-1 = auto).",
)
add_argument(
g,
flag_name="--hsdp-replicate-size",
env_var="DYN_OMNI_HSDP_REPLICATE_SIZE",
default=1,
arg_type=int,
help="Number of HSDP replica groups (default: 1).",
)
# Disaggregated stage worker flags
add_argument(
......@@ -244,7 +305,6 @@ class OmniArgGroup(ArgGroup):
"Requires --stage-configs-path."
),
)
add_negatable_bool_argument(
g,
flag_name="--omni-router",
......@@ -263,30 +323,18 @@ class OmniConfig(DynamoRuntimeConfig):
component: str = "backend"
endpoint: Optional[str] = None
# mirror vLLM
model: str
served_model_name: Optional[str] = None
# vLLM-Omni engine args
engine_args: OmniEngineArgs
# OmniArgGroup fields (populated by from_cli_args)
stage_configs_path: Optional[str] = None
default_video_fps: int = 16
enable_layerwise_offload: bool = False
layerwise_num_gpu_layers: int = 1
vae_use_slicing: bool = False
vae_use_tiling: bool = False
boundary_ratio: float = 0.875
flow_shift: Optional[float] = None
cache_backend: Optional[str] = None
cache_config: Optional[str] = None
enable_cache_dit_summary: bool = False
enable_cpu_offload: bool = False
enforce_eager: bool = False
ulysses_degree: int = 1
ring_degree: int = 1
cfg_parallel_size: int = 1
# Nested structs — each group of fields has a clear destination
diffusion: OmniDiffusionKwargs = dataclasses.field(
default_factory=OmniDiffusionKwargs
)
parallel: OmniParallelKwargs = dataclasses.field(default_factory=OmniParallelKwargs)
# TTS parameters
tts_max_instructions_length: int = 500
......@@ -299,15 +347,36 @@ class OmniConfig(DynamoRuntimeConfig):
stage_id: Optional[int] = None
omni_router: bool = False
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "OmniConfig":
    """Build the config from parsed CLI args and fill the nested kwarg structs.

    The parent implementation creates the config; afterwards ``diffusion`` and
    ``parallel`` are rebuilt from their dataclass defaults, overriding each
    field for which ``args`` carries an attribute of the same name. Fields
    absent from ``args`` keep their dataclass default.
    """
    config = super().from_cli_args(args)
    # (attribute on the config, dataclass providing the field names/defaults)
    nested_groups = (
        ("diffusion", OmniDiffusionKwargs),
        ("parallel", OmniParallelKwargs),
    )
    for attr_name, kwargs_cls in nested_groups:
        provided = {}
        for spec in dataclasses.fields(kwargs_cls):
            if hasattr(args, spec.name):
                provided[spec.name] = getattr(args, spec.name)
        setattr(config, attr_name, dataclasses.replace(kwargs_cls(), **provided))
    return config
def validate(self) -> None:
DynamoRuntimeConfig.validate(self)
if self.default_video_fps <= 0:
raise ValueError("--default-video-fps must be > 0")
if self.ulysses_degree <= 0:
if self.parallel.ulysses_degree <= 0:
raise ValueError("--ulysses-degree must be > 0")
if self.ring_degree <= 0:
if self.parallel.ring_degree <= 0:
raise ValueError("--ring-degree must be > 0")
if not (0 < self.boundary_ratio <= 1):
if not (0 < self.diffusion.boundary_ratio <= 1):
raise ValueError("--boundary-ratio must be in (0, 1]")
if self.stage_configs_path is None:
if self.stage_id is not None:
......@@ -334,7 +403,6 @@ def parse_omni_args() -> OmniConfig:
dynamo_runtime_argspec.add_arguments(parser)
omni_argspec.add_arguments(parser)
# Add vLLM-Omni engine args
vg = parser.add_argument_group(
"vLLM-Omni Engine Options. Please refer to vLLM-Omni documentation for more details."
)
......@@ -349,7 +417,6 @@ def parse_omni_args() -> OmniConfig:
args, unknown = parser.parse_known_args()
config = OmniConfig.from_cli_args(args)
# Default endpoint to "generate" if not explicitly provided by user
if config.endpoint is None:
config.endpoint = "generate"
......
......@@ -4,6 +4,7 @@
"""Base handler for vLLM-Omni multi-stage pipelines."""
import asyncio
import dataclasses
import logging
import time
from typing import Any, AsyncGenerator, Dict
......@@ -74,31 +75,17 @@ class BaseOmniHandler(BaseWorkerHandler[Dict[str, Any], Dict[str, Any]]):
if config.stage_configs_path:
omni_kwargs["stage_configs_path"] = config.stage_configs_path
# Diffusion engine-level params — read directly from config namespace
diffusion_fields = [
"enable_layerwise_offload",
"layerwise_num_gpu_layers",
"vae_use_slicing",
"vae_use_tiling",
"boundary_ratio",
"flow_shift",
"cache_backend",
"cache_config",
"enable_cache_dit_summary",
"enable_cpu_offload",
"enforce_eager",
]
for field in diffusion_fields:
value = getattr(config, field, None)
for field, value in dataclasses.asdict(config.diffusion).items():
if value is not None:
omni_kwargs[field] = value
# Build DiffusionParallelConfig if available
# tensor_parallel_size comes from engine_args (vLLM's --tensor-parallel-size)
if DiffusionParallelConfig is not None:
parallel_config = DiffusionParallelConfig(
ulysses_degree=getattr(config, "ulysses_degree", 1),
ring_degree=getattr(config, "ring_degree", 1),
cfg_parallel_size=getattr(config, "cfg_parallel_size", 1),
tensor_parallel_size=getattr(
config.engine_args, "tensor_parallel_size", 1
),
**dataclasses.asdict(config.parallel),
)
omni_kwargs["parallel_config"] = parallel_config
else:
......
......@@ -2,18 +2,20 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import random
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional, Union, cast
import PIL.Image
from fsspec.implementations.dirfs import DirFileSystem
from vllm.sampling_params import SamplingParams
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo._core import Context
from dynamo.common.multimodal import ImageLoader
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest
from dynamo.common.protocols.image_protocol import ImageNvExt, NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest, VideoNvExt
from dynamo.common.utils.output_modalities import RequestType, parse_request_type
from dynamo.common.utils.video_utils import compute_num_frames, parse_size
from dynamo.llm.exceptions import EngineShutdown
......@@ -255,35 +257,60 @@ class OmniHandler(BaseOmniHandler):
fps=0,
)
@staticmethod
def _update_if_not_none(object: Any, key: str, val: Any) -> None:
if val is not None:
setattr(object, key, val)
def _build_sampling_params_list(
    self, diffusion_sp: OmniDiffusionSamplingParams
) -> list:
    """Return one sampling-params entry per engine stage.

    Stages whose metadata reports ``stage_type == "diffusion"`` receive
    ``diffusion_sp`` itself. Every other stage (a missing ``stage_type`` is
    treated as ``"llm"``) receives a clone of that stage's default params, or
    a fresh ``SamplingParams()`` when the default has no ``clone`` method.
    If the engine client exposes no defaults, the result is ``[diffusion_sp]``.
    """
    # This is in sync with how vllm-omni builds sampling params currently.
    stage_defaults = list(self.engine_client.default_sampling_params_list or [])
    if not stage_defaults:
        return [diffusion_sp]

    per_stage = []
    for stage_idx, stage_default in enumerate(stage_defaults):
        metadata = self.engine_client.engine.get_stage_metadata(stage_idx)
        if metadata.get("stage_type", "llm") == "diffusion":
            per_stage.append(diffusion_sp)
        elif hasattr(stage_default, "clone"):
            per_stage.append(stage_default.clone())
        else:
            per_stage.append(SamplingParams())
    return per_stage
def _engine_inputs_from_image(self, req: NvCreateImageRequest) -> EngineInputs:
"""Build engine inputs from an NvCreateImageRequest."""
width, height = parse_size(req.size, default_w=1024, default_h=1024)
nvext = req.nvext
nvext = req.nvext or ImageNvExt()
prompt = OmniTextPrompt(
prompt=req.prompt,
negative_prompt=(
nvext.negative_prompt if nvext and nvext.negative_prompt else None
),
)
prompt = OmniTextPrompt(prompt=req.prompt)
if nvext and nvext.negative_prompt is not None:
prompt.negative_prompt = nvext.negative_prompt
sp = OmniDiffusionSamplingParams(
height=height,
width=width,
)
if req.n is not None:
sp.num_outputs_per_prompt = req.n
if nvext:
if nvext.num_inference_steps is not None:
sp.num_inference_steps = nvext.num_inference_steps
if nvext.guidance_scale is not None:
sp.guidance_scale = nvext.guidance_scale
if nvext.seed is not None:
sp.seed = nvext.seed
# TODO: Apply LoRA Request params here and move to shared utilities for disaggregated stages to use as well.
self._update_if_not_none(sp, "num_outputs_per_prompt", req.n)
self._update_if_not_none(sp, "num_inference_steps", nvext.num_inference_steps)
self._update_if_not_none(sp, "guidance_scale", nvext.guidance_scale)
# If seed is not provided, generate a random one to ensure
# a proper generator is initialized in the backend.
# This fixes issues where using the default global generator
# might produce blurry images in some environments.
sp.seed = (
nvext.seed if nvext.seed is not None else random.randint(0, 2**32 - 1)
)
return EngineInputs(
prompt=prompt,
sampling_params_list=[sp],
sampling_params_list=self._build_sampling_params_list(sp),
request_type=RequestType.IMAGE_GENERATION,
response_format=req.response_format,
)
......@@ -302,25 +329,19 @@ class OmniHandler(BaseOmniHandler):
I2V pipeline pre-process can use it.
"""
width, height = parse_size(req.size)
nvext = req.nvext
nvext_fps = nvext.fps if nvext else None
nvext_num_frames = nvext.num_frames if nvext else None
nvext = req.nvext or VideoNvExt()
num_frames = compute_num_frames(
num_frames=nvext_num_frames,
num_frames=nvext.num_frames,
seconds=req.seconds,
fps=nvext_fps,
fps=nvext.fps,
default_fps=DEFAULT_VIDEO_FPS,
)
fps = nvext_fps if nvext_fps is not None else DEFAULT_VIDEO_FPS
fps = nvext.fps if nvext.fps is not None else DEFAULT_VIDEO_FPS
prompt = OmniTextPrompt(
prompt=req.prompt,
negative_prompt=(
nvext.negative_prompt if nvext and nvext.negative_prompt else None
),
)
prompt = OmniTextPrompt(prompt=req.prompt)
if nvext.negative_prompt is not None:
prompt.negative_prompt = nvext.negative_prompt
if image is not None:
prompt["multi_modal_data"] = {"image": image}
......@@ -335,19 +356,14 @@ class OmniHandler(BaseOmniHandler):
width=width,
num_frames=num_frames,
)
if nvext:
if nvext.num_inference_steps is not None:
sp.num_inference_steps = nvext.num_inference_steps
if nvext.guidance_scale is not None:
sp.guidance_scale = nvext.guidance_scale
if nvext.seed is not None:
sp.seed = nvext.seed
if nvext.boundary_ratio is not None:
sp.boundary_ratio = nvext.boundary_ratio
if nvext.guidance_scale_2 is not None:
sp.guidance_scale_2 = nvext.guidance_scale_2
if fps is not None:
sp.fps = fps
self._update_if_not_none(sp, "num_inference_steps", nvext.num_inference_steps)
self._update_if_not_none(sp, "guidance_scale", nvext.guidance_scale)
sp.seed = (
nvext.seed if nvext.seed is not None else random.randint(0, 2**32 - 1)
)
self._update_if_not_none(sp, "boundary_ratio", nvext.boundary_ratio)
self._update_if_not_none(sp, "guidance_scale_2", nvext.guidance_scale_2)
self._update_if_not_none(sp, "fps", fps)
logger.info(
f"Video diffusion request: prompt='{req.prompt[:50]}...', "
......@@ -356,7 +372,7 @@ class OmniHandler(BaseOmniHandler):
return EngineInputs(
prompt=prompt,
sampling_params_list=[sp],
sampling_params_list=self._build_sampling_params_list(sp),
request_type=RequestType.VIDEO_GENERATION,
fps=fps,
)
......@@ -3,12 +3,17 @@
"""Unit tests for OmniConfig validation."""
import dataclasses
from types import SimpleNamespace
import pytest
try:
from dynamo.vllm.omni.args import OmniConfig
from dynamo.vllm.omni.args import (
OmniConfig,
OmniDiffusionKwargs,
OmniParallelKwargs,
)
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
......@@ -19,11 +24,25 @@ pytestmark = [
pytest.mark.pre_merge,
]
_DIFFUSION_FIELDS = {f.name for f in dataclasses.fields(OmniDiffusionKwargs)}
_PARALLEL_FIELDS = {f.name for f in dataclasses.fields(OmniParallelKwargs)}
def _make_omni_config(**overrides) -> OmniConfig:
"""Build a minimal OmniConfig with valid defaults, applying overrides."""
defaults: dict = {
# DynamoRuntimeConfig fields
"""Build a minimal OmniConfig with valid defaults, applying overrides.
Overrides for diffusion fields (e.g. boundary_ratio) and parallel fields
(e.g. ulysses_degree) are automatically routed to the correct nested struct.
"""
diffusion_overrides = {k: v for k, v in overrides.items() if k in _DIFFUSION_FIELDS}
parallel_overrides = {k: v for k, v in overrides.items() if k in _PARALLEL_FIELDS}
flat_overrides = {
k: v
for k, v in overrides.items()
if k not in _DIFFUSION_FIELDS and k not in _PARALLEL_FIELDS
}
flat_defaults: dict = {
"namespace": "dynamo",
"component": "backend",
"endpoint": None,
......@@ -42,45 +61,36 @@ def _make_omni_config(**overrides) -> OmniConfig:
"output_modalities": None,
"media_output_fs_url": "file:///tmp/dynamo_media",
"media_output_http_url": None,
# OmniConfig fields
"model": "test-model",
"served_model_name": None,
"engine_args": SimpleNamespace(),
"stage_configs_path": None,
"default_video_fps": 16,
"enable_layerwise_offload": False,
"layerwise_num_gpu_layers": 1,
"vae_use_slicing": False,
"vae_use_tiling": False,
"boundary_ratio": 0.875,
"flow_shift": None,
"cache_backend": None,
"cache_config": None,
"enable_cache_dit_summary": False,
"enable_cpu_offload": False,
"enforce_eager": False,
"ulysses_degree": 1,
"ring_degree": 1,
"cfg_parallel_size": 1,
"tts_max_instructions_length": 500,
"tts_max_new_tokens_min": 1,
"tts_max_new_tokens_max": 4096,
"tts_ref_audio_timeout": 15,
"tts_ref_audio_max_bytes": 50 * 1024 * 1024,
"stage_id": None,
"omni_router": False,
}
defaults.update(overrides)
flat_defaults.update(flat_overrides)
obj = OmniConfig.__new__(OmniConfig)
for k, v in defaults.items():
for k, v in flat_defaults.items():
setattr(obj, k, v)
obj.diffusion = dataclasses.replace(OmniDiffusionKwargs(), **diffusion_overrides)
obj.parallel = dataclasses.replace(OmniParallelKwargs(), **parallel_overrides)
return obj
def test_omni_config_valid_defaults():
"""Config with valid defaults passes validation."""
config = _make_omni_config()
config.validate() # should not raise
config.validate()
@pytest.mark.parametrize("fps", [0, -1, -100])
def test_omni_config_invalid_video_fps(fps):
"""Non-positive FPS must be rejected."""
config = _make_omni_config(default_video_fps=fps)
with pytest.raises(ValueError, match="--default-video-fps must be > 0"):
config.validate()
......@@ -88,7 +98,6 @@ def test_omni_config_invalid_video_fps(fps):
@pytest.mark.parametrize("degree", [0, -1])
def test_omni_config_invalid_ulysses_degree(degree):
"""Non-positive ulysses_degree must be rejected."""
config = _make_omni_config(ulysses_degree=degree)
with pytest.raises(ValueError, match="--ulysses-degree must be > 0"):
config.validate()
......@@ -96,7 +105,6 @@ def test_omni_config_invalid_ulysses_degree(degree):
@pytest.mark.parametrize("degree", [0, -1])
def test_omni_config_invalid_ring_degree(degree):
"""Non-positive ring_degree must be rejected."""
config = _make_omni_config(ring_degree=degree)
with pytest.raises(ValueError, match="--ring-degree must be > 0"):
config.validate()
......@@ -104,7 +112,6 @@ def test_omni_config_invalid_ring_degree(degree):
@pytest.mark.parametrize("ratio", [0, -0.1, 1.01, 2.0])
def test_omni_config_invalid_boundary_ratio(ratio):
"""boundary_ratio outside (0, 1] must be rejected."""
config = _make_omni_config(boundary_ratio=ratio)
with pytest.raises(ValueError, match=r"--boundary-ratio must be in \(0, 1\]"):
config.validate()
......@@ -112,12 +119,8 @@ def test_omni_config_invalid_boundary_ratio(ratio):
@pytest.mark.parametrize("ratio", [0.001, 0.5, 0.875, 1.0])
def test_omni_config_valid_boundary_ratio(ratio):
"""boundary_ratio within (0, 1] should pass."""
config = _make_omni_config(boundary_ratio=ratio)
config.validate() # should not raise
# --- disaggregated stage flag validation ---
config.validate()
def test_negative_stage_id_rejected():
......@@ -150,22 +153,20 @@ def test_stage_id_with_stage_configs_path_valid(tmp_path):
config = _make_omni_config(
stage_id=0, stage_configs_path=str(tmp_path / "stages.yaml")
)
config.validate() # should not raise
config.validate()
def test_omni_router_with_stage_configs_path_valid(tmp_path):
config = _make_omni_config(
omni_router=True, stage_configs_path=str(tmp_path / "stages.yaml")
)
config.validate() # should not raise
config.validate()
# --- vllm_omni API compatibility guards ---
# These tests catch regressions when vllm_omni is upgraded.
def test_omni_engine_args_importable():
"""vllm_omni.engine.arg_utils must export a usable engine args class."""
from vllm_omni.engine.arg_utils import OmniEngineArgs
assert hasattr(OmniEngineArgs, "add_cli_args")
......@@ -173,21 +174,17 @@ def test_omni_engine_args_importable():
def test_omni_engine_args_add_cli_args_no_extra_params():
"""add_cli_args must accept a parser and no other required args."""
from vllm_omni.engine.arg_utils import OmniEngineArgs
try:
from vllm.utils import FlexibleArgumentParser
except ImportError:
from vllm.utils.argparse_utils import FlexibleArgumentParser
parser = FlexibleArgumentParser(add_help=False)
OmniEngineArgs.add_cli_args(parser)
def test_omni_config_imports_cleanly():
"""OmniConfig and parse_omni_args must be importable without error."""
from dynamo.vllm.omni.args import OmniConfig, parse_omni_args
assert OmniConfig is not None
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests that every DiffusionParallelConfig field is either exposed in Dynamo or intentionally skipped."""
import dataclasses
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
try:
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.engine.arg_utils import OmniEngineArgs
from dynamo.vllm.omni.args import OmniDiffusionKwargs, OmniParallelKwargs
from dynamo.vllm.omni.base_handler import BaseOmniHandler
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
# These fields are not exposed in OmniParallelKwargs, because they are derived from other fields.
_SKIP_FIELDS = {
"sequence_parallel_size",
"enable_expert_parallel",
"ulysses_mode",
}
def _diffusion_parallel_fields() -> set:
    """Return the names of all dataclass fields of DiffusionParallelConfig."""
    names = set()
    for spec in dataclasses.fields(DiffusionParallelConfig):
        names.add(spec.name)
    return names
def _engine_args_fields() -> set:
    """Return every annotated attribute name found along OmniEngineArgs' MRO."""
    return {
        attr_name
        for klass in OmniEngineArgs.__mro__
        for attr_name in getattr(klass, "__annotations__", {})
    }
def _make_config(**parallel_overrides):
    """Return a MagicMock config with default diffusion kwargs.

    ``parallel_overrides`` replace the matching OmniParallelKwargs defaults.
    """
    mock_config = MagicMock()
    mock_config.model = "test-model"
    mock_config.stage_configs_path = None
    mock_config.engine_args.trust_remote_code = False
    mock_config.diffusion = OmniDiffusionKwargs()
    parallel_defaults = OmniParallelKwargs()
    mock_config.parallel = dataclasses.replace(parallel_defaults, **parallel_overrides)
    return mock_config
def _build_kwargs(config):
    """Call _build_omni_kwargs on a handler created without running __init__."""
    bare_handler = BaseOmniHandler.__new__(BaseOmniHandler)
    omni_kwargs = bare_handler._build_omni_kwargs(config)
    return omni_kwargs
class TestDiffusionParallelConfigCoverage:
    """Guards that keep Dynamo's parallel kwargs aligned with DiffusionParallelConfig."""

    def test_all_diffusion_parallel_config_fields_covered(self):
        """Every DiffusionParallelConfig field must be in OmniParallelKwargs, engine_args, or _SKIP_FIELDS.

        When vllm-omni adds a new parallelism field to DiffusionParallelConfig, this test fails.
        Fix by adding it to OmniParallelKwargs and OmniArgGroup, or to _SKIP_FIELDS
        """
        accounted_for = (
            _SKIP_FIELDS
            | {spec.name for spec in dataclasses.fields(OmniParallelKwargs)}
            | _engine_args_fields()
        )
        uncovered = [
            name for name in _diffusion_parallel_fields() if name not in accounted_for
        ]
        assert not uncovered, (
            f"DiffusionParallelConfig fields not covered: {uncovered}. "
            f"Add to OmniParallelKwargs and OmniArgGroup, or add to _SKIP_FIELDS with a reason."
        )

    def test_tensor_parallel_size_read_from_engine_args(self):
        """tensor_parallel_size must come from engine_args (vLLM's --tensor-parallel-size),
        not from OmniParallelKwargs, so it applies to both LLM encoder and diffusion transformer.
        """
        cfg = _make_config()
        cfg.engine_args.tensor_parallel_size = 4
        with patch(
            "dynamo.vllm.omni.base_handler.DiffusionParallelConfig"
        ) as mock_parallel_cls:
            mock_parallel_cls.return_value = SimpleNamespace()
            _build_kwargs(cfg)
        # call_args unpacks to (positional args, keyword args) of the last call.
        _positional, passed_kwargs = mock_parallel_cls.call_args
        assert passed_kwargs.get("tensor_parallel_size") == 4
......@@ -7,6 +7,8 @@ import pytest
try:
from PIL import Image
from vllm.sampling_params import SamplingParams
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest
......@@ -25,7 +27,7 @@ pytestmark = [
]
def _make_handler():
def _make_handler(stage_types=("diffusion",)):
with patch(
"dynamo.vllm.omni.omni_handler.BaseOmniHandler.__init__", return_value=None
):
......@@ -36,6 +38,22 @@ def _make_handler():
config.served_model_name = None
config.output_modalities = ["text"]
handler.config = config
defaults = []
for st in stage_types:
if st == "diffusion":
defaults.append(OmniDiffusionSamplingParams())
else:
llm_default = MagicMock(spec=SamplingParams)
llm_default.clone.return_value = SamplingParams()
defaults.append(llm_default)
engine_client = MagicMock()
engine_client.default_sampling_params_list = defaults
engine_client.engine.get_stage_metadata.side_effect = lambda i: {
"stage_type": stage_types[i]
}
handler.engine_client = engine_client
return handler
......@@ -167,6 +185,36 @@ class TestI2VEngineInputs:
assert empty.guidance_scale_2 is None
class TestBuildSamplingParamsList:
    """Checks the handler's _build_sampling_params_list across stage layouts."""

    def test_single_diffusion_stage(self):
        """A lone diffusion stage yields only the request's params object."""
        handler = _make_handler(stage_types=("diffusion",))
        request_sp = OmniDiffusionSamplingParams(height=512, width=512)
        params_list = handler._build_sampling_params_list(request_sp)
        assert len(params_list) == 1
        assert params_list[0] is request_sp

    def test_llm_then_diffusion(self):
        """An LLM stage gets SamplingParams; the diffusion stage gets the request's params."""
        handler = _make_handler(stage_types=("llm", "diffusion"))
        request_sp = OmniDiffusionSamplingParams(height=512, width=512)
        params_list = handler._build_sampling_params_list(request_sp)
        assert len(params_list) == 2
        assert isinstance(params_list[0], SamplingParams)
        assert params_list[1] is request_sp

    def test_fallback_when_defaults_empty(self):
        """With no per-stage defaults the result is just the request's params."""
        handler = _make_handler()
        handler.engine_client.default_sampling_params_list = []
        request_sp = OmniDiffusionSamplingParams(height=512, width=512)
        params_list = handler._build_sampling_params_list(request_sp)
        assert params_list == [request_sp]

    def test_llm_default_is_cloned(self):
        """The LLM stage default must be cloned exactly once."""
        handler = _make_handler(stage_types=("llm", "diffusion"))
        request_sp = OmniDiffusionSamplingParams()
        handler._build_sampling_params_list(request_sp)
        llm_default = handler.engine_client.default_sampling_params_list[0]
        llm_default.clone.assert_called_once()
class TestBuildOriginalPrompt:
"""build_original_prompt only carries prompt/negative_prompt/multi_modal_data.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment