Unverified Commit 4791aaaa authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

refactor(vLLM): Move video support from example to backend (#7663)

parent 234a89c0
......@@ -19,6 +19,7 @@ from dynamo.common.multimodal.embedding_transfer import (
TransferRequest,
)
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.multimodal.video_loader import VideoLoader
EMBEDDING_SENDER_FACTORIES: dict[
EmbeddingTransferMode, Callable[[], AbstractEmbeddingSender]
......@@ -43,6 +44,7 @@ __all__ = [
"EMBEDDING_RECEIVER_FACTORIES",
"EMBEDDING_SENDER_FACTORIES",
"ImageLoader",
"VideoLoader",
"NixlReadEmbeddingReceiver",
"NixlReadEmbeddingSender",
"NixlWriteEmbeddingSender",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
from pathlib import Path
from typing import Any, Awaitable, Dict, Final, List
from urllib.parse import urlparse
import numpy as np
import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async
logger = logging.getLogger(__name__)
URL_VARIANT_KEY: Final = "Url"
DECODED_VARIANT_KEY: Final = "Decoded"
def _require_vllm_video_media() -> tuple[Any, Any, Any]:
try:
from vllm.multimodal.media import MediaConnector, VideoMediaIO
from vllm.multimodal.media.image import ImageMediaIO
except ImportError as exc:
raise RuntimeError(
"vLLM multimodal media components are required to decode `video_url` "
"inputs in the vLLM backend."
) from exc
return MediaConnector, VideoMediaIO, ImageMediaIO
class VideoLoader:
    """Load video inputs as NumPy frame arrays for multimodal requests.

    URL inputs are fetched and decoded with vLLM's ``MediaConnector`` using a
    ``VideoMediaIO`` configured for RGB frames. Pre-decoded inputs are read
    through ``read_decoded_media_via_nixl`` and are only accepted when the
    loader was constructed with ``enable_frontend_decoding=True``.
    """

    # Default frame count handed to ``VideoMediaIO``. The environment variable
    # is read once, when the class body is executed.
    NUM_FRAMES_DEFAULT = int(os.environ.get("DYN_MM_VIDEO_NUM_FRAMES", "32"))

    def __init__(
        self,
        http_timeout: float = 60.0,
        num_frames: int = NUM_FRAMES_DEFAULT,
        enable_frontend_decoding: bool = False,
    ) -> None:
        """Configure the loader.

        Args:
            http_timeout: Timeout forwarded to the vLLM connector as
                ``fetch_timeout``; it is truncated to an ``int``.
            num_frames: Value passed as ``num_frames`` to ``VideoMediaIO``.
            enable_frontend_decoding: When true, a NIXL connector is created
                and initialized so ``Decoded`` items can be read.
        """
        self._http_timeout = int(http_timeout)
        self._num_frames = num_frames
        self._enable_frontend_decoding = enable_frontend_decoding
        self._nixl_connector = None
        # Built lazily so vLLM is only imported once a URL actually needs it.
        self._vllm_media_connector = None
        if self._enable_frontend_decoding:
            self._nixl_connector = nixl_connect.Connector()
            run_async(self._nixl_connector.initialize)

    @staticmethod
    def _normalize_video_url(video_url: str) -> str:
        """Turn a bare local path into a ``file://`` URI.

        Strings that already carry a URL scheme, and the empty string, are
        returned unchanged.

        Raises:
            FileNotFoundError: If ``video_url`` has no scheme and the path
                (after ``~`` expansion) does not exist.
        """
        parsed_url = urlparse(video_url)
        if parsed_url.scheme or not video_url:
            return video_url
        file_path = Path(video_url).expanduser()
        if not file_path.exists():
            raise FileNotFoundError(f"Error reading file: {file_path}")
        return file_path.resolve().as_uri()

    def _get_vllm_media_connector(self) -> Any:
        """Return the cached vLLM ``MediaConnector``, creating it on first use."""
        if self._vllm_media_connector is None:
            MediaConnector, _, _ = _require_vllm_video_media()
            # Match the previous backend behavior and allow direct local file paths.
            self._vllm_media_connector = MediaConnector(allowed_local_media_path="/")
        return self._vllm_media_connector

    def _create_vllm_video_io(self) -> Any:
        """Build a ``VideoMediaIO`` that yields RGB frames, ``num_frames`` per clip."""
        _, VideoMediaIO, ImageMediaIO = _require_vllm_video_media()
        return VideoMediaIO(
            ImageMediaIO(image_mode="RGB"),
            num_frames=self._num_frames,
        )

    async def _load_video_with_vllm(
        self, video_url: str
    ) -> tuple[np.ndarray, Dict[str, Any]]:
        """Fetch and decode ``video_url`` through the vLLM media connector."""
        connector = self._get_vllm_media_connector()
        normalized_url = self._normalize_video_url(video_url)
        # TODO: Add caching for repeated remote `video_url` downloads to avoid
        # refetching the same asset across requests.
        return await connector.load_from_url_async(
            normalized_url,
            self._create_vllm_video_io(),
            fetch_timeout=self._http_timeout,
        )

    async def load_video(self, video_url: str) -> tuple[np.ndarray, Dict[str, Any]]:
        """Load one video and return ``(frames, metadata)``.

        The returned frame array is made C-contiguous.

        Raises:
            FileNotFoundError: Propagated unchanged for a missing local path.
            ValueError: For every other failure, including an empty decoded
                clip; the underlying exception is attached as the cause.
        """
        try:
            frames, metadata = await self._load_video_with_vllm(video_url)
            if frames.size == 0:
                raise ValueError(
                    f"Failed to extract video frames from {video_url}. Decoded clip is empty."
                )
            return np.ascontiguousarray(frames), metadata
        except FileNotFoundError:
            raise
        except Exception as exc:
            logger.error("Error loading video from %s: %s", video_url, exc)
            raise ValueError(f"Failed to load video from {video_url}: {exc}") from exc

    async def _load_decoded_video(
        self, decoded_metadata: Dict[str, Any]
    ) -> tuple[np.ndarray, Dict[str, Any]]:
        """Read an already-decoded video via NIXL and return ``(frames, metadata)``.

        Raises:
            RuntimeError: If no NIXL connector was initialized.
            ValueError: If the transfer returned no metadata.
        """
        if self._nixl_connector is None:
            raise RuntimeError("NIXL connector is not initialized")
        frames, metadata = await read_decoded_media_via_nixl(
            self._nixl_connector,
            decoded_metadata,
            return_metadata=True,
        )
        if metadata is None:
            raise ValueError("Decoded video metadata is required")
        return np.ascontiguousarray(frames), metadata

    async def load_video_batch(
        self,
        video_mm_items: List[Dict[str, Any]],
    ) -> List[tuple[np.ndarray, Dict[str, Any]]]:
        """Load every supported item of ``video_mm_items`` concurrently.

        Items containing the ``Url`` key go through ``load_video``; items
        containing the ``Decoded`` key go through ``_load_decoded_video``.
        Anything else is skipped with a warning.

        Returns:
            One ``(frames, metadata)`` tuple per loaded item, in input order.

        Raises:
            ValueError: If a ``Decoded`` item is received while frontend
                decoding is disabled.
            asyncio.CancelledError: Re-raised if any load was cancelled.
            Exception: If one or more loads failed; the message concatenates
                the individual failures.
        """
        # Plan first, schedule second. Each entry is (source label, is_url,
        # payload). No coroutine exists until the whole batch has been
        # validated, so rejecting an item cannot leave earlier coroutines
        # un-awaited. The plan also keeps one source label per scheduled load,
        # so results stay aligned with their origin even when items are skipped.
        planned: List[tuple[str, bool, Any]] = []
        for item in video_mm_items:
            if isinstance(item, dict) and URL_VARIANT_KEY in item:
                url = item[URL_VARIANT_KEY]
                logger.debug("Preparing to load video from URL: %s...", url[:80])
                planned.append((url, True, url))
            elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
                if not self._enable_frontend_decoding:
                    raise ValueError(
                        "Received decoded video data but enable_frontend_decoding=False. "
                        "Enable frontend decoding to transfer decoded video frames via NIXL."
                    )
                planned.append(("decoded", False, item[DECODED_VARIANT_KEY]))
            else:
                logger.warning(
                    "Skipping unsupported video item: %s", repr(item)[:80]
                )

        video_futures: List[Awaitable[tuple[np.ndarray, Dict[str, Any]]]] = [
            self.load_video(payload) if is_url else self._load_decoded_video(payload)
            for _, is_url, payload in planned
        ]
        results = await asyncio.gather(*video_futures, return_exceptions=True)

        loaded_videos: list[tuple[np.ndarray, Dict[str, Any]]] = []
        collective_exceptions: list[str] = []
        for (source, _, _), result in zip(planned, results):
            if isinstance(result, BaseException):
                # Cancellation must propagate instead of being reported as a
                # load failure.
                if isinstance(result, asyncio.CancelledError):
                    raise result
                logger.error("Failed to load video from %s...: %s", source[:80], result)
                collective_exceptions.append(
                    f"Failed to load video from {source[:80]}...: {result}\n"
                )
                continue
            frames, metadata = result
            loaded_videos.append((np.ascontiguousarray(frames), metadata))

        if collective_exceptions:
            raise Exception("".join(collective_exceptions))
        return loaded_videos
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import AsyncMock
import numpy as np
import pytest
import dynamo.common.multimodal.video_loader as video_loader_module
from dynamo.common.multimodal.video_loader import VideoLoader
pytestmark = [
pytest.mark.unit,
pytest.mark.pre_merge,
pytest.mark.gpu_0,
]
def test_normalize_video_url_converts_local_paths(tmp_path):
    """A bare path to an existing file is rewritten to its resolved file URI."""
    sample_file = tmp_path / "sample.webm"
    sample_file.write_bytes(b"video")

    normalized = VideoLoader._normalize_video_url(str(sample_file))
    expected_uri = sample_file.resolve().as_uri()

    assert normalized == expected_uri
def test_normalize_video_url_preserves_data_urls():
    """A ``data:`` URL already has a scheme and must come back untouched."""
    original = "data:video/webm;base64,Zm9v"
    normalized = VideoLoader._normalize_video_url(original)
    assert normalized == original
@pytest.mark.asyncio
async def test_load_video_uses_vllm_media_connector():
    """``load_video`` returns C-contiguous frames and passes metadata through."""
    video_loader = VideoLoader()

    # Reversing the width axis gives a negative-stride (non-contiguous) view,
    # so the contiguity assertion below is meaningful.
    base_frames = np.arange(24, dtype=np.uint8).reshape(1, 2, 4, 3)
    strided_frames = base_frames[:, :, ::-1, :]
    expected_metadata = {"fps": 4.0, "frames_indices": [0], "total_num_frames": 1}

    stubbed_decode = AsyncMock(return_value=(strided_frames, expected_metadata))
    video_loader._load_video_with_vllm = stubbed_decode  # type: ignore[method-assign]

    result_frames, result_metadata = await video_loader.load_video(
        "data:video/webm;base64,Zm9v"
    )

    assert result_frames.flags["C_CONTIGUOUS"]
    np.testing.assert_array_equal(
        result_frames, np.ascontiguousarray(strided_frames)
    )
    assert result_metadata == expected_metadata
@pytest.mark.asyncio
async def test_load_video_batch_uses_url_loader():
    """``Url`` items are routed through ``load_video`` and keep their order."""

    def _clip(fill_value):
        # Each clip gets its own metadata dict, equal in content.
        return (
            np.full((1, 2, 2, 3), fill_value, dtype=np.uint8),
            {"fps": 2.0, "frames_indices": [0], "total_num_frames": 1},
        )

    expected_clips = [_clip(0), _clip(1)]

    video_loader = VideoLoader()
    video_loader.load_video = AsyncMock(side_effect=expected_clips)  # type: ignore[method-assign]

    batch = [
        {"Url": "https://example.com/one.mp4"},
        {"Url": "https://example.com/two.mp4"},
    ]
    videos = await video_loader.load_video_batch(batch)

    for position, (want_frames, want_metadata) in enumerate(expected_clips):
        loaded = videos[position]
        np.testing.assert_array_equal(loaded[0], want_frames)
        assert loaded[1] == want_metadata
@pytest.mark.asyncio
async def test_load_video_batch_rejects_decoded_variant_without_frontend_decoding():
    """A ``Decoded`` item is refused while frontend decoding is switched off."""
    video_loader = VideoLoader(enable_frontend_decoding=False)
    decoded_only_batch = [{"Decoded": {"shape": [1, 2, 2, 3]}}]

    with pytest.raises(ValueError, match="enable_frontend_decoding=False"):
        await video_loader.load_video_batch(decoded_only_batch)
@pytest.mark.asyncio
async def test_load_video_batch_reads_decoded_variant_with_metadata(monkeypatch):
    """``Decoded`` items are read through the patched NIXL reader."""
    # Construct with decoding disabled, then flip the flag and install a
    # placeholder connector by hand.
    video_loader = VideoLoader(enable_frontend_decoding=False)
    video_loader._enable_frontend_decoding = True
    video_loader._nixl_connector = object()

    clip_metadata = {"fps": 3.0, "frames_indices": [0], "total_num_frames": 1}
    decoded_item = {"shape": [1, 2, 2, 3], "metadata": clip_metadata}
    frames = np.arange(12, dtype=np.uint8).reshape(1, 2, 2, 3)

    nixl_reader = AsyncMock(return_value=(frames, decoded_item["metadata"]))
    monkeypatch.setattr(
        video_loader_module, "read_decoded_media_via_nixl", nixl_reader
    )

    videos = await video_loader.load_video_batch([{"Decoded": decoded_item}])
    first_video = videos[0]

    np.testing.assert_array_equal(first_video[0], np.ascontiguousarray(frames))
    assert first_video[1] == decoded_item["metadata"]
    nixl_reader.assert_awaited_once_with(
        video_loader._nixl_connector,
        decoded_item,
        return_metadata=True,
    )
......@@ -4,7 +4,7 @@
import logging
import time
import uuid
from typing import Any, Dict, Tuple
from typing import Any, Dict, Literal, Tuple, overload
import numpy as np
import torch
......@@ -15,6 +15,24 @@ from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescripto
logger = logging.getLogger(__name__)
@overload
async def read_decoded_media_via_nixl(
    connector: nixl_connect.Connector,
    decoded_meta: Dict[str, Any],
    return_metadata: Literal[False] = False,
) -> np.ndarray:
    # Typing-only overload: when ``return_metadata`` is omitted or ``False``
    # the call is typed as returning just the array.
    ...
@overload
async def read_decoded_media_via_nixl(
    connector: nixl_connect.Connector,
    decoded_meta: Dict[str, Any],
    return_metadata: Literal[True],
) -> Tuple[np.ndarray, Dict[str, Any] | None]:
    # Typing-only overload: with ``return_metadata=True`` the call is typed as
    # returning ``(array, metadata)``, where the metadata may be ``None``.
    ...
async def read_decoded_media_via_nixl(
connector: nixl_connect.Connector,
decoded_meta: Dict[str, Any],
......
......@@ -33,6 +33,7 @@ from dynamo.common.multimodal.embedding_transfer import (
NixlWriteEmbeddingReceiver,
)
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.multimodal.video_loader import VideoLoader
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager
from dynamo.common.utils.otel_tracing import build_trace_headers
......@@ -391,6 +392,9 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
self.image_loader = ImageLoader(
enable_frontend_decoding=enable_frontend_decoding
)
self.video_loader = VideoLoader(
enable_frontend_decoding=enable_frontend_decoding
)
self.embedding_loader = self.init_embedding_loader(config, encode_worker_client)
self.use_vllm_tokenizer = use_vllm_tokenizer
......@@ -1178,8 +1182,10 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
mm_map = request["multi_modal_data"]
# [gluo NOTE] If embedding loader is configured, currently we unconditionally
# fetch from the embedding loader.
vllm_mm_data = {}
# [gluo NOTE] If embedding loader is configured, fetch image embeddings first.
# Still continue below so mixed image+video requests can attach `video`.
if self.embedding_loader is not None:
# [gluo FIXME] couldn't simply pass 'mm_map.get(IMAGE_URL_KEY, [])' like below
# as currently the encode worker is using 'ImageLoader.load_image()' which doesn't
......@@ -1198,23 +1204,30 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
logger.debug(
f"Fetched multimodal embeddings for {len(vllm_mm_data)} items"
)
return vllm_mm_data if vllm_mm_data else None
# Fallback that the vLLM engine will perform encoding internally.
vllm_mm_data = {}
# Process image_url entries
image_mm_items = mm_map.get(IMAGE_URL_KEY, [])
if "image" not in vllm_mm_data and image_mm_items:
images = await self.image_loader.load_image_batch(
mm_map.get(IMAGE_URL_KEY, []),
image_mm_items,
)
if images:
# vLLM expects single image or list
vllm_mm_data["image"] = images[0] if len(images) == 1 else images
logger.debug(f"Extracted {len(images)} image(s) for multimodal processing")
logger.debug(
f"Extracted {len(images)} image(s) for multimodal processing"
)
video_mm_items = mm_map.get(VIDEO_URL_KEY, [])
if video_mm_items:
videos = await self.video_loader.load_video_batch(video_mm_items)
# Handle video_url entries (future expansion)
if VIDEO_URL_KEY in mm_map:
logger.warning("Video multimodal data not yet supported in standard worker")
if videos:
# vLLM expects single video or list
vllm_mm_data["video"] = videos[0] if len(videos) == 1 else videos
logger.debug(
f"Extracted {len(videos)} video(s) for multimodal processing"
)
return vllm_mm_data if vllm_mm_data else None
......
......@@ -51,7 +51,6 @@ class SupportedModels:
QWEN_3_VL_4B_FP8 = "Qwen/Qwen3-VL-4B-Instruct-FP8"
QWEN_3_VL_32B = "Qwen/Qwen3-VL-32B-Instruct"
QWEN_3_VL_32B_FP8 = "Qwen/Qwen3-VL-32B-Instruct-FP8"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
def normalize_model_name(model_name: str) -> str:
......@@ -198,10 +197,7 @@ def construct_mm_data(
) -> Dict[str, Any]:
"""Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings"""
# Handle video models
if is_model_supported(model, SupportedModels.LLAVA_NEXT_VIDEO_7B):
if video_numpy is None:
raise ValueError("No video frames provided.")
if video_numpy is not None:
return {"video": video_numpy}
# Handle image models - validate image embeddings first
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from types import SimpleNamespace
from unittest.mock import AsyncMock
import numpy as np
import pytest
from dynamo.vllm.handlers import BaseWorkerHandler
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
class _TestWorkerHandler(BaseWorkerHandler):
    """Minimal ``BaseWorkerHandler`` subclass used by the tests in this module."""

    async def generate(self, request, context):
        # Stub implementation: yields a single empty dict and ignores its
        # arguments. Presumably present only so the subclass can be
        # instantiated — confirm against BaseWorkerHandler.
        yield {}
def _make_handler(enable_multimodal: bool = True) -> _TestWorkerHandler:
    """Create a handler via ``__new__`` and attach only the stubbed attributes.

    The image and video loaders are namespaces whose batch methods are
    ``AsyncMock`` objects returning an empty list; no embedding loader is set.
    """
    stub_attributes = {
        "enable_multimodal": enable_multimodal,
        "config": SimpleNamespace(model="Qwen/Qwen3-VL-2B-Instruct"),
        "embedding_loader": None,
        "image_loader": SimpleNamespace(
            load_image_batch=AsyncMock(return_value=[])
        ),
        "video_loader": SimpleNamespace(
            load_video_batch=AsyncMock(return_value=[])
        ),
    }
    handler = _TestWorkerHandler.__new__(_TestWorkerHandler)
    for attribute_name, attribute_value in stub_attributes.items():
        setattr(handler, attribute_name, attribute_value)
    return handler
@pytest.mark.asyncio
async def test_extract_multimodal_data_loads_video_url_items():
    """A video-only request yields a ``video`` entry and never loads images."""
    handler = _make_handler()

    frames = np.zeros((2, 4, 4, 3), dtype=np.uint8)
    clip_metadata = {"fps": 2.0, "frames_indices": [0, 1], "total_num_frames": 2}
    loaded_video = (frames, clip_metadata)
    handler.video_loader.load_video_batch = AsyncMock(return_value=[loaded_video])

    request = {
        "multi_modal_data": {"video_url": [{"Url": "https://example.com/video.mp4"}]}
    }
    mm_data = await handler._extract_multimodal_data(request, "req-1", context=None)

    assert mm_data is not None
    assert mm_data["video"] is loaded_video
    handler.image_loader.load_image_batch.assert_not_awaited()
@pytest.mark.asyncio
async def test_extract_multimodal_data_merges_image_embeddings_with_video():
    """Image embeddings from the embedding loader are combined with loaded video."""
    handler = _make_handler()

    embedding_result = {"image": {"image_embeds": object()}}
    handler.embedding_loader = SimpleNamespace(
        load_multimodal_embeddings=AsyncMock(return_value=embedding_result)
    )

    frames = np.ones((3, 4, 4, 3), dtype=np.uint8)
    clip_metadata = {"fps": 2.0, "frames_indices": [0, 1, 2], "total_num_frames": 3}
    loaded_video = (frames, clip_metadata)
    handler.video_loader.load_video_batch = AsyncMock(return_value=[loaded_video])

    request = {
        "multi_modal_data": {
            "image_url": [{"Url": "https://example.com/image.png"}],
            "video_url": [{"Url": "https://example.com/video.mp4"}],
        }
    }
    mm_data = await handler._extract_multimodal_data(request, "req-2", context=None)

    assert mm_data is not None
    assert mm_data["image"] is embedding_result["image"]
    assert mm_data["video"] is loaded_video
    handler.image_loader.load_image_batch.assert_not_awaited()
@pytest.mark.asyncio
async def test_extract_multimodal_data_falls_back_to_image_loader_for_decoded_images():
    """Decoded image items bypass the embedding loader and use the image loader."""
    handler = _make_handler()

    decoded_image = object()
    frames = np.full((1, 2, 2, 3), 7, dtype=np.uint8)
    clip_metadata = {"fps": 2.0, "frames_indices": [0], "total_num_frames": 1}
    loaded_video = (frames, clip_metadata)

    # The embedding loader is configured but is expected to stay unused.
    handler.embedding_loader = SimpleNamespace(
        load_multimodal_embeddings=AsyncMock(return_value={"image": "unused"})
    )
    handler.image_loader.load_image_batch = AsyncMock(return_value=[decoded_image])
    handler.video_loader.load_video_batch = AsyncMock(return_value=[loaded_video])

    request = {
        "multi_modal_data": {
            "image_url": [{"Decoded": {"shape": [1, 1, 3]}}],
            "video_url": [{"Url": "https://example.com/video.mp4"}],
        }
    }
    mm_data = await handler._extract_multimodal_data(request, "req-3", context=None)

    assert mm_data is not None
    assert mm_data["image"] is decoded_image
    assert mm_data["video"] is loaded_video
    handler.embedding_loader.load_multimodal_embeddings.assert_not_awaited()
    handler.image_loader.load_image_batch.assert_awaited_once()
@pytest.mark.asyncio
async def test_extract_multimodal_data_rejects_requests_when_disabled():
    """Multimodal payloads raise ``ValueError`` when the feature is disabled."""
    handler = _make_handler(enable_multimodal=False)
    request = {
        "multi_modal_data": {
            "video_url": [{"Url": "https://example.com/video.mp4"}]
        }
    }

    with pytest.raises(ValueError, match="multimodal processing is not enabled"):
        await handler._extract_multimodal_data(request, "req-4", context=None)
......@@ -52,10 +52,10 @@ Dynamo provides support for improving latency and throughput for vision-and-lang
Reference implementations for deploying multimodal models:
- [vLLM multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/backends/vllm/launch)
- [vLLM multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/backends/vllm/launch) (image, video)
- [TRT-LLM multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/backends/trtllm/launch)
- [SGLang multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/backends/sglang/launch)
- [Experimental multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/launch) (video, audio)
- [Experimental multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/launch) (audio)
## Backend Documentation
......
......@@ -247,28 +247,24 @@ bash launch/disagg_multimodal_llama.sh
**Components:**
- workers: [VideoEncodeWorker](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/components/video_encode_worker.py) for decoding video into frames, and [VllmPDWorker](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the VideoEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
- worker: Standard `python -m dynamo.vllm --enable-multimodal` backend.
- frontend: Standard `python -m dynamo.frontend` OpenAI-compatible endpoint.
**Workflow:**
The VideoEncodeWorker decodes the video into frames. Unlike the image pipeline which generates embeddings, this pipeline passes raw frames directly to the VllmPDWorker via NATS and RDMA.
The Rust preprocessor tokenizes the request and forwards `multi_modal_data` with `video_url` entries. The vLLM backend decodes video URLs into sampled RGB frames and attaches them to `TokensPrompt(multi_modal_data=...)` for standard multimodal processing.
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --video_url--> video_encode_worker
video_encode_worker --> processor
video_encode_worker --frames--> pd_worker
pd_worker --> video_encode_worker
HTTP --> frontend
frontend --> vllm_worker
vllm_worker --> frontend
```
**Launch:**
```bash
cd $DYNAMO_HOME/examples/multimodal
cd $DYNAMO_HOME/examples/backends/vllm
bash launch/video_agg.sh
```
......@@ -278,7 +274,7 @@ bash launch/video_agg.sh
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"model": "Qwen/Qwen3-VL-2B-Instruct",
"messages": [
{
"role": "user",
......@@ -305,24 +301,20 @@ curl http://localhost:8000/v1/chat/completions \
**Workflow:**
For the LLaVA-NeXT-Video-7B model, frames are only required during the prefill stage. The VideoEncodeWorker is connected directly to the prefill worker, decoding the video into frames and passing them via RDMA.
The Rust preprocessor tokenizes the request and forwards `multi_modal_data` with `video_url` entries. The prefill worker decodes the video into sampled RGB frames locally, runs the multimodal prefill, and forwards KV state to the decode worker through the normal disaggregated vLLM path.
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --video_url--> video_encode_worker
video_encode_worker --> processor
video_encode_worker --frames--> prefill_worker
prefill_worker --> video_encode_worker
HTTP --> frontend
frontend --> prefill_worker
prefill_worker --> decode_worker
decode_worker --> prefill_worker
decode_worker --> frontend
```
**Launch:**
```bash
cd $DYNAMO_HOME/examples/multimodal
cd $DYNAMO_HOME/examples/backends/vllm
bash launch/video_disagg.sh
```
......@@ -655,7 +647,6 @@ The following models have been tested with Dynamo's vLLM multimodal backend:
- **Qwen3-VL** - `Qwen/Qwen3-VL-30B-A3B-Instruct-FP8`
- **LLaVA 1.5** - `llava-hf/llava-1.5-7b-hf`
- **Llama 4 Maverick** - `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`
- **LLaVA Next Video** - `llava-hf/LLaVA-NeXT-Video-7B-hf`
- **Qwen2-Audio** - `Qwen/Qwen2-Audio-7B-Instruct`
For a complete list of multimodal models supported by vLLM, see [vLLM Supported Multimodal Models](https://docs.vllm.ai/en/latest/models/supported_models/#list-of-multimodal-language-models). Models listed there should work with Simple Aggregated Mode but may not be explicitly tested.
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Aggregated video serving with standard Dynamo preprocessing and vLLM backend.
set -euo pipefail
# Stop any background jobs that are still running when the script exits.
cleanup() {
    echo "Cleaning up..."
    local running_jobs
    running_jobs="$(jobs -pr)"
    # Nothing left to stop.
    if [[ -z "$running_jobs" ]]; then
        return 0
    fi
    # Unquoted on purpose: each PID must become its own argument to kill.
    # Failures (e.g. a job that already exited) are ignored.
    kill $running_jobs 2>/dev/null || true
}
# Run cleanup whenever the script exits so background jobs are not left behind.
trap cleanup EXIT

# Resolve paths relative to this script; REPO_ROOT is four directories up.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)"
# Shared helpers. build_gpu_mem_args, print_launch_banner, print_curl_footer and
# wait_any_exit are not defined in this script — presumably they come from
# these two files (confirm).
source "$SCRIPT_DIR/../../../common/gpu_utils.sh"
source "$SCRIPT_DIR/../../../common/launch_utils.sh"
# Put the in-repo Python sources first, keeping any existing PYTHONPATH after them.
export PYTHONPATH="${REPO_ROOT}/components/src:${REPO_ROOT}/lib/bindings/python/src${PYTHONPATH:+:${PYTHONPATH}}"

# Defaults; each one can be overridden through the environment.
MODEL_NAME="${DYN_MODEL_NAME:-Qwen/Qwen3-VL-2B-Instruct}"
HTTP_PORT="${DYN_HTTP_PORT:-8000}"
GPU_DEVICE="${CUDA_VISIBLE_DEVICES:-0}"
MAX_MODEL_LEN="${MAX_MODEL_LEN:-8192}"
MAX_NUM_SEQS="${MAX_NUM_SEQS:-2}"
# Arguments forwarded verbatim to the vLLM worker.
EXTRA_ARGS=()

# Parse command-line options.
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            MODEL_NAME=$2
            shift 2
            ;;
        -h|--help)
            cat <<USAGE
Usage: $0 [OPTIONS] [-- EXTRA_VLLM_ARGS]
Options:
--model <model_name> Video-capable VLM to serve (default: $MODEL_NAME)
-h, --help Show this help message
Any arguments after '--' are passed through to the vLLM worker.
USAGE
            exit 0
            ;;
        # Everything after '--' is forwarded untouched.
        --)
            shift
            EXTRA_ARGS+=("$@")
            break
            ;;
        # Unrecognized arguments are forwarded to the worker as well.
        *)
            EXTRA_ARGS+=("$1")
            shift
            ;;
    esac
done

export DYN_REQUEST_PLANE=tcp
GPU_MEM_ARGS=$(build_gpu_mem_args vllm)

print_launch_banner --no-curl "Launching Aggregated Video Serving" "$MODEL_NAME" "$HTTP_PORT" \
    "Backend: dynamo.vllm --enable-multimodal" \
    "Video path: Standard TokensPrompt multi_modal_data flow"

# Unquoted heredoc delimiter: ${HTTP_PORT} and ${MODEL_NAME} are expanded and
# each '\\' is emitted as a single backslash.
print_curl_footer <<CURL
curl http://localhost:${HTTP_PORT}/v1/chat/completions \\
-H 'Content-Type: application/json' \\
-d '{
"model": "${MODEL_NAME}",
"messages": [{"role": "user", "content": [
{"type": "text", "text": "Describe the video in detail"},
{"type": "video_url", "video_url": {"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"}}
]}],
"max_tokens": 128
}'
CURL

# Start the frontend and the multimodal vLLM worker as background jobs.
python -m dynamo.frontend &

# $GPU_MEM_ARGS is left unquoted, so its contents are word-split into separate
# arguments.
CUDA_VISIBLE_DEVICES="$GPU_DEVICE" \
python -m dynamo.vllm \
    --enable-multimodal \
    --model "$MODEL_NAME" \
    --max-model-len "$MAX_MODEL_LEN" \
    --max-num-seqs "$MAX_NUM_SEQS" \
    $GPU_MEM_ARGS \
    "${EXTRA_ARGS[@]}" &

# Presumably blocks until one of the background jobs exits (helper not visible
# here — confirm); the EXIT trap then stops the rest.
wait_any_exit
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Disaggregated video serving with standard Dynamo preprocessing and vLLM backend.
set -euo pipefail
# Stop any background jobs that are still running when the script exits.
cleanup() {
    echo "Cleaning up..."
    local running_jobs
    running_jobs="$(jobs -pr)"
    # Nothing left to stop.
    if [[ -z "$running_jobs" ]]; then
        return 0
    fi
    # Unquoted on purpose: each PID must become its own argument to kill.
    # Failures (e.g. a job that already exited) are ignored.
    kill $running_jobs 2>/dev/null || true
}
# Run cleanup whenever the script exits so background jobs are not left behind.
trap cleanup EXIT

# Resolve paths relative to this script; REPO_ROOT is four directories up.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)"
# Shared helpers. build_gpu_mem_args, print_launch_banner, print_curl_footer and
# wait_any_exit are not defined in this script — presumably they come from
# these two files (confirm).
source "$SCRIPT_DIR/../../../common/gpu_utils.sh"
source "$SCRIPT_DIR/../../../common/launch_utils.sh"
# Put the in-repo Python sources first, keeping any existing PYTHONPATH after them.
export PYTHONPATH="${REPO_ROOT}/components/src:${REPO_ROOT}/lib/bindings/python/src${PYTHONPATH:+:${PYTHONPATH}}"

# Defaults; model and port can be overridden through the environment.
MODEL_NAME="${DYN_MODEL_NAME:-Qwen/Qwen3-VL-2B-Instruct}"
HTTP_PORT="${DYN_HTTP_PORT:-8000}"
SINGLE_GPU=false
# Arguments forwarded verbatim to both vLLM workers.
EXTRA_ARGS=()

# Parse command-line options.
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            MODEL_NAME=$2
            shift 2
            ;;
        --single-gpu)
            SINGLE_GPU=true
            shift
            ;;
        -h|--help)
            cat <<USAGE
Usage: $0 [OPTIONS] [-- EXTRA_VLLM_ARGS]
Options:
--model <model_name> Video-capable VLM to serve (default: $MODEL_NAME)
--single-gpu Run prefill and decode on one GPU for functional testing
-h, --help Show this help message
Any arguments after '--' are passed through to both vLLM workers.
USAGE
            exit 0
            ;;
        # Everything after '--' is forwarded untouched.
        --)
            shift
            EXTRA_ARGS+=("$@")
            break
            ;;
        # Unrecognized arguments are forwarded to the workers as well.
        *)
            EXTRA_ARGS+=("$1")
            shift
            ;;
    esac
done

export DYN_REQUEST_PLANE=tcp

# NOTE(review): in both branches below the output of build_gpu_mem_args is used
# as the *value* of --gpu-memory-utilization, whereas the aggregated launcher
# (video_agg.sh) expands the same helper's output unquoted as standalone
# arguments. Confirm what the helper prints; if it emits full flags, the
# fallback chains below would pass a flag string as the utilization value.
if [[ "$SINGLE_GPU" == "true" ]]; then
    # Both workers default to the same device and share its memory.
    GPU_LABEL="1 GPU"
    PREFILL_GPU="${DYN_PREFILL_WORKER_GPU:-${CUDA_VISIBLE_DEVICES:-0}}"
    DECODE_GPU="${DYN_DECODE_WORKER_GPU:-${CUDA_VISIBLE_DEVICES:-0}}"
    MAX_MODEL_LEN="${MAX_MODEL_LEN:-4096}"
    # 512 MiB.
    PD_KV_CACHE_BYTES=$((512 * 1024 * 1024))
    SHARED_GPU_FRACTION=$(build_gpu_mem_args vllm --workers-per-gpu 2)
    PREFILL_GPU_MEM="${DYN_PREFILL_GPU_MEM:-${SHARED_GPU_FRACTION:-0.45}}"
    DECODE_GPU_MEM="${DYN_DECODE_GPU_MEM:-${SHARED_GPU_FRACTION:-0.45}}"
    SHARED_ARGS=(
        --enforce-eager
        --max-model-len "$MAX_MODEL_LEN"
        --kv-cache-memory-bytes "$PD_KV_CACHE_BYTES"
        --limit-mm-per-prompt '{"image":1,"video":1,"audio":0}'
    )
else
    # Prefill defaults to GPU 0 and decode to GPU 1.
    GPU_LABEL="2 GPUs"
    PREFILL_GPU="${DYN_PREFILL_WORKER_GPU:-0}"
    DECODE_GPU="${DYN_DECODE_WORKER_GPU:-1}"
    MAX_MODEL_LEN="${MAX_MODEL_LEN:-8192}"
    GPU_MEM_ARGS=$(build_gpu_mem_args vllm)
    PREFILL_GPU_MEM="${DYN_PREFILL_GPU_MEM:-${GPU_MEM_ARGS:-0.9}}"
    DECODE_GPU_MEM="${DYN_DECODE_GPU_MEM:-${GPU_MEM_ARGS:-0.9}}"
    SHARED_ARGS=(--max-model-len "$MAX_MODEL_LEN")
fi

print_launch_banner --no-curl "Launching Disaggregated Video Serving ($GPU_LABEL)" "$MODEL_NAME" "$HTTP_PORT" \
    "Backend: Prefill + decode workers via dynamo.vllm" \
    "Video path: Standard TokensPrompt multi_modal_data flow"

# Unquoted heredoc delimiter: ${HTTP_PORT} and ${MODEL_NAME} are expanded and
# each '\\' is emitted as a single backslash.
print_curl_footer <<CURL
curl http://localhost:${HTTP_PORT}/v1/chat/completions \\
-H 'Content-Type: application/json' \\
-d '{
"model": "${MODEL_NAME}",
"messages": [{"role": "user", "content": [
{"type": "text", "text": "Describe the video in detail"},
{"type": "video_url", "video_url": {"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"}}
]}],
"max_tokens": 128
}'
CURL

# Start the frontend, then the prefill and decode workers, all as background jobs.
python -m dynamo.frontend &

# Prefill worker. It uses its own NIXL side-channel port (20098) and KV-events
# endpoint (20081), distinct from the decode worker's.
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \
CUDA_VISIBLE_DEVICES="$PREFILL_GPU" \
python -m dynamo.vllm \
    --disaggregation-mode prefill \
    --enable-multimodal \
    --model "$MODEL_NAME" \
    --gpu-memory-utilization "$PREFILL_GPU_MEM" \
    "${SHARED_ARGS[@]}" \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
    --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' \
    "${EXTRA_ARGS[@]}" &

# Decode worker, on side-channel port 20099 and KV-events endpoint 20082.
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
CUDA_VISIBLE_DEVICES="$DECODE_GPU" \
python -m dynamo.vllm \
    --disaggregation-mode decode \
    --enable-multimodal \
    --model "$MODEL_NAME" \
    --gpu-memory-utilization "$DECODE_GPU_MEM" \
    "${SHARED_ARGS[@]}" \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
    --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' \
    "${EXTRA_ARGS[@]}" &

# Presumably blocks until one of the background jobs exits (helper not visible
# here — confirm); the EXIT trap then stops the rest.
wait_any_exit
......@@ -232,11 +232,17 @@ class Processor(ProcessMixIn):
for message in raw_request.messages:
for item in message.content:
if item.type == "image_url":
multimodal_input.image_url = item.image_url.url
raise ValueError(
"Image requests should use the standard `python -m dynamo.frontend` "
"+ `python -m dynamo.vllm --enable-multimodal` flow instead of the "
"legacy multimodal example processor."
)
elif item.type == "video_url":
if multimodal_input.image_url is not None:
raise ValueError("Cannot provide both image and video URLs")
multimodal_input.video_url = item.video_url.url
raise ValueError(
"Video requests should use the standard `python -m dynamo.frontend` "
"+ `python -m dynamo.vllm --enable-multimodal` flow instead of the "
"legacy multimodal example processor."
)
elif item.type == "audio_url":
if (
multimodal_input.image_url is not None
......@@ -250,7 +256,10 @@ class Processor(ProcessMixIn):
and multimodal_input.video_url is None
and multimodal_input.audio_url is None
):
raise ValueError("Either image URL or video URL or audio URL is required")
raise ValueError(
"Audio requests are the only multimodal mode supported by the "
"legacy example processor."
)
async for response in self._generate(
chat_request, multimodal_input, RequestType.CHAT
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import os
import signal
import sys
from io import BytesIO
from queue import Queue
from typing import AsyncIterator, Optional, Tuple
import av
import numpy as np
import torch
import uvloop
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
import dynamo.nixl_connect as connect
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
from utils.video_utils import (
calculate_frame_sampling_indices,
get_video_metadata,
load_video_content,
open_video_container,
prepare_tensor_for_rdma,
read_video_pyav,
resize_video_frames,
)
configure_dynamo_logging()
logger = logging.getLogger(__name__)
# Select the array backend once at import time: cupy when it imports and
# reports CUDA as available, otherwise numpy. Either way the module is bound
# to ``array_module`` and ``DEVICE`` records which one was chosen.
try:
    import cupy as array_module
    if not array_module.cuda.is_available():
        # Raised deliberately so "cupy installed but no CUDA" takes the same
        # fallback path as "cupy not installed".
        raise ImportError("CUDA is not available.")
    DEVICE = "cuda"
    logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
    logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
    import numpy as array_module
    DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
class VllmEncodeWorker:
    """Encode-side worker that turns a video URL into raw frames for a downstream worker.

    For each request the worker loads the video, samples and decodes a fixed
    number of frames, resizes them, exposes the resulting tensor through a
    readable ``dynamo.nixl_connect`` descriptor, forwards the request to the
    downstream client and relays that client's responses back to the caller.
    """

    def __init__(
        self,
        args: argparse.Namespace,
        engine_args: AsyncEngineArgs,
        pd_worker_client: Client,
    ) -> None:
        """
        Store the configuration and create the (empty) download cache.

        Args:
            args: Parsed command line; ``num_frames_to_sample`` is read from it.
            engine_args: vLLM engine arguments; ``model`` is read from it.
            pd_worker_client: Client of the downstream endpoint that receives
                the request once the frames are ready.
        """
        self.pd_worker_client = pd_worker_client
        self.engine_args = engine_args
        self.model = self.engine_args.model
        self.min_workers = 1

        # Video processing parameters
        self.num_frames_to_sample = args.num_frames_to_sample
        self.frame_height = 336
        self.frame_width = 336
        self.frame_channels = 3

        # Downloaded videos keyed by URL, plus the queue that decides eviction order.
        self._video_content_cache: dict[str, BytesIO] = {}
        self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
        self._http_timeout = 60.0

        # Created in async_init(); kept as None until then so that an early
        # generate() call fails with a clear message instead of AttributeError.
        self._connector: Optional[connect.Connector] = None

    def cleanup(self):
        """Release worker resources. Currently there is nothing to release."""
        pass

    async def generate(self, request: vLLMMultimodalRequest) -> AsyncIterator[str]:
        """
        Encode the request's video and stream the downstream worker's responses.

        Args:
            request: A ``vLLMMultimodalRequest``, or its JSON string / plain
                data form, which is validated into one.

        Yields:
            One JSON-serialized ``MyRequestOutput`` per downstream response.

        Raises:
            ValueError: If the request has no video URL, or the video cannot be
                loaded or decoded.
            RuntimeError: If ``async_init`` has not been awaited yet.
        """
        logger.debug(f"Got raw request: {request}")
        if not isinstance(request, vLLMMultimodalRequest):
            if isinstance(request, str):
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")

        request_id = request.request_id
        video_url = request.multimodal_input.video_url
        if video_url is None:
            raise ValueError("Video URL is required.")
        if self._connector is None:
            raise RuntimeError(
                "VllmEncodeWorker.async_init() must be awaited before generate()."
            )

        try:
            tensor_for_descriptor = await self._load_frames_tensor(
                video_url, request_id
            )

            request.embeddings_shape = tuple(tensor_for_descriptor.shape)
            descriptor = connect.Descriptor(tensor_for_descriptor)

            with await self._connector.create_readable(descriptor) as readable:
                request.serialized_request = readable.metadata()
                # Clear the video URL as a hint that the frames travel through
                # the readable descriptor rather than being fetched again.
                request.multimodal_input.video_url = None

                logger.debug(f"Request: {request.model_dump_json()}")

                # Get the response generator
                response_generator = await self.pd_worker_client.round_robin(
                    request.model_dump_json()
                )
                await readable.wait_for_completion()

                async for response in response_generator:
                    output = MyRequestOutput.model_validate_json(response.data())
                    yield MyRequestOutput(
                        request_id=output.request_id,
                        prompt=output.prompt,
                        prompt_token_ids=output.prompt_token_ids,
                        prompt_logprobs=output.prompt_logprobs,
                        outputs=output.outputs,
                        finished=output.finished,
                    ).model_dump_json()
        except (
            FileNotFoundError,
            av.FFmpegError,
            ValueError,
        ) as e:
            logger.error(
                f"Error processing request {request_id} ({video_url[:100]}...): {type(e).__name__} - {e}"
            )
            raise  # Re-raise to be handled by the service framework
        except Exception as e:
            logger.exception(
                f"Unexpected error processing request {request_id} ({video_url[:100]}...): {e}"
            )
            raise

    async def _load_frames_tensor(self, video_url: str, request_id: str) -> torch.Tensor:
        """Load, sample, decode and resize the video; return the tensor to publish."""
        container: Optional[av.container.InputContainer] = None
        try:
            video_content_stream = await load_video_content(
                video_url,
                self._video_content_cache,
                self._cache_queue,
                self._http_timeout,
            )

            container = await open_video_container(video_content_stream, video_url)
            if not container or not container.streams.video:
                logger.error(f"No video stream found in {video_url}.")
                raise ValueError(f"No video stream in {video_url}.")

            total_frames, duration_sec = get_video_metadata(container)
            indices = calculate_frame_sampling_indices(
                total_frames, self.num_frames_to_sample, duration_sec, video_url
            )

            # Decode video frames
            clip_np: np.ndarray = await read_video_pyav(container, indices)
        finally:
            # The decoded frames are already a NumPy array, so the container is
            # closed here instead of staying open while responses are streamed.
            if container:
                await asyncio.to_thread(container.close)

        if clip_np.size == 0:
            raise ValueError(
                f"Failed to extract any video frames from {video_url} for indices {indices.tolist()}. Clip is empty."
            )
        logger.debug(
            f"Successfully extracted {len(clip_np) if clip_np.ndim > 1 and clip_np.shape[0] > 0 else 0} frames for {video_url} with original shape {clip_np.shape}."
        )

        # Convert the NumPy array from the video decoder into a PyTorch tensor
        # so the resize step can use PyTorch interpolation.
        frames_tensor_orig_res = torch.from_numpy(clip_np)  # Shape: (T, H, W, C)
        resized_frames_tensor_hwc = resize_video_frames(
            frames_tensor_orig_res, self.frame_height, self.frame_width
        )
        return prepare_tensor_for_rdma(resized_frames_tensor_hwc, request_id)

    async def async_init(self, runtime: DistributedRuntime):
        """Create the connector used to publish frame tensors; call once before serving."""
        logger.info("Startup started.")
        # Create and initialize a dynamo connector for this worker.
        # We need this to move data between this worker and remote workers efficiently.
        self._connector = connect.Connector()
        logger.info("Startup completed.")

    @classmethod
    def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
        """
        Parse the encode worker's command line.

        Adds ``--endpoint``, ``--downstream-endpoint`` and
        ``--num-frames-to-sample`` and delegates the rest to ``base_parse_args``.
        The endpoint defaults are built from the ``DYN_NAMESPACE`` environment
        variable (``"dynamo"`` when unset).

        Returns:
            The ``(args, config)`` pair produced by ``base_parse_args``.
        """
        DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
        DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.encoder.generate"
        DEFAULT_DOWNSTREAM_ENDPOINT = f"dyn://{DYN_NAMESPACE}.llm.generate"

        parser = FlexibleArgumentParser(
            description="vLLM based encoder for Dynamo LLM."
        )
        parser.add_argument(
            "--endpoint",
            type=str,
            default=DEFAULT_ENDPOINT,
            help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
        )
        parser.add_argument(
            "--downstream-endpoint",
            type=str,
            default=DEFAULT_DOWNSTREAM_ENDPOINT,
            help=f"The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
        )
        parser.add_argument(
            "--num-frames-to-sample",
            type=int,
            default=8,
            help="Number of frames to sample from the video. Default: 8",
        )

        args, config = base_parse_args(parser)
        return args, config
async def graceful_shutdown(runtime):
    """
    Shut the DistributedRuntime down in response to a termination signal.

    Once `runtime.shutdown()` is called the endpoints stop being available
    straight away, while requests that are already in flight keep running
    until they finish. When the last of them completes, the `serve_endpoint`
    calls return and the engine is left for Python's garbage collector.
    """
    announce = logging.info
    announce("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    announce("DistributedRuntime shutdown complete")
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """
    Entry coroutine of the encode worker process.

    Installs SIGTERM/SIGINT handlers that trigger `graceful_shutdown`, parses
    the command line and then hands over to `init`, which serves the endpoint.

    Args:
        runtime: The runtime passed in by the `dynamo_worker` decorator.
    """
    # Runtime setup
    # Set up signal handler for graceful shutdown
    loop = asyncio.get_running_loop()

    # The event loop keeps only weak references to tasks, so a task nobody else
    # references can be garbage-collected before it runs. Hold the shutdown
    # task here until it is done.
    shutdown_tasks: set[asyncio.Task] = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logging.info("Signal handlers set up for graceful shutdown")

    # worker setup
    args, config = VllmEncodeWorker.parse_args()
    await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
    """
    Build the encode worker, connect it to the downstream worker and serve it.

    The served endpoint is taken from `config`; the downstream endpoint is
    parsed from `args.downstream_endpoint`. Serving starts only after the
    downstream client reports instances, and `handler.cleanup()` always runs
    when serving ends, whether it returned or raised.
    """
    serving_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )

    downstream_parts = parse_endpoint(args.downstream_endpoint)
    parsed_namespace, parsed_component_name, parsed_endpoint_name = downstream_parts
    downstream_endpoint = runtime.endpoint(
        f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
    )
    pd_worker_client = await downstream_endpoint.client()

    handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
    await handler.async_init(runtime)

    logger.info("Waiting for PD Worker Instances ...")
    await pd_worker_client.wait_for_instances()

    logger.info(f"Starting to serve the {args.endpoint} endpoint...")
    model_labels = [("model", config.model)]
    try:
        await asyncio.gather(
            serving_endpoint.serve_endpoint(
                handler.generate, metrics_labels=model_labels
            ),
        )
    except Exception as serve_error:
        logger.error(f"Failed to serve endpoints: {serve_error}")
        raise
    finally:
        handler.cleanup()
# Script entry point: switch asyncio over to uvloop, then run the worker
# coroutine until it returns.
if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())
......@@ -235,11 +235,7 @@ class VllmPDWorker(VllmBaseWorker):
f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
).client()
if "video" in self.engine_args.model.lower():
self.EMBEDDINGS_DTYPE = torch.uint8
else:
self.EMBEDDINGS_DTYPE = torch.float16
self.EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker.
......@@ -283,14 +279,7 @@ class VllmPDWorker(VllmBaseWorker):
request.serialized_request, descriptor
)
await read_op.wait_for_completion()
if "video" in self.engine_args.model.lower():
video_numpy = embeddings.numpy()
multi_modal_data = construct_mm_data(
self.engine_args.model,
self.EMBEDDINGS_DTYPE,
video_numpy=video_numpy,
)
elif "audio" in self.engine_args.model.lower():
if "audio" in self.engine_args.model.lower():
multi_modal_data = construct_mm_data(
self.engine_args.model,
self.EMBEDDINGS_DTYPE,
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Launch the video example with one encode worker and one LLM worker:
# frontend, processor, video encode worker (GPU 0) and worker (GPU 1).
# Every component runs in the background; the script waits for all of them.
set -e
# On exit, signal the whole process group so no background component is left running.
trap 'echo Cleaning up...; kill 0' EXIT

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../common/gpu_utils.sh"

# Default values
MODEL_NAME="llava-hf/LLaVA-NeXT-Video-7B-hf"
PROMPT_TEMPLATE="USER: <video>\n<prompt> ASSISTANT:"
NUM_FRAMES_TO_SAMPLE=8

# run ingress
python -m dynamo.frontend --http-port=8000 &

# run processor
python3 components/processor.py --model "$MODEL_NAME" --prompt-template "$PROMPT_TEMPLATE" &

# run E/P/D workers
GPU_MEM_ARGS=$(build_gpu_mem_args vllm)
CUDA_VISIBLE_DEVICES=0 python3 components/video_encode_worker.py --model "$MODEL_NAME" --num-frames-to-sample "$NUM_FRAMES_TO_SAMPLE" &
# GPU_MEM_ARGS stays unquoted on purpose: it is expected to expand to zero or
# more separate command-line arguments (verify against build_gpu_mem_args).
# shellcheck disable=SC2086
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model "$MODEL_NAME" --worker-type prefill $GPU_MEM_ARGS &

# Wait for all background processes to complete
wait
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Launch the video example with separate prefill and decode workers:
# frontend, processor, video encode worker (GPU 0), prefill worker (GPU 1)
# and decode worker (GPU 2). Every component runs in the background; the
# script waits for all of them.
set -e
# On exit, signal the whole process group so no background component is left running.
trap 'echo Cleaning up...; kill 0' EXIT

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../common/gpu_utils.sh"

# Default values
MODEL_NAME="llava-hf/LLaVA-NeXT-Video-7B-hf"
PROMPT_TEMPLATE="USER: <video>\n<prompt> ASSISTANT:"
NUM_FRAMES_TO_SAMPLE=8

# run ingress
python -m dynamo.frontend --http-port=8000 &

# run processor
python3 components/processor.py --model "$MODEL_NAME" --prompt-template "$PROMPT_TEMPLATE" &

# run E/P/D workers
GPU_MEM_ARGS=$(build_gpu_mem_args vllm)
CUDA_VISIBLE_DEVICES=0 python3 components/video_encode_worker.py --model "$MODEL_NAME" --num-frames-to-sample "$NUM_FRAMES_TO_SAMPLE" &
# GPU_MEM_ARGS stays unquoted on purpose: it is expected to expand to zero or
# more separate command-line arguments (verify against build_gpu_mem_args).
# Each worker gets its own KV event port and NIXL side-channel port.
# shellcheck disable=SC2086
DYN_VLLM_KV_EVENT_PORT=20081 VLLM_NIXL_SIDE_CHANNEL_PORT=20098 CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model "$MODEL_NAME" --worker-type prefill --enable-disagg $GPU_MEM_ARGS &
# shellcheck disable=SC2086
DYN_VLLM_KV_EVENT_PORT=20082 VLLM_NIXL_SIDE_CHANNEL_PORT=20099 CUDA_VISIBLE_DEVICES=2 python3 components/worker.py --model "$MODEL_NAME" --worker-type decode --enable-disagg $GPU_MEM_ARGS &

# Wait for all background processes to complete
wait
......@@ -27,7 +27,6 @@ class SupportedModels:
LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
QWEN_2_AUDIO_7B = "Qwen/Qwen2-Audio-7B-Instruct"
......@@ -45,7 +44,6 @@ def construct_mm_data(
model: str,
embeddings_dtype: torch.dtype,
image_embeds: Optional[torch.Tensor] = None,
video_numpy: Optional[Any] = None,
image_grid_thw: Optional[List[Any]] = None,
audio_embeds: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor | Dict[str, Any]]:
......@@ -54,11 +52,6 @@ def construct_mm_data(
audio_embeds = audio_embeds.to(torch.bfloat16)
assert audio_embeds.ndim == 2, "Audio embeddings must be 2D"
return {"audio": [audio_embeds]}
# Handle video models
if model == SupportedModels.LLAVA_NEXT_VIDEO_7B:
if video_numpy is None:
raise ValueError("No video frames provided.")
return {"video": video_numpy}
# Handle image models - validate image embeddings first
if image_embeds is None:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import logging
import os
from io import BytesIO
from queue import Queue
from typing import Tuple
from urllib.parse import urlparse
import av
import httpx
import numpy as np
import torch
import torch.nn.functional as F
from .http_client import get_http_client
logger = logging.getLogger(__name__)
async def load_video_content(
    video_url: str,
    video_content_cache: dict[str, BytesIO],
    cache_queue: Queue[str],
    http_timeout: float = 60.0,
) -> BytesIO:
    """
    Load video content from various sources (URL, data URI, file).

    HTTP(S) downloads are cached. Every caller receives its own BytesIO, so
    the read position of one request never affects another.

    Args:
        video_url: The video URL or path. Supported forms are ``http(s)://``,
            base64 ``data:`` URIs, ``file://`` URIs and bare filesystem paths.
        video_content_cache: Cache dictionary for storing downloaded content
        cache_queue: Queue for managing cache eviction (oldest entry first)
        http_timeout: Timeout for HTTP requests

    Returns:
        BytesIO stream containing video data, positioned at the start

    Raises:
        ValueError: If video source is unsupported or invalid, or loading fails
        FileNotFoundError: If local file doesn't exist
    """
    from urllib.parse import unquote

    parsed_url = urlparse(video_url)
    is_http = parsed_url.scheme in ("http", "https")
    # Key the cache on the URL exactly as given: URL paths and query strings
    # are case-sensitive, so folding case could return a different video.
    cache_key = video_url

    if is_http:
        cached_content = video_content_cache.get(cache_key)
        if cached_content is not None:
            logger.debug(f"Video content found in cache for URL: {video_url}")
            # A BytesIO carries a read position; hand out a fresh stream so a
            # decoder still reading an earlier copy is not repositioned.
            return BytesIO(cached_content.getvalue())

    try:
        video_data: BytesIO
        if parsed_url.scheme == "data":
            if not parsed_url.path.startswith(("video/", "application/octet-stream")):
                raise ValueError("Data URL must be a video type or octet-stream")

            media_type_and_data = parsed_url.path.split(",", 1)
            if len(media_type_and_data) != 2:
                raise ValueError("Invalid Data URL format: missing comma separator")

            media_type, data_segment = media_type_and_data
            if ";base64" not in media_type:
                raise ValueError("Video Data URL currently must be base64 encoded")

            try:
                video_bytes = base64.b64decode(data_segment)
                video_data = BytesIO(video_bytes)
            except binascii.Error as e:
                raise ValueError(f"Invalid base64 encoding for video data: {e}") from e

        elif is_http:
            http_client = get_http_client(http_timeout)
            logger.debug(f"Downloading video from URL: {video_url}")
            response = await http_client.get(video_url, timeout=http_timeout)
            response.raise_for_status()
            if not response.content:
                raise ValueError(f"Empty response content from video URL: {video_url}")
            video_data = BytesIO(response.content)
            video_data.seek(0)
            logger.debug(
                f"Video downloaded from {video_url}, size: {len(response.content)} bytes."
            )

        elif parsed_url.scheme == "file" or not parsed_url.scheme:
            # A file:// URI percent-encodes characters such as spaces; decode
            # them to get the real filesystem path. Bare paths are used as is.
            file_path = unquote(parsed_url.path) if parsed_url.scheme else video_url
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Error reading file: {file_path}")
            with open(file_path, "rb") as f:
                video_bytes = f.read()
            video_data = BytesIO(video_bytes)

        else:
            raise ValueError(
                f"Unsupported video source scheme: {parsed_url.scheme} for URL {video_url}"
            )

        if is_http:  # Cache successfully downloaded content
            # A concurrent request may have cached this URL while we were
            # downloading; only enqueue the key the first time it is stored.
            if cache_key not in video_content_cache:
                if cache_queue.full():
                    oldest_url = cache_queue.get_nowait()
                    video_content_cache.pop(oldest_url, None)
                cache_queue.put_nowait(cache_key)
            # Keep a private copy in the cache; the caller owns `video_data`.
            video_content_cache[cache_key] = BytesIO(video_data.getvalue())

        return video_data

    except httpx.HTTPStatusError as e:
        logger.error(
            f"HTTP error {e.response.status_code} loading video {video_url}: {e.response.text[:200]}"
        )
        raise ValueError(
            f"Failed to download video {video_url}: HTTP {e.response.status_code}"
        ) from e
    except httpx.RequestError as e:
        logger.error(f"Request error loading video {video_url}: {e}")
        raise ValueError(f"Network request failed for video {video_url}") from e
    except FileNotFoundError as e:
        logger.error(f"File error loading video {video_url}: {e}")
        raise
    except Exception as e:
        logger.error(
            f"Error loading video content from {video_url}: {type(e).__name__} - {e}"
        )
        raise ValueError(f"Failed to load video content: {e}") from e
async def open_video_container(
    video_content_stream: BytesIO, video_url: str
) -> av.container.InputContainer:
    """
    Open an in-memory video with PyAV without blocking the event loop.

    The blocking ``av.open`` call runs in a worker thread.

    Args:
        video_content_stream: BytesIO stream holding the encoded video
        video_url: Where the video came from; used only in log and error text

    Returns:
        The opened PyAV container

    Raises:
        ValueError: If PyAV rejects the data or opening fails for any other reason
    """

    def _open_blocking():
        try:
            opened = av.open(video_content_stream, mode="r")
        except av.FFmpegError as ffmpeg_error:
            logger.error(
                f"PyAV error opening video stream from {video_url}: {ffmpeg_error}"
            )
            raise ValueError(
                f"Invalid video format or corrupted data from {video_url}."
            ) from ffmpeg_error
        except Exception as other_error:
            logger.error(
                f"Unexpected error opening video stream from {video_url} with PyAV: {other_error}"
            )
            raise ValueError(
                f"Unexpected error opening video from {video_url}."
            ) from other_error
        else:
            return opened

    return await asyncio.to_thread(_open_blocking)
def get_video_metadata(container: av.container.InputContainer) -> Tuple[int, float]:
    """
    Report the frame count and duration of a container's first video stream.

    Args:
        container: Opened PyAV container

    Returns:
        Tuple of (total_frames, duration_in_seconds). ``(0, 0.0)`` when the
        container is missing or has no video stream; the duration is ``0.0``
        when the stream lacks a duration or a time base.
    """
    video_streams = container.streams.video if container else None
    if not video_streams:
        return 0, 0.0

    primary_stream = video_streams[0]
    frame_count = primary_stream.frames

    # The duration is expressed in time-base units; it is still informative
    # for streams whose reported frame count is 0.
    has_timing = primary_stream.duration and primary_stream.time_base
    seconds = (
        float(primary_stream.duration * primary_stream.time_base)
        if has_timing
        else 0.0
    )
    return frame_count, seconds
async def read_video_pyav(
    container: av.container.InputContainer, indices: np.ndarray
) -> np.ndarray:
    """
    Decode the requested frames with PyAV in a worker thread.

    Frames are numbered in the order the decoder yields them, starting at 0.

    Args:
        container: The video container to decode from
        indices: Frame indices to extract

    Returns:
        NumPy array of the selected frames stacked along the first axis, each
        converted with ``to_ndarray(format="rgb24")``. With empty ``indices``
        the result is the first frame alone when the stream reports frames,
        otherwise an empty array.

    Raises:
        ValueError: If indices were given but none of them could be decoded
    """

    def _decode_selected():
        container.seek(0)  # Start decoding from the beginning
        wanted = set(indices)

        # Bounds of the requested range; -1 as upper bound means "no bound".
        if len(indices) > 0:
            first_wanted, last_wanted = np.min(indices), np.max(indices)
        else:
            first_wanted, last_wanted = 0, -1

        if not wanted:
            video_streams = container.streams.video
            if video_streams and video_streams[0].frames > 0:
                logger.warning(
                    "read_video_pyav called with empty indices for a non-empty video, attempting to read first frame."
                )
                try:
                    opening_frame = next(container.decode(video=0))
                    return np.stack([opening_frame.to_ndarray(format="rgb24")])
                except StopIteration:
                    logger.error(
                        "Failed to read even the first frame despite non-empty indices check."
                    )
                    return np.array([])

        picked = []
        for position, frame in enumerate(container.decode(video=0)):
            if last_wanted != -1 and position > last_wanted:
                break  # Everything requested lies behind us
            if position < first_wanted or position not in wanted:
                continue
            picked.append(frame)

        if picked:
            return np.stack([frame.to_ndarray(format="rgb24") for frame in picked])

        if len(wanted) > 0:
            # Nothing matched: count what the container can actually deliver so
            # the error message is useful. The count is best effort.
            decodable = 0
            try:
                container.seek(0)
                for _ in container.decode(video=0):
                    decodable += 1
            except Exception:  # Counting failed; report what was counted so far
                pass
            raise ValueError(
                f"Could not decode any frames for the given indices: {indices.tolist()}. "
                f"Video might be shorter than expected or indices out of bounds. "
                f"Actual decodable frames in container (approx): {decodable}."
            )

        return np.array([])

    return await asyncio.to_thread(_decode_selected)
def calculate_frame_sampling_indices(
    total_frames: int,
    num_frames_to_sample: int,
    duration_sec: float = 0,
    video_url: str = "",
) -> np.ndarray:
    """
    Calculate frame indices to sample from a video.

    With a known frame count the indices are spread evenly from the first to
    the last frame (or cover every frame when the video is shorter than the
    requested sample count). With an unknown frame count (0) but a non-zero
    duration, the first ``num_frames_to_sample`` indices are used.

    Args:
        total_frames: Total number of frames in the video (0 if unknown)
        num_frames_to_sample: Number of frames to sample
        duration_sec: Duration of video in seconds (for logging)
        video_url: Video URL for logging purposes

    Returns:
        Sorted array of unique frame indices to sample. Never empty when
        ``total_frames`` is positive.

    Raises:
        ValueError: If video has 0 frames and 0 duration
    """
    if total_frames == 0 and duration_sec == 0:
        logger.error(f"Video file '{video_url}' has 0 frames and 0 duration.")
        raise ValueError(f"Video {video_url} has 0 frames and 0 duration.")

    if total_frames == 0 and duration_sec > 0:
        logger.warning(
            f"Video {video_url} reports 0 frames but has duration {duration_sec:.2f}s. "
            "Frame sampling may be based on requested count directly."
        )

    logger.debug(
        f"Video {video_url} has {total_frames} frames (duration: {duration_sec:.2f}s). "
        f"Sampling {num_frames_to_sample} frames."
    )

    indices: np.ndarray
    if total_frames > 0:
        if total_frames < num_frames_to_sample:
            logger.warning(
                f"Video frames ({total_frames}) < samples ({num_frames_to_sample}). "
                f"Using all {total_frames} available frames."
            )
            indices = np.arange(total_frames, dtype=int)
        else:
            indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
    else:  # total_frames is 0 (likely a stream), sample by count.
        logger.warning(
            f"Video {video_url} frame count is 0. Attempting to sample {num_frames_to_sample} "
            "frames by index. This might fail if stream is too short."
        )
        indices = np.arange(num_frames_to_sample, dtype=int)

    # Truncating evenly spaced positions to integers can repeat an index.
    indices = np.unique(indices)

    if len(indices) == 0 and total_frames > 0:
        # Frames exist but nothing was selected (a sample count of 0). Fall
        # back to at least one frame; clamping to 1 is what makes this
        # fallback effective, since min() alone would yield 0 again.
        fallback_count = max(1, min(num_frames_to_sample, total_frames))
        indices = np.arange(fallback_count, dtype=int)
    # An empty selection with total_frames == 0 is left for the decoder to handle.

    logger.debug(f"Selected frame indices for {video_url}: {indices.tolist()}")
    return indices
def resize_video_frames(
    frames_tensor: torch.Tensor, target_height: int, target_width: int
) -> torch.Tensor:
    """
    Scale every frame to the target size with bilinear interpolation.

    Args:
        frames_tensor: Input tensor with shape (T, H, W, C)
        target_height: Target frame height
        target_width: Target frame width

    Returns:
        Float tensor with shape (T, target_height, target_width, C)
    """
    # F.interpolate expects channels before the spatial axes: (T, C, H, W).
    channels_first = frames_tensor.permute(0, 3, 1, 2).float()
    scaled = F.interpolate(
        channels_first,
        size=(target_height, target_width),
        mode="bilinear",
        align_corners=False,
    )
    # Back to channels-last: (T, H_new, W_new, C).
    channels_last = scaled.permute(0, 2, 3, 1)
    logger.debug(f"Resized frames to shape: {channels_last.shape}")
    return channels_last
def prepare_tensor_for_rdma(
    frames_tensor: torch.Tensor, request_id: str
) -> torch.Tensor:
    """
    Convert a frames tensor into the layout used for the RDMA transfer.

    Args:
        frames_tensor: Input frames tensor
        request_id: Request ID for logging

    Returns:
        The frames as a contiguous uint8 tensor on the CPU
    """
    # The transfer buffer needs CPU-resident, uint8, contiguous memory.
    staged = frames_tensor.to(device="cpu", dtype=torch.uint8)
    staged = staged.contiguous()

    logger.debug(
        f"Req {request_id}: Preparing raw frames tensor (shape: {staged.shape}, "
        f"dtype: {staged.dtype}, device: {staged.device}, "
        f"contiguous: {staged.is_contiguous()}) for RDMA."
    )
    return staged
......@@ -7,6 +7,7 @@ import logging
import os
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import pytest
......@@ -50,6 +51,10 @@ class VLLMConfig(EngineConfig):
vllm_dir = os.environ.get("VLLM_DIR") or os.path.join(
WORKSPACE_DIR, "examples/backends/vllm"
)
LOCAL_VIDEO_TEST_PATH = Path(
WORKSPACE_DIR, "lib/llm/tests/data/media/240p_10.mp4"
).resolve()
LOCAL_VIDEO_TEST_URI = LOCAL_VIDEO_TEST_PATH.as_uri()
# vLLM test configurations
......@@ -531,20 +536,18 @@ vllm_configs = {
),
],
),
# Video multimodal tests for nightly CI pipeline
# These tests validate video inference capabilities with LLaVA-NeXT-Video model
# Reference: Linear OPS-3015
# Video multimodal tests for CI using the vLLM video launch scripts.
"multimodal_video_agg": VLLMConfig(
name="multimodal_video_agg",
directory=os.path.join(WORKSPACE_DIR, "examples/multimodal"),
directory=vllm_dir,
script_name="video_agg.sh",
marks=[
pytest.mark.gpu_2,
pytest.mark.nightly,
pytest.mark.gpu_1,
pytest.mark.pre_merge,
], # TODO: profile to get max_vram and timeout
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
model="Qwen/Qwen3-VL-2B-Instruct",
delayed_start=60, # Video models require longer loading time
script_args=["--model", "llava-hf/LLaVA-NeXT-Video-7B-hf"],
script_args=["--model", "Qwen/Qwen3-VL-2B-Instruct"],
timeout=600, # 10 minutes for video processing overhead
request_payloads=[
chat_payload(
......@@ -552,13 +555,11 @@ vllm_configs = {
{"type": "text", "text": "Describe the video in detail"},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
},
"video_url": {"url": LOCAL_VIDEO_TEST_URI},
},
],
repeat_count=1,
expected_response=["rabbit"],
expected_response=["red", "static", "still"],
temperature=0.0,
max_tokens=100,
)
......@@ -566,15 +567,15 @@ vllm_configs = {
),
"multimodal_video_disagg": VLLMConfig(
name="multimodal_video_disagg",
directory=os.path.join(WORKSPACE_DIR, "examples/multimodal"),
directory=vllm_dir,
script_name="video_disagg.sh",
marks=[
pytest.mark.gpu_2,
pytest.mark.nightly,
pytest.mark.gpu_1,
pytest.mark.pre_merge,
], # TODO: profile to get max_vram and timeout
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
model="Qwen/Qwen3-VL-2B-Instruct",
delayed_start=60, # Video models require longer loading time
script_args=["--model", "llava-hf/LLaVA-NeXT-Video-7B-hf"],
script_args=["--model", "Qwen/Qwen3-VL-2B-Instruct", "--single-gpu"],
timeout=600, # 10 minutes for video processing overhead
request_payloads=[
chat_payload(
......@@ -582,13 +583,11 @@ vllm_configs = {
{"type": "text", "text": "Describe the video in detail"},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
},
"video_url": {"url": LOCAL_VIDEO_TEST_URI},
},
],
repeat_count=1,
expected_response=["rabbit"],
expected_response=["red", "static", "still"],
temperature=0.0,
max_tokens=100,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment