".github/vscode:/vscode.git/clone" did not exist on "a3d624a72668e007923e4941da4ae1ec193dbdc4"
Unverified commit 1c199f88, authored by Ayush Agarwal and committed by GitHub
Browse files

feat: vllm omni disagg (#7409)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent ba34f131
......@@ -231,6 +231,31 @@ class OmniArgGroup(ArgGroup):
help="Number of GPUs used for classifier free guidance parallelism.",
)
# Disaggregated stage worker flags
add_argument(
g,
flag_name="--stage-id",
env_var="DYN_OMNI_STAGE_ID",
default=None,
arg_type=int,
help=(
"Stage ID for disaggregated omni mode. "
"Run a single stage as an independent Dynamo worker. "
"Requires --stage-configs-path."
),
)
add_negatable_bool_argument(
g,
flag_name="--omni-router",
env_var="DYN_OMNI_ROUTER",
default=False,
help=(
"Run as the stage router, orchestrating the multi-stage DAG. "
"Requires --stage-configs-path. Mutually exclusive with --stage-id."
),
)
class OmniConfig(DynamoRuntimeConfig):
"""Configuration for Dynamo vLLM-Omni worker."""
......@@ -270,6 +295,10 @@ class OmniConfig(DynamoRuntimeConfig):
tts_ref_audio_timeout: int = 15
tts_ref_audio_max_bytes: int = 50 * 1024 * 1024
# Disaggregated stage worker fields
stage_id: Optional[int] = None
omni_router: bool = False
def validate(self) -> None:
DynamoRuntimeConfig.validate(self)
if self.default_video_fps <= 0:
......@@ -280,6 +309,15 @@ class OmniConfig(DynamoRuntimeConfig):
raise ValueError("--ring-degree must be > 0")
if not (0 < self.boundary_ratio <= 1):
raise ValueError("--boundary-ratio must be in (0, 1]")
if self.stage_configs_path is None:
if self.stage_id is not None:
raise ValueError("--stage-id requires --stage-configs-path")
if self.omni_router:
raise ValueError("--omni-router requires --stage-configs-path")
if self.stage_id is not None and self.stage_id < 0:
raise ValueError("--stage-id must be >= 0")
if self.stage_id is not None and self.omni_router:
raise ValueError("--stage-id and --omni-router are mutually exclusive")
def parse_omni_args() -> OmniConfig:
......
......@@ -17,6 +17,8 @@ except ImportError:
DiffusionParallelConfig = None # type: ignore[assignment, misc]
from dynamo._core import Context
from dynamo.common.protocols.audio_protocol import NvAudioSpeechResponse
from dynamo.common.utils.output_modalities import RequestType
from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params
logger = logging.getLogger(__name__)
......@@ -170,11 +172,7 @@ class BaseOmniHandler(BaseWorkerHandler[Dict[str, Any], Dict[str, Any]]):
For AUDIO_GENERATION returns NvAudioSpeechResponse format.
For all other types returns OpenAI chat.completion.chunk format.
"""
from dynamo.common.utils.output_modalities import RequestType
if request_type == RequestType.AUDIO_GENERATION:
from dynamo.common.protocols.audio_protocol import NvAudioSpeechResponse
return NvAudioSpeechResponse(
id=request_id,
model=self.config.served_model_name or self.config.model,
......
......@@ -20,7 +20,9 @@ from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.health_check import VllmOmniHealthCheckPayload
from dynamo.vllm.main import setup_metrics_collection
from dynamo.vllm.omni.tts_utils import (
from dynamo.vllm.omni.stage_router import init_omni_stage_router
from dynamo.vllm.omni.stage_worker import init_omni_stage
from dynamo.vllm.omni.utils import (
cleanup_dummy_tokenizer_for_tts,
ensure_dummy_tokenizer_for_tts,
)
......@@ -145,8 +147,15 @@ async def worker():
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)
await init_omni(runtime, config, shutdown_event)
logger.debug("Omni worker completed, exiting...")
if config.stage_id is not None:
await init_omni_stage(runtime, config, shutdown_endpoints, shutdown_event)
logger.debug("init_omni_stage completed (stage %d)", config.stage_id)
elif config.omni_router:
await init_omni_stage_router(runtime, config, shutdown_endpoints)
logger.debug("init_omni_stage_router completed")
else:
await init_omni(runtime, config, shutdown_event)
logger.debug("Omni worker completed, exiting...")
def main():
......
......@@ -17,7 +17,18 @@ import uuid
from io import BytesIO
from typing import Any, Dict, Optional
import numpy as np
import soundfile as sf
import torch
from diffusers.utils.export_utils import export_to_video
from dynamo.common.protocols.audio_protocol import AudioData, NvAudioSpeechResponse
from dynamo.common.protocols.image_protocol import ImageData, NvImagesResponse
from dynamo.common.protocols.video_protocol import NvVideosResponse, VideoData
from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.output_modalities import RequestType
from dynamo.common.utils.video_utils import normalize_video_frames
logger = logging.getLogger(__name__)
......@@ -92,7 +103,6 @@ class DiffusionFormatter:
)
if not images:
return None
from dynamo.common.utils.output_modalities import RequestType
if request_type == RequestType.VIDEO_GENERATION:
return await self._encode_video(
......@@ -108,12 +118,6 @@ class DiffusionFormatter:
async def _encode_video(
self, images: list, request_id: str, fps: int
) -> Dict[str, Any] | None:
from diffusers.utils.export_utils import export_to_video
from dynamo.common.protocols.video_protocol import NvVideosResponse, VideoData
from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.video_utils import normalize_video_frames
try:
start_time = time.time()
frame_list = normalize_video_frames(images)
......@@ -157,9 +161,6 @@ class DiffusionFormatter:
request_type: Any,
response_format: Optional[str] = None,
) -> Dict[str, Any] | None:
from dynamo.common.protocols.image_protocol import ImageData, NvImagesResponse
from dynamo.common.utils.output_modalities import RequestType
if not images:
return _error_chunk(request_id, self._model_name, "No images generated")
......@@ -209,8 +210,6 @@ class DiffusionFormatter:
async def _prepare_images(
self, images: list, request_id: str, response_format: Optional[str] = None
) -> list:
from dynamo.common.storage import upload_to_fs
outlist = []
for img in images:
buf = BytesIO()
......@@ -239,12 +238,10 @@ class AudioFormatter:
def __init__(
self, model_name: str, media_fs: Any, media_http_url: Optional[str]
) -> None:
from dynamo.common.protocols.audio_protocol import AudioData
self._model_name = model_name
self._media_fs = media_fs
self._media_http_url = media_http_url
self._AudioData = AudioData
self._AudioData = AudioData # stored for use in format()
async def format(
self, stage_output: Any, request_id: str, **ctx: Any
......@@ -284,8 +281,6 @@ class AudioFormatter:
)
if response_format == "url":
from dynamo.common.storage import upload_to_fs
ext = encode_fmt if encode_fmt != "opus" else "ogg"
url = await upload_to_fs(
self._media_fs,
......@@ -299,8 +294,6 @@ class AudioFormatter:
b64_json=base64.b64encode(audio_bytes).decode()
)
from dynamo.common.protocols.audio_protocol import NvAudioSpeechResponse
return NvAudioSpeechResponse(
id=request_id,
object="audio.speech",
......@@ -317,9 +310,6 @@ class AudioFormatter:
return self._error_response(request_id, str(e))
def _extract_audio_tensor(self, mm_output: Dict[str, Any]) -> tuple:
import numpy as np
import torch
audio_key = "audio" if "audio" in mm_output else "model_outputs"
audio_val = mm_output.get(audio_key)
if audio_val is None:
......@@ -350,8 +340,6 @@ class AudioFormatter:
def _encode_audio(
self, audio_np: Any, sample_rate: int, fmt: str = "wav", speed: float = 1.0
) -> tuple:
import soundfile as sf
if speed != 1.0:
try:
import librosa
......@@ -381,8 +369,6 @@ class AudioFormatter:
return buf.getvalue(), media_type
def _error_response(self, request_id: str, error: str) -> Dict[str, Any]:
from dynamo.common.protocols.audio_protocol import NvAudioSpeechResponse
return NvAudioSpeechResponse(
id=request_id,
model=self._model_name,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Stage router for disaggregated omni pipelines."""
import json
import logging
import uuid
from typing import Any, AsyncGenerator, Dict, List
from vllm_omni.entrypoints.utils import load_stage_configs_from_yaml
from dynamo import prometheus_names
from dynamo.common.storage import get_fs
from dynamo.common.utils.output_modalities import (
RequestType,
get_output_modalities,
parse_request_type,
)
from dynamo.llm import ModelInput, register_model
from dynamo.runtime import DistributedRuntime
from dynamo.vllm.main import setup_metrics_collection
from dynamo.vllm.omni.args import OmniConfig
from dynamo.vllm.omni.output_formatter import OutputFormatter
from dynamo.vllm.omni.stage_worker import _resolve_model_type
from dynamo.vllm.omni.types import StageOutput
from dynamo.vllm.omni.utils import shm_deserialize
logger = logging.getLogger(__name__)
class OmniStageRouter:
    """Message broker that drives one request through every stage of an omni pipeline.

    The router keeps one endpoint client per stage (keyed by the stage's
    ``model_stage`` name), calls the stages in the order they appear in the
    stage-config YAML, and formats the final stage's shared-memory output with
    ``OutputFormatter``. It runs no engine itself.
    """

    def __init__(
        self,
        config: OmniConfig,
        stage_configs_path: str,
    ) -> None:
        """Load the stage configs and build the output formatter.

        Args:
            config: Worker configuration; supplies the model names, media
                output settings and default video FPS used for formatting.
            stage_configs_path: Path to the stage-config YAML, loaded with
                ``load_stage_configs_from_yaml``.
        """
        self.config = config
        self.stage_configs = load_stage_configs_from_yaml(stage_configs_path)
        # model_stage name -> endpoint client; populated via set_stage_client().
        self.stage_clients: Dict[str, Any] = {}

        media_fs = (
            get_fs(config.media_output_fs_url) if config.media_output_fs_url else None
        )
        self._formatter = OutputFormatter(
            model_name=config.served_model_name or config.model,
            media_fs=media_fs,
            media_http_url=config.media_output_http_url,
            default_fps=config.default_video_fps,
        )

    def set_stage_client(self, model_stage: str, client: Any) -> None:
        """Register the client used to reach the stage named ``model_stage``."""
        self.stage_clients[model_stage] = client
        logger.info("Registered stage client: %s", model_stage)

    async def generate(
        self,
        request: dict,
        context,  # noqa: ARG002 — context unused; router generates its own request_id
    ) -> AsyncGenerator[dict, None]:
        """Run ``request`` through all stages and yield the formatted result.

        Stage 0 receives the raw request plus a freshly generated
        ``request_id``; every later stage receives the previous stage's output
        reduced to the inter-stage protocol fields. Any failure is reported as
        a single ``{"error": ..., "finished": True}`` chunk and ends the
        generator.

        Args:
            request: Raw frontend request dict.
            context: Unused; the router generates its own request id.

        Yields:
            The formatted output chunk(s) of the final stage, or one error chunk.
        """
        request_id = str(uuid.uuid4())
        _, request_type = parse_request_type(request, self.config.output_modalities)

        stage_outputs: List[StageOutput] = []

        for stage_idx, stage_cfg in enumerate(self.stage_configs):
            model_stage = getattr(
                stage_cfg.engine_args, "model_stage", f"stage{stage_idx}"
            )
            client = self.stage_clients.get(model_stage)
            if client is None:
                yield {
                    "error": f"No client for stage '{model_stage}'",
                    "finished": True,
                }
                return

            if stage_idx == 0:
                # This is a workaround for now to pass in the raw request to stage 0.
                # StageRequest validates it but ignores any unknown keys, so it gets
                # passed through.
                stage_request = {"request_id": request_id, **request}
            else:
                stage_request = stage_outputs[-1].to_next_stage_request(request_id)

            raw_stage_output: Dict[str, Any] = {}
            logger.info(
                "Router: stage %d request keys=%s",
                stage_idx,
                list(stage_request.keys()),
            )
            # For now, it is just one chunk output from the stage. The loop is kept
            # in case stages stream multiple chunks in the future; chunks are merged
            # key-by-key into one dict.
            async for chunk in await client.round_robin(stage_request):
                data = chunk.data()
                if isinstance(data, (str, bytes)):
                    data = json.loads(data)
                raw_stage_output.update(data)

            if not raw_stage_output:
                # An empty reply has neither connector refs, SHM metadata nor an
                # error. Validating it would produce an all-None StageOutput and
                # the next stage would be sent a request holding only request_id.
                yield {
                    "error": f"Stage '{model_stage}' returned no output",
                    "finished": True,
                }
                return

            stage_outputs.append(StageOutput.model_validate(raw_stage_output))

            if stage_outputs[-1].error:
                yield {"error": stage_outputs[-1].error, "finished": True}
                return

        final = stage_outputs[-1]
        if not final.shm_meta:
            yield {"error": "No SHM output from final stage", "finished": True}
            return

        # Build formatting context from the original request.
        nvext = request.get("nvext") or {}
        fmt_ctx: Dict[str, Any] = {}
        if nvext.get("fps") is not None:
            fmt_ctx["fps"] = nvext["fps"]
        if request.get("response_format") is not None:
            fmt_ctx["response_format"] = request["response_format"]
        if nvext.get("speed") is not None:
            fmt_ctx["speed"] = nvext["speed"]

        async for chunk in self._format_output(
            final, request_id, request_type, fmt_ctx
        ):
            yield chunk

    async def _format_output(
        self,
        stage_output: StageOutput,
        request_id: str,
        request_type: RequestType,
        ctx: dict,
    ) -> AsyncGenerator[dict, None]:
        """Read the final stage's result from SHM and format it via OutputFormatter.

        Args:
            stage_output: Validated output of the final stage; its ``shm_meta``
                locates the serialized result.
            request_id: Id passed to the formatter.
            request_type: Request type passed to the formatter.
            ctx: Extra keyword arguments forwarded to the formatter.

        Yields:
            The formatted chunk, or an error chunk when the formatter returns a
            falsy value. Yields nothing when ``shm_meta`` is missing.
        """
        shm_meta = stage_output.shm_meta
        if not shm_meta:
            logger.warning("Router: no shm_meta in stage output")
            return

        result = shm_deserialize(shm_meta)
        chunk = await self._formatter.format(
            result, request_id, request_type=request_type, **ctx
        )
        if chunk:
            yield chunk
        else:
            final_output_type = getattr(result, "final_output_type", "unknown")
            logger.warning(
                "Router: formatter returned None, final_output_type=%s",
                final_output_type,
            )
            yield {
                "error": f"Formatter returned no output for type '{final_output_type}'",
                "finished": True,
            }
async def init_omni_stage_router(
    runtime: DistributedRuntime,
    config: OmniConfig,
    shutdown_endpoints: list,
) -> None:
    """Initialize OmniStageRouter as a Dynamo backend endpoint and serve it.

    Creates the router endpoint, connects a client to every stage endpoint
    (``{namespace}.{model_stage}.generate``), registers the model, then serves
    ``router.generate`` until the endpoint stops.

    Args:
        runtime: Distributed runtime used to create endpoints and clients.
        config: Omni configuration; ``stage_configs_path`` must be set.
        shutdown_endpoints: List replaced in place with the router's endpoint.

    Raises:
        Exception: Whatever ``serve_endpoint`` raises; logged and re-raised.
    """
    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint or 'generate'}"
    )
    shutdown_endpoints[:] = [generate_endpoint]

    router = OmniStageRouter(config, config.stage_configs_path)  # type: ignore[arg-type]

    setup_metrics_collection(config, generate_endpoint, logger)

    # Discover stage endpoints. The fallback name is built from the stage's
    # position so it matches the key OmniStageRouter.generate looks up and the
    # name init_omni_stage registers. Using the index also avoids eagerly reading
    # stage_cfg.stage_id for the getattr default when model_stage is present.
    for stage_idx, stage_cfg in enumerate(router.stage_configs):
        model_stage = getattr(
            stage_cfg.engine_args, "model_stage", f"stage{stage_idx}"
        )
        client = await runtime.endpoint(
            f"{config.namespace}.{model_stage}.generate"
        ).client()
        await client.wait_for_instances()
        router.set_stage_client(model_stage, client)

    final_cfg = router.stage_configs[-1]
    final_output_type = getattr(final_cfg, "final_output_type", "image")

    # Prefer the model type derived from the configured output modalities; fall
    # back to the final stage's output type when that yields nothing.
    model_type = get_output_modalities(config.output_modalities, config.model)
    if model_type is None:
        model_type = _resolve_model_type(final_output_type)

    await register_model(
        ModelInput.Text,
        model_type,
        generate_endpoint,
        config.model,
        config.served_model_name,
    )
    logger.info("OmniStageRouter registered at '%s'", generate_endpoint)

    try:
        await generate_endpoint.serve_endpoint(
            router.generate,
            graceful_shutdown=True,
            metrics_labels=[
                (
                    prometheus_names.labels.MODEL,
                    config.served_model_name or config.model,
                ),
                (
                    prometheus_names.labels.MODEL_NAME,
                    config.served_model_name or config.model,
                ),
            ],
        )
    except Exception as e:
        logger.error("OmniStageRouter endpoint failed: %s", e)
        raise
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Single-stage omni worker for disaggregated pipelines."""
import asyncio
import importlib
import logging
import os
import tempfile
from dataclasses import dataclass
from typing import Any, AsyncGenerator
import yaml
from vllm_omni.distributed.omni_connectors import initialize_orchestrator_connectors
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.stage_utils import serialize_obj, shm_write_bytes
from vllm_omni.entrypoints.utils import load_stage_configs_from_yaml
from dynamo import prometheus_names
from dynamo.llm import ModelType
from dynamo.runtime import DistributedRuntime
from dynamo.vllm.health_check import VllmOmniHealthCheckPayload
from dynamo.vllm.main import setup_metrics_collection
from dynamo.vllm.omni.args import OmniConfig
from dynamo.vllm.omni.types import StageEngine, StageRequest, _int_keyed
from dynamo.vllm.omni.utils import _build_sampling_params, parse_omni_request
logger = logging.getLogger(__name__)
@dataclass
class _Proxy:
"""Satisfies stage_list[i].engine_outputs for processor functions.
Processor functions (e.g. ar2diffusion) access stage_list[i].engine_outputs
as a list of OmniRequestOutput objects.
"""
engine_outputs: Any = None
class OmniStageWorker:
    """Single-stage worker: resolves inputs → runs processor → runs engine → writes output.

    How ``generate`` obtains the engine prompt depends on the request:

    * Request carries ``stage_connector_refs``: previous stage outputs are
      fetched from the connectors and, if ``custom_process_input_func`` is
      configured, run through that processor (e.g. thinker2talker).
    * Request carries a ``request_id`` but no refs: the raw request is parsed
      with ``parse_omni_request``.
    * Neither: the request dict itself is used as the prompt.

    Non-final stages write their result to the connector for edge
    ``(stage_id, stage_id + 1)`` and yield the accumulated
    ``stage_connector_refs``. Final stages, and non-final stages that have no
    such connector, write to SHM and yield ``shm_meta``.
    """

    def __init__(
        self,
        engine: StageEngine,
        stage_config: Any,
        connectors: dict,
        stage_id: int,
        output_modalities: list | None = None,
        default_video_fps: int = 16,
    ) -> None:
        """Store the engine and read per-stage settings from ``stage_config``.

        Args:
            engine: Engine whose ``generate`` produces this stage's output.
            stage_config: Parsed stage config; optional attributes are read
                with ``getattr`` defaults.
            connectors: Mapping ``(from_stage, to_stage) -> connector``.
            stage_id: Index of this stage in the pipeline.
            output_modalities: Passed to ``parse_omni_request``; ``None``
                becomes an empty list.
            default_video_fps: Passed to ``parse_omni_request``.
        """
        self.engine = engine
        self.stage_id = stage_id
        self.connectors = connectors  # {(from_stage, to_stage): vllm_omni connector}
        self.final_output: bool = getattr(stage_config, "final_output", False)
        self._output_modalities = output_modalities or []
        self._default_video_fps = default_video_fps
        self.stage_config = stage_config

        func_path = getattr(stage_config, "custom_process_input_func", None)
        self._processor = _load_processor(func_path)
        self._engine_input_source: list[int] = getattr(
            stage_config, "engine_input_source", []
        )
        self._requires_mm: bool = getattr(
            stage_config, "requires_multimodal_data", False
        )

    async def generate(self, request: dict, context) -> AsyncGenerator[dict, None]:
        """Run this stage for one request and yield a single result dict.

        Args:
            request: Request dict, validated as a ``StageRequest``.
            context: Supplies ``context.id()`` when the request has no
                ``request_id``.

        Yields:
            Exactly one dict with ``finished: True``, holding one of:
            ``error``; ``stage_connector_refs`` (plus ``original_prompt`` and
            optionally ``sampling_params_list``); or ``shm_meta``.
        """
        req = StageRequest.model_validate(request)
        request_id = req.request_id or context.id()
        original_prompt = req.original_prompt
        # JSON sends dict keys as strings; normalize to int for stage_connector_refs.
        stage_connector_refs = _int_keyed(req.stage_connector_refs)

        # --- Resolve engine inputs ---
        sampling_params_list_override: dict | None = None

        if stage_connector_refs:
            # Stage N > 0: fetch previous stage outputs from connectors, run pre-processor.
            sampling_params_list_override = req.sampling_params_list
            try:
                stage_list = self._fetch_stage_inputs(stage_connector_refs, request_id)
            except RuntimeError as e:
                yield {"error": str(e), "finished": True}
                return

            if len(stage_list) != len(
                self._engine_input_source or stage_connector_refs
            ):
                logger.warning(
                    "Stage %d: expected %d stage inputs, got %d",
                    self.stage_id,
                    len(self._engine_input_source or stage_connector_refs),
                    len(stage_list),
                )

            if self._processor is not None:
                prompt = self._processor(
                    stage_list,
                    self._engine_input_source,
                    [original_prompt],
                    self._requires_mm,
                )
                # Unwrap a single-element list to the bare prompt.
                if isinstance(prompt, list) and len(prompt) == 1:
                    prompt = prompt[0]
            else:
                # No processor: use the most recent fetched stage output directly.
                prompt = stage_list[-1].engine_outputs[0]
        elif req.request_id is not None:
            # Stage 0 via router: raw request forwarded with request_id — parse it.
            parsed = parse_omni_request(
                request, self._output_modalities, self._default_video_fps
            )
            prompt = parsed["engine_inputs"]
            original_prompt = parsed["original_prompt"]
            sampling_params_list_override = parsed["sampling_params_list"]
        else:
            # Direct frontend → stage (single-stage, no router).
            prompt = request

        logger.debug(
            "Stage %d: engine.generate for %s — prompt type=%s",
            self.stage_id,
            request_id,
            type(prompt).__name__,
        )

        sp = _build_sampling_params(self.stage_config, sampling_params_list_override)

        # Only the last chunk produced by the engine is forwarded.
        last_result = None
        try:
            async for chunk in self.engine.generate(
                prompt, request_id, sampling_params_list=sp
            ):
                last_result = chunk
        except Exception as e:
            logger.error(
                "Stage %d engine error for %s: %s",
                self.stage_id,
                request_id,
                e,
                exc_info=True,
            )
            yield {"error": str(e), "finished": True}
            return

        if last_result is None:
            # The engine yielded nothing. Report it here instead of writing None
            # to the connector/SHM and letting a later step fail.
            logger.error(
                "Stage %d: engine produced no output for %s",
                self.stage_id,
                request_id,
            )
            yield {"error": "engine produced no output", "finished": True}
            return

        # --- Write output ---
        if not self.final_output:
            from_s, to_s = _connector_key(self.stage_id, self.stage_id + 1)
            connector = self.connectors.get((from_s, to_s))
            if connector is not None:
                try:
                    ok, _, metadata = connector.put(  # type: ignore[arg-type]
                        from_s, to_s, request_id, last_result
                    )
                except Exception as e:
                    logger.error(
                        "Stage %d: connector.put() raised %s: %s",
                        self.stage_id,
                        type(e).__name__,
                        e,
                        exc_info=True,
                    )
                    yield {"error": f"connector.put() raised: {e}", "finished": True}
                    return
                if not ok:
                    yield {"error": "connector.put() failed", "finished": True}
                    return
                # Keys are emitted as strings; the received refs are carried
                # forward and this stage's own ref is added.
                out: dict = {
                    "original_prompt": original_prompt,
                    "stage_connector_refs": {
                        **{str(k): v for k, v in stage_connector_refs.items()},
                        str(self.stage_id): metadata,
                    },
                    "finished": True,
                }
                if sampling_params_list_override is not None:
                    out["sampling_params_list"] = sampling_params_list_override
                yield out
                return
            logger.warning(
                "Stage %d: no connector found for edge (%s→%s), falling through to SHM",
                self.stage_id,
                from_s,
                to_s,
            )

        # Final stage → router: write output to shared memory and return the SHM handle.
        # The router reads it back via shm_deserialize() to format the response.
        #
        # NOTE: This is a single-node-only workaround — SHM requires the final stage
        # worker and the router to reside on the same machine. A proper multi-node
        # solution would use a connector edge (like inter-stage connectors) instead.
        # TODO: shm_meta should be replaced by a YAML-configured connector edge.
        shm_meta = shm_write_bytes(serialize_obj(last_result), name=request_id)
        yield {"shm_meta": shm_meta, "finished": True}

    def _fetch_stage_inputs(
        self, stage_connector_refs: dict[int, Any], request_id: str
    ) -> list[_Proxy]:
        """Fetch previous stage outputs from connectors for the processor/engine.

        Fetches only the stages listed in ``engine_input_source`` (or all refs,
        in ascending stage order, if that list is empty).

        Args:
            stage_connector_refs: Stage id -> connector metadata from that stage.
            request_id: Request id passed to ``connector.get``.

        Returns:
            One ``_Proxy`` per source stage, in source order, each wrapping the
            fetched payload in a single-element ``engine_outputs`` list.

        Raises:
            RuntimeError: On a missing ref, missing connector, failing
                ``connector.get`` or empty payload, so the caller can turn it
                into an error chunk.
        """
        sources = self._engine_input_source or sorted(stage_connector_refs.keys())
        stage_list = []
        for stage_k in sources:
            if (meta_k := stage_connector_refs.get(stage_k)) is None:
                raise RuntimeError(
                    f"Stage {self.stage_id}: no connector ref for source stage {stage_k}"
                )
            if (
                connector := self.connectors.get(_connector_key(stage_k, self.stage_id))
            ) is None:
                raise RuntimeError(
                    f"Stage {self.stage_id}: no connector for edge ({stage_k}→{self.stage_id})"
                )
            try:
                payload = connector.get(
                    str(stage_k), str(self.stage_id), request_id, metadata=meta_k
                )
            except Exception as e:
                raise RuntimeError(
                    f"Stage {self.stage_id}: connector.get() failed: {e}"
                ) from e
            # A tuple result carries the payload in its first element.
            payload_data = payload[0] if isinstance(payload, tuple) else payload
            if not payload_data:
                raise RuntimeError(
                    f"Stage {self.stage_id}: empty payload from connector ({stage_k}→{self.stage_id})"
                )
            engine_inputs = (
                payload_data.get("engine_inputs")
                if isinstance(payload_data, dict)
                else payload_data
            )
            stage_list.append(_Proxy(engine_outputs=[engine_inputs]))
        return stage_list
async def init_omni_stage(
    runtime: DistributedRuntime,
    config: OmniConfig,
    shutdown_endpoints: list,
    shutdown_event: asyncio.Event | None = None,
) -> None:
    """Create and serve the worker for the single stage named by ``config.stage_id``.

    Loads the stage-config YAML, builds the engine and connectors for the
    selected stage, and serves ``OmniStageWorker.generate`` on
    ``{namespace}.{model_stage}.generate``. No model is registered here.

    Args:
        runtime: Distributed runtime used to create the endpoint.
        config: Omni configuration; ``stage_id`` and ``stage_configs_path``
            must be set.
        shutdown_endpoints: List replaced in place with this stage's endpoint.
        shutdown_event: Awaited by non-leader data-parallel ranks before
            returning; when ``None`` those ranks return immediately.

    Raises:
        ValueError: If ``stage_id`` is unset or not a valid index into the YAML.
        Exception: Whatever ``serve_endpoint`` raises; logged and re-raised.
    """
    if config.stage_id is None:
        raise ValueError("--stage-id is required for stage worker initialization")
    stage_id: int = config.stage_id

    all_stages = load_stage_configs_from_yaml(config.stage_configs_path)  # type: ignore[arg-type]
    num_stages = len(all_stages)
    if stage_id >= num_stages:
        raise ValueError(
            f"--stage-id {stage_id} out of range (YAML has {num_stages} stages)"
        )

    this_stage = all_stages[stage_id]
    stage_type: str = getattr(this_stage, "stage_type", "llm")

    # The endpoint is keyed by the stage name; init_omni_stage_router connects to
    # stages under the same {namespace}.{model_stage}.generate path.
    stage_name = getattr(this_stage.engine_args, "model_stage", f"stage{stage_id}")
    endpoint = runtime.endpoint(f"{config.namespace}.{stage_name}.generate")
    shutdown_endpoints[:] = [endpoint]

    engine = _create_engine(config.model, this_stage, stage_type)
    logger.info("Stage %d: engine created (type=%s)", stage_id, stage_type)

    # Inter-stage connectors; their type comes from the YAML config
    # (SharedMemoryConnector, MooncakeConnector, etc.).
    _, connectors = initialize_orchestrator_connectors(config.stage_configs_path)  # type: ignore[arg-type]

    worker = OmniStageWorker(
        engine=engine,
        stage_config=this_stage,
        connectors=connectors,
        output_modalities=config.output_modalities,
        default_video_fps=config.default_video_fps,
        stage_id=stage_id,
    )

    setup_metrics_collection(config, endpoint, logger)

    # A non-zero data-parallel rank does not serve the endpoint: it only waits
    # for shutdown (when an event was supplied) and returns.
    dp_rank = config.engine_args.data_parallel_rank
    if dp_rank:
        logger.info(
            "Stage %d: non-leader DP rank %d; waiting for shutdown",
            stage_id,
            dp_rank,
        )
        if shutdown_event is not None:
            await shutdown_event.wait()
        return

    logger.info(
        "Stage %d: serving internal stage endpoint '%s' (not registering model)",
        stage_id,
        endpoint,
    )

    health_check = await VllmOmniHealthCheckPayload.create(engine)  # type: ignore[arg-type]
    health_check_payload = health_check.to_dict()

    try:
        labels = [
            (
                prometheus_names.labels.MODEL,
                config.served_model_name or config.model,
            ),
            (
                prometheus_names.labels.MODEL_NAME,
                config.served_model_name or config.model,
            ),
        ]
        await endpoint.serve_endpoint(
            worker.generate,
            graceful_shutdown=True,
            metrics_labels=labels,
            health_check_payload=health_check_payload,
        )
    except Exception as e:
        logger.error("Stage %d: endpoint failed: %s", stage_id, e)
        raise
def _connector_key(from_stage: int, to_stage: int) -> tuple[str, str]:
"""Build the connector dict key used by initialize_orchestrator_connectors."""
return (str(from_stage), str(to_stage))
def _load_processor(func_path: str | None) -> Any:
"""Load a processor function from a dotted module path, or return None."""
if not func_path:
return None
module_path, func_name = func_path.rsplit(".", 1)
return getattr(importlib.import_module(module_path), func_name)
def _create_engine(model: str, stage_config: Any, stage_type: str) -> StageEngine:
    """Create an AsyncOmni engine from a temporary single-stage YAML.

    The stage config is converted to a one-stage YAML with no runtime edges,
    written to a temporary file, and passed to ``AsyncOmni``. The file is
    removed once the constructor returns or raises.

    Args:
        model: Model identifier passed to ``AsyncOmni``.
        stage_config: Parsed config of the stage to run.
        stage_type: Stage type written into the generated YAML.

    Returns:
        The constructed ``AsyncOmni`` instance.
    """
    single_stage_config = {
        "stage_args": [_stage_config_to_dict(stage_config, stage_type)],
        "runtime": {"edges": []},
    }
    # delete=False because the file is reopened by path after being closed.
    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
    try:
        # The write sits inside the try so the file is also removed when
        # yaml.dump fails, not only when AsyncOmni construction fails.
        with tmp:
            yaml.dump(single_stage_config, tmp)
        return AsyncOmni(model=model, stage_configs_path=tmp.name)
    finally:
        os.unlink(tmp.name)
def _stage_config_to_dict(stage_config: Any, stage_type: str) -> dict:
    """Convert a parsed stage config to the dict for a single-stage YAML.

    The result always has ``stage_id`` 0 and ``final_output`` True, regardless
    of the stage's position in the original pipeline. ``default_sampling_params``,
    ``is_comprehension`` and ``runtime`` are copied only when present and not
    ``None``; ``runtime["devices"]`` is overridden with ``"0"``.

    Args:
        stage_config: Parsed stage config (attribute access).
        stage_type: Value stored under ``stage_type``.

    Returns:
        A new dict describing the stage. ``stage_config.runtime`` is not modified.
    """
    from omegaconf import OmegaConf  # type: ignore[import-not-found]

    def _to_plain(obj: Any) -> Any:
        """Return OmegaConf nodes as containers and objects as a shallow dict of vars."""
        if OmegaConf.is_config(obj):
            return OmegaConf.to_container(obj, resolve=True)
        if hasattr(obj, "__dict__"):
            return dict(vars(obj))
        return obj

    result: dict = {
        "stage_id": 0,
        "stage_type": stage_type,
        "engine_args": _to_plain(stage_config.engine_args),
        "final_output": True,
        "final_output_type": getattr(stage_config, "final_output_type", "text"),
    }
    for key in ("default_sampling_params", "is_comprehension"):
        val = getattr(stage_config, key, None)
        if val is not None:
            result[key] = _to_plain(val)

    runtime = getattr(stage_config, "runtime", None)
    if runtime is not None:
        # Copy before overriding: _to_plain returns an already-plain dict as the
        # same object, so writing "devices" into it would alter the caller's config.
        rt = dict(_to_plain(runtime))
        # NOTE(review): presumably the stage process sees only its own device(s),
        # re-indexed from 0 — confirm against how workers are launched.
        rt["devices"] = "0"
        result["runtime"] = rt
    return result
def _resolve_model_type(final_output_type: str) -> ModelType:
    """Map a final output type name to a ``ModelType``.

    ``"image"`` and ``"video"`` map to ``ModelType.Images`` and
    ``ModelType.Videos``; anything else maps to ``ModelType.Chat``.
    """
    by_output_type = {
        "image": ModelType.Images,
        "video": ModelType.Videos,
    }
    return by_output_type.get(final_output_type, ModelType.Chat)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Protocol types for disaggregated omni stage workers and connectors.
"""
import dataclasses
import logging
from typing import Any, AsyncGenerator, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, model_validator
@runtime_checkable
class StageEngine(Protocol):
    """Structural type for an engine that produces outputs for one pipeline stage.

    The signature follows ``AsyncOmni.generate()``, which the original author
    notes is the only vllm_omni engine with a consistent async-generator
    interface for both LLM and diffusion stages.
    """

    def generate(
        self,
        prompt: Any,
        request_id: str = "",
        *,
        sampling_params_list: Any = None,
    ) -> AsyncGenerator[Any, None]:
        """Return an async generator of engine outputs for ``prompt``."""
        ...
class StageOutput(BaseModel):
    """Validated output dict from a stage worker.

    Unknown keys are dropped (extra="ignore") so arbitrary stage output does
    not accumulate across stages; a warning lists what was dropped.
    ``finished``/``error`` are consumed by the router and not forwarded to
    subsequent stages.
    """

    model_config = ConfigDict(extra="ignore")

    @model_validator(mode="before")
    @classmethod
    def _warn_dropped_keys(cls, values: Any) -> Any:
        """Log any input keys that are not declared fields; input is returned unchanged."""
        if isinstance(values, dict):
            # Derive the known keys from the declared fields so this check cannot
            # drift from the model when a field is added or removed.
            dropped = set(values.keys()) - set(cls.model_fields)
            if dropped:
                # Use a module-named logger rather than the root logger so the
                # message can be filtered like the rest of the package's logs.
                logging.getLogger(__name__).warning(
                    "StageOutput: dropping unexpected keys from stage response: %s",
                    sorted(dropped),
                )
        return values

    # TODO: shm_meta should be gone; it is a workaround to send the final output
    # to the router via shared memory.
    shm_meta: dict | None = None
    original_prompt: dict | None = None
    # stage_connector_refs maps stage_id (str key from JSON) → opaque connector metadata
    # returned by connector.put(). This metadata is an address ticket passed to
    # connector.get(metadata=...) by the next stage to locate and fetch the data.
    # The format is connector-specific and opaque to the router:
    #   SHM connector:   {"shm": {"name": "<block_name>", "size": N}, "size": N}
    #                    or {"inline_bytes": b"...", "size": N} (small payloads)
    #   Mooncake (RDMA): {"source_host": "...", "source_port": N, "data_size": N, ...}
    # Keys arrive as strings from JSON; workers normalize them to int via _int_keyed().
    stage_connector_refs: dict[str, Any] | None = None
    sampling_params_list: dict | None = None
    finished: bool | None = None
    error: str | None = None

    def to_next_stage_request(self, request_id: str) -> dict:
        """Build the request dict for the next stage from the inter-stage fields.

        Only ``original_prompt``, ``stage_connector_refs`` and
        ``sampling_params_list`` are included, and only when not ``None``.
        ``shm_meta`` is intentionally excluded — it is final-stage → router only.

        Args:
            request_id: Id stored under ``request_id`` in the result.

        Returns:
            The request dict for the next stage.
        """
        fields = self.model_dump(
            include={"original_prompt", "stage_connector_refs", "sampling_params_list"},
            exclude_none=True,
        )
        fields["request_id"] = request_id
        return fields
class StageRequest(BaseModel):
    """Validated view of the request dict a stage worker receives.

    Every field is optional and unknown keys are ignored (extra="ignore"), so
    one model accepts all request shapes a worker sees:
      Stage 0 via router: the raw frontend request with ``request_id`` added
      Stage N>0:          {request_id, original_prompt, stage_connector_refs: {"0": ref0, ...}}
      Direct:             raw frontend request (no router, single-stage deployment)
    """

    model_config = ConfigDict(extra="ignore")

    request_id: str | None = None
    engine_inputs: Any = None
    original_prompt: dict | None = None
    # Address tickets from previous stages, in the same format as
    # StageOutput.stage_connector_refs. Keys are strings here; callers convert
    # them to int with _int_keyed().
    stage_connector_refs: dict[str, Any] | None = None
    sampling_params_list: dict | None = None
def _int_keyed(d: dict | None) -> dict[int, Any]:
"""Normalize JSON-deserialized string keys back to int for stage_connector_refs."""
if not d:
return {}
return {int(k): v for k, v in d.items()}
@dataclasses.dataclass
class OmniInterStageRequest:
    """Message carrying a request id, the original prompt and per-stage connector refs.

    ``to_dict`` and ``from_dict`` convert to and from a plain dict;
    ``from_dict`` converts the ref keys back to int. Per the original author's
    notes, the prompt is meant to be JSON-serializable (no tensors) while
    tensors travel through the connector payload, and the refs let any stage
    rebuild ``stage_list`` for N-stage processor functions.
    """

    request_id: str
    # OmniPromptType | list | None — typed as Any to avoid importing vllm_omni at
    # module level.
    original_prompt: Any
    # Stage id -> connector ref; starts empty and gains one entry per completed stage.
    stage_connector_refs: dict[int, Any] = dataclasses.field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the three fields as a dict (values are not copied)."""
        payload: dict = {"request_id": self.request_id}
        payload["original_prompt"] = self.original_prompt
        payload["stage_connector_refs"] = self.stage_connector_refs
        return payload

    @classmethod
    def from_dict(cls, d: dict) -> "OmniInterStageRequest":
        """Build an instance from ``d``; ``stage_connector_refs`` is optional."""
        request_id = d["request_id"]
        original_prompt = d["original_prompt"]
        refs = _int_keyed(d.get("stage_connector_refs"))
        return cls(
            request_id=request_id,
            original_prompt=original_prompt,
            stage_connector_refs=refs,
        )
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""TTS/audio utility functions for the vLLM-Omni backend."""
"""Shared utilities for the vLLM-Omni backend."""
import json
import logging
from pathlib import Path
from typing import Any, cast
logger = logging.getLogger(__name__)
from huggingface_hub import scan_cache_dir
from vllm.sampling_params import SamplingParams
from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer
from vllm_omni.entrypoints.stage_utils import shm_read_bytes
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo.common.utils.output_modalities import RequestType, parse_request_type
from dynamo.common.utils.video_utils import compute_num_frames, parse_size
DEFAULT_IMAGE_SIZE = "1024x1024"
DEFAULT_VIDEO_SIZE = "832x480"
def shm_deserialize(shm_meta: dict) -> Any:
    """Fetch the bytes described by ``shm_meta`` from shared memory and decode them.

    The bytes are read with ``shm_read_bytes`` and decoded with
    ``OmniSerializer.deserialize``; callers use this to recover an
    OmniRequestOutput.
    """
    raw_bytes = shm_read_bytes(shm_meta)
    return OmniSerializer.deserialize(raw_bytes)
def build_original_prompt(request: dict, nvext: dict, height: int, width: int) -> Any:
    """Assemble the prompt mapping that processor functions (ar2diffusion etc.) read.

    Only ``prompt``, ``negative_prompt`` and, when truthy, ``multi_modal_data``
    are taken from ``request``. ``nvext``, ``height`` and ``width`` are accepted
    by the signature but not read in this function.
    """
    text = request.get("prompt", "")
    negative = request.get("negative_prompt", None)
    original_prompt = OmniTextPrompt(prompt=text, negative_prompt=negative)

    media = request.get("multi_modal_data")
    if media:
        original_prompt["multi_modal_data"] = media
    return original_prompt
def parse_omni_request(
    request: dict, output_modalities: list, default_video_fps: int = 16
) -> dict:
    """Parse a raw frontend request into engine_inputs, original_prompt, sampling_params_list.

    Args:
        request: Raw request mapping received from the frontend.
        output_modalities: Forwarded to ``parse_request_type`` to classify the request.
        default_video_fps: Fallback fps handed to ``compute_num_frames`` for
            video requests.

    Returns:
        A dict with three keys:
        engine_inputs: text prompt (str or OmniTextPrompt) for the stage 0 engine
        original_prompt: rich prompt dict for processor functions
        sampling_params_list: raw user overrides dict (height/width/nvext) or None for chat
    """
    _, request_type = parse_request_type(request, output_modalities)

    if request_type in (RequestType.VIDEO_GENERATION, RequestType.IMAGE_GENERATION):
        is_video = request_type == RequestType.VIDEO_GENERATION
        nvext = request.get("nvext") or {}
        default_size = DEFAULT_VIDEO_SIZE if is_video else DEFAULT_IMAGE_SIZE
        size_kwargs = {} if is_video else {"default_w": 1024, "default_h": 1024}
        width, height = parse_size(request.get("size", default_size), **size_kwargs)
        # nvext is spread last, so keys it carries overwrite the parsed geometry.
        sp: dict = {"height": height, "width": width, **nvext}
        if is_video:
            sp["num_frames"] = compute_num_frames(
                num_frames=nvext.get("num_frames"),
                fps=nvext.get("fps"),
                default_fps=default_video_fps,
            )
        return {
            "engine_inputs": OmniTextPrompt(prompt=request.get("prompt", "")),
            "original_prompt": build_original_prompt(request, nvext, height, width),
            "sampling_params_list": sp,
        }

    # Chat / text
    # ``or []`` mirrors the nvext handling above: an explicit ``"messages": None``
    # falls through to the ``prompt`` fallback instead of raising in reversed().
    messages = request.get("messages") or []
    # The most recent user message wins; with no user message, use ``prompt``.
    text = next(
        (m.get("content", "") for m in reversed(messages) if m.get("role") == "user"),
        request.get("prompt", ""),
    )
    return {
        "engine_inputs": text,
        "original_prompt": {"prompt": text},
        "sampling_params_list": None,
    }
def _build_sampling_params(stage_config: Any, overrides: dict | None) -> list | None:
    """Build a one-element list of typed sampling params from the stage's YAML defaults.

    Returns ``None`` when ``stage_config`` has no truthy
    ``default_sampling_params``. Otherwise the defaults are used to construct
    ``OmniDiffusionSamplingParams`` when ``stage_type`` is ``"diffusion"`` and
    ``SamplingParams`` for any other stage type (``"llm"`` is assumed when the
    attribute is missing). Each entry in ``overrides`` is then applied with
    ``setattr`` if the params object already has an attribute of that name;
    other entries are ignored.
    """
    from omegaconf import OmegaConf  # type: ignore[import-not-found]

    yaml_defaults = getattr(stage_config, "default_sampling_params", None)
    if not yaml_defaults:
        return None

    # OmegaConf nodes are converted to plain containers; anything else is copied.
    container = (
        OmegaConf.to_container(yaml_defaults, resolve=True)
        if OmegaConf.is_config(yaml_defaults)
        else dict(yaml_defaults)
    )
    init_kwargs = cast(dict[str, Any], container)

    is_diffusion = getattr(stage_config, "stage_type", "llm") == "diffusion"
    params_cls = OmniDiffusionSamplingParams if is_diffusion else SamplingParams
    typed_params = params_cls(**init_kwargs)

    for field_name, value in (overrides or {}).items():
        if hasattr(typed_params, field_name):
            setattr(typed_params, field_name, value)
    return [typed_params]
def ensure_dummy_tokenizer_for_tts(model: str) -> list[Path]:
......@@ -24,8 +126,6 @@ def ensure_dummy_tokenizer_for_tts(model: str) -> list[Path]:
This is a short-term workaround. The long-term fix is making TokenizerKind
optional in ModelDeploymentCard::from_repo_checkout().
"""
from huggingface_hub import scan_cache_dir
created: list[Path] = []
cache_info = scan_cache_dir()
for repo in cache_info.repos:
......@@ -33,16 +133,12 @@ def ensure_dummy_tokenizer_for_tts(model: str) -> list[Path]:
for revision in repo.revisions:
tokenizer_path = Path(revision.snapshot_path) / "tokenizer.json"
if not tokenizer_path.exists():
logger.warning(
logging.warning(
"TTS model %s has no tokenizer.json; "
"creating a minimal placeholder at %s",
model,
tokenizer_path,
)
# Write a minimal but valid HF tokenizer JSON that
# tokenizers.TokenizerFast.from_file() can parse without
# crashing. The "model" key with type "BPE" is the
# minimum required structure.
minimal_tokenizer = {
"version": "1.0",
"model": {"type": "BPE", "vocab": {}, "merges": []},
......@@ -63,6 +159,6 @@ def cleanup_dummy_tokenizer_for_tts(paths: list[Path]):
for path in paths:
try:
path.unlink(missing_ok=True)
logger.info("Removed dummy tokenizer placeholder: %s", path)
logging.info("Removed dummy tokenizer placeholder: %s", path)
except OSError as e:
logger.warning("Failed to remove dummy tokenizer %s: %s", path, e)
logging.warning("Failed to remove dummy tokenizer %s: %s", path, e)
......@@ -62,6 +62,8 @@ def _make_omni_config(**overrides) -> OmniConfig:
"ulysses_degree": 1,
"ring_degree": 1,
"cfg_parallel_size": 1,
"stage_id": None,
"omni_router": False,
}
defaults.update(overrides)
obj = OmniConfig.__new__(OmniConfig)
......@@ -115,6 +117,49 @@ def test_omni_config_valid_boundary_ratio(ratio):
config.validate() # should not raise
# --- disaggregated stage flag validation ---
def test_negative_stage_id_rejected():
    """A stage id below zero fails validation even with a stage config path set."""
    cfg = _make_omni_config(stage_configs_path="/fake/path.yaml", stage_id=-1)
    expected_message = "--stage-id must be >= 0"
    with pytest.raises(ValueError, match=expected_message):
        cfg.validate()
def test_stage_id_requires_stage_configs_path():
    """Setting a stage id without a stage config path is rejected."""
    cfg = _make_omni_config(stage_configs_path=None, stage_id=0)
    expected_message = "--stage-id requires"
    with pytest.raises(ValueError, match=expected_message):
        cfg.validate()
def test_omni_router_requires_stage_configs_path():
    """Enabling the router without a stage config path is rejected."""
    cfg = _make_omni_config(stage_configs_path=None, omni_router=True)
    expected_message = "--omni-router requires"
    with pytest.raises(ValueError, match=expected_message):
        cfg.validate()
def test_stage_id_and_omni_router_mutually_exclusive(tmp_path):
    """A config that sets both a stage id and the router flag is rejected."""
    stages_file = tmp_path / "stages.yaml"
    cfg = _make_omni_config(
        stage_configs_path=str(stages_file), omni_router=True, stage_id=0
    )
    with pytest.raises(ValueError, match="mutually exclusive"):
        cfg.validate()
def test_stage_id_with_stage_configs_path_valid(tmp_path):
    """A stage id paired with a stage config path validates cleanly."""
    stages_file = tmp_path / "stages.yaml"
    cfg = _make_omni_config(stage_configs_path=str(stages_file), stage_id=0)
    # The test passes when validate() returns without raising.
    cfg.validate()
def test_omni_router_with_stage_configs_path_valid(tmp_path):
    """The router flag paired with a stage config path validates cleanly."""
    stages_file = tmp_path / "stages.yaml"
    cfg = _make_omni_config(stage_configs_path=str(stages_file), omni_router=True)
    # The test passes when validate() returns without raising.
    cfg.validate()
# --- vllm_omni API compatibility guards ---
# These tests catch regressions when vllm_omni is upgraded.
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for omni/types.py Protocol definitions.
No GPU, no vllm_omni — pure structural typing checks.
"""
import json
import pytest
try:
from dynamo.vllm.omni.types import OmniInterStageRequest, StageEngine, StageOutput
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_1,
pytest.mark.pre_merge,
]
class _MockEngine:
def generate(self, prompt, request_id="", *, sampling_params_list=None):
async def _gen():
yield {}
return _gen()
def test_stage_engine_protocol_satisfied():
    """An object with a matching ``generate`` passes the StageEngine isinstance check."""
    engine = _MockEngine()
    assert isinstance(engine, StageEngine)
def test_missing_generate_not_stage_engine():
    """A bare ``object()`` does not pass the StageEngine isinstance check."""
    plain = object()
    assert isinstance(plain, StageEngine) is False
# ── StageOutput ───────────────────────────────────────────
class TestStageOutput:
    """Checks on StageOutput validation and its to_next_stage_request() result."""

    def test_unknown_keys_are_dropped(self):
        """Keys outside the model do not become attributes."""
        payload = {"shm_meta": {"name": "x"}, "unknown_key": "noise"}
        validated = StageOutput.model_validate(payload)
        assert validated.shm_meta == {"name": "x"}
        assert not hasattr(validated, "unknown_key")

    def test_to_next_stage_request_excludes_finished_and_error(self):
        """``finished`` and ``error`` are absent from the next-stage request."""
        payload = {
            "stage_connector_refs": {"0": {"name": "x"}},
            "finished": True,
            "error": None,
        }
        forwarded = StageOutput.model_validate(payload).to_next_stage_request("req-1")
        for stripped_key in ("finished", "error"):
            assert stripped_key not in forwarded
        assert forwarded["request_id"] == "req-1"

    def test_to_next_stage_request_excludes_shm_meta(self):
        """shm_meta is final-stage → router only; must not be forwarded to next stage."""
        validated = StageOutput.model_validate({"shm_meta": {"name": "x"}})
        forwarded = validated.to_next_stage_request("req-2")
        assert "shm_meta" not in forwarded

    def test_to_next_stage_request_passes_stage_connector_refs(self):
        """Prompt, connector refs and request id all appear in the next-stage request."""
        payload = {
            "original_prompt": {"prompt": "hi"},
            "stage_connector_refs": {"0": {"ref": "abc"}},
        }
        forwarded = StageOutput.model_validate(payload).to_next_stage_request("req-3")
        assert forwarded["original_prompt"] == {"prompt": "hi"}
        assert forwarded["stage_connector_refs"] == {"0": {"ref": "abc"}}
        assert forwarded["request_id"] == "req-3"
# ── OmniInterStageRequest ──────────────────────────────────
class TestOmniInterStageRequest:
    """Round-trip checks for OmniInterStageRequest.to_dict() / from_dict()."""

    def test_roundtrip_empty_refs(self):
        """A request built without refs comes back with an empty refs dict."""
        prompt = {"prompt": "hello", "height": 1024, "width": 1024}
        original = OmniInterStageRequest(request_id="req-1", original_prompt=prompt)

        recovered = OmniInterStageRequest.from_dict(original.to_dict())

        assert recovered.request_id == "req-1"
        assert recovered.original_prompt == {
            "prompt": "hello",
            "height": 1024,
            "width": 1024,
        }
        assert recovered.stage_connector_refs == {}

    def test_roundtrip_with_refs(self):
        """An int-keyed ref survives a dict round-trip unchanged."""
        original = OmniInterStageRequest(
            request_id="req-2",
            original_prompt={"prompt": "a cat"},
            stage_connector_refs={0: {"name": "abc-shm", "size": 9000}},
        )

        recovered = OmniInterStageRequest.from_dict(original.to_dict())

        assert recovered.stage_connector_refs[0] == {"name": "abc-shm", "size": 9000}

    def test_int_keys_preserved_after_json_roundtrip(self):
        """JSON serializes dict keys as strings — from_dict must convert back to int."""
        original = OmniInterStageRequest(
            request_id="req-3",
            original_prompt=None,
            stage_connector_refs={0: "ref0", 1: "ref1"},
        )

        # Simulate JSON round-trip (Dynamo network boundary)
        wire_text = json.dumps(original.to_dict())
        recovered = OmniInterStageRequest.from_dict(json.loads(wire_text))

        for stage_key in (0, 1):
            assert stage_key in recovered.stage_connector_refs
        first_key = next(iter(recovered.stage_connector_refs))
        assert isinstance(first_key, int)
......@@ -13,6 +13,7 @@ try:
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest, VideoNvExt
from dynamo.common.utils.output_modalities import RequestType
from dynamo.vllm.omni.omni_handler import EngineInputs, OmniHandler
from dynamo.vllm.omni.utils import build_original_prompt, parse_omni_request
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
......@@ -164,3 +165,90 @@ class TestI2VEngineInputs:
empty = VideoNvExt()
assert empty.boundary_ratio is None
assert empty.guidance_scale_2 is None
class TestBuildOriginalPrompt:
    """build_original_prompt only carries prompt/negative_prompt/multi_modal_data.

    height/width/num_inference_steps live in OmniDiffusionSamplingParams, not the prompt.
    """

    def test_basic_fields(self):
        """Prompt text is copied; geometry arguments do not appear in the result."""
        built = build_original_prompt(
            {"prompt": "a cat"}, nvext={}, height=512, width=512
        )
        assert built["prompt"] == "a cat"
        assert built.get("negative_prompt") is None
        for geometry_key in ("height", "width"):
            assert geometry_key not in built

    def test_negative_prompt_from_request(self):
        """The request's negative_prompt is used, not the one inside nvext."""
        request = {"prompt": "a cat", "negative_prompt": "blurry"}
        built = build_original_prompt(
            request,
            nvext={"negative_prompt": "ignored"},
            height=1024,
            width=1024,
        )
        assert built["negative_prompt"] == "blurry"

    def test_multi_modal_data_forwarded(self):
        """The multi_modal_data payload is passed through by identity."""
        sentinel_image = object()
        request = {"prompt": "x", "multi_modal_data": {"image": sentinel_image}}
        built = build_original_prompt(request, nvext={}, height=512, width=512)
        assert built["multi_modal_data"]["image"] is sentinel_image

    def test_no_inference_steps_or_guidance(self):
        """Sampling knobs supplied through nvext are not copied into the prompt."""
        built = build_original_prompt(
            {"prompt": "x"},
            nvext={"num_inference_steps": 50, "guidance_scale": 7.5},
            height=512,
            width=512,
        )
        for sampling_key in ("num_inference_steps", "guidance_scale"):
            assert sampling_key not in built
class TestParseOmniRequest:
    """parse_omni_request: original_prompt only has prompt/negative_prompt,
    geometry goes into sampling_params_list dict."""

    def test_image_sampling_params_has_geometry(self):
        """The parsed size lands in sampling_params_list as height and width."""
        request = dict(prompt="a sunset", size="512x512", output_modalities=["image"])
        parsed = parse_omni_request(request, ["image"])
        overrides = parsed["sampling_params_list"]
        assert overrides["height"] == 512
        assert overrides["width"] == 512

    def test_image_original_prompt_no_geometry(self):
        """original_prompt keeps the text but carries no height or width."""
        request = dict(prompt="a sunset", size="512x512", output_modalities=["image"])
        parsed = parse_omni_request(request, ["image"])
        prompt = parsed["original_prompt"]
        assert prompt["prompt"] == "a sunset"
        for geometry_key in ("height", "width"):
            assert geometry_key not in prompt

    def test_nvext_params_go_into_sampling_params_not_prompt(self):
        """nvext entries show up in sampling_params_list and not in original_prompt."""
        request = dict(
            prompt="x",
            size="512x512",
            nvext={"num_inference_steps": 30, "guidance_scale": 4.0},
        )
        parsed = parse_omni_request(request, ["image"])

        overrides = parsed["sampling_params_list"]
        assert overrides["num_inference_steps"] == 30
        assert overrides["guidance_scale"] == 4.0

        prompt = parsed["original_prompt"]
        for sampling_key in ("num_inference_steps", "guidance_scale"):
            assert sampling_key not in prompt
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for OmniStageRouter."""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
try:
from dynamo.vllm.omni import stage_router
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_1,
pytest.mark.pre_merge,
]
class _Chunk:
def __init__(self, payload):
self._payload = payload
def data(self):
return self._payload
class _StageClient:
def __init__(self, handler):
self._handler = handler
async def round_robin(self, request):
async def _gen():
payload = await self._handler(request)
yield _Chunk(payload)
return _gen()
def _make_stage_cfg(stage_id: int):
return SimpleNamespace(
stage_id=stage_id,
engine_args=SimpleNamespace(model_stage=f"stage{stage_id}"),
)
def _make_router(stage_configs, stage_clients, formatter=None):
    """Create an OmniStageRouter without running ``__init__`` and set its attributes directly.

    A fresh ``AsyncMock`` is used as the formatter when ``formatter`` is falsy.
    """
    router_cls = stage_router.OmniStageRouter
    instance = router_cls.__new__(router_cls)
    instance.config = SimpleNamespace(output_modalities=None)
    instance.stage_configs = stage_configs
    instance.stage_clients = stage_clients
    instance._formatter = formatter if formatter else AsyncMock()
    return instance
def _patched_generate(router, request, request_id="req-1", request_type="chat"):
return (
patch(
"dynamo.vllm.omni.stage_router.parse_request_type",
return_value=(None, request_type),
),
patch("dynamo.vllm.omni.stage_router.uuid.uuid4", return_value=request_id),
)
# ── issue-004: opaque router ──────────────────────────────
@pytest.mark.asyncio
async def test_generate_passes_stage_connector_refs_opaquely():
    """Router must pass stage_connector_refs from stage output to next stage unchanged."""
    seen_by_stage1 = {}

    async def stage0_handler(request):
        return {
            "original_prompt": {"prompt": "hi"},
            "stage_connector_refs": {"0": {"shm_name": "abc", "size": 42}},
            "finished": True,
        }

    async def stage1_handler(request):
        seen_by_stage1.update(request)
        return {"shm_meta": {"x": 1}, "finished": True}

    formatter = AsyncMock()
    formatter.format.return_value = {"finished": True}

    clients = {
        "stage0": _StageClient(stage0_handler),
        "stage1": _StageClient(stage1_handler),
    }
    router = _make_router(
        stage_configs=[_make_stage_cfg(idx) for idx in (0, 1)],
        stage_clients=clients,
        formatter=formatter,
    )

    type_patch, uuid_patch = _patched_generate(router, {"prompt": "x"})
    with type_patch, uuid_patch, patch.object(
        stage_router, "shm_deserialize", return_value=SimpleNamespace()
    ):
        # Drain the stream; only the side effects on seen_by_stage1 matter here.
        async for _ in router.generate({"prompt": "x"}, None):
            pass

    # Router must forward stage_connector_refs and original_prompt verbatim — never inspect them.
    assert seen_by_stage1["stage_connector_refs"] == {
        "0": {"shm_name": "abc", "size": 42}
    }
    assert seen_by_stage1["original_prompt"] == {"prompt": "hi"}
    assert seen_by_stage1["request_id"] == "req-1"
    # 'finished' must be stripped — it is a router signal, not a stage protocol field.
    assert "finished" not in seen_by_stage1
@pytest.mark.asyncio
async def test_generate_concurrent_requests_have_independent_connector_refs():
    """Concurrent requests must carry independent stage_connector_refs (no cross-leakage)."""
    refs_seen_at_stage1: dict = {}
    release_a = asyncio.Event()

    async def stage0_handler(request):
        rid = request["request_id"]
        return {
            "original_prompt": {"prompt": "x"},
            "stage_connector_refs": {"0": f"ref-for-{rid}"},
            "finished": True,
        }

    async def stage1_handler(request):
        rid = request["request_id"]
        # req-A is parked at stage 1 until the other request reaches it.
        if rid != "req-A":
            release_a.set()
        else:
            await release_a.wait()
        refs_seen_at_stage1[rid] = request.get("stage_connector_refs")
        return {"shm_meta": {"x": 1}, "finished": True}

    formatter = AsyncMock()
    formatter.format.return_value = {"finished": True}

    clients = {
        "stage0": _StageClient(stage0_handler),
        "stage1": _StageClient(stage1_handler),
    }
    router = _make_router(
        stage_configs=[_make_stage_cfg(idx) for idx in (0, 1)],
        stage_clients=clients,
        formatter=formatter,
    )

    async def run_one(request_id):
        with patch(
            "dynamo.vllm.omni.stage_router.parse_request_type",
            return_value=(None, "chat"),
        ), patch(
            "dynamo.vllm.omni.stage_router.uuid.uuid4", return_value=request_id
        ), patch.object(
            stage_router, "shm_deserialize", return_value=SimpleNamespace()
        ):
            return [c async for c in router.generate({"prompt": "x"}, None)]

    await asyncio.gather(run_one("req-A"), run_one("req-B"))

    assert refs_seen_at_stage1["req-A"] == {"0": "ref-for-req-A"}
    assert refs_seen_at_stage1["req-B"] == {"0": "ref-for-req-B"}
@pytest.mark.asyncio
async def test_generate_stage_error_stops_pipeline():
    """Error from any stage must immediately stop the pipeline; later stages must not run."""
    stage1_invocations = []

    async def stage0_handler(request):
        return {"error": "thinker exploded", "finished": True}

    async def stage1_handler(request):
        # Any entry here means the router kept going after the stage 0 error.
        stage1_invocations.append(request)
        return {"shm_meta": {"x": 1}, "finished": True}

    clients = {
        "stage0": _StageClient(stage0_handler),
        "stage1": _StageClient(stage1_handler),
    }
    router = _make_router(
        stage_configs=[_make_stage_cfg(idx) for idx in (0, 1)],
        stage_clients=clients,
    )

    type_patch, uuid_patch = _patched_generate(router, {"prompt": "x"})
    with type_patch, uuid_patch:
        emitted = [c async for c in router.generate({"prompt": "x"}, None)]

    assert emitted == [{"error": "thinker exploded", "finished": True}]
    assert not stage1_invocations
# ── existing tests (formatting + error paths) ────────────
@pytest.mark.asyncio
async def test_generate_delegates_formatting_to_output_formatter():
    """Final stage output should be deserialized and passed to OutputFormatter."""
    deserialized = SimpleNamespace(final_output_type="image")
    formatter = AsyncMock()
    formatter.format.return_value = {"data": [{"b64_json": "abc"}]}

    async def stage0_handler(request):
        return {"shm_meta": {"some": "meta"}, "finished": True}

    router = _make_router(
        stage_configs=[_make_stage_cfg(0)],
        stage_clients={"stage0": _StageClient(stage0_handler)},
        formatter=formatter,
    )

    request = {"prompt": "x", "response_format": "b64_json"}
    with patch.object(
        stage_router, "shm_deserialize", return_value=deserialized
    ), patch(
        "dynamo.vllm.omni.stage_router.parse_request_type",
        return_value=(None, "image_generation"),
    ), patch(
        "dynamo.vllm.omni.stage_router.uuid.uuid4", return_value="req-fmt"
    ):
        emitted = [c async for c in router.generate(request, context=None)]

    assert emitted == [{"data": [{"b64_json": "abc"}]}]
    formatter.format.assert_awaited_once_with(
        deserialized,
        "req-fmt",
        request_type="image_generation",
        response_format="b64_json",
    )
@pytest.mark.asyncio
async def test_generate_yields_error_when_no_shm_meta():
    """When final stage returns no shm_meta, generate yields an error."""

    async def stage0_handler(request):
        return {"finished": True}

    router = _make_router(
        stage_configs=[_make_stage_cfg(0)],
        stage_clients={"stage0": _StageClient(stage0_handler)},
    )

    with patch(
        "dynamo.vllm.omni.stage_router.parse_request_type",
        return_value=(None, "chat"),
    ), patch("dynamo.vllm.omni.stage_router.uuid.uuid4", return_value="r"):
        emitted = [c async for c in router.generate({"prompt": "x"}, context=None)]

    expected_chunk = {"error": "No SHM output from final stage", "finished": True}
    assert emitted == [expected_chunk]
# ── issue-007: router forwards raw request to stage 0 ────────────
@pytest.mark.asyncio
async def test_generate_forwards_raw_request_to_stage0():
    """Stage 0 must receive the raw request fields + request_id (no router parsing)."""
    seen_by_stage0 = {}

    async def stage0_handler(request):
        seen_by_stage0.update(request)
        return {"shm_meta": {"x": 1}, "finished": True}

    formatter = AsyncMock()
    formatter.format.return_value = {"finished": True}

    router = _make_router(
        stage_configs=[_make_stage_cfg(0)],
        stage_clients={"stage0": _StageClient(stage0_handler)},
        formatter=formatter,
    )

    request = {
        "prompt": "a dog",
        "size": "832x480",
        "nvext": {"num_inference_steps": 30},
    }
    with patch(
        "dynamo.vllm.omni.stage_router.parse_request_type",
        return_value=(None, "video_generation"),
    ), patch(
        "dynamo.vllm.omni.stage_router.uuid.uuid4", return_value="req-raw"
    ), patch.object(
        stage_router, "shm_deserialize", return_value=SimpleNamespace()
    ):
        # Drain the stream; the assertions look only at what stage 0 received.
        async for _ in router.generate(request, None):
            pass

    expected_fields = {
        "request_id": "req-raw",
        "prompt": "a dog",
        "size": "832x480",
        "nvext": {"num_inference_steps": 30},
    }
    for key, value in expected_fields.items():
        assert seen_by_stage0[key] == value
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for OmniStageWorker.
No GPU, no vllm_omni — uses mock StageEngine matching AsyncOmni.generate() signature.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
try:
from dynamo.vllm.omni.stage_worker import OmniStageWorker, _Proxy
from dynamo.vllm.omni.utils import _build_sampling_params
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_1,
pytest.mark.pre_merge,
]
class _MockEngine:
"""Satisfies StageEngine Protocol — matches AsyncOmni.generate() signature."""
def __init__(self, output=None):
self.received_prompt = None
self.received_request_id = None
self.received_sampling_params_list = None
self._output = output or {"output": "mock", "finished": True}
def generate(self, prompt, request_id="", *, sampling_params_list=None):
self.received_prompt = prompt
self.received_request_id = request_id
self.received_sampling_params_list = sampling_params_list
async def _gen():
yield self._output
return _gen()
class _ErrorEngine:
def generate(self, prompt, request_id="", *, sampling_params_list=None):
async def _gen():
raise RuntimeError("engine exploded")
yield # make it an async generator
return _gen()
class _MockContext:
def id(self):
return "test-req-id"
def _make_stage_config(**overrides):
defaults = dict(
stage_type="llm",
final_output=False,
final_output_type="text",
engine_input_source=[],
)
defaults.update(overrides)
return SimpleNamespace(**defaults)
def _make_worker(engine=None, stage_config=None, connectors=None, stage_id=0):
    """Construct an OmniStageWorker, substituting defaults for any falsy argument.

    Defaults: a fresh ``_MockEngine``, ``_make_stage_config()`` and an empty
    connector mapping.
    """
    if not engine:
        engine = _MockEngine()
    if not stage_config:
        stage_config = _make_stage_config()
    if not connectors:
        connectors = {}
    return OmniStageWorker(
        engine=engine,
        stage_config=stage_config,
        connectors=connectors,
        stage_id=stage_id,
    )
@pytest.mark.asyncio
async def test_direct_input_path():
    """Stage 0 direct path: engine receives the full request dict as prompt."""
    engine = _MockEngine()
    worker = _make_worker(engine=engine)
    request = {"engine_inputs": {"prompt": "hello"}, "sampling_params_list": None}

    emitted = []
    async for chunk in worker.generate(request, _MockContext()):
        emitted.append(chunk)

    # Direct path (no request_id, no stage_connector_refs) passes the whole request as prompt.
    assert engine.received_prompt == request
    assert any("shm_meta" in chunk for chunk in emitted)
@pytest.mark.asyncio
async def test_stage_connector_refs_input_path():
    """Stage N>0: engine receives output fetched from connector via stage_connector_refs."""
    engine = _MockEngine()
    upstream_payload = {"prior_token_ids": [1, 2, 3]}

    inbound = MagicMock()
    inbound.get.return_value = {"engine_inputs": upstream_payload}
    outbound = MagicMock()
    outbound.put.return_value = (True, 0, {"name": "ref1", "size": 10})

    worker = _make_worker(
        engine=engine,
        connectors={("0", "1"): inbound, ("1", "2"): outbound},
        stage_id=1,
    )
    request = {
        "request_id": "req-1",
        "original_prompt": {"prompt": "hello"},
        "stage_connector_refs": {"0": {"name": "ref0", "size": 5}},
    }

    emitted = [chunk async for chunk in worker.generate(request, _MockContext())]

    inbound.get.assert_called_once_with(
        "0", "1", "req-1", metadata={"name": "ref0", "size": 5}
    )
    assert engine.received_prompt == upstream_payload
    assert len(emitted) == 1

    only_chunk = emitted[0]
    # This stage's ref is added alongside the one inherited from stage 0.
    assert only_chunk["stage_connector_refs"]["1"] == {"name": "ref1", "size": 10}
    assert only_chunk["stage_connector_refs"]["0"] == {"name": "ref0", "size": 5}
    assert only_chunk["original_prompt"] == {"prompt": "hello"}
@pytest.mark.asyncio
async def test_stage_connector_refs_with_processor():
    """Stage N>0 with processor: processor receives stage_list built from connector output."""
    engine = _MockEngine()
    upstream_output = {"latents": [0.1, 0.2]}
    processor_result = {"diffusion_input": True}

    inbound = MagicMock()
    inbound.get.return_value = {"engine_inputs": upstream_output}
    outbound = MagicMock()
    outbound.put.return_value = (True, 0, {"name": "ref1"})

    recorded_calls = []

    def mock_processor(stage_list, engine_input_source, original_prompts, requires_mm):
        # Capture the arguments the worker hands to the processor.
        recorded_calls.append(
            dict(
                stage_list=stage_list,
                engine_input_source=engine_input_source,
                original_prompts=original_prompts,
            )
        )
        return [processor_result]

    stage_cfg = _make_stage_config(
        stage_type="llm",
        final_output=False,
        custom_process_input_func=None,
        engine_input_source=[0],
        requires_multimodal_data=False,
    )
    worker = OmniStageWorker(
        engine=engine,
        stage_config=stage_cfg,
        connectors={("0", "1"): inbound, ("1", "2"): outbound},
        stage_id=1,
    )
    # Install the processor directly on the constructed worker.
    worker._processor = mock_processor

    request = {
        "request_id": "req-proc",
        "original_prompt": {"prompt": "hi", "height": 480},
        "stage_connector_refs": {"0": {"name": "ref0"}},
    }

    emitted = [chunk async for chunk in worker.generate(request, _MockContext())]

    assert len(recorded_calls) == 1
    call = recorded_calls[0]
    assert call["stage_list"][0].engine_outputs == [upstream_output]
    assert call["original_prompts"] == [{"prompt": "hi", "height": 480}]
    assert engine.received_prompt == processor_result
    assert emitted[0]["stage_connector_refs"]["1"] == {"name": "ref1"}
@pytest.mark.asyncio
async def test_engine_error_yields_error_chunk():
    """Engine raises → yields {error: ..., finished: True}, no crash."""
    worker = _make_worker(engine=_ErrorEngine())
    request = {"engine_inputs": {"prompt": "hello"}}

    emitted = []
    async for chunk in worker.generate(request, _MockContext()):
        emitted.append(chunk)

    assert any("error" in chunk for chunk in emitted)
    assert any(chunk.get("finished") for chunk in emitted)
@pytest.mark.asyncio
async def test_connector_put_failure_yields_error():
    """connector.put() returning ok=False → yields error, stops."""
    failing_connector = MagicMock()
    failing_connector.get.return_value = {"engine_inputs": {"x": 1}}
    failing_connector.put.return_value = (False, 0, {})

    worker = _make_worker(connectors={("1", "2"): failing_connector}, stage_id=1)
    request = {
        "request_id": "req-fail",
        "stage_connector_refs": {"0": {"name": "ref0"}},
    }

    # Input fetching is stubbed out so only the put() failure path is exercised.
    stubbed_inputs = [_Proxy(engine_outputs=[{"x": 1}])]
    with patch.object(worker, "_fetch_stage_inputs", return_value=stubbed_inputs):
        emitted = [chunk async for chunk in worker.generate(request, _MockContext())]

    assert emitted == [{"error": "connector.put() failed", "finished": True}]
# ── _fetch_stage_inputs method unit tests ──────────────────
def _make_worker_at_stage(stage_id, connectors, engine_input_source=None):
    """Build an OmniStageWorker for ``stage_id`` with a fresh ``_MockEngine``.

    When ``engine_input_source`` is falsy, the stage config's input source
    defaults to ``[stage_id - 1]``.
    """
    sources = engine_input_source if engine_input_source else [stage_id - 1]
    stage_cfg = _make_stage_config(engine_input_source=sources)
    return OmniStageWorker(
        engine=_MockEngine(),
        stage_config=stage_cfg,
        connectors=connectors,
        stage_id=stage_id,
    )
def test_fetch_stage_inputs_calls_correct_connector():
    """The ("0", "1") connector is queried with the stage-0 ref and its payload is wrapped."""
    ref_meta = {"name": "ref0"}
    edge_connector = MagicMock()
    edge_connector.get.return_value = {"engine_inputs": {"tok": [1, 2]}}
    worker = _make_worker_at_stage(
        1, connectors={("0", "1"): edge_connector}, engine_input_source=[0]
    )

    fetched = worker._fetch_stage_inputs({0: ref_meta}, "r1")

    edge_connector.get.assert_called_once_with("0", "1", "r1", metadata=ref_meta)
    assert fetched is not None
    assert fetched[0].engine_outputs == [{"tok": [1, 2]}]
def test_fetch_stage_inputs_raises_on_missing_connector():
    """With no connector registered for the incoming edge, fetching raises RuntimeError."""
    worker = _make_worker_at_stage(1, connectors={}, engine_input_source=[0])
    refs = {0: {"name": "ref0"}}
    with pytest.raises(RuntimeError, match="no connector for edge"):
        worker._fetch_stage_inputs(refs, "r1")
def test_fetch_stage_inputs_raises_on_missing_ref():
    """A connector exists but the stage-0 ref is absent, so fetching raises RuntimeError."""
    connectors = {("0", "1"): MagicMock()}
    worker = _make_worker_at_stage(1, connectors=connectors, engine_input_source=[0])
    no_refs = {}  # ref for stage 0 missing
    with pytest.raises(RuntimeError, match="no connector ref"):
        worker._fetch_stage_inputs(no_refs, "r1")
def test_build_sampling_params_user_overrides_yaml_defaults():
    """User overrides applied on top of YAML defaults via setattr; unspecified keys preserved."""
    yaml_defaults = {
        "num_inference_steps": 20,
        "guidance_scale": 5.0,
        "height": 480,
        "width": 832,
    }
    stage_config = SimpleNamespace(
        stage_type="diffusion", default_sampling_params=yaml_defaults
    )
    user_overrides = {"num_inference_steps": 50}

    built = _build_sampling_params(stage_config, user_overrides)

    assert built is not None
    params = built[0]
    assert params.num_inference_steps == 50  # user override wins
    assert params.guidance_scale == 5.0  # YAML default preserved
def test_build_sampling_params_no_defaults_returns_none():
    """No default_sampling_params on stage_config -> returns None."""
    stage_config = SimpleNamespace(stage_type="llm")
    # Neither a missing nor an empty overrides mapping changes the outcome.
    for overrides in (None, {}):
        assert _build_sampling_params(stage_config, overrides) is None
@pytest.mark.asyncio
async def test_image_request_with_default_sampling_params():
    """Image stage with default_sampling_params builds typed params from YAML defaults + overrides."""
    engine = _MockEngine()
    yaml_defaults = {
        "num_inference_steps": 20,
        "guidance_scale": 1.5,
        "height": 1024,
        "width": 1024,
    }
    stage_cfg = _make_stage_config(
        stage_type="diffusion",
        final_output=True,
        default_sampling_params=yaml_defaults,
    )
    worker = OmniStageWorker(
        engine=engine,
        stage_config=stage_cfg,
        connectors={},
        stage_id=0,
        output_modalities=["image"],
    )
    request = {
        "request_id": "img-req-1",
        "prompt": "a red apple",
        "size": "1024x1024",
    }

    emitted = [chunk async for chunk in worker.generate(request, _MockContext())]

    error_chunks = [chunk for chunk in emitted if "error" in chunk]
    assert not error_chunks
    assert engine.received_sampling_params_list is not None
@pytest.mark.asyncio
async def test_sampling_params_propagate_in_stage_output():
    """Non-final stage must include sampling_params_list in its output for downstream stages."""
    engine = _MockEngine()

    inbound = MagicMock()
    inbound.get.return_value = {"engine_inputs": {"latents": [1, 2]}}
    outbound = MagicMock()
    outbound.put.return_value = (True, 0, {"name": "ref1"})

    # Stage 1: non-final, receives stage_connector_refs from stage 0
    worker = _make_worker(
        engine=engine,
        connectors={("0", "1"): inbound, ("1", "2"): outbound},
        stage_id=1,
        stage_config=_make_stage_config(final_output=False),
    )
    request = {
        "request_id": "req-sp",
        "original_prompt": {"prompt": "hi"},
        "stage_connector_refs": {"0": {"name": "ref0"}},
        "sampling_params_list": {
            "num_inference_steps": 42,
            "height": 480,
            "width": 832,
        },
    }

    # Typed-param construction is stubbed; only propagation of the raw dict is checked.
    stub_builder = patch(
        "dynamo.vllm.omni.stage_worker._build_sampling_params", return_value=None
    )
    with stub_builder:
        emitted = [chunk async for chunk in worker.generate(request, _MockContext())]

    assert len(emitted) == 1
    assert emitted[0].get("sampling_params_list") == {
        "num_inference_steps": 42,
        "height": 480,
        "width": 832,
    }
......@@ -221,7 +221,7 @@ class TestDiffusionFormatterVideo:
f = _make_diffusion_formatter()
with patch(
"dynamo.common.utils.video_utils.normalize_video_frames",
"dynamo.vllm.omni.output_formatter.normalize_video_frames",
side_effect=RuntimeError("boom"),
):
chunk = await f._encode_video([MagicMock()], "req-1", fps=16)
......
......@@ -35,7 +35,7 @@ The `--output-modalities` flag determines which endpoint(s) the worker registers
| Modality | Models |
|---|---|
| Text-to-Text | `Qwen/Qwen2.5-Omni-7B` |
| Text-to-Image | `Qwen/Qwen-Image`, `AIDC-AI/Ovis-Image-7B` |
| Text-to-Image | `Qwen/Qwen-Image`, `AIDC-AI/Ovis-Image-7B`, `zai-org/GLM-Image` (disagg) |
| Text-to-Video | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
| Image-to-Video | `Wan-AI/Wan2.2-TI2V-5B-Diffusers`, `Wan-AI/Wan2.2-I2V-A14B-Diffusers` |
| Text-to-Audio (TTS) | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice`, `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` |
......@@ -318,6 +318,95 @@ For S3 credential configuration, set the standard AWS environment variables (`AW
Omni pipelines are configured via YAML stage configs. See [`examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml`](https://github.com/ai-dynamo/dynamo/blob/main/examples/backends/vllm/launch/stage_configs/single_stage_llm.yaml) for an example. For full documentation on stage config format and multi-stage pipelines, refer to the [vLLM-Omni Stage Configs documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/).
## Disaggregated Multi-Stage Serving
For models with multiple pipeline stages (e.g., AR + Diffusion), Dynamo supports disaggregated serving where each stage runs as an independent process on its own GPU. This enables independent scaling, GPU isolation, and multi-worker replicas per stage.
### Architecture
Each stage runs as an independent process on its own GPU. A lightweight router coordinates them, acting as a **pure message broker** — it never inspects or transforms inter-stage data.
```mermaid
flowchart LR
client(Client) --> frontend(Frontend)
frontend --> router(Router)
router -->|request| s0(Stage 0)
s0 -->|ref| router
router -->|ref| s1(Stage 1)
s1 -->|result| router
router --> frontend --> client
s0 <-->|bulk data| conn[(Connector)]
conn <--> s1
```
**How it works:**
- The router sends the initial request to Stage 0 and receives back a lightweight connector reference (pointer to the output in shared memory).
- The router forwards that reference — unchanged — to Stage 1. It never reads the bulk data.
- Each stage fetches its inputs from the connector, runs any model-specific processor (e.g., `ar2diffusion`, `thinker2talker`), then runs its engine.
- The final stage's result is returned to the router, which formats it and sends the response back through the frontend to the client.
- Connector references accumulate as the pipeline progresses, so any stage can access outputs from all previous stages.
### Data Flow
```mermaid
sequenceDiagram
participant C as Client
participant R as Router
participant S0 as Stage 0 (AR)
participant SHM as Connector
participant S1 as Stage 1 (DiT)
C->>R: POST /v1/images/generations
R->>S0: request + prompt
S0->>SHM: store output
S0-->>R: connector ref
R->>S1: connector ref (opaque)
S1->>SHM: fetch output
S1->>S1: processor → engine
S1-->>R: result
R-->>C: {"data": [...]}
```
### Quick Start: GLM-Image (2-Stage, 2 GPUs)
GLM-Image is a 2-stage text-to-image model with an AR stage (generates prior token IDs) and a DiT stage (diffusion denoising + VAE decode). The built-in vLLM-Omni stage config already assigns each stage to a separate GPU.
```bash
bash examples/backends/vllm/launch/disagg_omni_glm_image.sh
```
Test:
```bash
curl -s http://localhost:8000/v1/images/generations \
-H "Content-Type: application/json" \
-d '{
"model": "zai-org/GLM-Image",
"prompt": "A red apple on a white table",
"size": "1024x1024",
"response_format": "url"
}' | jq
```
### Scaling Stage Replicas
Each stage registers independently with Dynamo's service discovery. To scale a bottleneck stage, launch additional workers with the same `--stage-id` on different GPUs — the router automatically load-balances across all replicas for that stage. Other stages are unaffected.
### Tested Models
| Model | Stages | Output | Stage Config |
|---|---|---|---|
| GLM-Image (`zai-org/GLM-Image`) | AR -> DiT | Image | `glm_image.yaml` (built-in) |
### CLI Flags (Disaggregated Mode)
| Flag | Description |
|---|---|
| `--stage-id <int>` | Run as a single-stage worker for the given stage ID (must be `>= 0`). Requires `--stage-configs-path`. Mutually exclusive with `--omni-router`. |
| `--omni-router` | Run as the stage router. Requires `--stage-configs-path`. Mutually exclusive with `--stage-id`. |
| `--stage-configs-path <path>` | Path to vLLM-Omni stage configuration YAML. |
## Current Limitations
- Image input is supported only for I2V via `input_reference` in `/v1/videos`. Other endpoints accept text prompts only.
......@@ -325,3 +414,4 @@ Omni pipelines are configured via YAML stage configs. See [`examples/backends/vl
- Each worker supports a single output modality at a time.
- Audio: streaming (`stream: true`) is not yet supported.
- Audio: Base task (voice cloning) is not yet supported.
- Disaggregated mode: `async_chunk=true` (streaming between stages) is not yet supported.
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# 2-stage disaggregated GLM-Image text-to-image generation.
#   Stage 0: AR  (GPU 0) — generates prior_token_ids
#   Stage 1: DiT (GPU 1) — diffusion denoising + VAE decode → image
#   Router:  orchestrates the 2-stage pipeline, formats response
#
# Usage:
#   disagg_omni_glm_image.sh [--model NAME] [--stage-configs-path PATH] [extra args...]
#
# Environment:
#   MODEL          model name (overridden by --model; default zai-org/GLM-Image)
#   STAGE_CONFIG   stage config YAML (overridden by --stage-configs-path)
#   DYN_HTTP_PORT  port used in the banner and example curl (default 8000)
#   DYN_NAMESPACE  namespace to use; a timestamped one is generated when unset
#
# Arguments other than --model / --stage-configs-path are forwarded unchanged to
# every `python -m dynamo.vllm.omni` process started below.

set -e
trap 'echo Cleaning up...; kill 0' EXIT

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../../common/launch_utils.sh"

MODEL="${MODEL:-zai-org/GLM-Image}"
STAGE_CONFIG="${STAGE_CONFIG:-}"

# Parse the command line first so an explicit --stage-configs-path skips the
# vllm_omni import used to locate the built-in config.
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            # Without this check `shift 2` would abort silently under `set -e`.
            [[ $# -ge 2 ]] || { echo "Error: --model requires a value" >&2; exit 1; }
            MODEL="$2"; shift 2 ;;
        --stage-configs-path)
            [[ $# -ge 2 ]] || { echo "Error: --stage-configs-path requires a value" >&2; exit 1; }
            STAGE_CONFIG="$2"; shift 2 ;;
        *) EXTRA_ARGS+=("$1"); shift ;;
    esac
done

# Resolve vllm-omni's built-in GLM-Image stage config when none was supplied.
if [ -z "$STAGE_CONFIG" ]; then
    STAGE_CONFIG="$(python -c "import vllm_omni, os; print(os.path.join(os.path.dirname(vllm_omni.__file__), 'model_executor/stage_configs/glm_image.yaml'))" 2>/dev/null | tail -1)"
fi

# The resolver's stderr and exit status are discarded above, so a failed import
# yields an empty path. Fail here with a clear message instead of launching
# three processes that would each reject the path later.
if [ ! -f "$STAGE_CONFIG" ]; then
    echo "Error: stage config not found: '${STAGE_CONFIG}'." >&2
    echo "Pass --stage-configs-path or set STAGE_CONFIG (is vllm_omni installed?)." >&2
    exit 1
fi

HTTP_PORT="${DYN_HTTP_PORT:-8000}"

# Use an isolated namespace by default to avoid stale discovery/model-card
# collisions from previous disaggregated runs (which can route directly to dit).
if [ -z "${DYN_NAMESPACE:-}" ]; then
    export DYN_NAMESPACE="dynamo-omni-glm-$(date +%s)"
fi
echo "Namespace: ${DYN_NAMESPACE}"

print_launch_banner --no-curl "Disaggregated GLM-Image (2-stage, 2 GPUs)" "$MODEL" "$HTTP_PORT"
print_curl_footer <<CURL
  curl -s http://localhost:${HTTP_PORT}/v1/images/generations \\
    -H 'Content-Type: application/json' \\
    -d '{
      "model": "${MODEL}",
      "prompt": "a red apple on a white table",
      "size": "1024x1024"
    }' | jq
CURL

export FLASHINFER_DISABLE_VERSION_CHECK=1

# Stage 0: AR worker (GPU 0) — generates prior_token_ids
echo "Starting Stage 0 (AR)..."
CUDA_VISIBLE_DEVICES=0 DYN_SYSTEM_PORT=8081 \
python -m dynamo.vllm.omni \
    --model "$MODEL" \
    --stage-id 0 \
    --stage-configs-path "$STAGE_CONFIG" \
    --output-modalities image \
    --media-output-fs-url file:///tmp/dynamo_media \
    "${EXTRA_ARGS[@]}" &
sleep 20

# Stage 1: DiT worker (GPU 1) — diffusion denoising + VAE decode
echo "Starting Stage 1 (DiT)..."
CUDA_VISIBLE_DEVICES=1 DYN_SYSTEM_PORT=8082 \
python -m dynamo.vllm.omni \
    --model "$MODEL" \
    --stage-id 1 \
    --stage-configs-path "$STAGE_CONFIG" \
    --output-modalities image \
    --media-output-fs-url file:///tmp/dynamo_media \
    "${EXTRA_ARGS[@]}" &
sleep 20

# Router — discovers stage workers, orchestrates pipeline, formats response
echo "Starting Router..."
DYN_SYSTEM_PORT=8083 \
python -m dynamo.vllm.omni \
    --model "$MODEL" \
    --omni-router \
    --stage-configs-path "$STAGE_CONFIG" \
    --output-modalities image \
    --media-output-fs-url file:///tmp/dynamo_media \
    "${EXTRA_ARGS[@]}" &
sleep 5

# Frontend
echo "Starting Frontend..."
python -m dynamo.frontend &

wait_any_exit
......@@ -144,6 +144,32 @@ class VLLMOmniConfig(EngineConfig):
vllm_omni_configs = {
"omni_disagg_t2i": VLLMOmniConfig(
name="omni_disagg_t2i",
directory=vllm_dir,
script_name="disagg_omni_glm_image.sh",
marks=[
pytest.mark.gpu_2,
pytest.mark.pre_merge,
pytest.mark.timeout(1200),
pytest.mark.skip(
reason="zai-org/GLM-Image requires ~23GB per GPU across 2 GPUs, exceeds CI capacity"
),
],
model="zai-org/GLM-Image",
request_payloads=[
ImageGenerationPayload(
body={
"prompt": "A red apple on a white table",
"size": "1024x1024",
"response_format": "url",
},
repeat_count=1,
expected_response=[],
expected_log=[],
),
],
),
"omni_text": VLLMOmniConfig(
name="omni_text",
directory=vllm_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment