"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "0fc5273c1187a44712849c73be48fa3f1f084da7"
Unverified Commit 658b0d5a authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

refactor: move omni post-processing to dedicated formatters / processors (#7746)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 12b4dec5
......@@ -7,22 +7,13 @@ Extracted from omni_handler.py to keep modality-specific logic separate.
OmniHandler holds an instance as ``self.audio`` (composition).
"""
import asyncio
import base64
import logging
import time
import uuid
from io import BytesIO
from typing import Any, Dict
from vllm_omni.inputs.data import OmniTextPrompt
from dynamo.common.protocols.audio_protocol import (
AudioData,
NvAudioSpeechResponse,
NvCreateAudioSpeechRequest,
)
from dynamo.common.storage import upload_to_fs
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.utils.output_modalities import RequestType
logger = logging.getLogger(__name__)
......@@ -426,157 +417,3 @@ class AudioGenerationHandler:
"Failed to estimate TTS prompt length, using fallback 2048: %s", e
)
return 2048
# -- Audio output formatting ----------------------------------------------
def _extract_audio_tensor(self, mm_output: Dict[str, Any]) -> tuple:
"""Extract audio tensor and sample rate from multimodal_output dict."""
import numpy as np
import torch
audio_key = "audio" if "audio" in mm_output else "model_outputs"
audio_val = mm_output.get(audio_key)
if audio_val is None:
raise ValueError(
f"No audio data in multimodal_output. Keys: {list(mm_output.keys())}"
)
if isinstance(audio_val, list):
audio_val = torch.cat(audio_val, dim=-1)
if hasattr(audio_val, "float"):
audio_np = audio_val.float().detach().cpu().numpy()
elif isinstance(audio_val, np.ndarray):
audio_np = audio_val.astype(np.float32)
else:
audio_np = np.array(audio_val, dtype=np.float32)
if audio_np.ndim > 1:
audio_np = audio_np.squeeze()
sr_raw = mm_output.get("sr", 24000)
if isinstance(sr_raw, list):
sr_raw = sr_raw[-1] if sr_raw else 24000
sample_rate = sr_raw.item() if hasattr(sr_raw, "item") else int(sr_raw)
return audio_np, sample_rate
def _encode_audio(
    self, audio_np, sample_rate: int, fmt: str = "wav", speed: float = 1.0
) -> tuple:
    """Serialize a waveform into encoded audio bytes via soundfile.

    When ``speed`` differs from 1.0 the waveform is first time-stretched
    with librosa; if librosa cannot be imported the adjustment is skipped
    with a warning. Unknown format names fall back to WAV with a warning.

    Args:
        audio_np: Waveform samples to encode.
        sample_rate: Sample rate passed to ``soundfile.write``.
        fmt: Requested format name (case-insensitive); empty means ``"wav"``.
        speed: Time-stretch rate.

    Returns:
        ``(audio_bytes, media_type)`` for the encoded audio.
    """
    import soundfile as sf

    if speed != 1.0:
        try:
            import librosa

            audio_np = librosa.effects.time_stretch(y=audio_np, rate=speed)
        except ImportError:
            logger.warning("librosa not installed, ignoring speed adjustment")

    # Format name -> (soundfile container, MIME type, extra write kwargs).
    codec_table = {
        "wav": ("WAV", "audio/wav", {}),
        "pcm": ("RAW", "audio/pcm", {"subtype": "PCM_16"}),
        "flac": ("FLAC", "audio/flac", {}),
        "mp3": ("MP3", "audio/mpeg", {}),
        "aac": ("AAC", "audio/aac", {}),
        "opus": ("OGG", "audio/ogg", {"subtype": "OPUS"}),
    }

    fmt = (fmt or "wav").lower()
    codec = codec_table.get(fmt)
    if codec is None:
        logger.warning(f"Unsupported format '{fmt}', defaulting to wav")
        codec = codec_table["wav"]
    container, mime_type, write_kwargs = codec

    sink = BytesIO()
    sf.write(sink, audio_np, sample_rate, format=container, **write_kwargs)
    return sink.getvalue(), mime_type
async def format_output(
    self,
    mm_output: Dict[str, Any],
    request_id: str,
    response_format: str | None = None,
    request_type: RequestType = RequestType.AUDIO_GENERATION,
    speed: float = 1.0,
) -> Dict[str, Any] | None:
    """Turn engine audio output into a dumped ``NvAudioSpeechResponse``.

    Args:
        mm_output: Multimodal output dict holding the generated audio.
        request_id: Identifier echoed back as the response ``id``.
        response_format: ``"url"`` uploads WAV bytes and returns the URL;
            ``None`` and ``"b64_json"`` return base64-encoded WAV; any other
            value is used as the encoding format and returned base64-encoded.
        request_type: Accepted in the signature; not read by this method.
        speed: Speed factor forwarded to ``_encode_audio``.

    Returns:
        ``model_dump()`` of the response. ``status`` is ``"completed"`` on
        success and ``"failed"`` when ``mm_output`` is empty or any step
        raises.
    """
    if not mm_output:
        empty_result = NvAudioSpeechResponse(
            id=request_id,
            model=self.config.served_model_name or self.config.model,
            status="failed",
            created=int(time.time()),
            error="No audio generated",
        )
        return empty_result.model_dump()

    try:
        start_time = time.time()
        audio_np, sample_rate = self._extract_audio_tensor(mm_output)

        # "url" and "b64_json" describe delivery, not encoding: use WAV.
        if response_format in (None, "url", "b64_json"):
            encode_fmt = "wav"
        else:
            encode_fmt = response_format
        assert encode_fmt is not None

        # Encoding is blocking work, so run it off the event loop.
        audio_bytes, _media_type = await asyncio.to_thread(
            self._encode_audio, audio_np, sample_rate, encode_fmt, speed
        )
        logger.info(
            f"Audio encoded for request {request_id}: "
            f"{len(audio_np)} samples, sr={sample_rate}, "
            f"{len(audio_bytes)} bytes {encode_fmt}"
        )
        elapsed = time.time() - start_time

        if response_format != "url":
            encoded_text = base64.b64encode(audio_bytes).decode("utf-8")
            payload = AudioData(b64_json=encoded_text)
        else:
            suffix = "ogg" if encode_fmt == "opus" else encode_fmt
            target = f"audios/{request_id}/{uuid.uuid4()}.{suffix}"
            public_url = await upload_to_fs(
                self.media_output_fs,
                target,
                audio_bytes,
                self.media_output_http_url,
            )
            payload = AudioData(url=public_url)

        return NvAudioSpeechResponse(
            id=request_id,
            object="audio.speech",
            model=self.config.served_model_name or self.config.model,
            status="completed",
            progress=100,
            created=int(time.time()),
            data=[payload],
            inference_time_s=elapsed,
        ).model_dump()
    except Exception as e:
        logger.error(f"Failed to process audio for request {request_id}: {e}")
        return NvAudioSpeechResponse(
            id=request_id,
            object="audio.speech",
            model=self.config.served_model_name or self.config.model,
            status="failed",
            progress=0,
            created=int(time.time()),
            data=[],
            error=str(e),
        ).model_dump()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import logging
import tempfile
import time
import uuid
from dataclasses import dataclass
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, Optional, Union, cast
import PIL.Image
from diffusers.utils import export_to_video
from fsspec.implementations.dirfs import DirFileSystem
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo._core import Context
from dynamo.common.multimodal import ImageLoader
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import (
ImageData,
NvCreateImageRequest,
NvImagesResponse,
)
from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest,
NvVideosResponse,
VideoData,
)
from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest
from dynamo.common.utils.output_modalities import RequestType, parse_request_type
from dynamo.common.utils.video_utils import (
compute_num_frames,
normalize_video_frames,
parse_size,
)
from dynamo.common.utils.video_utils import compute_num_frames, parse_size
from dynamo.llm.exceptions import EngineShutdown
from dynamo.vllm.omni.audio_handler import AudioGenerationHandler
from dynamo.vllm.omni.base_handler import BaseOmniHandler
from dynamo.vllm.omni.output_formatter import OutputFormatter
logger = logging.getLogger(__name__)
......@@ -104,6 +85,13 @@ class OmniHandler(BaseOmniHandler):
self.media_output_http_url = media_output_http_url
self._image_loader = ImageLoader()
self.output_formatter = OutputFormatter(
model_name=config.served_model_name or config.model,
media_fs=media_output_fs,
media_http_url=media_output_http_url,
default_fps=getattr(config, "default_video_fps", 16),
)
# Audio/TTS handler — composition, not inheritance.
self.audio = AudioGenerationHandler(
config=config,
......@@ -189,54 +177,22 @@ class OmniHandler(BaseOmniHandler):
async for stage_output in self.engine_client.generate(
**generate_kwargs,
):
if (
stage_output.final_output_type == "text"
and stage_output.request_output
):
chunk = self._format_text_chunk(
stage_output.request_output,
request_id,
previous_text,
)
if chunk:
output = stage_output.request_output.outputs[0]
previous_text = output.text
yield chunk
elif (
stage_output.final_output_type == "image"
and stage_output.images
):
# vllm-omni uses final_output_type="image" for both
# image and video diffusion outputs. Use the parsed
# request type to route to the correct formatter.
if inputs.request_type == RequestType.VIDEO_GENERATION:
chunk = await self._format_video_chunk(
stage_output.images,
request_id,
fps=inputs.fps,
)
else:
chunk = await self._format_image_chunk(
stage_output.images,
chunk = await self.output_formatter.format(
stage_output,
request_id,
response_format=inputs.response_format,
request_type=inputs.request_type,
)
if chunk:
yield chunk
elif stage_output.final_output_type == "audio":
mm_output = stage_output.multimodal_output
if mm_output:
chunk = await self.audio.format_output(
mm_output,
request_id,
fps=inputs.fps,
response_format=inputs.response_format,
request_type=inputs.request_type,
previous_text=previous_text,
speed=inputs.speed,
)
if chunk:
# Track text state for streaming delta
if (
stage_output.final_output_type == "text"
and stage_output.request_output
):
previous_text = stage_output.request_output.outputs[0].text
yield chunk
except EngineShutdown:
......@@ -404,220 +360,3 @@ class OmniHandler(BaseOmniHandler):
request_type=RequestType.VIDEO_GENERATION,
fps=fps,
)
async def _prepare_image_output(
self, images: list, request_id: str, response_format: str | None = None
) -> list:
"""Prepare image output for response.
Args:
images: List of PIL Image objects.
request_id: Unique request identifier.
response_format: Response format ("url" or "b64_json").
Returns:
List of image URLs or base64 data-URL strings.
"""
outlist = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="PNG")
image_bytes = buffer.getvalue()
if response_format == "url":
storage_path = f"images/{request_id}/{uuid.uuid4()}.png"
url = await upload_to_fs(
self.media_output_fs,
storage_path,
image_bytes,
self.media_output_http_url,
)
outlist.append(url)
elif response_format == "b64_json" or response_format is None:
img_base64 = base64.b64encode(image_bytes).decode("utf-8")
data_url = f"data:image/png;base64,{img_base64}"
outlist.append(data_url)
else:
raise ValueError(f"Invalid response format: {response_format}")
return outlist
async def _format_image_chunk(
    self,
    images: list,
    request_id: str,
    response_format: str | None = None,
    request_type: RequestType = RequestType.IMAGE_GENERATION,
) -> Dict[str, Any] | None:
    """Shape generated images into the response for the calling endpoint.

    Args:
        images: Generated images, passed on to ``_prepare_image_output``.
        request_id: Unique request identifier.
        response_format: ``"url"``, ``"b64_json"`` or ``None``.
        request_type: Chat completion or image generation.

    Returns:
        A ``chat.completion.chunk`` dict for chat requests, a dumped
        ``NvImagesResponse`` for image-generation requests, an error chunk
        when ``images`` is empty, and ``None`` for other request types.

    Raises:
        ValueError: If ``response_format`` is not a supported value.
    """
    if not images:
        return self._error_chunk(request_id, "No images generated")

    encoded = await self._prepare_image_output(images, request_id, response_format)

    if request_type == RequestType.CHAT_COMPLETION:
        content_parts = [
            {"type": "image_url", "image_url": {"url": item}} for item in encoded
        ]
        return {
            "id": request_id,
            "created": int(time.time()),
            "object": "chat.completion.chunk",
            "model": self.config.served_model_name or self.config.model,
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": content_parts},
                    "finish_reason": "stop",
                }
            ],
        }

    if request_type == RequestType.IMAGE_GENERATION:
        entries = []
        for item in encoded:
            if response_format == "url":
                entries.append(ImageData(url=item))
            elif response_format == "b64_json" or response_format is None:
                if item.startswith("data:image"):
                    # Drop the "data:...;base64," prefix; keep the payload.
                    _prefix, payload = item.split(",", 1)
                    entries.append(ImageData(b64_json=payload))
                else:
                    entries.append(ImageData(b64_json=item))
            else:
                raise ValueError(f"Invalid response format: {response_format}")
        return NvImagesResponse(created=int(time.time()), data=entries).model_dump()

    return None
async def _format_video_chunk(
self,
images: list,
request_id: str,
fps: int,
) -> Dict[str, Any] | None:
"""Convert diffusion output frames to MP4 and return as NvVideosResponse.
Args:
images: List of PIL Image frames from the diffusion stage.
request_id: Unique request identifier.
fps: Frames per second for the output video.
Returns:
``NvVideosResponse.model_dump()`` dict, or ``None`` if no frames.
"""
if not images:
return None
try:
start_time = time.time()
frame_list = normalize_video_frames(images)
logger.info(
f"Encoding {len(frame_list)} frames to MP4 for request {request_id} "
f"(fps={fps})"
)
# Encode frames to MP4 via temp file, then read bytes for upload
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=True) as tmp:
await asyncio.to_thread(export_to_video, frame_list, tmp.name, fps)
video_bytes = tmp.read()
# Upload via filesystem
storage_path = f"videos/{request_id}.mp4"
video_url = await upload_to_fs(
self.media_output_fs,
storage_path,
video_bytes,
self.media_output_http_url,
)
logger.info(f"Video uploaded to {video_url} for request {request_id}")
inference_time = time.time() - start_time
response = NvVideosResponse(
id=request_id,
object="video",
model=self.config.served_model_name or self.config.model,
status="completed",
progress=100,
created=int(time.time()),
data=[VideoData(url=video_url)],
inference_time_s=inference_time,
)
return response.model_dump()
except Exception as e:
logger.error(f"Failed to encode video for request {request_id}: {e}")
error_response = NvVideosResponse(
id=request_id,
object="video",
model=self.config.served_model_name or self.config.model,
status="failed",
progress=0,
created=int(time.time()),
data=[],
error=str(e),
)
return error_response.model_dump()
def _format_text_chunk(
self,
request_output,
request_id: str,
previous_text: str,
) -> Dict[str, Any] | None:
"""Format text output as OpenAI chat completion chunk."""
if not request_output.outputs:
return self._error_chunk(request_id, "No outputs from engine")
output = request_output.outputs[0]
# Calculate delta text (new text since last chunk)
delta_text = output.text[len(previous_text) :]
chunk = {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self.config.served_model_name or self.config.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": delta_text,
},
"finish_reason": (
normalize_finish_reason(output.finish_reason)
if output.finish_reason
else None
),
}
],
}
# Add usage on final chunk
if output.finish_reason:
chunk["usage"] = self._build_completion_usage(request_output)
return chunk
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Modality-specific output formatters for vLLM-Omni.
Extracted from OmniHandler and AudioGenerationHandler so that any consumer
(aggregated handler, disaggregated router, test harness) can format engine
output without creating an engine or loading model weights.
"""
import asyncio
import base64
import logging
import tempfile
import time
import uuid
from io import BytesIO
from typing import Any, Dict, Optional
from dynamo.common.utils.engine_response import normalize_finish_reason
logger = logging.getLogger(__name__)
class TextFormatter:
"""Formats LLM text output as OpenAI chat completion chunks."""
def __init__(self, model_name: str) -> None:
self._model_name = model_name
def format(
self,
request_output: Any,
request_id: str,
*,
previous_text: str = "",
) -> Dict[str, Any] | None:
if not request_output.outputs:
return _error_chunk(request_id, self._model_name, "No outputs from engine")
output = request_output.outputs[0]
delta_text = output.text[len(previous_text) :]
chunk: Dict[str, Any] = {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self._model_name,
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": delta_text},
"finish_reason": (
normalize_finish_reason(output.finish_reason)
if output.finish_reason
else None
),
}
],
}
if output.finish_reason:
chunk["usage"] = _build_completion_usage(request_output)
return chunk
class DiffusionFormatter:
"""Formats diffusion output (images/video frames) for the frontend.
Handles both image and video — routes by request_type since vllm-omni
reports final_output_type="image" for all diffusion outputs.
"""
def __init__(
self,
model_name: str,
media_fs: Any,
media_http_url: Optional[str],
default_fps: int = 16,
) -> None:
self._model_name = model_name
self._media_fs = media_fs
self._media_http_url = media_http_url
self._default_fps = default_fps
async def format(
self, stage_output: Any, request_id: str, *, request_type: Any, **ctx: Any
) -> Dict[str, Any] | None:
images = (
stage_output.images if hasattr(stage_output, "images") else stage_output
)
if not images:
return None
from dynamo.common.utils.output_modalities import RequestType
if request_type == RequestType.VIDEO_GENERATION:
return await self._encode_video(
images, request_id, fps=ctx.get("fps", self._default_fps)
)
return await self._encode_image(
images,
request_id,
request_type=request_type,
response_format=ctx.get("response_format"),
)
async def _encode_video(
self, images: list, request_id: str, fps: int
) -> Dict[str, Any] | None:
from diffusers.utils.export_utils import export_to_video
from dynamo.common.protocols.video_protocol import NvVideosResponse, VideoData
from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.video_utils import normalize_video_frames
try:
start_time = time.time()
frame_list = normalize_video_frames(images)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=True) as tmp:
await asyncio.to_thread(export_to_video, frame_list, tmp.name, fps)
video_bytes = tmp.read()
video_url = await upload_to_fs(
self._media_fs,
f"videos/{request_id}.mp4",
video_bytes,
self._media_http_url,
)
return NvVideosResponse(
id=request_id,
object="video",
model=self._model_name,
status="completed",
progress=100,
created=int(time.time()),
data=[VideoData(url=video_url)],
inference_time_s=time.time() - start_time,
).model_dump()
except Exception as e:
logger.error("Failed to encode video for request %s: %s", request_id, e)
return NvVideosResponse(
id=request_id,
object="video",
model=self._model_name,
status="failed",
progress=0,
created=int(time.time()),
data=[],
error=str(e),
).model_dump()
async def _encode_image(
self,
images: list,
request_id: str,
*,
request_type: Any,
response_format: Optional[str] = None,
) -> Dict[str, Any] | None:
from dynamo.common.protocols.image_protocol import ImageData, NvImagesResponse
from dynamo.common.utils.output_modalities import RequestType
if not images:
return _error_chunk(request_id, self._model_name, "No images generated")
data_urls = await self._prepare_images(images, request_id, response_format)
if request_type == RequestType.CHAT_COMPLETION:
return {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": self._model_name,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": [
{"type": "image_url", "image_url": {"url": u}}
for u in data_urls
],
},
"finish_reason": "stop",
}
],
}
if request_type == RequestType.IMAGE_GENERATION:
image_data_list = []
for data_url in data_urls:
if response_format == "url":
image_data_list.append(ImageData(url=data_url))
elif response_format == "b64_json" or response_format is None:
b64 = (
data_url.split(",", 1)[1]
if data_url.startswith("data:")
else data_url
)
image_data_list.append(ImageData(b64_json=b64))
else:
raise ValueError(f"Invalid response format: {response_format}")
return NvImagesResponse(
created=int(time.time()), data=image_data_list
).model_dump()
return None
async def _prepare_images(
self, images: list, request_id: str, response_format: Optional[str] = None
) -> list:
from dynamo.common.storage import upload_to_fs
outlist = []
for img in images:
buf = BytesIO()
img.save(buf, format="PNG")
image_bytes = buf.getvalue()
if response_format == "url":
url = await upload_to_fs(
self._media_fs,
f"images/{request_id}/{uuid.uuid4()}.png",
image_bytes,
self._media_http_url,
)
outlist.append(url)
elif response_format == "b64_json" or response_format is None:
outlist.append(
f"data:image/png;base64,{base64.b64encode(image_bytes).decode()}"
)
else:
raise ValueError(f"Invalid response format: {response_format}")
return outlist
class AudioFormatter:
"""Formats audio multimodal_output → NvAudioSpeechResponse."""
def __init__(
self, model_name: str, media_fs: Any, media_http_url: Optional[str]
) -> None:
from dynamo.common.protocols.audio_protocol import AudioData
self._model_name = model_name
self._media_fs = media_fs
self._media_http_url = media_http_url
self._AudioData = AudioData
async def format(
self, stage_output: Any, request_id: str, **ctx: Any
) -> Dict[str, Any] | None:
mm_output = (
stage_output.multimodal_output
if hasattr(stage_output, "multimodal_output")
else stage_output
)
if not mm_output:
return self._error_response(request_id, "No audio generated")
response_format = ctx.get("response_format")
speed = ctx.get("speed", 1.0)
try:
start_time = time.time()
audio_np, sample_rate = self._extract_audio_tensor(mm_output)
encode_fmt = (
"wav"
if response_format in (None, "url", "b64_json")
else response_format
)
assert encode_fmt is not None
audio_bytes, media_type = await asyncio.to_thread(
self._encode_audio, audio_np, sample_rate, encode_fmt, speed
)
logger.info(
"Audio encoded for request %s: %d samples, sr=%d, %d bytes %s",
request_id,
len(audio_np),
sample_rate,
len(audio_bytes),
encode_fmt,
)
if response_format == "url":
from dynamo.common.storage import upload_to_fs
ext = encode_fmt if encode_fmt != "opus" else "ogg"
url = await upload_to_fs(
self._media_fs,
f"audios/{request_id}/{uuid.uuid4()}.{ext}",
audio_bytes,
self._media_http_url,
)
audio_data_obj = self._AudioData(url=url)
else:
audio_data_obj = self._AudioData(
b64_json=base64.b64encode(audio_bytes).decode()
)
from dynamo.common.protocols.audio_protocol import NvAudioSpeechResponse
return NvAudioSpeechResponse(
id=request_id,
object="audio.speech",
model=self._model_name,
status="completed",
progress=100,
created=int(time.time()),
data=[audio_data_obj],
inference_time_s=time.time() - start_time,
).model_dump()
except Exception as e:
logger.error("Failed to process audio for request %s: %s", request_id, e)
return self._error_response(request_id, str(e))
def _extract_audio_tensor(self, mm_output: Dict[str, Any]) -> tuple:
import numpy as np
import torch
audio_key = "audio" if "audio" in mm_output else "model_outputs"
audio_val = mm_output.get(audio_key)
if audio_val is None:
raise ValueError(
f"No audio data in multimodal_output. Keys: {list(mm_output.keys())}"
)
if isinstance(audio_val, list):
audio_val = torch.cat(audio_val, dim=-1)
if hasattr(audio_val, "float"):
audio_np = audio_val.float().detach().cpu().numpy()
elif isinstance(audio_val, np.ndarray):
audio_np = audio_val.astype(np.float32)
else:
audio_np = np.array(audio_val, dtype=np.float32)
if audio_np.ndim > 1:
audio_np = audio_np.squeeze()
sr_raw = mm_output.get("sr", 24000)
if isinstance(sr_raw, list):
sr_raw = sr_raw[-1] if sr_raw else 24000
sample_rate = sr_raw.item() if hasattr(sr_raw, "item") else int(sr_raw)
return audio_np, sample_rate
def _encode_audio(
self, audio_np: Any, sample_rate: int, fmt: str = "wav", speed: float = 1.0
) -> tuple:
import soundfile as sf
if speed != 1.0:
try:
import librosa
audio_np = librosa.effects.time_stretch(y=audio_np, rate=speed)
except ImportError:
logger.warning("librosa not installed, ignoring speed adjustment")
fmt = (fmt or "wav").lower()
format_map = {
"wav": ("WAV", "audio/wav", {}),
"pcm": ("RAW", "audio/pcm", {"subtype": "PCM_16"}),
"flac": ("FLAC", "audio/flac", {}),
"mp3": ("MP3", "audio/mpeg", {}),
"aac": ("AAC", "audio/aac", {}),
"opus": ("OGG", "audio/ogg", {"subtype": "OPUS"}),
}
if fmt not in format_map:
logger.warning("Unsupported format '%s', defaulting to wav", fmt)
fmt = "wav"
sf_format, media_type, kwargs = format_map[fmt]
buf = BytesIO()
sf.write(buf, audio_np, sample_rate, format=sf_format, **kwargs)
return buf.getvalue(), media_type
def _error_response(self, request_id: str, error: str) -> Dict[str, Any]:
from dynamo.common.protocols.audio_protocol import NvAudioSpeechResponse
return NvAudioSpeechResponse(
id=request_id,
model=self._model_name,
status="failed",
created=int(time.time()),
error=error,
).model_dump()
def _error_chunk(
request_id: str, model_name: str, error_message: str
) -> Dict[str, Any]:
"""Error response in OpenAI chat.completion.chunk format."""
return {
"id": request_id,
"created": int(time.time()),
"object": "chat.completion.chunk",
"model": model_name,
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": f"Error: {error_message}"},
"finish_reason": "error",
}
],
}
def _build_completion_usage(request_output: Any) -> Dict[str, Any]:
"""Build completion usage stats from a vLLM RequestOutput."""
prompt_tokens = (
len(request_output.prompt_token_ids)
if getattr(request_output, "prompt_token_ids", None)
else None
)
completion_tokens = len(request_output.outputs[0].token_ids)
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": (
prompt_tokens + completion_tokens if prompt_tokens is not None else None
),
"prompt_tokens_details": (
{"cached_tokens": num_cached}
if (num_cached := getattr(request_output, "num_cached_tokens", None))
else None
),
}
class OutputFormatter:
"""Dispatches raw engine output to modality-specific formatters.
Shared by OmniHandler (aggregated) and any future disaggregated router.
"""
def __init__(
self,
model_name: str,
media_fs: Any = None,
media_http_url: Optional[str] = None,
default_fps: int = 16,
) -> None:
self._formatters: Dict[str, Any] = {
"text": TextFormatter(model_name),
"image": DiffusionFormatter(
model_name, media_fs, media_http_url, default_fps
),
"audio": AudioFormatter(model_name, media_fs, media_http_url),
}
async def format(
self,
stage_output: Any,
request_id: str,
*,
request_type: Any = None,
**ctx: Any,
) -> Dict[str, Any] | None:
fmt_type = getattr(stage_output, "final_output_type", None)
formatter = self._formatters.get(fmt_type) if fmt_type else None
if formatter is None:
return None
# TextFormatter is sync and takes request_output, not stage_output.
if fmt_type == "text":
ro = getattr(stage_output, "request_output", None)
if not ro:
return None
return formatter.format(
ro, request_id, previous_text=ctx.get("previous_text", "")
)
return await formatter.format(
stage_output, request_id, request_type=request_type, **ctx
)
......@@ -190,94 +190,3 @@ class TestEngineInputsFromAudio:
req = NvCreateAudioSpeechRequest(input="hello", speed=2.0)
inputs = await handler.build_engine_inputs(req)
assert inputs.speed == 2.0
class TestExtractAudioTensor:
    """Tests for _extract_audio_tensor."""

    def test_extracts_from_audio_key(self):
        """Audio stored under ``"audio"`` is returned with its sample rate."""
        import numpy as np

        samples = np.array([0.1, -0.2, 0.3], dtype=np.float32)
        payload = {"audio": samples, "sr": 24000}
        waveform, rate = _make_audio_handler()._extract_audio_tensor(payload)
        assert rate == 24000
        assert len(waveform) == 3

    def test_extracts_from_model_outputs_key(self):
        """``"model_outputs"`` is used when there is no ``"audio"`` key."""
        import numpy as np

        samples = np.array([0.5, -0.5], dtype=np.float32)
        payload = {"model_outputs": samples, "sr": 16000}
        waveform, rate = _make_audio_handler()._extract_audio_tensor(payload)
        assert rate == 16000
        assert len(waveform) == 2

    def test_missing_audio_raises(self):
        """A dict without audio data raises ValueError."""
        subject = _make_audio_handler()
        with pytest.raises(ValueError, match="No audio data"):
            subject._extract_audio_tensor({"sr": 24000})

    def test_squeezes_extra_dims(self):
        """A leading singleton dimension is squeezed away."""
        import numpy as np

        batched = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
        waveform, _rate = _make_audio_handler()._extract_audio_tensor(
            {"audio": batched, "sr": 24000}
        )
        assert waveform.ndim == 1
class TestEncodeAudio:
    """Tests for _encode_audio."""

    def test_wav_encoding(self):
        """WAV output is non-empty and starts with the RIFF magic bytes."""
        import numpy as np

        subject = _make_audio_handler()
        silence = np.zeros(2400, dtype=np.float32)
        encoded, mime = subject._encode_audio(silence, 24000, "wav")
        assert mime == "audio/wav"
        assert len(encoded) > 0
        assert encoded[:4] == b"RIFF"  # WAV header

    def test_unsupported_format_falls_back_to_wav(self):
        """An unknown format name is encoded as WAV."""
        import numpy as np

        subject = _make_audio_handler()
        silence = np.zeros(100, dtype=np.float32)
        _encoded, mime = subject._encode_audio(silence, 24000, "xyz")
        assert mime == "audio/wav"

    def test_default_format_is_wav(self):
        """Omitting the format argument encodes as WAV."""
        import numpy as np

        subject = _make_audio_handler()
        silence = np.zeros(100, dtype=np.float32)
        _encoded, mime = subject._encode_audio(silence, 24000)
        assert mime == "audio/wav"
class TestFormatAudioChunk:
    """Tests for format_output."""

    @pytest.mark.asyncio
    async def test_empty_mm_output_returns_error(self):
        """An empty multimodal output yields a failed response."""
        subject = _make_audio_handler()
        response = await subject.format_output({}, "req-1")
        assert response["status"] == "failed"
        assert "No audio generated" in response["error"]

    @pytest.mark.asyncio
    async def test_successful_generation(self):
        """Valid audio yields a completed response with one base64 entry."""
        import numpy as np

        subject = _make_audio_handler()
        noise = np.random.randn(4800).astype(np.float32)
        response = await subject.format_output({"audio": noise, "sr": 24000}, "req-1")
        assert response["status"] == "completed"
        assert response["object"] == "audio.speech"
        assert len(response["data"]) == 1
        assert response["data"][0]["b64_json"] is not None
......@@ -48,44 +48,6 @@ class TestEngineInputs:
assert ei.response_format is None
class TestPrepareImageOutput:
    """Tests for _prepare_image_output."""

    @staticmethod
    def _fake_image(payload: bytes):
        """Return a mock image whose ``save`` writes ``payload`` to the buffer."""
        fake = MagicMock()
        fake.save = lambda b, format: b.write(payload)
        return fake

    @pytest.mark.asyncio
    async def test_b64_json(self):
        """b64_json format returns data URI with base64 prefix."""
        subject = _make_handler()
        outputs = await subject._prepare_image_output(
            [self._fake_image(b"fake_png_data")], "req-1", "b64_json"
        )
        assert len(outputs) == 1
        assert outputs[0].startswith("data:image/png;base64,")

    @pytest.mark.asyncio
    async def test_b64_default_when_none(self):
        """None response_format defaults to base64 encoding."""
        subject = _make_handler()
        outputs = await subject._prepare_image_output(
            [self._fake_image(b"data")], "req-1", None
        )
        assert outputs[0].startswith("data:image/png;base64,")

    @pytest.mark.asyncio
    async def test_invalid_format(self):
        """Unsupported response_format raises ValueError."""
        subject = _make_handler()
        with pytest.raises(ValueError, match="Invalid response format"):
            await subject._prepare_image_output([MagicMock()], "req-1", "invalid")

    @pytest.mark.asyncio
    async def test_multiple_images(self):
        """Multiple input images produce one output entry each."""
        subject = _make_handler()
        fakes = [self._fake_image(b"px") for _ in range(3)]
        outputs = await subject._prepare_image_output(fakes, "req-1", "b64_json")
        assert len(outputs) == 3
class TestBuildEngineInputs:
@pytest.mark.asyncio
async def test_chat_completion(self):
......@@ -144,132 +106,6 @@ class TestBuildEngineInputs:
assert inputs.prompt["prompt"] == "Hello world"
class TestFormatTextChunk:
    """Tests for _format_text_chunk."""

    def _make_output(self, text="hello world", finish_reason=None):
        """Build a mock request output holding a single completion."""
        completion = MagicMock()
        completion.text = text
        completion.finish_reason = finish_reason
        wrapper = MagicMock()
        wrapper.outputs = [completion]
        wrapper.prompt_token_ids = [1, 2, 3]
        return wrapper

    @staticmethod
    def _stub_usage(ro):
        """Fixed usage payload standing in for the handler's usage builder."""
        return {
            "prompt_tokens": 3,
            "completion_tokens": 1,
        }

    def test_delta_text(self):
        """Delta content is the diff between current and previous text."""
        subject = _make_handler()
        result = subject._format_text_chunk(
            self._make_output("hello world"), "req-1", "hello "
        )
        assert result["choices"][0]["delta"]["content"] == "world"

    def test_no_outputs_returns_error(self):
        """Empty engine outputs produce an error chunk."""
        subject = _make_handler()
        empty = MagicMock()
        empty.outputs = []
        result = subject._format_text_chunk(empty, "req-1", "")
        assert "Error" in result["choices"][0]["delta"]["content"]

    def test_finish_reason_included(self):
        """Final chunk includes finish_reason and usage stats."""
        subject = _make_handler()
        subject._build_completion_usage = self._stub_usage
        finished = self._make_output("done", finish_reason="stop")
        result = subject._format_text_chunk(finished, "req-1", "")
        assert result["choices"][0]["finish_reason"] == "stop"
        assert "usage" in result

    def test_finish_reason_abort_normalized(self):
        """Abort finish reason is normalized to 'cancelled'."""
        subject = _make_handler()
        subject._build_completion_usage = self._stub_usage
        aborted = self._make_output("done", finish_reason="abort")
        result = subject._format_text_chunk(aborted, "req-1", "")
        assert result["choices"][0]["finish_reason"] == "cancelled"

    def test_finish_reason_none_when_not_finished(self):
        """finish_reason is None when output has no finish_reason."""
        subject = _make_handler()
        result = subject._format_text_chunk(self._make_output("partial"), "req-1", "")
        assert result["choices"][0]["finish_reason"] is None
class TestFormatImageChunk:
    """Exercises ``handler._format_image_chunk`` for chat and image-generation routes."""

    @staticmethod
    def _fake_image():
        """Return a mock image whose ``save`` writes two placeholder bytes."""
        image = MagicMock()
        image.save = lambda b, format: b.write(b"px")
        return image

    @pytest.mark.asyncio
    async def test_chat_completion_format(self):
        """The chat completion route returns image_url content parts."""
        handler = _make_handler()
        images = [self._fake_image()]
        chunk = await handler._format_image_chunk(
            images, "req-1", request_type=RequestType.CHAT_COMPLETION
        )
        assert chunk["object"] == "chat.completion.chunk"
        first_part = chunk["choices"][0]["delta"]["content"][0]
        assert first_part["type"] == "image_url"

    @pytest.mark.asyncio
    async def test_image_generation_b64_format(self):
        """Image generation with response_format='b64_json' populates b64_json."""
        handler = _make_handler()
        images = [self._fake_image()]
        chunk = await handler._format_image_chunk(
            images,
            "req-1",
            response_format="b64_json",
            request_type=RequestType.IMAGE_GENERATION,
        )
        first_entry = chunk["data"][0]
        assert first_entry["b64_json"] is not None

    @pytest.mark.asyncio
    async def test_image_generation_default_format_returns_b64(self):
        """Image generation with response_format=None still populates b64_json."""
        handler = _make_handler()
        images = [self._fake_image()]
        chunk = await handler._format_image_chunk(
            images,
            "req-1",
            response_format=None,
            request_type=RequestType.IMAGE_GENERATION,
        )
        first_entry = chunk["data"][0]
        assert first_entry["b64_json"] is not None

    @pytest.mark.asyncio
    async def test_empty_images_returns_error(self):
        """An empty image list produces an error chunk."""
        handler = _make_handler()
        chunk = await handler._format_image_chunk([], "req-1")
        delta = chunk["choices"][0]["delta"]
        assert "Error" in delta["content"]
class TestFormatVideoChunk:
    """Exercises ``handler._format_video_chunk`` on empty input and on encoder failure."""

    @pytest.mark.asyncio
    async def test_empty_frames_returns_none(self):
        """An empty frame list returns None."""
        handler = _make_handler()
        outcome = await handler._format_video_chunk([], "req-1", fps=16)
        assert outcome is None

    @pytest.mark.asyncio
    async def test_error_returns_failed_status(self):
        """A failure while normalizing frames is reported as a failed response with the error text."""
        handler = _make_handler()
        frames = [MagicMock()]
        failing_normalizer = patch(
            "dynamo.vllm.omni.omni_handler.normalize_video_frames",
            side_effect=RuntimeError("boom"),
        )
        with failing_normalizer:
            chunk = await handler._format_video_chunk(frames, "req-1", fps=16)
        assert chunk["status"] == "failed"
        assert "boom" in chunk["error"]
class TestI2VEngineInputs:
"""Tests for image-to-video: multi_modal_data attachment, I2V nvext params, and protocol fields."""
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Tests for output_formatter.py — modality-specific formatters."""
from unittest.mock import MagicMock
import pytest
try:
from dynamo.vllm.omni.output_formatter import (
DiffusionFormatter,
TextFormatter,
_build_completion_usage,
_error_chunk,
)
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
# Module-level marker list: pytest applies every mark in ``pytestmark`` to
# each test collected from this file.
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_1,
pytest.mark.pre_merge,
]
# ── TextFormatter ──────────────────────────────────────────
def _make_request_output(text="hello world", finish_reason=None):
output = MagicMock()
output.text = text
output.finish_reason = finish_reason
output.token_ids = [1, 2, 3] # 3 completion tokens
ro = MagicMock()
ro.outputs = [output]
ro.prompt_token_ids = [
10,
20,
30,
40,
50,
] # 5 prompt tokens (different from completion)
return ro
class TestTextFormatter:
    """Exercises ``TextFormatter.format`` on mocked request outputs."""

    def test_delta_text(self):
        """Only the text beyond ``previous_text`` appears as delta content."""
        formatter = TextFormatter(model_name="test-model")
        request_output = _make_request_output("hello world")
        chunk = formatter.format(request_output, "req-1", previous_text="hello ")
        delta = chunk["choices"][0]["delta"]
        assert delta["content"] == "world"

    def test_no_outputs_returns_error(self):
        """A request output with an empty ``outputs`` list yields an error chunk."""
        formatter = TextFormatter(model_name="test-model")
        request_output = MagicMock()
        request_output.outputs = []
        chunk = formatter.format(request_output, "req-1")
        delta = chunk["choices"][0]["delta"]
        assert "Error" in delta["content"]

    def test_finish_reason_included(self):
        """A finished output carries its finish_reason and a usage entry."""
        formatter = TextFormatter(model_name="test-model")
        request_output = _make_request_output("done", finish_reason="stop")
        chunk = formatter.format(request_output, "req-1")
        first_choice = chunk["choices"][0]
        assert first_choice["finish_reason"] == "stop"
        assert "usage" in chunk

    def test_finish_reason_abort_normalized(self):
        """An 'abort' finish reason is reported as 'cancelled'."""
        formatter = TextFormatter(model_name="test-model")
        request_output = _make_request_output("done", finish_reason="abort")
        chunk = formatter.format(request_output, "req-1")
        first_choice = chunk["choices"][0]
        assert first_choice["finish_reason"] == "cancelled"

    def test_finish_reason_none_when_not_finished(self):
        """finish_reason stays None while the output has no finish_reason."""
        formatter = TextFormatter(model_name="test-model")
        request_output = _make_request_output("partial")
        chunk = formatter.format(request_output, "req-1")
        first_choice = chunk["choices"][0]
        assert first_choice["finish_reason"] is None

    def test_model_name_in_response(self):
        """The chunk's ``model`` field echoes the formatter's model name."""
        formatter = TextFormatter(model_name="my-model")
        request_output = _make_request_output()
        chunk = formatter.format(request_output, "req-1")
        assert chunk["model"] == "my-model"

    def test_usage_has_prompt_and_completion_tokens(self):
        """Usage reports 5 prompt ids, 3 completion ids, and their sum."""
        formatter = TextFormatter(model_name="test-model")
        request_output = _make_request_output("done", finish_reason="stop")
        usage = formatter.format(request_output, "req-1")["usage"]
        assert usage["prompt_tokens"] == 5  # 5 prompt token IDs
        assert usage["completion_tokens"] == 3  # 3 completion token IDs
        assert usage["total_tokens"] == 8
# ── Helpers ────────────────────────────────────────────────
class TestErrorChunk:
    """Checks the shape of the chunk produced by ``_error_chunk``."""

    def test_error_chunk_format(self):
        """The message is prefixed with 'Error: ' and finish_reason is 'error'."""
        chunk = _error_chunk("req-1", "my-model", "something broke")
        first_choice = chunk["choices"][0]
        assert first_choice["delta"]["content"] == "Error: something broke"
        assert first_choice["finish_reason"] == "error"
        assert chunk["model"] == "my-model"
# ── DiffusionFormatter ─────────────────────────────────────
def _make_diffusion_formatter():
    """Create a ``DiffusionFormatter`` with no media filesystem or HTTP URL configured."""
    formatter = DiffusionFormatter(
        model_name="test-model",
        media_fs=None,
        media_http_url=None,
    )
    return formatter
class TestDiffusionFormatterPrepareImages:
    """Exercises ``DiffusionFormatter._prepare_images`` with mocked images."""

    _DATA_URL_PREFIX = "data:image/png;base64,"

    @staticmethod
    def _fake_image(payload):
        """Return a mock image whose ``save`` writes ``payload`` to the buffer."""
        image = MagicMock()
        image.save = lambda b, format: b.write(payload)
        return image

    @pytest.mark.asyncio
    async def test_b64_json(self):
        """'b64_json' yields one PNG data URL per image."""
        formatter = _make_diffusion_formatter()
        images = [self._fake_image(b"fake_png_data")]
        encoded = await formatter._prepare_images(images, "req-1", "b64_json")
        assert len(encoded) == 1
        assert encoded[0].startswith(self._DATA_URL_PREFIX)

    @pytest.mark.asyncio
    async def test_b64_default_when_none(self):
        """A response format of None also yields a PNG data URL."""
        formatter = _make_diffusion_formatter()
        images = [self._fake_image(b"data")]
        encoded = await formatter._prepare_images(images, "req-1", None)
        assert encoded[0].startswith(self._DATA_URL_PREFIX)

    @pytest.mark.asyncio
    async def test_invalid_format(self):
        """An unsupported response format raises ValueError."""
        formatter = _make_diffusion_formatter()
        images = [MagicMock()]
        with pytest.raises(ValueError, match="Invalid response format"):
            await formatter._prepare_images(images, "req-1", "invalid")

    @pytest.mark.asyncio
    async def test_multiple_images(self):
        """Three input images come back as three output entries."""
        formatter = _make_diffusion_formatter()
        images = [self._fake_image(b"px") for _ in range(3)]
        encoded = await formatter._prepare_images(images, "req-1", "b64_json")
        assert len(encoded) == 3
class TestDiffusionFormatterImage:
    """Exercises ``DiffusionFormatter._encode_image`` for chat and image-generation routes."""

    @staticmethod
    def _fake_image():
        """Return a mock image whose ``save`` writes two placeholder bytes."""
        image = MagicMock()
        image.save = lambda b, format: b.write(b"px")
        return image

    @pytest.mark.asyncio
    async def test_chat_completion_format(self):
        """The chat completion route returns image_url content parts."""
        from dynamo.common.utils.output_modalities import RequestType

        formatter = _make_diffusion_formatter()
        images = [self._fake_image()]
        chunk = await formatter._encode_image(
            images, "req-1", request_type=RequestType.CHAT_COMPLETION
        )
        assert chunk["object"] == "chat.completion.chunk"
        first_part = chunk["choices"][0]["delta"]["content"][0]
        assert first_part["type"] == "image_url"

    @pytest.mark.asyncio
    async def test_image_generation_b64_format(self):
        """Image generation with response_format='b64_json' populates b64_json."""
        from dynamo.common.utils.output_modalities import RequestType

        formatter = _make_diffusion_formatter()
        images = [self._fake_image()]
        chunk = await formatter._encode_image(
            images,
            "req-1",
            response_format="b64_json",
            request_type=RequestType.IMAGE_GENERATION,
        )
        first_entry = chunk["data"][0]
        assert first_entry["b64_json"] is not None

    @pytest.mark.asyncio
    async def test_image_generation_default_format_returns_b64(self):
        """Image generation with response_format=None still populates b64_json."""
        from dynamo.common.utils.output_modalities import RequestType

        formatter = _make_diffusion_formatter()
        images = [self._fake_image()]
        chunk = await formatter._encode_image(
            images,
            "req-1",
            response_format=None,
            request_type=RequestType.IMAGE_GENERATION,
        )
        first_entry = chunk["data"][0]
        assert first_entry["b64_json"] is not None

    @pytest.mark.asyncio
    async def test_empty_images_returns_error(self):
        """An empty image list produces an error chunk."""
        from dynamo.common.utils.output_modalities import RequestType

        formatter = _make_diffusion_formatter()
        chunk = await formatter._encode_image(
            [], "req-1", request_type=RequestType.IMAGE_GENERATION
        )
        delta = chunk["choices"][0]["delta"]
        assert "Error" in delta["content"]
class TestDiffusionFormatterVideo:
    """Exercises the video path of ``DiffusionFormatter`` on empty input and on encoder failure."""

    @pytest.mark.asyncio
    async def test_empty_frames_returns_none(self):
        """A stage with an empty ``images`` list formats to None for video generation."""
        from dynamo.common.utils.output_modalities import RequestType

        formatter = _make_diffusion_formatter()
        stage_output = MagicMock()
        stage_output.images = []
        outcome = await formatter.format(
            stage_output, "req-1", request_type=RequestType.VIDEO_GENERATION
        )
        assert outcome is None

    @pytest.mark.asyncio
    async def test_error_returns_failed_status(self):
        """A failure while normalizing frames is reported as a failed response with the error text."""
        from unittest.mock import patch

        formatter = _make_diffusion_formatter()
        frames = [MagicMock()]
        failing_normalizer = patch(
            "dynamo.common.utils.video_utils.normalize_video_frames",
            side_effect=RuntimeError("boom"),
        )
        with failing_normalizer:
            chunk = await formatter._encode_video(frames, "req-1", fps=16)
        assert chunk["status"] == "failed"
        assert "boom" in chunk["error"]
class TestBuildCompletionUsage:
    """Checks the token counts reported by ``_build_completion_usage``."""

    def test_basic(self):
        """Five prompt ids and three completion ids give a total of eight."""
        request_output = _make_request_output("hello", finish_reason="stop")
        usage = _build_completion_usage(request_output)
        expected_counts = (
            ("prompt_tokens", 5),
            ("completion_tokens", 3),
            ("total_tokens", 8),
        )
        for field, count in expected_counts:
            assert usage[field] == count

    def test_no_prompt_tokens(self):
        """With ``prompt_token_ids`` set to None, prompt and total counts are None."""
        request_output = _make_request_output()
        request_output.prompt_token_ids = None
        usage = _build_completion_usage(request_output)
        for field in ("prompt_tokens", "total_tokens"):
            assert usage[field] is None
# ── AudioFormatter ─────────────────────────────────────────
class TestAudioFormatterExtractTensor:
    """Exercises ``AudioFormatter._extract_audio_tensor`` on numpy inputs."""

    @staticmethod
    def _new_formatter():
        """Create an ``AudioFormatter`` with no media filesystem or HTTP URL configured."""
        from dynamo.vllm.omni.output_formatter import AudioFormatter

        return AudioFormatter(model_name="test", media_fs=None, media_http_url=None)

    def test_extracts_from_audio_key(self):
        """Samples under the 'audio' key are returned together with 'sr'."""
        import numpy as np

        formatter = self._new_formatter()
        samples = np.array([0.1, -0.2, 0.3], dtype=np.float32)
        waveform, sample_rate = formatter._extract_audio_tensor(
            {"audio": samples, "sr": 24000}
        )
        assert sample_rate == 24000
        assert len(waveform) == 3

    def test_extracts_from_model_outputs_key(self):
        """Samples under the 'model_outputs' key are used when 'audio' is absent."""
        import numpy as np

        formatter = self._new_formatter()
        samples = np.array([0.5, -0.5], dtype=np.float32)
        waveform, sample_rate = formatter._extract_audio_tensor(
            {"model_outputs": samples, "sr": 16000}
        )
        assert sample_rate == 16000
        assert len(waveform) == 2

    def test_missing_audio_raises(self):
        """A dict holding only 'sr' raises ValueError mentioning 'No audio data'."""
        formatter = self._new_formatter()
        with pytest.raises(ValueError, match="No audio data"):
            formatter._extract_audio_tensor({"sr": 24000})

    def test_squeezes_extra_dims(self):
        """A (1, 3) input array comes back one-dimensional."""
        import numpy as np

        formatter = self._new_formatter()
        samples = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
        waveform, _ = formatter._extract_audio_tensor({"audio": samples, "sr": 24000})
        assert waveform.ndim == 1
class TestAudioFormatterEncode:
    """Exercises ``AudioFormatter._encode_audio`` and its choice of media type."""

    @staticmethod
    def _new_formatter():
        """Create an ``AudioFormatter`` with no media filesystem or HTTP URL configured."""
        from dynamo.vllm.omni.output_formatter import AudioFormatter

        return AudioFormatter(model_name="test", media_fs=None, media_http_url=None)

    def test_wav_encoding(self):
        """'wav' output is tagged audio/wav and starts with the RIFF magic bytes."""
        import numpy as np

        formatter = self._new_formatter()
        silence = np.zeros(2400, dtype=np.float32)
        payload, media_type = formatter._encode_audio(silence, 24000, "wav")
        assert media_type == "audio/wav"
        assert payload[:4] == b"RIFF"

    def test_unsupported_format_falls_back_to_wav(self):
        """An unrecognised format name still reports audio/wav."""
        import numpy as np

        formatter = self._new_formatter()
        silence = np.zeros(100, dtype=np.float32)
        _, media_type = formatter._encode_audio(silence, 24000, "xyz")
        assert media_type == "audio/wav"

    def test_default_format_is_wav(self):
        """Omitting the format argument reports audio/wav."""
        import numpy as np

        formatter = self._new_formatter()
        silence = np.zeros(100, dtype=np.float32)
        _, media_type = formatter._encode_audio(silence, 24000)
        assert media_type == "audio/wav"
class TestAudioFormatterFormat:
    """Exercises ``AudioFormatter.format`` on empty and populated multimodal dicts."""

    @pytest.mark.asyncio
    async def test_empty_returns_error(self):
        """An empty multimodal dict yields a failed response mentioning 'No audio generated'."""
        from dynamo.vllm.omni.output_formatter import AudioFormatter

        f = AudioFormatter(model_name="test", media_fs=None, media_http_url=None)
        result = await f.format({}, "req-1")
        assert result["status"] == "failed"
        assert "No audio generated" in result["error"]

    @pytest.mark.asyncio
    async def test_successful_generation(self):
        """A float32 waveform with a sample rate yields one completed, base64-encoded entry."""
        import numpy as np

        from dynamo.vllm.omni.output_formatter import AudioFormatter

        f = AudioFormatter(model_name="test", media_fs=None, media_http_url=None)
        # Seeded generator: the waveform is identical on every run, so any
        # failure can be reproduced instead of depending on a random sample.
        samples = np.random.default_rng(0).standard_normal(4800).astype(np.float32)
        mm = {"audio": samples, "sr": 24000}
        result = await f.format(mm, "req-1")
        assert result["status"] == "completed"
        assert result["object"] == "audio.speech"
        assert len(result["data"]) == 1
        assert result["data"][0]["b64_json"] is not None
# ── OutputFormatter dispatcher ─────────────────────────────
class TestOutputFormatter:
    """Tests pass the full ctx that _generate_openai_mode actually sends
    (request_type, fps, response_format, previous_text, speed) to catch
    signature mismatches in individual formatters early."""

    # Full ctx matching _generate_openai_mode's call signature; request_type
    # is passed explicitly by each test because it varies per modality.
    _FULL_CTX = dict(fps=16, response_format=None, previous_text="", speed=1.0)

    @pytest.mark.asyncio
    async def test_routes_text(self):
        """A stage typed 'text' is formatted from its request_output into a text delta."""
        from dynamo.common.utils.output_modalities import RequestType
        from dynamo.vllm.omni.output_formatter import OutputFormatter

        f = OutputFormatter(model_name="test-model")
        stage = MagicMock()
        stage.final_output_type = "text"
        stage.request_output = _make_request_output("hello world")
        chunk = await f.format(
            stage, "req-1", request_type=RequestType.CHAT_COMPLETION, **self._FULL_CTX
        )
        assert chunk["choices"][0]["delta"]["content"] == "hello world"

    @pytest.mark.asyncio
    async def test_routes_image(self):
        """A stage typed 'image' is formatted from its images into image_url parts."""
        from dynamo.common.utils.output_modalities import RequestType
        from dynamo.vllm.omni.output_formatter import OutputFormatter

        f = OutputFormatter(model_name="test-model")
        stage = MagicMock()
        stage.final_output_type = "image"
        img = MagicMock()
        img.save = lambda b, format: b.write(b"px")
        stage.images = [img]
        chunk = await f.format(
            stage, "req-1", request_type=RequestType.CHAT_COMPLETION, **self._FULL_CTX
        )
        assert chunk["choices"][0]["delta"]["content"][0]["type"] == "image_url"

    @pytest.mark.asyncio
    async def test_routes_audio(self):
        """A stage typed 'audio' is formatted from its multimodal_output into a completed response."""
        import numpy as np

        from dynamo.common.utils.output_modalities import RequestType
        from dynamo.vllm.omni.output_formatter import OutputFormatter

        f = OutputFormatter(model_name="test-model")
        stage = MagicMock()
        stage.final_output_type = "audio"
        # Seeded generator: the waveform is identical on every run, so any
        # failure can be reproduced instead of depending on a random sample.
        samples = np.random.default_rng(0).standard_normal(2400).astype(np.float32)
        stage.multimodal_output = {
            "audio": samples,
            "sr": 24000,
        }
        chunk = await f.format(
            stage, "req-1", request_type=RequestType.AUDIO_GENERATION, **self._FULL_CTX
        )
        assert chunk["status"] == "completed"

    @pytest.mark.asyncio
    async def test_unknown_type_returns_none(self):
        """An unrecognised final_output_type formats to None."""
        from dynamo.vllm.omni.output_formatter import OutputFormatter

        f = OutputFormatter(model_name="test-model")
        stage = MagicMock()
        stage.final_output_type = "unknown_modality"
        result = await f.format(stage, "req-1")
        assert result is None

    @pytest.mark.asyncio
    async def test_text_without_request_output_returns_none(self):
        """A 'text' stage whose request_output is None formats to None."""
        from dynamo.vllm.omni.output_formatter import OutputFormatter

        f = OutputFormatter(model_name="test-model")
        stage = MagicMock()
        stage.final_output_type = "text"
        stage.request_output = None
        result = await f.format(stage, "req-1")
        assert result is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment