Unverified Commit 8fd1b847 authored by Neal Vaidya's avatar Neal Vaidya Committed by GitHub
Browse files

feat(vllm): add audio_url loading to multimodal handler (#7955)


Signed-off-by: default avatarNeal Vaidya <nealv@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent b579a772
...@@ -7,6 +7,7 @@ from collections.abc import Callable ...@@ -7,6 +7,7 @@ from collections.abc import Callable
from dynamo.common.constants import EmbeddingTransferMode from dynamo.common.constants import EmbeddingTransferMode
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
from dynamo.common.multimodal.audio_loader import AudioLoader
from dynamo.common.multimodal.embedding_transfer import ( from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver, AbstractEmbeddingReceiver,
AbstractEmbeddingSender, AbstractEmbeddingSender,
...@@ -41,6 +42,7 @@ EMBEDDING_RECEIVER_FACTORIES: dict[ ...@@ -41,6 +42,7 @@ EMBEDDING_RECEIVER_FACTORIES: dict[
__all__ = [ __all__ = [
"AsyncEncoderCache", "AsyncEncoderCache",
"AudioLoader",
"EMBEDDING_RECEIVER_FACTORIES", "EMBEDDING_RECEIVER_FACTORIES",
"EMBEDDING_SENDER_FACTORIES", "EMBEDDING_SENDER_FACTORIES",
"ImageLoader", "ImageLoader",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from pathlib import Path
from typing import Any, Awaitable, Dict, Final, List
from urllib.parse import urlparse
import numpy as np
import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Constants for multimodal data variants.
# A multimodal item dict carries exactly one of these keys: "Url" maps to a
# URL/path string to fetch and decode, "Decoded" maps to metadata used to read
# already-decoded media via NIXL.
URL_VARIANT_KEY: Final = "Url"
DECODED_VARIANT_KEY: Final = "Decoded"
try:
from vllm.multimodal.media import MediaConnector
from vllm.multimodal.media.audio import AudioMediaIO
except ImportError:
MediaConnector = None # type: ignore[assignment]
AudioMediaIO = None # type: ignore[assignment]
def _require_vllm_audio_media() -> tuple[Any, Any]:
    """Hand back vLLM's ``(MediaConnector, AudioMediaIO)`` pair.

    Raises:
        RuntimeError: If either name is ``None``, which is the placeholder
            assigned when the guarded vLLM import at module level fails.
    """
    components = (MediaConnector, AudioMediaIO)
    if any(component is None for component in components):
        raise RuntimeError(
            "vLLM multimodal media components are required to decode `audio_url` "
            "inputs in the vLLM backend."
        )
    return components
class AudioLoader:
    """Async audio loader for multimodal pipelines.

    Delegates URL fetching and decoding to vLLM's ``MediaConnector`` +
    ``AudioMediaIO`` so that the exact same loading logic runs whether the
    request arrives via ``vllm serve`` or through Dynamo. Returns
    ``(waveform, sample_rate)`` tuples at the native sample rate — vLLM's
    model-specific ``MultiModalDataParser`` handles resampling and channel
    normalization downstream.

    Also supports the NIXL decoded variant for frontend-decoded audio
    transferred via RDMA.
    """

    def __init__(
        self,
        http_timeout: float = 30.0,
        enable_frontend_decoding: bool = False,
    ) -> None:
        """Configure the loader.

        Args:
            http_timeout: Value forwarded as ``fetch_timeout`` to the vLLM
                media connector; must be strictly positive.
            enable_frontend_decoding: When true, a NIXL connector is created
                and initialized here so ``Decoded`` items can be read.

        Raises:
            ValueError: If ``http_timeout`` is zero or negative.
        """
        if http_timeout <= 0:
            raise ValueError(f"http_timeout must be positive, got {http_timeout}")
        self._http_timeout = http_timeout
        self._enable_frontend_decoding = enable_frontend_decoding
        self._nixl_connector = None
        # Created lazily on first use (see _get_vllm_media_connector).
        self._vllm_media_connector = None
        if self._enable_frontend_decoding:
            self._nixl_connector = nixl_connect.Connector()
            run_async(self._nixl_connector.initialize)

    @staticmethod
    def _normalize_audio_url(audio_url: str) -> str:
        """Convert bare filesystem paths to file:// URIs.

        HTTP(S) and data: URLs are returned unchanged, as is any other string
        that parses with a scheme, and the empty string.

        Raises:
            FileNotFoundError: If a bare path does not exist on disk.
        """
        parsed_url = urlparse(audio_url)
        if parsed_url.scheme or not audio_url:
            return audio_url
        file_path = Path(audio_url).expanduser()
        if not file_path.exists():
            raise FileNotFoundError(f"Error reading file: {file_path}")
        return file_path.resolve().as_uri()

    def _get_vllm_media_connector(self) -> Any:
        """Return the cached vLLM ``MediaConnector``, creating it on first call."""
        if self._vllm_media_connector is None:
            MediaConnector, _ = _require_vllm_audio_media()
            # NOTE(security): allowed_local_media_path="/" permits file:// URLs
            # and bare paths anywhere on the local filesystem. If `audio_url`
            # values can come from untrusted clients this should be narrowed —
            # flagged for review, behavior intentionally left unchanged.
            self._vllm_media_connector = MediaConnector(allowed_local_media_path="/")
        return self._vllm_media_connector

    def _create_vllm_audio_io(self) -> Any:
        """Return a fresh vLLM ``AudioMediaIO`` instance."""
        _, AudioMediaIO = _require_vllm_audio_media()
        return AudioMediaIO()

    @_nvtx.annotate("mm:audio:load_with_vllm", color="cyan")
    async def _load_audio_with_vllm(self, audio_url: str) -> tuple[np.ndarray, float]:
        """Fetch and decode ``audio_url`` through the vLLM media connector."""
        connector = self._get_vllm_media_connector()
        normalized_url = self._normalize_audio_url(audio_url)
        # TODO: Add caching for repeated remote `audio_url` downloads to avoid
        # refetching the same asset across requests.
        return await connector.load_from_url_async(
            normalized_url,
            self._create_vllm_audio_io(),
            fetch_timeout=self._http_timeout,
        )

    @_nvtx.annotate("mm:audio:load_audio", color="cyan")
    async def load_audio(self, audio_url: str) -> tuple[np.ndarray, float]:
        """Load audio from a URL and return a (waveform, sample_rate) tuple.

        Supports http(s), data: URIs, file:// paths, and bare filesystem paths.
        Audio is loaded at the native sample rate — no resampling is performed.

        Raises:
            FileNotFoundError: Propagated unchanged when a local path is missing.
            ValueError: For any other failure, including an empty decoded
                waveform; the original exception is chained as the cause.
        """
        try:
            waveform, sr = await self._load_audio_with_vllm(audio_url)
            if waveform.size == 0:
                raise ValueError(
                    f"Failed to decode audio from {audio_url}. Decoded waveform is empty."
                )
            return waveform, sr
        except FileNotFoundError:
            raise
        except Exception as exc:
            logger.error("Error loading audio from %s: %s", audio_url, exc)
            raise ValueError(f"Failed to load audio from {audio_url}: {exc}") from exc

    async def _load_decoded_audio(
        self, decoded_metadata: Dict[str, Any]
    ) -> tuple[np.ndarray, float]:
        """Read pre-decoded audio via NIXL RDMA.

        The sample rate is taken from the returned metadata's ``sample_rate``
        entry and falls back to 16000 when the metadata is missing or has no
        such entry.

        Raises:
            RuntimeError: If the NIXL connector was never created.
        """
        if self._nixl_connector is None:
            raise RuntimeError("NIXL connector is not initialized")
        result = await read_decoded_media_via_nixl(
            self._nixl_connector,
            decoded_metadata,
            return_metadata=True,
        )
        frames, metadata = result
        if metadata is None:
            metadata = {}
        sr = metadata.get("sample_rate", 16000)
        return frames, float(sr)

    async def load_audio_batch(
        self,
        audio_mm_items: List[Dict[str, Any]],
    ) -> List[tuple[np.ndarray, float]]:
        """Load a batch of audio files from multimodal data items.

        Supports two paths:
        1. Url variant: Download and decode audio via vLLM's MediaConnector
        2. Decoded variant: Read pre-decoded audio via NIXL RDMA
           (requires enable_frontend_decoding=True)

        Returns:
            List of (waveform, sample_rate) tuples, in the order of the input.

        Raises:
            ValueError: If an item is malformed, or a ``Decoded`` item arrives
                while frontend decoding is disabled.
            Exception: If one or more loads fail; the message concatenates
                every individual failure.
        """
        audio_futures: List[Awaitable[tuple[np.ndarray, float]]] = []
        try:
            for idx, item in enumerate(audio_mm_items):
                if isinstance(item, dict) and URL_VARIANT_KEY in item:
                    url = item[URL_VARIANT_KEY]
                    audio_futures.append(self.load_audio(url))
                    logger.debug("Preparing to load audio from URL: %s...", url[:80])
                elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
                    if self._enable_frontend_decoding:
                        metadata = item[DECODED_VARIANT_KEY]
                        audio_futures.append(self._load_decoded_audio(metadata))
                    else:
                        raise ValueError(
                            "Received decoded audio data but enable_frontend_decoding=False. "
                            "Enable frontend decoding to transfer decoded audio via NIXL."
                        )
                else:
                    raise ValueError(
                        f"Invalid audio multimodal item at index {idx}. "
                        "Expected dict with 'Url' or 'Decoded' key."
                    )
        except BaseException:
            # A later item was rejected after coroutines for earlier items had
            # already been created. Those will never be awaited, so close them
            # explicitly to avoid "coroutine ... was never awaited" warnings.
            for pending in audio_futures:
                close = getattr(pending, "close", None)
                if close is not None:
                    close()
            raise

        results = await asyncio.gather(*audio_futures, return_exceptions=True)

        loaded_audio: list[tuple[np.ndarray, float]] = []
        collective_exceptions: list[str] = []
        for media_item, result in zip(audio_mm_items, results, strict=True):
            if isinstance(result, BaseException):
                # Cancellation must propagate as-is rather than being folded
                # into the aggregated error message.
                if isinstance(result, asyncio.CancelledError):
                    raise result
                source = media_item.get(URL_VARIANT_KEY, "decoded")
                logger.error("Failed to load audio from %s...: %s", source[:80], result)
                collective_exceptions.append(
                    f"Failed to load audio from {source[:80]}...: {result}\n"
                )
                continue
            loaded_audio.append(result)

        if collective_exceptions:
            raise Exception("".join(collective_exceptions))
        return loaded_audio
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import AsyncMock
import numpy as np
import pytest
import dynamo.common.multimodal.audio_loader as audio_loader_module
from dynamo.common.multimodal.audio_loader import AudioLoader
# Markers applied by pytest to every test in this module.
pytestmark = [
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
def test_normalize_audio_url_converts_local_paths(tmp_path):
    """An existing bare filesystem path is rewritten to its ``file://`` URI."""
    wav_file = tmp_path / "sample.wav"
    wav_file.write_bytes(b"RIFF")

    normalized = AudioLoader._normalize_audio_url(str(wav_file))

    assert normalized == wav_file.resolve().as_uri()
def test_normalize_audio_url_preserves_data_urls():
    """A ``data:`` URI passes through normalization untouched."""
    original = "data:audio/wav;base64,UklGRg=="
    normalized = AudioLoader._normalize_audio_url(original)
    assert normalized == original
def test_normalize_audio_url_preserves_http_urls():
    """An ``https://`` URL passes through normalization untouched."""
    original = "https://example.com/audio.wav"
    normalized = AudioLoader._normalize_audio_url(original)
    assert normalized == original
def test_normalize_audio_url_raises_on_missing_file():
    """A bare path that does not exist is reported as ``FileNotFoundError``."""
    missing_path = "/nonexistent/audio.wav"
    with pytest.raises(FileNotFoundError, match="Error reading file"):
        AudioLoader._normalize_audio_url(missing_path)
@pytest.mark.asyncio
async def test_load_audio_uses_vllm_media_connector():
    """``load_audio`` returns exactly what the vLLM-backed loader produced."""
    loader = AudioLoader()
    expected_waveform = np.random.randn(16000).astype(np.float32)
    expected_rate = 44100.0
    # Replace the vLLM-backed step so no real fetch or decode happens.
    loader._load_audio_with_vllm = AsyncMock(  # type: ignore[method-assign]
        return_value=(expected_waveform, expected_rate)
    )

    result = await loader.load_audio("data:audio/wav;base64,UklGRg==")
    actual_waveform, actual_rate = result

    np.testing.assert_array_equal(actual_waveform, expected_waveform)
    assert actual_rate == expected_rate
@pytest.mark.asyncio
async def test_load_audio_rejects_empty_waveform():
    """A zero-length decoded waveform is rejected with ``ValueError``."""
    empty_result = (np.array([], dtype=np.float32), 16000.0)
    loader = AudioLoader()
    loader._load_audio_with_vllm = AsyncMock(  # type: ignore[method-assign]
        return_value=empty_result
    )

    with pytest.raises(ValueError, match="empty"):
        await loader.load_audio("https://example.com/empty.wav")
@pytest.mark.asyncio
async def test_load_audio_batch_uses_url_loader():
    """``Url`` items are loaded through ``load_audio`` and keep request order."""
    loader = AudioLoader()
    expected = [
        (np.zeros(8000, dtype=np.float32), 16000.0),
        (np.ones(8000, dtype=np.float32), 44100.0),
    ]
    loader.load_audio = AsyncMock(side_effect=expected)  # type: ignore[method-assign]
    batch = [
        {"Url": "https://example.com/one.wav"},
        {"Url": "https://example.com/two.wav"},
    ]

    audios = await loader.load_audio_batch(batch)

    assert len(audios) == 2
    for position in (0, 1):
        np.testing.assert_array_equal(audios[position][0], expected[position][0])
        assert audios[position][1] == expected[position][1]
@pytest.mark.asyncio
async def test_load_audio_batch_rejects_malformed_items():
    """An item with neither ``Url`` nor ``Decoded`` raises ``ValueError``."""
    malformed_batch = [{"bad_key": "value"}]
    loader = AudioLoader(enable_frontend_decoding=False)

    with pytest.raises(ValueError, match="Invalid audio multimodal item"):
        await loader.load_audio_batch(malformed_batch)
@pytest.mark.asyncio
async def test_load_audio_batch_rejects_decoded_variant_without_frontend_decoding():
    """A ``Decoded`` item is refused when frontend decoding is disabled."""
    decoded_batch = [{"Decoded": {"shape": [16000]}}]
    loader = AudioLoader(enable_frontend_decoding=False)

    with pytest.raises(ValueError, match="enable_frontend_decoding=False"):
        await loader.load_audio_batch(decoded_batch)
@pytest.mark.asyncio
async def test_load_audio_batch_reads_decoded_variant(monkeypatch):
    """A ``Decoded`` item is read via the NIXL helper with its metadata rate."""
    # Build the loader with frontend decoding off so no real NIXL connector is
    # initialized, then flip the private state to reach the decoded path.
    loader = AudioLoader(enable_frontend_decoding=False)
    loader._enable_frontend_decoding = True
    loader._nixl_connector = object()

    sample_metadata = {"sample_rate": 44100}
    decoded_item = {
        "shape": [16000],
        "metadata": sample_metadata,
    }
    expected_waveform = np.random.randn(16000).astype(np.float32)
    read_decoded = AsyncMock(return_value=(expected_waveform, sample_metadata))
    monkeypatch.setattr(
        audio_loader_module, "read_decoded_media_via_nixl", read_decoded
    )

    audios = await loader.load_audio_batch([{"Decoded": decoded_item}])

    assert len(audios) == 1
    loaded_waveform, loaded_rate = audios[0][0], audios[0][1]
    np.testing.assert_array_equal(loaded_waveform, expected_waveform)
    assert loaded_rate == 44100.0
    read_decoded.assert_awaited_once_with(
        loader._nixl_connector,
        decoded_item,
        return_metadata=True,
    )
...@@ -27,6 +27,7 @@ from dynamo._core import Context ...@@ -27,6 +27,7 @@ from dynamo._core import Context
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
from dynamo.common.multimodal.audio_loader import AudioLoader
from dynamo.common.multimodal.embedding_transfer import ( from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver, LocalEmbeddingReceiver,
NixlReadEmbeddingReceiver, NixlReadEmbeddingReceiver,
...@@ -65,6 +66,7 @@ from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader ...@@ -65,6 +66,7 @@ from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
# Multimodal data dictionary keys # Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url" IMAGE_URL_KEY: Final = "image_url"
VIDEO_URL_KEY: Final = "video_url" VIDEO_URL_KEY: Final = "video_url"
AUDIO_URL_KEY: Final = "audio_url"
URL_VARIANT_KEY: Final = "Url" URL_VARIANT_KEY: Final = "Url"
DECODED_VARIANT_KEY: Final = "Decoded" DECODED_VARIANT_KEY: Final = "Decoded"
...@@ -396,6 +398,9 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -396,6 +398,9 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
self.image_loader = ImageLoader( self.image_loader = ImageLoader(
enable_frontend_decoding=enable_frontend_decoding enable_frontend_decoding=enable_frontend_decoding
) )
self.audio_loader = AudioLoader(
enable_frontend_decoding=enable_frontend_decoding
)
self.video_loader = VideoLoader( self.video_loader = VideoLoader(
enable_frontend_decoding=enable_frontend_decoding enable_frontend_decoding=enable_frontend_decoding
) )
...@@ -1238,6 +1243,16 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -1238,6 +1243,16 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
f"Extracted {len(videos)} video(s) for multimodal processing" f"Extracted {len(videos)} video(s) for multimodal processing"
) )
# Handle audio_url entries
audio_mm_items = mm_map.get(AUDIO_URL_KEY, [])
if audio_mm_items:
audios = await self.audio_loader.load_audio_batch(audio_mm_items)
if audios:
vllm_mm_data["audio"] = audios[0] if len(audios) == 1 else audios
logger.debug(
f"Extracted {len(audios)} audio item(s) for multimodal processing"
)
return vllm_mm_data if vllm_mm_data else None return vllm_mm_data if vllm_mm_data else None
def _build_prompt_from_request( def _build_prompt_from_request(
...@@ -1632,7 +1647,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1632,7 +1647,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# supported — synthetic image data would be overwritten. # supported — synthetic image data would be overwritten.
if multi_modal_data is None and has_mm_data: if multi_modal_data is None and has_mm_data:
mm = request["multi_modal_data"] mm = request["multi_modal_data"]
if mm.get(VIDEO_URL_KEY) or mm.get("audio_url"): if mm.get(VIDEO_URL_KEY) or mm.get(AUDIO_URL_KEY):
multi_modal_data = await self._extract_multimodal_data( multi_modal_data = await self._extract_multimodal_data(
request, request_id, context request, request_id, context
) )
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from types import SimpleNamespace
from unittest.mock import AsyncMock
import numpy as np
import pytest
from PIL import Image
from dynamo.vllm.handlers import BaseWorkerHandler
# Markers applied by pytest to every test in this module; the asyncio marker
# lets the `async def` tests below run without a per-test decorator.
pytestmark = [
pytest.mark.asyncio,
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
class _TestWorkerHandler(BaseWorkerHandler):
    """Concrete ``BaseWorkerHandler`` subclass used by the tests in this module."""

    async def generate(self, request, context):
        """Yield a single empty dict; the arguments are ignored."""
        yield {}
def _make_handler(enable_multimodal: bool = True) -> _TestWorkerHandler:
    """Build a handler without running ``__init__`` and attach stub loaders.

    The image and audio loaders are stand-ins whose batch methods are
    ``AsyncMock`` objects returning an empty list.
    """
    image_stub = SimpleNamespace(load_image_batch=AsyncMock(return_value=[]))
    audio_stub = SimpleNamespace(load_audio_batch=AsyncMock(return_value=[]))

    # __new__ skips BaseWorkerHandler.__init__, so only the attributes set
    # below exist on the instance.
    instance = _TestWorkerHandler.__new__(_TestWorkerHandler)
    instance.enable_multimodal = enable_multimodal
    instance.config = SimpleNamespace(model="Qwen/Qwen3-Omni-30B-A3B-Instruct")
    instance.embedding_loader = None
    instance.image_loader = image_stub
    instance.audio_loader = audio_stub
    return instance
async def test_extract_multimodal_data_loads_audio_url_items():
    """A single ``audio_url`` item is returned unwrapped under ``"audio"``."""
    expected_audio = (np.random.randn(16000).astype(np.float32), 16000.0)
    handler = _make_handler()
    handler.audio_loader.load_audio_batch = AsyncMock(return_value=[expected_audio])
    request = {
        "multi_modal_data": {"audio_url": [{"Url": "https://example.com/audio.wav"}]}
    }

    result = await handler._extract_multimodal_data(request, "req-1", context=None)

    assert result is not None
    assert result["audio"] is expected_audio
    handler.audio_loader.load_audio_batch.assert_awaited_once()
async def test_extract_multimodal_data_loads_image_and_audio_together():
    """Image and audio entries in one request are both loaded and returned."""
    handler = _make_handler()
    expected_image = Image.new("RGB", (2, 2))
    expected_audio = (np.random.randn(16000).astype(np.float32), 16000.0)
    handler.image_loader.load_image_batch = AsyncMock(return_value=[expected_image])
    handler.audio_loader.load_audio_batch = AsyncMock(return_value=[expected_audio])
    multi_modal_data = {
        "image_url": [{"Url": "https://example.com/image.png"}],
        "audio_url": [{"Url": "https://example.com/audio.wav"}],
    }

    result = await handler._extract_multimodal_data(
        {"multi_modal_data": multi_modal_data},
        "req-2",
        context=None,
    )

    assert result is not None
    assert result["image"] is expected_image
    assert result["audio"] is expected_audio
    for batch_loader in (
        handler.image_loader.load_image_batch,
        handler.audio_loader.load_audio_batch,
    ):
        batch_loader.assert_awaited_once()
async def test_extract_multimodal_data_multiple_audio_items():
    """Several audio items come back as a list rather than a single tuple."""
    loaded_batch = [
        (np.zeros(8000, dtype=np.float32), 16000.0),
        (np.ones(8000, dtype=np.float32), 44100.0),
    ]
    handler = _make_handler()
    handler.audio_loader.load_audio_batch = AsyncMock(return_value=loaded_batch)
    audio_items = [
        {"Url": "https://example.com/a.wav"},
        {"Url": "https://example.com/b.wav"},
    ]

    result = await handler._extract_multimodal_data(
        {"multi_modal_data": {"audio_url": audio_items}},
        "req-3",
        context=None,
    )

    assert result is not None
    # With more than one item the list is passed through, not unwrapped.
    audio_entry = result["audio"]
    assert isinstance(audio_entry, list)
    assert len(audio_entry) == 2
async def test_extract_multimodal_data_rejects_requests_when_disabled():
    """Multimodal payloads raise ``ValueError`` when multimodal is disabled."""
    request = {
        "multi_modal_data": {
            "audio_url": [{"Url": "https://example.com/audio.wav"}]
        }
    }
    handler = _make_handler(enable_multimodal=False)

    with pytest.raises(ValueError, match="multimodal processing is not enabled"):
        await handler._extract_multimodal_data(request, "req-4", context=None)
...@@ -16,7 +16,7 @@ This document provides a comprehensive guide for multimodal inference using vLLM ...@@ -16,7 +16,7 @@ This document provides a comprehensive guide for multimodal inference using vLLM
| ------------------------ | ---------- | ------------- | | ------------------------ | ---------- | ------------- |
| **Image** | Yes | Yes | | **Image** | Yes | Yes |
| **Video** | Yes | Yes | | **Video** | Yes | Yes |
| **Audio** (Experimental) | Yes | Yes | | **Audio** | Yes | No |
### Supported URL Formats ### Supported URL Formats
...@@ -144,50 +144,42 @@ bash launch/disagg_multimodal_epd.sh --model Qwen/Qwen3-VL-2B-Instruct ...@@ -144,50 +144,42 @@ bash launch/disagg_multimodal_epd.sh --model Qwen/Qwen3-VL-2B-Instruct
bash launch/disagg_multimodal_epd.sh --model Qwen/Qwen3-VL-2B-Instruct --single-gpu bash launch/disagg_multimodal_epd.sh --model Qwen/Qwen3-VL-2B-Instruct --single-gpu
``` ```
## Audio Serving (Experimental) ## Audio Serving
### Audio Aggregated Serving Dynamo supports `audio_url` requests for audio-capable models. Audio is loaded by the backend worker via vLLM's `AudioMediaIO` at native sample rate — vLLM's model-specific processor handles resampling and feature extraction internally. Omni models can handle `image_url`, `video_url`, and `audio_url` in the same request.
**Components:** ### Aggregated Serving
- workers: [AudioEncodeWorker](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/components/audio_encode_worker.py) for decoding audio into embeddings, and [VllmPDWorker](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/components/worker.py) for prefilling and decoding. Use the same aggregated multimodal launcher with an audio-capable model:
- processor: Tokenizes the prompt and passes it to the AudioEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
**Workflow:** ```bash
pip install 'vllm[audio]' # installs librosa and other audio dependencies
cd $DYNAMO_HOME/examples/backends/vllm
bash launch/agg_multimodal.sh --model Qwen/Qwen3-Omni-30B-A3B-Instruct
```
```mermaid ```mermaid
flowchart LR flowchart LR
HTTP --> processor HTTP --> frontend
processor --> HTTP frontend --> HTTP
processor --audio_url--> audio_encode_worker frontend --audio_url--> vllm_worker
audio_encode_worker --> processor vllm_worker --> frontend
audio_encode_worker --embeddings--> pd_worker
pd_worker --> audio_encode_worker
``` ```
**Launch:** **Audio request:**
```bash
pip install 'vllm[audio]' accelerate # multimodal audio models dependency
cd $DYNAMO_HOME/examples/multimodal
bash launch/audio_agg.sh
```
**Client:**
```bash ```bash
curl http://localhost:8000/v1/chat/completions \ curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "Qwen/Qwen2-Audio-7B-Instruct", "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": "What is recited in the audio?" "text": "What sound is this?"
}, },
{ {
"type": "audio_url", "type": "audio_url",
...@@ -198,38 +190,11 @@ curl http://localhost:8000/v1/chat/completions \ ...@@ -198,38 +190,11 @@ curl http://localhost:8000/v1/chat/completions \
] ]
} }
], ],
"max_tokens": 6000, "max_tokens": 100,
"temperature": 0.8,
"stream": false "stream": false
}' | jq }' | jq
``` ```
### Audio Disaggregated Serving
**Workflow:**
For the Qwen2-Audio model, audio embeddings are only required during the prefill stage. The AudioEncodeWorker is connected directly to the prefill worker.
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --audio_url--> audio_encode_worker
audio_encode_worker --> processor
audio_encode_worker --embeddings--> prefill_worker
prefill_worker --> audio_encode_worker
prefill_worker --> decode_worker
decode_worker --> prefill_worker
```
**Launch:**
```bash
pip install 'vllm[audio]' accelerate # multimodal audio models dependency
cd $DYNAMO_HOME/examples/multimodal
bash launch/audio_disagg.sh
```
## Embedding Cache ## Embedding Cache
Dynamo supports embedding cache in both aggregated and disaggregated settings: Dynamo supports embedding cache in both aggregated and disaggregated settings:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import logging
import os
import signal
import sys
from typing import AsyncIterator, Tuple
import torch
import uvloop
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
import dynamo.nixl_connect as connect
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.audio_loader import AudioLoader
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
configure_dynamo_logging()
logger = logging.getLogger(__name__)

# Pick the array backend: cupy when it imports and reports CUDA as available,
# numpy otherwise. DEVICE records which of the two was chosen.
try:
    import cupy as array_module

    if not array_module.cuda.is_available():
        raise ImportError("CUDA is not available.")
except ImportError as import_error:
    logger.warning(f"Failed to import cupy, falling back to numpy: {import_error}.")
    import numpy as array_module

    DEVICE = "cpu"
else:
    DEVICE = "cuda"
    logger.info("Using cupy for array operations (GPU mode).")

# Passed to AudioLoader as its `cache_size`.
CACHE_SIZE_MAXIMUM = 8
class VllmEncodeWorker:
    """Encode worker that turns an audio URL into embeddings for a PD worker.

    For each request the audio is loaded, run through the Qwen2-Audio audio
    tower and multi-modal projector, exposed to the downstream worker through a
    readable NIXL descriptor, and the downstream responses are streamed back.
    """

    def __init__(
        self,
        args: argparse.Namespace,
        engine_args: AsyncEngineArgs,
        pd_worker_client: Client,
    ) -> None:
        """Load the audio processor and model named by ``engine_args.model``.

        Args:
            args: Parsed command-line arguments (not read here).
            engine_args: vLLM engine arguments; only ``model`` is used.
            pd_worker_client: Client for the downstream prefill/decode worker.
        """
        self.pd_worker_client = pd_worker_client
        self.engine_args = engine_args
        self.model = self.engine_args.model
        # Created in async_init(); None until then so generate() can fail with
        # a clear message instead of an AttributeError.
        self._connector = None

        self.audio_loader = AudioLoader(cache_size=CACHE_SIZE_MAXIMUM)
        self.audio_processor = AutoProcessor.from_pretrained(
            self.model, trust_remote_code=True
        )
        self.audio_model = Qwen2AudioForConditionalGeneration.from_pretrained(
            self.model, device_map="auto", dtype=torch.float16
        ).eval()

    def get_audio_embeddings(self, audio_features):
        """Run processor output through the audio tower and projector.

        Args:
            audio_features: Processor output exposing ``input_features`` and
                ``feature_attention_mask``.

        Returns:
            The projected audio features with padded positions removed by
            boolean-mask indexing.
        """
        input_features, feature_attention_mask = (
            audio_features.input_features,
            audio_features.feature_attention_mask,
        )
        with torch.no_grad():
            (
                audio_feat_lengths,
                audio_output_lengths,
            ) = self.audio_model.audio_tower._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
            batch_size, _, max_mel_seq_len = input_features.shape
            max_seq_len = (max_mel_seq_len - 2) // 2 + 1
            # Create a sequence tensor of shape (batch_size, max_seq_len)
            seq_range = (
                torch.arange(
                    0,
                    max_seq_len,
                    dtype=audio_feat_lengths.dtype,
                    device=audio_feat_lengths.device,
                )
                .unsqueeze(0)
                .expand(batch_size, max_seq_len)
            )
            lengths_expand = audio_feat_lengths.unsqueeze(1).expand(
                batch_size, max_seq_len
            )
            # Create mask: True marks positions at or beyond each item's length.
            padding_mask = seq_range >= lengths_expand

            audio_attention_mask_ = padding_mask.view(
                batch_size, 1, 1, max_seq_len
            ).expand(batch_size, 1, max_seq_len, max_seq_len)
            audio_attention_mask = audio_attention_mask_.to(
                dtype=self.audio_model.audio_tower.conv1.weight.dtype,
                device=self.audio_model.audio_tower.conv1.weight.device,
            )
            # Padded positions get -inf in the mask handed to the audio tower.
            audio_attention_mask[audio_attention_mask_] = float("-inf")

            audio_outputs = self.audio_model.audio_tower(
                input_features, attention_mask=audio_attention_mask
            )
            selected_audio_feature = audio_outputs.last_hidden_state
            audio_features = self.audio_model.multi_modal_projector(
                selected_audio_feature
            )

            num_audios, max_audio_tokens, embed_dim = audio_features.shape
            audio_features_mask = torch.arange(
                max_audio_tokens, device=audio_output_lengths.device
            )[None, :]
            # Keep only token positions below each item's output length.
            audio_features_mask = audio_features_mask < audio_output_lengths[:, None]
            audio_features = audio_features[audio_features_mask]

            return audio_features

    def cleanup(self):
        """No-op hook called when serving stops."""
        pass

    async def generate(
        self, request: vLLMMultimodalRequest
    ) -> AsyncIterator[str]:
        """Encode the request's audio and stream downstream responses.

        Args:
            request: A ``vLLMMultimodalRequest``, or a JSON string / mapping
                that validates into one.

        Yields:
            JSON strings, each a serialized ``MyRequestOutput`` rebuilt from a
            downstream response.

        Raises:
            RuntimeError: If ``async_init`` has not been awaited yet.
            Exception: Any failure during processing is logged and re-raised.
        """
        if self._connector is None:
            raise RuntimeError(
                "VllmEncodeWorker.async_init() must be awaited before generate() is called."
            )

        logger.debug(f"Got raw request: {request}")
        if not isinstance(request, vLLMMultimodalRequest):
            if isinstance(request, str):
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")

        request_id = request.request_id

        # The following steps encode the requested audio and provide useful embeddings.
        # 1. Open the audio from the provided URL.
        # 2. Process the audio using the audio processor.
        # 3. Run the audio through the audio model's audio tower.
        # 4. Run the results of the audio tower through the multi-modal projector.
        # 5. Create a descriptor for the embeddings.
        # 6. Create a write operation using the serialized request and the descriptor.
        # 7. Await for the write operation to complete.
        # 8. Yield the encode response.
        try:
            audio, sr = await self.audio_loader.load_audio(
                request.multimodal_input.audio_url
            )
            audio_features = self.audio_processor(
                text="test<|AUDIO|>", audio=audio, return_tensors="pt", padding=False
            )
            with torch.no_grad():
                audio_embeddings = self.get_audio_embeddings(audio_features)
            descriptor = connect.Descriptor(audio_embeddings)

            with await self._connector.create_readable(descriptor) as readable:
                request.serialized_request = readable.metadata()
                # Clear the audio URL as hint that the audio is passed as embeddings.
                request.multimodal_input.audio_url = None
                request.embeddings_shape = tuple(audio_embeddings.shape)
                logger.debug(f"Request: {request.model_dump_json()}")

                response_generator = await self.pd_worker_client.round_robin(
                    request.model_dump_json()
                )
                await readable.wait_for_completion()

                async for response in response_generator:
                    output = MyRequestOutput.model_validate_json(response.data())
                    yield MyRequestOutput(
                        request_id=output.request_id,
                        prompt=output.prompt,
                        prompt_token_ids=output.prompt_token_ids,
                        prompt_logprobs=output.prompt_logprobs,
                        outputs=output.outputs,
                        finished=output.finished,
                    ).model_dump_json()
        except Exception as e:
            logger.error(f"Error processing request {request_id}: {e}")
            raise

    async def async_init(self, runtime: DistributedRuntime):
        """Create the NIXL connector used by ``generate``."""
        logger.info("Startup started.")
        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        self._connector = connect.Connector()
        logger.info("Startup completed.")

    @classmethod
    def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
        """Parse command-line options for the encode worker.

        Default endpoints are built from the ``DYN_NAMESPACE`` environment
        variable, which itself defaults to ``"dynamo"``.

        Returns:
            The ``(args, config)`` pair produced by ``base_parse_args``.
        """
        DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
        DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.encoder.generate"
        DEFAULT_DOWNSTREAM_ENDPOINT = f"dyn://{DYN_NAMESPACE}.llm.generate"

        parser = FlexibleArgumentParser(
            description="vLLM based encoder for Dynamo LLM."
        )
        parser.add_argument(
            "--endpoint",
            type=str,
            default=DEFAULT_ENDPOINT,
            help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
        )
        parser.add_argument(
            "--downstream-endpoint",
            type=str,
            default=DEFAULT_DOWNSTREAM_ENDPOINT,
            help=f"The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
        )

        args, config = base_parse_args(parser)

        return args, config
async def graceful_shutdown(runtime):
    """Shut down the DistributedRuntime after a termination signal.

    Calling ``runtime.shutdown()`` makes the endpoints unavailable right away,
    while requests already in flight are still processed to completion. Once
    they are done the ``serve_endpoint`` functions return and the engine is
    released by Python's garbage collector.
    """
    announce = logging.info
    announce("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    announce("DistributedRuntime shutdown complete")
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Install signal handlers, parse arguments and serve the encode worker.

    Args:
        runtime: The distributed runtime supplied by the ``dynamo_worker``
            decorator.
    """
    # Runtime setup
    # Set up signal handler for graceful shutdown
    loop = asyncio.get_running_loop()

    # The event loop only keeps weak references to tasks, so hold a strong
    # reference until the shutdown task finishes; otherwise it could be
    # garbage-collected before it completes.
    shutdown_tasks = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logging.info("Signal handlers set up for graceful shutdown")

    # worker setup
    args, config = VllmEncodeWorker.parse_args()
    await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
    """
    Instantiate the encode worker and serve its generate endpoint.

    Resolves this worker's endpoint from `config` and a client for the
    downstream endpoint from `args.downstream_endpoint`, builds and
    initializes the handler, waits for downstream instances, then serves
    until the endpoint returns. `handler.cleanup()` runs whether serving
    ends normally or with an error; errors are logged and re-raised.
    """
    own_endpoint_path = f"{config.namespace}.{config.component}.{config.endpoint}"
    generate_endpoint = runtime.endpoint(own_endpoint_path)

    downstream_parts = parse_endpoint(args.downstream_endpoint)
    parsed_namespace, parsed_component_name, parsed_endpoint_name = downstream_parts
    downstream_endpoint = runtime.endpoint(
        f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
    )
    pd_worker_client = await downstream_endpoint.client()

    handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
    await handler.async_init(runtime)

    logger.info("Waiting for PD Worker Instances ...")
    await pd_worker_client.wait_for_instances()

    logger.info(f"Starting to serve the {args.endpoint} endpoint...")

    labels = [("model", config.model)]
    try:
        serving = generate_endpoint.serve_endpoint(
            handler.generate, metrics_labels=labels
        )
        await asyncio.gather(serving)
    except Exception as exc:
        logger.error(f"Failed to serve endpoints: {exc}")
        raise
    finally:
        handler.cleanup()
if __name__ == "__main__":
    # Install uvloop before asyncio.run() creates the event loop.
    uvloop.install()
    # worker() is called without arguments; presumably the @dynamo_worker()
    # decorator supplies the runtime -- confirm against its definition.
    asyncio.run(worker())
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Launches the audio multimodal example: the dynamo frontend, the processor,
# an audio encode worker on GPU 0 and a vLLM worker (--worker-type prefill)
# on GPU 1. Missing Python dependencies are installed with pip first.

set -e
trap 'echo Cleaning up...; kill 0' EXIT

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
# wait_any_exit and build_vllm_gpu_mem_args are presumably defined in these
# helper files -- confirm against their contents.
source "$SCRIPT_DIR/../../common/launch_utils.sh"
source "$SCRIPT_DIR/../../common/gpu_utils.sh"

# Default values
MODEL_NAME="Qwen/Qwen2-Audio-7B-Instruct"
PROMPT_TEMPLATE=""
PROVIDED_PROMPT_TEMPLATE=""

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            # Without this check `shift 2` fails and `set -e` exits silently.
            if [[ $# -lt 2 ]]; then
                echo "Error: --model requires a value" >&2
                echo "Use --help for usage information" >&2
                exit 1
            fi
            MODEL_NAME=$2
            shift 2
            ;;
        --prompt-template)
            if [[ $# -lt 2 ]]; then
                echo "Error: --prompt-template requires a value" >&2
                echo "Use --help for usage information" >&2
                exit 1
            fi
            PROVIDED_PROMPT_TEMPLATE=$2
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
            echo " --prompt-template <template> Specify the multi-modal prompt template to use. Qwen/Qwen2-Audio-7B-Instruct has a predefined template."
            echo " -h, --help Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use --help for usage information"
            exit 1
            ;;
    esac
done

# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
    PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2-Audio-7B-Instruct" ]]; then
    PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\n<prompt><|im_end|>\n<|im_start|>assistant\n"
else
    echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
    echo "Please provide a prompt template using --prompt-template option."
    echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
    exit 1
fi

# Check and install required dependencies for audio multimodal models
echo "Checking audio multimodal dependencies..."
DEPS_MISSING=false

# Check for accelerate
if ! python -c "import accelerate" &> /dev/null; then
    echo " accelerate not found"
    DEPS_MISSING=true
else
    echo " ✓ accelerate is installed"
fi

# Check for vllm with audio support
if ! python -c "import vllm" &> /dev/null; then
    echo " vllm not found"
    DEPS_MISSING=true
else
    # Check if audio dependencies are available (librosa is a key audio dependency)
    if ! python -c "import librosa" &> /dev/null; then
        echo " vllm audio dependencies not found"
        DEPS_MISSING=true
    else
        echo " ✓ vllm with audio support is installed"
    fi
fi

# Install missing dependencies
if [ "$DEPS_MISSING" = true ]; then
    echo "Installing missing dependencies..."
    pip install 'vllm[audio]' accelerate
    echo "Dependencies installed successfully"
else
    echo "All required dependencies are already installed"
fi

# run ingress
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python -m dynamo.frontend &

# run processor
# MODEL_NAME is quoted so a value containing spaces or glob characters is
# passed as a single argument.
python3 components/processor.py --model "$MODEL_NAME" --prompt-template "$PROMPT_TEMPLATE" &

# run E/P/D workers
GPU_MEM_ARGS=$(build_vllm_gpu_mem_args)

CUDA_VISIBLE_DEVICES=0 \
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT2:-8082} \
python3 components/audio_encode_worker.py --model "$MODEL_NAME" &

# GPU_MEM_ARGS stays unquoted: presumably it may expand to several arguments
# (or to nothing), which relies on word splitting.
CUDA_VISIBLE_DEVICES=1 \
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 \
python3 components/worker.py --model "$MODEL_NAME" --worker-type prefill $GPU_MEM_ARGS &

# Exit on first worker failure; kill 0 in the EXIT trap tears down the rest
wait_any_exit
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Launches the disaggregated audio multimodal example: the dynamo frontend,
# the processor, an audio encode worker on GPU 0, a prefill worker on GPU 1
# and a decode worker on GPU 2 (both with --enable-disagg). Missing Python
# dependencies are installed with pip first.

set -e
trap 'echo Cleaning up...; kill 0' EXIT

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
# wait_any_exit and build_vllm_gpu_mem_args are presumably defined in these
# helper files -- confirm against their contents.
source "$SCRIPT_DIR/../../common/launch_utils.sh"
source "$SCRIPT_DIR/../../common/gpu_utils.sh"

# Default values
MODEL_NAME="Qwen/Qwen2-Audio-7B-Instruct"
PROMPT_TEMPLATE=""
PROVIDED_PROMPT_TEMPLATE=""

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            # Without this check `shift 2` fails and `set -e` exits silently.
            if [[ $# -lt 2 ]]; then
                echo "Error: --model requires a value" >&2
                echo "Use --help for usage information" >&2
                exit 1
            fi
            MODEL_NAME=$2
            shift 2
            ;;
        --prompt-template)
            if [[ $# -lt 2 ]]; then
                echo "Error: --prompt-template requires a value" >&2
                echo "Use --help for usage information" >&2
                exit 1
            fi
            PROVIDED_PROMPT_TEMPLATE=$2
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
            echo " --prompt-template <template> Specify the multi-modal prompt template to use. Qwen/Qwen2-Audio-7B-Instruct has a predefined template."
            echo " -h, --help Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use --help for usage information"
            exit 1
            ;;
    esac
done

# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
    PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2-Audio-7B-Instruct" ]]; then
    PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\n<prompt><|im_end|>\n<|im_start|>assistant\n"
else
    echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
    echo "Please provide a prompt template using --prompt-template option."
    echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
    exit 1
fi

# Check and install required dependencies for audio multimodal models
echo "Checking audio multimodal dependencies..."
DEPS_MISSING=false

# Check for accelerate
if ! python -c "import accelerate" &> /dev/null; then
    echo " accelerate not found"
    DEPS_MISSING=true
else
    echo " ✓ accelerate is installed"
fi

# Check for vllm with audio support
if ! python -c "import vllm" &> /dev/null; then
    echo " vllm not found"
    DEPS_MISSING=true
else
    # Check if audio dependencies are available (librosa is a key audio dependency)
    if ! python -c "import librosa" &> /dev/null; then
        echo " vllm audio dependencies not found"
        DEPS_MISSING=true
    else
        echo " ✓ vllm with audio support is installed"
    fi
fi

# Install missing dependencies
if [ "$DEPS_MISSING" = true ]; then
    echo "Installing missing dependencies..."
    pip install 'vllm[audio]' accelerate
    echo "Dependencies installed successfully"
else
    echo "All required dependencies are already installed"
fi

# run ingress
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python -m dynamo.frontend &

# run processor
# MODEL_NAME is quoted so a value containing spaces or glob characters is
# passed as a single argument.
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT4:-8084} \
python3 components/processor.py --model "$MODEL_NAME" --prompt-template "$PROMPT_TEMPLATE" &

# run E/P/D workers
GPU_MEM_ARGS=$(build_vllm_gpu_mem_args)

CUDA_VISIBLE_DEVICES=0 \
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT3:-8083} \
python3 components/audio_encode_worker.py --model "$MODEL_NAME" &

# GPU_MEM_ARGS stays unquoted: presumably it may expand to several arguments
# (or to nothing), which relies on word splitting.
CUDA_VISIBLE_DEVICES=1 \
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
DYN_VLLM_KV_EVENT_PORT=20081 \
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \
python3 components/worker.py --model "$MODEL_NAME" --worker-type prefill --enable-disagg $GPU_MEM_ARGS &

CUDA_VISIBLE_DEVICES=2 \
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT2:-8082} \
DYN_VLLM_KV_EVENT_PORT=20082 \
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
python3 components/worker.py --model "$MODEL_NAME" --worker-type decode --enable-disagg $GPU_MEM_ARGS &

# Exit on first worker failure; kill 0 in the EXIT trap tears down the rest
wait_any_exit
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import functools
import logging
from io import BytesIO
from typing import Tuple
from urllib.parse import urlparse
import httpx
import librosa
import numpy as np
logger = logging.getLogger(__name__)
class AudioLoader:
    """
    Download audio over HTTP(S) and decode it with librosa.

    Results are memoized per instance in an LRU cache keyed on
    (audio_url, sampling_rate). A cache hit returns the same array object
    that earlier callers received, so callers should treat the returned
    array as read-only.
    """

    CACHE_SIZE_MAXIMUM = 8
    DEFAULT_HTTP_TIMEOUT = 30.0

    def __init__(
        self,
        cache_size: int = CACHE_SIZE_MAXIMUM,
        *,
        http_timeout: float = DEFAULT_HTTP_TIMEOUT,
    ):
        """
        Args:
            cache_size: Maximum number of decoded results kept in the cache.
            http_timeout: Timeout, in seconds, passed to the HTTP client.
        """
        self._http_timeout = http_timeout

        # functools.lru_cache is not directly compatible with async methods.
        # We create a synchronous method for fetching and processing audio,
        # and then apply lru_cache to it. This cached synchronous method
        # is then called from our async method using asyncio.to_thread.
        self._load_and_process_audio_cached = functools.lru_cache(maxsize=cache_size)(
            self._load_and_process_audio
        )

    def _load_and_process_audio(
        self, audio_url: str, sampling_rate: int
    ) -> Tuple[np.ndarray, float]:
        """
        Synchronously download `audio_url` and decode it with librosa.

        This method is memoized through the lru_cache wrapper created in
        `__init__`; failed calls raise and are therefore not cached.

        Returns:
            The `(audio_data, sr)` pair produced by `librosa.load`.

        Raises:
            httpx.HTTPError: on transport errors or a non-success status.
            ValueError: if the response body is empty.
        """
        with httpx.Client(timeout=self._http_timeout) as client:
            response = client.get(audio_url)
            response.raise_for_status()

            if not response.content:
                raise ValueError("Empty response content from audio URL")

            audio_data_stream = BytesIO(response.content)
            audio_data, sr = librosa.load(audio_data_stream, sr=sampling_rate)
            return audio_data, sr

    async def load_audio(
        self, audio_url: str, sampling_rate: int = 16000
    ) -> Tuple[np.ndarray, float]:
        """
        Load audio from an http/https URL without blocking the event loop.

        Args:
            audio_url: URL of the audio resource; only http and https are accepted.
            sampling_rate: Value forwarded to `librosa.load` as `sr`.

        Returns:
            The `(audio_data, sr)` pair, possibly served from the cache.

        Raises:
            ValueError: if the URL scheme is unsupported, or if loading fails
                for any non-HTTP reason (the original error is chained as
                `__cause__`).
            httpx.HTTPError: re-raised unchanged after being logged.
        """
        parsed_url = urlparse(audio_url)

        if parsed_url.scheme not in ("http", "https"):
            raise ValueError(f"Invalid audio source scheme: {parsed_url.scheme}")

        try:
            # Offload the synchronous, cached function to a separate thread
            # to avoid blocking the asyncio event loop.
            return await asyncio.to_thread(
                self._load_and_process_audio_cached, audio_url, sampling_rate
            )
        except httpx.HTTPError as e:
            logger.error("HTTP error loading audio: %s", e)
            raise
        except Exception as e:
            logger.error("Error loading audio: %s", e)
            # Chain explicitly so the root cause stays attached to the ValueError.
            raise ValueError(f"Failed to load audio: {e}") from e
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import pytest import pytest
from dynamo.common.utils.paths import WORKSPACE_DIR
from tests.utils.multimodal import ( from tests.utils.multimodal import (
MultimodalModelProfile, MultimodalModelProfile,
TopologyConfig, TopologyConfig,
...@@ -19,12 +16,8 @@ VLLM_TOPOLOGY_SCRIPTS: dict[str, str] = { ...@@ -19,12 +16,8 @@ VLLM_TOPOLOGY_SCRIPTS: dict[str, str] = {
"e_pd": "disagg_multimodal_e_pd.sh", "e_pd": "disagg_multimodal_e_pd.sh",
"epd": "disagg_multimodal_epd.sh", "epd": "disagg_multimodal_epd.sh",
"p_d": "disagg_multimodal_p_d.sh", "p_d": "disagg_multimodal_p_d.sh",
"audio_agg": "audio_agg.sh",
"audio_disagg": "audio_disagg.sh",
} }
_AUDIO_DIR = os.path.join(WORKSPACE_DIR, "examples/multimodal")
VLLM_MULTIMODAL_PROFILES: list[MultimodalModelProfile] = [ VLLM_MULTIMODAL_PROFILES: list[MultimodalModelProfile] = [
MultimodalModelProfile( MultimodalModelProfile(
name="Qwen/Qwen3-VL-2B-Instruct", name="Qwen/Qwen3-VL-2B-Instruct",
...@@ -84,23 +77,18 @@ VLLM_MULTIMODAL_PROFILES: list[MultimodalModelProfile] = [ ...@@ -84,23 +77,18 @@ VLLM_MULTIMODAL_PROFILES: list[MultimodalModelProfile] = [
}, },
request_payloads=[make_image_payload(["purple"])], request_payloads=[make_image_payload(["purple"])],
), ),
# Audio: uses agg topology with DYN_CHAT_PROCESSOR=vllm because the Rust
# Jinja engine cannot render multimodal content arrays (audio_url).
MultimodalModelProfile( MultimodalModelProfile(
name="Qwen/Qwen2-Audio-7B-Instruct", name="Qwen/Qwen2-Audio-7B-Instruct",
short_name="qwen2-audio-7b", short_name="qwen2-audio-7b",
topologies={ topologies={
"audio_agg": TopologyConfig( "agg": TopologyConfig(
marks=[pytest.mark.nightly], marks=[pytest.mark.post_merge],
timeout_s=600,
directory=_AUDIO_DIR,
),
"audio_disagg": TopologyConfig(
marks=[pytest.mark.nightly],
timeout_s=600, timeout_s=600,
directory=_AUDIO_DIR, env={"DYN_CHAT_PROCESSOR": "vllm"},
gpu_marker="gpu_4",
), ),
}, },
gpu_marker="gpu_2",
request_payloads=[make_audio_payload(["Hester", "Pynne"])], request_payloads=[make_audio_payload(["Hester", "Pynne"])],
), ),
MultimodalModelProfile( MultimodalModelProfile(
......
...@@ -102,6 +102,7 @@ class TopologyConfig: ...@@ -102,6 +102,7 @@ class TopologyConfig:
directory: Optional[str] = None # override profile-level directory directory: Optional[str] = None # override profile-level directory
gpu_marker: Optional[str] = None # override profile-level gpu_marker gpu_marker: Optional[str] = None # override profile-level gpu_marker
single_gpu: bool = False # append --single-gpu to script_args single_gpu: bool = False # append --single-gpu to script_args
env: dict[str, str] = field(default_factory=dict) # extra env vars for subprocess
@dataclass @dataclass
...@@ -187,5 +188,6 @@ def make_multimodal_configs( ...@@ -187,5 +188,6 @@ def make_multimodal_configs(
marks=marks, marks=marks,
delayed_start=topo_cfg.delayed_start, delayed_start=topo_cfg.delayed_start,
request_payloads=profile.request_payloads, request_payloads=profile.request_payloads,
env=topo_cfg.env,
) )
return configs return configs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment