Unverified Commit 78a11074 authored by Kris Hung, committed by GitHub

feat: Add vLLM multimodal video support (#2738)


Signed-off-by: krishung5 <krish@nvidia.com>
parent 46175f5d
......@@ -326,3 +326,179 @@ You should see a response similar to this:
```json
{"id": "6cc99123ad6948d685b8695428238d4b", "object": "chat.completion", "created": 1752708043, "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "choices": [{"index": 0, "message": {"role": "assistant", "content": "The image depicts a street scene with a trolley bus as the central focus. The trolley bus is positioned on the left side of the road, facing the camera, and features a white and yellow color scheme. A prominent sign on the front of the bus reads \"OUT OF SERVICE\" in orange letters.\n\n**Key Elements:**\n\n* **Trolley Bus:** The bus is the main subject of the image, showcasing its distinctive design and color.\n* **Sign:** The \"OUT OF SERVICE\" sign is clearly visible on the front of the bus, indicating its current status.\n* **Street Scene:** The surrounding environment includes trees, buildings, and power lines, creating a sense of context and atmosphere.\n* **Lighting:** The image is characterized by a misty or foggy quality, with soft lighting that adds to the overall mood.\n\n**Overall Impression:**\n\nThe image presents a serene and somewhat melancholic scene, with the out-of-service trolley bus serving as a focal point. The misty atmosphere and soft lighting contribute to a contemplative ambiance, inviting the viewer to reflect on the situation."}, "finish_reason": "stop"}]}
```
## Multimodal Aggregated Video Serving
This example demonstrates deploying an aggregated multimodal model that can process video inputs.
### Components
- workers: For video serving, the [VideoEncodeWorker](components/video_encode_worker.py) decodes the video into frames and sends them to the [VllmPDWorker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the VideoEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have two workers, [VideoEncodeWorker](components/video_encode_worker.py) and [VllmPDWorker](components/worker.py).
The VideoEncodeWorker is responsible for decoding the video into a series of frames. Unlike the image pipeline, which generates embeddings,
this pipeline passes the raw frames directly to the VllmPDWorker via a combination of NATS and RDMA.
The VllmPDWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
Separating video processing from the prefill and decode stages makes the deployment more flexible: the
VideoEncodeWorker can be scaled independently of the prefill and decode workers if needed.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --video_url--> video_encode_worker
video_encode_worker --> processor
video_encode_worker --frames--> pd_worker
pd_worker --> video_encode_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal
bash launch/video_agg.sh
```
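The frame decoding step can be pictured with a small sketch. It assumes frames are sampled uniformly across the clip, which is a common choice but is not stated in this README; the actual logic lives in `utils/video_utils.py`. The helper names `sample_frame_indices` and `frame_buffer_nbytes` are hypothetical. The defaults (8 frames, 336x336 pixels, 3 channels) are taken from the launch script and the encode worker in this commit.

```python
import numpy as np


def sample_frame_indices(total_frames: int, num_frames: int = 8) -> np.ndarray:
    """Pick evenly spaced frame indices (assumed strategy, hypothetical helper)."""
    if total_frames <= 0:
        raise ValueError("total_frames must be positive")
    # Never ask for more frames than the clip contains.
    count = min(num_frames, total_frames)
    # linspace includes both endpoints, so the first and last frames are kept.
    return np.linspace(0, total_frames - 1, num=count).round().astype(np.int64)


def frame_buffer_nbytes(
    num_frames: int = 8, height: int = 336, width: int = 336, channels: int = 3
) -> int:
    """Size in bytes of a uint8 frame tensor laid out as (T, H, W, C)."""
    return num_frames * height * width * channels


indices = sample_frame_indices(total_frames=240)
print(indices.tolist())
print(frame_buffer_nbytes())  # 8 * 336 * 336 * 3 = 2709504 bytes
```

With the defaults, each request moves roughly 2.7 MB of raw pixel data to the downstream worker, which is one reason the pipeline transfers frames over RDMA rather than embedding them in the request message.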
### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the video in detail"
},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}' | jq
```
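The same request can be sent from Python. The sketch below uses only the standard library and mirrors the `curl` payload above; `build_video_request` and `send_request` are hypothetical helper names, and the endpoint assumes the frontend is listening on `localhost:8080` as in the `curl` example.

```python
import json
import urllib.request


def build_video_request(
    video_url: str,
    prompt: str = "Describe the video in detail",
    model: str = "llava-hf/LLaVA-NeXT-Video-7B-hf",
    max_tokens: int = 300,
) -> dict:
    """Build a chat completion payload with one text item and one video item."""
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "video_url", "video_url": {"url": video_url}},
                ],
            }
        ],
        "max_tokens": max_tokens,
        "stream": False,
    }


def send_request(
    payload: dict,
    endpoint: str = "http://localhost:8080/v1/chat/completions",
    timeout: float = 300.0,
) -> dict:
    """POST the payload as JSON and return the decoded JSON response."""
    request = urllib.request.Request(
        endpoint,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.load(response)


payload = build_video_request(
    "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
)
print(json.dumps(payload, indent=2))
# With the deployment running, send_request(payload) returns the completion.
```

The generous timeout is a judgment call: the worker has to download and decode the video before generation starts, so the first response can take a while.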
You should see a response describing the video's content, similar to this:
```json
{
"id": "7587e7d152014bae8e5c4e25f9fda0ed",
"choices": [
{
"index": 0,
"message": {
"content": " The video takes us away to a lively world of wildlife and natural beauty, featuring a white rabbit in a vibrant forest setting. At the beginning of the clip, the white rabbit is seen standing on a rock, facing towards the right side of the frame, with bushes and trees in the backdrop. The rabbit appears to be alert, given its ears are up and its ears perked in the air. As the clip progresses, the movement of the rabbit brings it around a tree, where its legs are partially hidden by the dense vegetation. It then sits down and grooms its fur, a behavior that suggests it is comfortable in its surroundings. \n\nThe scene then switches to a close-up shot of the rabbit, giving us a better view of its features and expressions. In this camera angle, the rabbit appears more dynamic and alert, with its breathing more visible, signaling its health and well-being. The camera pans out, and we see the rabbit heading towards the left side of the screen, possibly curious or hunting for food, with its ears perked up again. The lush greenery of the forest unfolds in the background, adding to the feeling of a wild and thriving environment.\n\n\nThe rabbit, upturned slightly while walking, finds a pile of dirt and rocks and sits there, fully clothed, perhaps taking a break from its exploration. There's a mention of a blue bird that appears to perch atop a log, adding a touch of whimsy to the scene. Lastly, the rabbit is observed relaxing on the rocks, resting comfortably, and looking off to the right side—a moment of tranquility in a bustling ecosystem. Throughout the clip, the rabbit's outfit remains the same, allowing for a clear focus on its behavior and characteristics while fitting in its habitat.",
"role": "assistant",
"reasoning_content": null
},
"finish_reason": "stop"
}
],
"created": 1756251832,
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"object": "chat.completion",
"usage": null
}
```
## Multimodal Disaggregated Video Serving
This example demonstrates deploying a disaggregated multimodal model that can process video inputs.
### Components
- workers: For disaggregated video serving, we have three workers, [VideoEncodeWorker](components/video_encode_worker.py) for decoding video into frames,
[VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the VideoEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
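Before handing a request to the VideoEncodeWorker, the processor in this commit collects the media URL from the message content and rejects requests that carry no media item or that mix an image with a video. The sketch below illustrates that rule with plain dictionaries. `MediaInput` and `extract_media` are hypothetical names, and unlike the processor code, this version rejects the image-plus-video combination regardless of which item comes first.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class MediaInput:
    """Holds at most one media URL per request (hypothetical stand-in)."""

    image_url: Optional[str] = None
    video_url: Optional[str] = None


def extract_media(messages: List[Dict[str, Any]]) -> MediaInput:
    """Scan chat messages and return the single image or video URL."""
    media = MediaInput()
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue  # Plain string content carries no media items.
        for item in content:
            kind = item.get("type")
            if kind == "image_url":
                if media.video_url is not None:
                    raise ValueError("Cannot provide both image and video URLs")
                media.image_url = item["image_url"]["url"]
            elif kind == "video_url":
                if media.image_url is not None:
                    raise ValueError("Cannot provide both image and video URLs")
                media.video_url = item["video_url"]["url"]
    if media.image_url is None and media.video_url is None:
        raise ValueError("Either image URL or video URL is required")
    return media


messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the video in detail"},
            {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
        ],
    }
]
print(extract_media(messages))
```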
### Graph
In this graph, we have three workers, [VideoEncodeWorker](components/video_encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
For the LLaVA-NeXT-Video-7B model, frames are only required during the prefill stage. As such, the VideoEncodeWorker is connected directly to the prefill worker.
The VideoEncodeWorker is responsible for decoding the video into a series of frames and passing them to the prefill worker via RDMA.
The prefill worker performs the prefilling step and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --video_url--> video_encode_worker
video_encode_worker --> processor
video_encode_worker --frames--> prefill_worker
prefill_worker --> video_encode_worker
prefill_worker --> decode_worker
decode_worker --> prefill_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal
bash launch/video_disagg.sh
```
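On the receiving side, the prefill worker in this commit chooses the transfer dtype from the model name: `uint8` when the name contains "video" (raw frames), `float16` otherwise (image embeddings). The sketch below illustrates that selection with NumPy so it runs without a GPU stack; the real worker uses `torch` dtypes, and `transfer_dtype` and `allocate_receive_buffer` are hypothetical names.

```python
from typing import Tuple

import numpy as np


def transfer_dtype(model_name: str) -> np.dtype:
    """Raw video frames travel as uint8; image embeddings as float16."""
    if "video" in model_name.lower():
        return np.dtype(np.uint8)
    return np.dtype(np.float16)


def allocate_receive_buffer(model_name: str, shape: Tuple[int, ...]) -> np.ndarray:
    """Allocate an uninitialized buffer that the transfer would fill."""
    return np.empty(shape, dtype=transfer_dtype(model_name))


# Shape is (frames, height, width, channels), matching the encode worker defaults.
buffer = allocate_receive_buffer("llava-hf/LLaVA-NeXT-Video-7B-hf", (8, 336, 336, 3))
print(buffer.dtype, buffer.nbytes)
```

Matching on a substring of the model name is simple but brittle: a video-capable model whose name lacks "video" would be given the `float16` path.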
### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the video in detail"
},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}' | jq
```
You should see a response describing the video's content, similar to this:
```json
{
"id": "7587e7d152014bae8e5c4e25f9fda0ed",
"choices": [
{
"index": 0,
"message": {
"content": " The video takes us away to a lively world of wildlife and natural beauty, featuring a white rabbit in a vibrant forest setting. At the beginning of the clip, the white rabbit is seen standing on a rock, facing towards the right side of the frame, with bushes and trees in the backdrop. The rabbit appears to be alert, given its ears are up and its ears perked in the air. As the clip progresses, the movement of the rabbit brings it around a tree, where its legs are partially hidden by the dense vegetation. It then sits down and grooms its fur, a behavior that suggests it is comfortable in its surroundings. \n\nThe scene then switches to a close-up shot of the rabbit, giving us a better view of its features and expressions. In this camera angle, the rabbit appears more dynamic and alert, with its breathing more visible, signaling its health and well-being. The camera pans out, and we see the rabbit heading towards the left side of the screen, possibly curious or hunting for food, with its ears perked up again. The lush greenery of the forest unfolds in the background, adding to the feeling of a wild and thriving environment.\n\n\nThe rabbit, upturned slightly while walking, finds a pile of dirt and rocks and sits there, fully clothed, perhaps taking a break from its exploration. There's a mention of a blue bird that appears to perch atop a log, adding a touch of whimsy to the scene. Lastly, the rabbit is observed relaxing on the rocks, resting comfortably, and looking off to the right side—a moment of tranquility in a bustling ecosystem. Throughout the clip, the rabbit's outfit remains the same, allowing for a clear focus on its behavior and characteristics while fitting in its habitat.",
"role": "assistant",
"reasoning_content": null
},
"finish_reason": "stop"
}
],
"created": 1756251832,
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"object": "chat.completion",
"usage": null
}
```
......@@ -106,7 +106,12 @@ class VllmEncodeWorker:
# 8. Yield the encode response.
try:
image = await self.image_loader.load_image(request.image_url)
if not request.multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.")
image = await self.image_loader.load_image(
request.multimodal_input.image_url
)
logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
......@@ -135,7 +140,7 @@ class VllmEncodeWorker:
with self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.image_url = None
request.multimodal_input.image_url = None
logger.debug(f"Request: {request.model_dump_json()}")
......
......@@ -40,13 +40,16 @@ from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
from utils.protocol import (
MultiModalInput,
MultiModalRequest,
MyRequestOutput,
vLLMMultimodalRequest,
)
configure_dynamo_logging()
logger = logging.getLogger(__name__)
prompt_template = "USER: <image>\n<prompt> ASSISTANT:"
class RequestType(Enum):
CHAT = "chat"
......@@ -134,7 +137,7 @@ class Processor(ProcessMixIn):
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
image: str,
multimodal_input: MultiModalInput,
request_type: RequestType,
):
request_id = str(uuid.uuid4().hex)
......@@ -151,7 +154,7 @@ class Processor(ProcessMixIn):
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
image_url=image,
multimodal_input=multimodal_input,
)
# model_dump_json() serializes the request to JSON string
......@@ -233,16 +236,23 @@ class Processor(ProcessMixIn):
temperature=raw_request.temperature,
request_id=str(uuid.uuid4()),
)
image_url = None
multimodal_input = MultiModalInput()
for message in raw_request.messages:
for item in message.content:
if item.type == "image_url":
image_url = item.image_url.url
if image_url is None:
raise ValueError("Image URL is required")
multimodal_input.image_url = item.image_url.url
elif item.type == "video_url":
if multimodal_input.image_url is not None:
raise ValueError("Cannot provide both image and video URLs")
multimodal_input.video_url = item.video_url.url
if multimodal_input.image_url is None and multimodal_input.video_url is None:
raise ValueError("Either image URL or video URL is required")
async for response in self._generate(chat_request, image_url, RequestType.CHAT):
async for response in self._generate(
chat_request, multimodal_input, RequestType.CHAT
):
logger.debug(
f"Generated response type {type(response)}, content: {response}"
)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import logging
import os
import signal
import sys
from io import BytesIO
from queue import Queue
from typing import AsyncIterator, Optional, Tuple
import av
import numpy as np
import torch
import uvloop
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
import dynamo.nixl_connect as connect
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
from utils.video_utils import (
calculate_frame_sampling_indices,
get_video_metadata,
load_video_content,
open_video_container,
prepare_tensor_for_rdma,
read_video_pyav,
resize_video_frames,
)
configure_dynamo_logging()
logger = logging.getLogger(__name__)
try:
import cupy as array_module
if not array_module.cuda.is_available():
raise ImportError("CUDA is not available.")
DEVICE = "cuda"
logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
import numpy as array_module
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
class VllmEncodeWorker:
def __init__(
self,
args: argparse.Namespace,
engine_args: AsyncEngineArgs,
pd_worker_client: Client,
) -> None:
self.pd_worker_client = pd_worker_client
self.engine_args = engine_args
self.model = self.engine_args.model
self.min_workers = 1
# Video processing parameters
self.num_frames_to_sample = args.num_frames_to_sample
self.frame_height = 336
self.frame_width = 336
self.frame_channels = 3
self._video_content_cache: dict[str, BytesIO] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
self._http_timeout = 60.0
def cleanup(self):
pass
async def generate(
self, request: vLLMMultimodalRequest
) -> AsyncIterator[MyRequestOutput]:
logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")
request_id = request.request_id
video_url = request.multimodal_input.video_url
if video_url is None:
raise ValueError("Video URL is required.")
container: Optional[av.container.InputContainer] = None
try:
video_content_stream = await load_video_content(
video_url,
self._video_content_cache,
self._cache_queue,
self._http_timeout,
)
# Open video container using utility function
container = await open_video_container(video_content_stream, video_url)
if not container or not container.streams.video:
logger.error(f"No video stream found in {video_url}.")
raise ValueError(f"No video stream in {video_url}.")
# Get video metadata using utility function
total_frames, duration_sec = get_video_metadata(container)
# Calculate frame sampling indices using utility function
indices = calculate_frame_sampling_indices(
total_frames, self.num_frames_to_sample, duration_sec, video_url
)
if not container:
raise ValueError(f"Container is None for {video_url}")
# Decode video frames
clip_np: np.ndarray = await read_video_pyav(container, indices)
if clip_np.size == 0:
raise ValueError(
f"Failed to extract any video frames from {video_url} for indices {indices.tolist()}. Clip is empty."
)
logger.debug(
f"Successfully extracted {len(clip_np) if clip_np.ndim > 1 and clip_np.shape[0] > 0 else 0} frames for {video_url} with original shape {clip_np.shape}."
)
# Convert the NumPy array from the video decoder into a PyTorch tensor.
# This is a required step to use PyTorch functions for GPU-accelerated image processing.
frames_tensor_orig_res = torch.from_numpy(clip_np) # Shape: (T, H, W, C)
# Resize frames using utility function
resized_frames_tensor_hwc = resize_video_frames(
frames_tensor_orig_res, self.frame_height, self.frame_width
)
# Prepare tensor for RDMA using utility function
tensor_for_descriptor = prepare_tensor_for_rdma(
resized_frames_tensor_hwc, request_id
)
request.embeddings_shape = tuple(tensor_for_descriptor.shape)
descriptor = connect.Descriptor(tensor_for_descriptor)
with self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
# Clear the video URL as a hint that the video is passed as decoded frames.
request.multimodal_input.video_url = None
logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator
response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json()
)
await readable.wait_for_completion()
async for response in response_generator:
output = MyRequestOutput.model_validate_json(response.data())
yield MyRequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
).model_dump_json()
except (
FileNotFoundError,
av.FFmpegError,
ValueError,
) as e:
logger.error(
f"Error processing request {request_id} ({video_url[:100]}...): {type(e).__name__} - {e}"
)
raise # Re-raise to be handled by the service framework
except Exception as e:
logger.exception(
f"Unexpected error processing request {request_id} ({video_url[:100]}...): {e}"
)
raise
finally:
if container:
await asyncio.to_thread(container.close)
async def async_init(self, runtime: DistributedRuntime):
logger.info("Startup started.")
# Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
@classmethod
def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
DEFAULT_ENDPOINT = "dyn://dynamo.encoder.generate"
DEFAULT_DOWNSTREAM_ENDPOINT = "dyn://dynamo.llm.generate"
parser = FlexibleArgumentParser(
description="vLLM based encoder for Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
default=DEFAULT_DOWNSTREAM_ENDPOINT,
help=f"The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
)
parser.add_argument(
"--num-frames-to-sample",
type=int,
default=8,
help="Number of frames to sample from the video. Default: 8",
)
args, config = base_parse_args(parser)
return args, config
async def graceful_shutdown(runtime):
"""
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
However, in-flight requests will still be processed until they are finished.
After all in-flight requests are finished, the `serve_endpoint` functions will return
and the engine will be shutdown by Python's garbage collector.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Runtime setup
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown")
# worker setup
args, config = VllmEncodeWorker.parse_args()
await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
generate_endpoint = component.endpoint(config.endpoint)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint
)
pd_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
await handler.async_init(runtime)
logger.info("Waiting for PD Worker Instances ...")
await pd_worker_client.wait_for_instances()
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
handler.cleanup()
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
......@@ -245,8 +245,13 @@ class VllmPDWorker(VllmBaseWorker):
.client()
)
self.EMBEDDINGS_DTYPE = torch.float16
if "video" in self.engine_args.model.lower():
self.EMBEDDINGS_DTYPE = torch.uint8
else:
self.EMBEDDINGS_DTYPE = torch.float16
self.EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
parsed_namespace, _, _ = parse_endpoint(self.endpoint)
......@@ -277,7 +282,10 @@ class VllmPDWorker(VllmBaseWorker):
)
descriptor = connect.Descriptor(embeddings)
if request.image_url is None:
if (
request.multimodal_input.image_url is None
and request.multimodal_input.video_url is None
):
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
......@@ -287,20 +295,31 @@ class VllmPDWorker(VllmBaseWorker):
request.serialized_request, descriptor
)
await read_op.wait_for_completion()
multi_modal_data = construct_mm_data(
self.engine_args.model,
embeddings,
self.EMBEDDINGS_DTYPE,
request.image_grid_thw,
)
if "video" in self.engine_args.model.lower():
video_numpy = embeddings.numpy()
multi_modal_data = construct_mm_data(
self.engine_args.model,
self.EMBEDDINGS_DTYPE,
video_numpy=video_numpy,
)
else:
multi_modal_data = construct_mm_data(
self.engine_args.model,
self.EMBEDDINGS_DTYPE,
image_embeds=embeddings,
image_grid_thw=request.image_grid_thw,
)
else:
# Use PIL image instead of image embeddings
multi_modal_data = {
"image": await self.image_loader.load_image(request.image_url)
"image": await self.image_loader.load_image(
request.multimodal_input.image_url
)
}
# Remove the image features from the request as they are not required
request.image_url = None
request.multimodal_input.image_url = None
request.multimodal_input.video_url = None
request.serialized_request = None
pd_request = copy.deepcopy(request)
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="llava-hf/LLaVA-NeXT-Video-7B-hf"
PROMPT_TEMPLATE="USER: <video>\n<prompt> ASSISTANT:"
NUM_FRAMES_TO_SAMPLE=8
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" &
# run E/P/D workers
CUDA_VISIBLE_DEVICES=0 python3 components/video_encode_worker.py --model $MODEL_NAME --num-frames-to-sample $NUM_FRAMES_TO_SAMPLE &
CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill &
# Wait for all background processes to complete
wait
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="llava-hf/LLaVA-NeXT-Video-7B-hf"
PROMPT_TEMPLATE="USER: <video>\n<prompt> ASSISTANT:"
NUM_FRAMES_TO_SAMPLE=8
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" &
# run E/P/D workers
CUDA_VISIBLE_DEVICES=0 python3 components/video_encode_worker.py --model $MODEL_NAME --num-frames-to-sample $NUM_FRAMES_TO_SAMPLE &
CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill --enable-disagg &
CUDA_VISIBLE_DEVICES=2 python3 components/worker.py --model $MODEL_NAME --worker-type decode --enable-disagg &
# Wait for all background processes to complete
wait
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
import httpx
logger = logging.getLogger(__name__)
# Global HTTP client instance
_global_http_client: Optional[httpx.AsyncClient] = None
def get_http_client(timeout: float = 60.0) -> httpx.AsyncClient:
"""
Get or create a shared HTTP client instance.
Args:
timeout: Timeout for HTTP requests
Returns:
Shared HTTP client instance
"""
global _global_http_client
if _global_http_client is None or _global_http_client.is_closed:
_global_http_client = httpx.AsyncClient(
timeout=timeout,
follow_redirects=True,
limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
)
logger.info(f"Shared HTTP client initialized with timeout={timeout}s")
return _global_http_client
......@@ -23,17 +23,18 @@ from urllib.parse import urlparse
import httpx
from PIL import Image
from .http_client import get_http_client
logger = logging.getLogger(__name__)
class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
def __init__(self, cache_size: int = CACHE_SIZE_MAXIMUM):
self._http_timeout = 30.0
self._http_client = httpx.AsyncClient(
timeout=self._http_timeout, follow_redirects=True
)
def __init__(
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
):
self._http_timeout = http_timeout
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
......@@ -64,10 +65,9 @@ class ImageLoader:
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
if not self._http_client:
raise RuntimeError("HTTP client not initialized")
http_client = get_http_client(self._http_timeout)
response = await self._http_client.get(image_url)
response = await http_client.get(image_url)
response.raise_for_status()
if not response.content:
......
......@@ -14,10 +14,10 @@
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import torch
from transformers import AutoConfig, AutoModel
from transformers import AutoModel
logger = logging.getLogger(__name__)
......@@ -40,52 +40,47 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
return model
def get_vision_embeddings_info(
model_id: str,
) -> Tuple[Tuple[int, int, int], torch.dtype]:
"""Calculate vision embeddings size and dtype using model config
Returns a tuple of (batch_size, seq_len, hidden_dim), dtype.
"""
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
if model_id == SupportedModels.LLAVA_1_5_7B:
seq_len = 577
elif model_id == SupportedModels.QWEN_2_5_VL_7B:
seq_len = 345
else:
seq_len = 0
if not hasattr(config, "torch_dtype"):
raise ValueError("Model config missing required 'torch_dtype' attribute")
if not hasattr(config, "hidden_size"):
logger.warning(
"Model config missing required 'hidden_size' attribute, using 4096"
)
hidden_size = 4096
else:
hidden_size = config.hidden_size
return (1, seq_len, hidden_size), config.torch_dtype
def construct_mm_data(
model: str,
image_embeds: torch.Tensor,
embeddings_dtype: torch.dtype,
image_grid_thw: Optional[List[Any]],
image_embeds: Optional[torch.Tensor] = None,
video_numpy: Optional[Any] = None,
image_grid_thw: Optional[List[Any]] = None,
) -> Dict[str, torch.Tensor | Dict[str, Any]]:
"""Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings"""
# Handle video models
if model == SupportedModels.LLAVA_NEXT_VIDEO_7B:
if video_numpy is None:
raise ValueError("No video frames provided.")
return {"video": video_numpy}
# Handle image models - validate image embeddings first
if image_embeds is None:
raise ValueError("No image embeddings provided.")
image_embeds = image_embeds.to(embeddings_dtype)
# Model-specific image handling
if model == SupportedModels.QWEN_2_5_VL_7B:
if image_grid_thw is not None and len(image_grid_thw) > 0:
grid_thw_tensor = torch.tensor(image_grid_thw)
else:
raise ValueError("No image grid provided.")
return {
"image": {
"image_embeds": image_embeds.squeeze(0),
"image_grid_thw": grid_thw_tensor,
}
}
return _construct_qwen_image_data(image_embeds, image_grid_thw)
else:
# Default image handling for other models (e.g., LLAVA_1_5_7B)
return {"image": image_embeds}
def _construct_qwen_image_data(
image_embeds: torch.Tensor, image_grid_thw: Optional[List[Any]]
) -> Dict[str, Dict[str, torch.Tensor]]:
"""Construct image data specifically for Qwen models."""
if image_grid_thw is None or len(image_grid_thw) == 0:
raise ValueError("No image grid provided for Qwen model.")
grid_thw_tensor = torch.tensor(image_grid_thw)
return {
"image": {
"image_embeds": image_embeds.squeeze(0),
"image_grid_thw": grid_thw_tensor,
}
}
......@@ -18,7 +18,7 @@ import json
from typing import Any, List, Literal, Optional, Tuple, Union
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
......@@ -107,7 +107,16 @@ class ImageContent(BaseModel):
image_url: ImageURLDetail
class VideoURLDetail(BaseModel):
    url: str


class VideoContent(BaseModel):
    type: Literal["video_url"]
    video_url: VideoURLDetail


MessageContent = Union[TextContent, ImageContent, VideoContent]
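As an illustration of the request shape that `VideoContent` accepts, the sketch below dispatches on the `type` tag using plain dictionaries rather than pydantic. `parse_content_part` is a hypothetical helper and not part of this change; it only mirrors the `type` / `video_url.url` layout of the models above, and the example URL is a placeholder.

```python
from typing import Any, Dict, Tuple

# Hypothetical helper: mirrors the `type`-tagged layout of TextContent,
# ImageContent and VideoContent using plain dicts instead of pydantic models.
_URL_FIELDS = {"image_url": "image_url", "video_url": "video_url"}


def parse_content_part(part: Dict[str, Any]) -> Tuple[str, str]:
    """Return (kind, value) for one message content part."""
    kind = part.get("type")
    if kind == "text":
        return "text", str(part["text"])
    if kind in _URL_FIELDS:
        detail = part[_URL_FIELDS[kind]]
        if not isinstance(detail, dict) or "url" not in detail:
            raise ValueError(f"{kind} part must contain a 'url' field")
        return kind, str(detail["url"])
    raise ValueError(f"Unsupported content type: {kind!r}")


message_content = [
    {"type": "text", "text": "Describe the video in detail"},
    {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
]
parsed = [parse_content_part(p) for p in message_content]
```

In the actual protocol module, pydantic performs this validation when it resolves the `MessageContent` union.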
class ChatMessage(BaseModel):
......@@ -124,22 +133,18 @@ class MultiModalRequest(BaseModel):
stream: Optional[bool] = True
class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)
class MultiModalInput(BaseModel):
image_url: Optional[str] = None
image_grid_thw: Optional[List[Any]] = None
embeddings_shape: Optional[Tuple[int, int, int]] = None
serialized_request: Optional[connect.RdmaMetadata] = None
video_url: Optional[str] = None
class EncodeRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str
request_id: str
multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
image_grid_thw: Optional[List[Any]] = None
embeddings_shape: Optional[
Union[Tuple[int, int, int], Tuple[int, int, int, int]]
] = None
serialized_request: Optional[connect.RdmaMetadata] = None
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import logging
import os
from io import BytesIO
from queue import Queue
from typing import Tuple
from urllib.parse import urlparse
import av
import httpx
import numpy as np
import torch
import torch.nn.functional as F
from .http_client import get_http_client
logger = logging.getLogger(__name__)
async def load_video_content(
    video_url: str,
    video_content_cache: dict[str, BytesIO],
    cache_queue: Queue[str],
    http_timeout: float = 60.0,
) -> BytesIO:
    """
    Load video content from various sources (URL, data URI, file).

    Args:
        video_url: The video URL or path
        video_content_cache: Cache dictionary for storing downloaded content
        cache_queue: Queue for managing cache eviction
        http_timeout: Timeout for HTTP requests

    Returns:
        BytesIO stream containing video data

    Raises:
        ValueError: If video source is unsupported or invalid
        FileNotFoundError: If local file doesn't exist
        RuntimeError: If HTTP client initialization fails
    """
    parsed_url = urlparse(video_url)
    video_url_lower = video_url.lower()

    if parsed_url.scheme in ("http", "https"):
        if video_url_lower in video_content_cache:
            logger.debug(f"Video content found in cache for URL: {video_url}")
            cached_content = video_content_cache[video_url_lower]
            cached_content.seek(0)
            return cached_content

    try:
        video_data: BytesIO
        if parsed_url.scheme == "data":
            if not parsed_url.path.startswith(("video/", "application/octet-stream")):
                raise ValueError("Data URL must be a video type or octet-stream")
            media_type_and_data = parsed_url.path.split(",", 1)
            if len(media_type_and_data) != 2:
                raise ValueError("Invalid Data URL format: missing comma separator")
            media_type, data_segment = media_type_and_data
            if ";base64" not in media_type:
                raise ValueError("Video Data URL currently must be base64 encoded")
            try:
                video_bytes = base64.b64decode(data_segment)
                video_data = BytesIO(video_bytes)
            except binascii.Error as e:
                raise ValueError(f"Invalid base64 encoding for video data: {e}") from e
        elif parsed_url.scheme in ("http", "https"):
            http_client = get_http_client(http_timeout)
            logger.debug(f"Downloading video from URL: {video_url}")
            response = await http_client.get(video_url, timeout=http_timeout)
            response.raise_for_status()
            if not response.content:
                raise ValueError(f"Empty response content from video URL: {video_url}")
            video_data = BytesIO(response.content)
            video_data.seek(0)
            logger.debug(
                f"Video downloaded from {video_url}, size: {len(response.content)} bytes."
            )
        elif parsed_url.scheme == "file" or not parsed_url.scheme:
            file_path = parsed_url.path if parsed_url.scheme else video_url
            # Ensure path is absolute or resolve relative to a known base if necessary
            # For simplicity, assuming it's an accessible path.
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Error reading file: {file_path}")
            with open(file_path, "rb") as f:
                video_bytes = f.read()
            video_data = BytesIO(video_bytes)
        else:
            raise ValueError(
                f"Unsupported video source scheme: {parsed_url.scheme} for URL {video_url}"
            )

        if parsed_url.scheme in (
            "http",
            "https",
        ):  # Cache successfully downloaded content
            if cache_queue.full():
                oldest_url = cache_queue.get_nowait()
                if oldest_url in video_content_cache:
                    del video_content_cache[oldest_url]
            # Store the BytesIO object directly; it will be seek(0)'d when retrieved
            video_content_cache[video_url_lower] = video_data
            cache_queue.put(video_url_lower)

        return video_data
    except httpx.HTTPStatusError as e:
        logger.error(
            f"HTTP error {e.response.status_code} loading video {video_url}: {e.response.text[:200]}"
        )
        raise ValueError(
            f"Failed to download video {video_url}: HTTP {e.response.status_code}"
        ) from e
    except httpx.RequestError as e:
        logger.error(f"Request error loading video {video_url}: {e}")
        raise ValueError(f"Network request failed for video {video_url}") from e
    except FileNotFoundError as e:
        logger.error(f"File error loading video {video_url}: {e}")
        raise
    except Exception as e:
        logger.error(
            f"Error loading video content from {video_url}: {type(e).__name__} - {e}"
        )
        raise ValueError(f"Failed to load video content: {e}") from e
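The `data:` branch of `load_video_content` can be tried in isolation. The following is a minimal sketch of that branch only, using the same standard-library calls; `decode_video_data_url` is a hypothetical name introduced here, and the payload is placeholder bytes rather than a real video.

```python
import base64
import binascii
from io import BytesIO
from urllib.parse import urlparse


def decode_video_data_url(video_url: str) -> BytesIO:
    """Decode a base64 `data:` URL into an in-memory byte stream."""
    parsed = urlparse(video_url)
    if parsed.scheme != "data":
        raise ValueError("Not a data URL")
    # For a data URL, everything after "data:" ends up in `path`,
    # e.g. "video/mp4;base64,AAAA...".
    if not parsed.path.startswith(("video/", "application/octet-stream")):
        raise ValueError("Data URL must be a video type or octet-stream")
    parts = parsed.path.split(",", 1)
    if len(parts) != 2:
        raise ValueError("Invalid Data URL format: missing comma separator")
    media_type, data_segment = parts
    if ";base64" not in media_type:
        raise ValueError("Video Data URL currently must be base64 encoded")
    try:
        return BytesIO(base64.b64decode(data_segment))
    except binascii.Error as e:
        raise ValueError(f"Invalid base64 encoding for video data: {e}") from e


# Round-trip some placeholder bytes (not a real video) through a data URL.
payload = b"not-a-real-video"
url = "data:video/mp4;base64," + base64.b64encode(payload).decode("ascii")
stream = decode_video_data_url(url)
```

Wrapping `binascii.Error` in `ValueError` keeps the caller's error handling uniform across the data, HTTP, and file branches.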
async def open_video_container(
    video_content_stream: BytesIO, video_url: str
) -> av.container.InputContainer:
    """
    Open a video container from a BytesIO stream using PyAV.

    Args:
        video_content_stream: BytesIO stream containing video data
        video_url: Original video URL for error reporting

    Returns:
        Opened PyAV container

    Raises:
        ValueError: If video format is invalid or corrupted
    """

    def open_video_container_sync():
        try:
            return av.open(video_content_stream, mode="r")
        except av.FFmpegError as ave:
            logger.error(f"PyAV error opening video stream from {video_url}: {ave}")
            raise ValueError(
                f"Invalid video format or corrupted data from {video_url}."
            ) from ave
        except Exception as e:
            logger.error(
                f"Unexpected error opening video stream from {video_url} with PyAV: {e}"
            )
            raise ValueError(f"Unexpected error opening video from {video_url}.") from e

    return await asyncio.to_thread(open_video_container_sync)
def get_video_metadata(container: av.container.InputContainer) -> Tuple[int, float]:
    """
    Extract metadata from video container.

    Args:
        container: Opened PyAV container

    Returns:
        Tuple of (total_frames, duration_in_seconds)
    """
    if not container or not container.streams.video:
        return 0, 0.0

    stream_info = container.streams.video[0]
    total_frames = stream_info.frames
    # Duration can be useful for streams where total_frames is 0
    if stream_info.duration and stream_info.time_base:
        duration_sec = float(stream_info.duration * stream_info.time_base)
    else:
        duration_sec = 0.0
    return total_frames, duration_sec
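The duration arithmetic in `get_video_metadata` relies on PyAV reporting `duration` as a tick count in units of `time_base`, which is a `fractions.Fraction`. A small worked example of that conversion follows; the tick count and the 90 kHz time base are illustrative values, and `duration_seconds` is a hypothetical helper.

```python
from fractions import Fraction
from typing import Optional


def duration_seconds(duration: Optional[int], time_base: Optional[Fraction]) -> float:
    """Convert a duration counted in time_base ticks into seconds."""
    if duration and time_base:
        return float(duration * time_base)
    return 0.0  # metadata missing: same fallback as get_video_metadata


# Illustrative values: 900900 ticks at a 1/90000 s time base.
seconds = duration_seconds(900_900, Fraction(1, 90_000))
```

Multiplying before converting to `float` keeps the computation exact until the final step.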
async def read_video_pyav(
    container: av.container.InputContainer, indices: np.ndarray
) -> np.ndarray:
    """
    Decode the video with PyAV decoder. Async wrapper.

    Args:
        container: The video container to decode from
        indices: Frame indices to extract

    Returns:
        NumPy array of decoded frames

    Raises:
        ValueError: If no frames could be decoded for the given indices
    """

    def blocking_decode():
        container.seek(0)  # Reset container for decoding
        processed_indices = set(indices)

        # Determine min/max index to optimize decoding loop slightly
        min_idx = 0
        max_idx = -1
        if len(indices) > 0:
            min_idx = np.min(indices)
            max_idx = np.max(indices)

        if (
            not processed_indices
            and container.streams.video
            and container.streams.video[0].frames > 0
        ):
            logger.warning(
                "read_video_pyav called with empty indices for a non-empty video, attempting to read first frame."
            )
            try:
                frame = next(container.decode(video=0))
                return np.stack([frame.to_ndarray(format="rgb24")])
            except StopIteration:
                logger.error(
                    "Failed to read even the first frame despite non-empty video check."
                )
                return np.array([])

        decoded_frames_list = []
        for i, frame in enumerate(container.decode(video=0)):
            if i > max_idx and max_idx != -1:  # max_idx is -1 if indices is empty
                break
            if i >= min_idx and i in processed_indices:
                decoded_frames_list.append(frame)

        if not decoded_frames_list and len(processed_indices) > 0:
            actual_decoded_count = 0
            try:
                container.seek(0)  # Reset for counting
                for _ in container.decode(video=0):
                    actual_decoded_count += 1
            except Exception:  # Handle cases where re-decoding/counting fails
                pass  # Keep original error message
            raise ValueError(
                f"Could not decode any frames for the given indices: {indices.tolist()}. "
                f"Video might be shorter than expected or indices out of bounds. "
                f"Actual decodable frames in container (approx): {actual_decoded_count}."
            )

        return (
            np.stack([x.to_ndarray(format="rgb24") for x in decoded_frames_list])
            if decoded_frames_list
            else np.array([])
        )

    return await asyncio.to_thread(blocking_decode)
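The core of `blocking_decode` is a sequential scan that keeps frames at the requested positions and stops once the largest index has been passed. The sketch below shows that pattern on a stand-in generator instead of a PyAV container. `select_by_index` and `fake_decoder` are hypothetical names, and unlike the function above, an empty index list simply returns an empty result here.

```python
from typing import Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def select_by_index(frames: Iterable[T], indices: Sequence[int]) -> List[T]:
    """Keep only the frames whose position is in `indices`, stopping early."""
    wanted = set(int(i) for i in indices)
    if not wanted:
        return []
    max_idx = max(wanted)
    selected: List[T] = []
    for i, frame in enumerate(frames):
        if i > max_idx:
            break  # nothing left to collect, so stop decoding
        if i in wanted:
            selected.append(frame)
    return selected


def fake_decoder(n: int, consumed: List[int]) -> Iterator[str]:
    """Stand-in for a decoder: yields labels and records how far it got."""
    for i in range(n):
        consumed.append(i)
        yield f"frame-{i}"


consumed: List[int] = []
picked = select_by_index(fake_decoder(100, consumed), [0, 5, 9])
```

With a 100-frame stand-in and a largest index of 9, the scan pulls only eleven frames from the decoder before breaking, which is the saving the `max_idx` check is after.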
def calculate_frame_sampling_indices(
    total_frames: int,
    num_frames_to_sample: int,
    duration_sec: float = 0,
    video_url: str = "",
) -> np.ndarray:
    """
    Calculate frame indices to sample from a video.

    Args:
        total_frames: Total number of frames in the video
        num_frames_to_sample: Number of frames to sample
        duration_sec: Duration of video in seconds (for logging)
        video_url: Video URL for logging purposes

    Returns:
        Array of frame indices to sample

    Raises:
        ValueError: If video has 0 frames and 0 duration
    """
    if total_frames == 0 and duration_sec == 0:
        logger.error(f"Video file '{video_url}' has 0 frames and 0 duration.")
        raise ValueError(f"Video {video_url} has 0 frames and 0 duration.")

    if total_frames == 0 and duration_sec > 0:
        logger.warning(
            f"Video {video_url} reports 0 frames but has duration {duration_sec:.2f}s. "
            "Frame sampling may be based on requested count directly."
        )

    logger.debug(
        f"Video {video_url} has {total_frames} frames (duration: {duration_sec:.2f}s). "
        f"Sampling {num_frames_to_sample} frames."
    )

    indices: np.ndarray
    if total_frames > 0:
        if total_frames < num_frames_to_sample:
            logger.warning(
                f"Video frames ({total_frames}) < samples ({num_frames_to_sample}). "
                f"Using all {total_frames} available frames."
            )
            indices = np.arange(0, total_frames).astype(int)
        else:
            indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
    else:  # total_frames is 0 (likely a stream), sample by count.
        logger.warning(
            f"Video {video_url} frame count is 0. Attempting to sample {num_frames_to_sample} "
            "frames by index. This might fail if stream is too short."
        )
        indices = np.arange(0, num_frames_to_sample).astype(int)

    # Ensure indices are unique, especially after linspace for small numbers.
    indices = np.unique(indices)

    # Safety checks for edge cases
    if len(indices) == 0 and total_frames > 0:
        # If unique resulted in empty but there are frames, sample at least one
        actual_samples = min(num_frames_to_sample, total_frames)
        indices = np.arange(0, actual_samples).astype(int)
    elif len(indices) == 0 and total_frames == 0:
        # If indices is empty and total_frames is 0, let downstream handle this case
        pass

    logger.debug(f"Selected frame indices for {video_url}: {indices.tolist()}")
    return indices
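To make the sampling rule concrete, here is a pure-Python sketch of the main case: evenly spaced indices over `[0, total_frames - 1]`, truncated to integers and de-duplicated. It approximates the `np.linspace(..., dtype=int)` call with exact integer arithmetic and deliberately omits the zero-frame stream fallback and the logging; `uniform_indices` is a hypothetical name.

```python
from typing import List


def uniform_indices(total_frames: int, num_samples: int) -> List[int]:
    """Pick up to `num_samples` evenly spaced indices in [0, total_frames - 1]."""
    if total_frames <= 0 or num_samples <= 0:
        return []
    if total_frames < num_samples:
        return list(range(total_frames))  # fewer frames than requested: use all
    if num_samples == 1:
        return [0]
    last = total_frames - 1
    # Integer floor of i * last / (num_samples - 1), then de-duplicate.
    raw = [(i * last) // (num_samples - 1) for i in range(num_samples)]
    return sorted(set(raw))


print(uniform_indices(100, 8))  # [0, 14, 28, 42, 56, 70, 84, 99]
print(uniform_indices(5, 8))    # [0, 1, 2, 3, 4]
```

The first and last frames are always included when at least two samples are requested, so the sampled clip spans the whole video.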
def resize_video_frames(
    frames_tensor: torch.Tensor, target_height: int, target_width: int
) -> torch.Tensor:
    """
    Resize video frames using PyTorch interpolation.

    Args:
        frames_tensor: Input tensor with shape (T, H, W, C)
        target_height: Target frame height
        target_width: Target frame width

    Returns:
        Resized tensor with shape (T, target_height, target_width, C)
    """
    # Permute to (T, C, H, W) for interpolate
    frames_tensor_chw = frames_tensor.permute(0, 3, 1, 2).float()

    # Resize
    resized_frames_tensor_chw = F.interpolate(
        frames_tensor_chw,
        size=(target_height, target_width),
        mode="bilinear",
        align_corners=False,
    )

    # Permute back to (T, H_new, W_new, C)
    resized_frames_tensor_hwc = resized_frames_tensor_chw.permute(0, 2, 3, 1)
    logger.debug(f"Resized frames to shape: {resized_frames_tensor_hwc.shape}")
    return resized_frames_tensor_hwc
def prepare_tensor_for_rdma(
    frames_tensor: torch.Tensor, request_id: str
) -> torch.Tensor:
    """
    Prepare video frames tensor for RDMA transfer.

    Args:
        frames_tensor: Input frames tensor
        request_id: Request ID for logging

    Returns:
        Tensor prepared for RDMA (CPU, uint8, contiguous)
    """
    # Ensure the tensor is contiguous, on CPU and uint8 for the NIXL buffer.
    tensor_for_descriptor = frames_tensor.to(
        device="cpu", dtype=torch.uint8
    ).contiguous()

    logger.debug(
        f"Req {request_id}: Preparing raw frames tensor (shape: {tensor_for_descriptor.shape}, "
        f"dtype: {tensor_for_descriptor.dtype}, device: {tensor_for_descriptor.device}, "
        f"contiguous: {tensor_for_descriptor.is_contiguous()}) for RDMA."
    )
    return tensor_for_descriptor
......@@ -212,6 +212,19 @@ pub struct ImageUrl {
pub detail: Option<ImageDetail>,
}
#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)]
#[builder(name = "VideoUrlArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct VideoUrl {
    /// Either a URL of the video or the base64 encoded video data.
    pub url: String,
    /// Specifies the detail level of the video processing.
    pub detail: Option<ImageDetail>,
}
#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)]
#[builder(name = "ChatCompletionRequestMessageContentPartImageArgs")]
#[builder(pattern = "mutable")]
......@@ -222,6 +235,16 @@ pub struct ChatCompletionRequestMessageContentPartImage {
pub image_url: ImageUrl,
}
#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)]
#[builder(name = "ChatCompletionRequestMessageContentPartVideoArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct ChatCompletionRequestMessageContentPartVideo {
    pub video_url: VideoUrl,
}
#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum InputAudioFormat {
......@@ -255,6 +278,7 @@ pub struct ChatCompletionRequestMessageContentPartAudio {
pub enum ChatCompletionRequestUserMessageContentPart {
Text(ChatCompletionRequestMessageContentPartText),
ImageUrl(ChatCompletionRequestMessageContentPartImage),
VideoUrl(ChatCompletionRequestMessageContentPartVideo),
InputAudio(ChatCompletionRequestMessageContentPartAudio),
}
......
......@@ -15,7 +15,7 @@ use serde::{Deserialize, Serialize};
use crate::error::OpenAIError;
use super::{ImageDetail, ImageUrl, VideoUrl};
#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
......@@ -114,6 +114,7 @@ pub enum MessageContent {
Text(MessageContentTextObject),
ImageFile(MessageContentImageFileObject),
ImageUrl(MessageContentImageUrlObject),
VideoUrl(MessageContentVideoUrlObject),
Refusal(MessageContentRefusalObject),
}
......@@ -198,6 +199,12 @@ pub struct MessageContentImageUrlObject {
pub image_url: ImageUrl,
}
/// References a video URL in the content of a message.
#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct MessageContentVideoUrlObject {
    pub video_url: VideoUrl,
}
#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct MessageRequestContentTextObject {
/// Text content to be sent to the model
......@@ -220,6 +227,7 @@ pub enum MessageContentInput {
Text(MessageRequestContentTextObject),
ImageFile(MessageContentImageFileObject),
ImageUrl(MessageContentImageUrlObject),
VideoUrl(MessageContentVideoUrlObject),
}
#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)]
#[builder(name = "CreateMessageRequestArgs")]
......@@ -289,6 +297,7 @@ pub struct MessageDelta {
pub enum MessageDeltaContent {
ImageFile(MessageDeltaContentImageFileObject),
ImageUrl(MessageDeltaContentImageUrlObject),
VideoUrl(MessageDeltaContentVideoUrlObject),
Text(MessageDeltaContentTextObject),
Refusal(MessageDeltaContentRefusalObject),
}
......@@ -363,3 +372,11 @@ pub struct MessageDeltaContentImageUrlObject {
pub image_url: Option<ImageUrl>,
}
#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)]
pub struct MessageDeltaContentVideoUrlObject {
    /// The index of the content part in the message.
    pub index: u32,
    pub video_url: Option<VideoUrl>,
}
......@@ -92,7 +92,7 @@ pub enum InputContent {
InputItemContentList(Vec<ContentType>),
}
/// Parts of a message: text, image, video, file, or audio.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentType {
......@@ -100,6 +100,8 @@ pub enum ContentType {
InputText(InputText),
/// An image input to the model.
InputImage(InputImage),
/// A video input to the model.
InputVideo(InputVideo),
/// A file input to the model.
InputFile(InputFile),
}
......@@ -129,6 +131,26 @@ pub struct InputImage {
image_url: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default, Builder)]
#[builder(
    name = "InputVideoArgs",
    pattern = "mutable",
    setter(into, strip_option),
    default
)]
#[builder(build_fn(error = "OpenAIError"))]
pub struct InputVideo {
    /// The detail level of the video to be sent to the model.
    detail: ImageDetail,
    /// The ID of the file to be sent to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    file_id: Option<String>,
    /// The URL of the video to be sent to the model. A fully qualified URL or base64 encoded video
    /// in a data URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    video_url: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default, Builder)]
#[builder(
name = "InputFileArgs",
......
......@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
def create_payload_for_config(config: "VLLMConfig") -> Payload:
"""Create a payload using the model from the vLLM config"""
if config.name in ["multimodal_agg_llava", "multimodal_agg_qwen"]:
# Special handling for multimodal models
return Payload(
payload_chat={
......@@ -50,6 +50,32 @@ def create_payload_for_config(config: "VLLMConfig") -> Payload:
expected_log=[],
expected_response=["bus"],
)
    elif config.name == "multimodal_video_agg":
        # Special handling for multimodal video models
        return Payload(
            payload_chat={
                "model": config.model,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "Describe the video in detail"},
                            {
                                "type": "video_url",
                                "video_url": {
                                    "url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
                                },
                            },
                        ],
                    }
                ],
                "max_tokens": 300,
                "stream": False,
            },
            repeat_count=1,
            expected_log=[],
            expected_response=["rabbit"],
        )
else:
# Use base implementation for standard text models
return base_create_payload(config)
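For manual testing outside the pytest harness, the same payload can be assembled with the standard library. This is a sketch under assumptions: the frontend address `http://localhost:8000` is a guess that depends on the deployment, and `build_video_chat_request` is a hypothetical helper. The endpoint path, model name, and video URL are taken from the test configuration in this change. The request is built but not sent.

```python
import json
import urllib.request

# Assumption: the frontend listens on localhost:8000; adjust for your deployment.
BASE_URL = "http://localhost:8000"


def build_video_chat_request(
    model: str, video_url: str, prompt: str
) -> urllib.request.Request:
    """Build (but do not send) a chat completion request with one video part."""
    body = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "video_url", "video_url": {"url": video_url}},
                ],
            }
        ],
        "max_tokens": 300,
        "stream": False,
    }
    return urllib.request.Request(
        f"{BASE_URL}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_video_chat_request(
    "llava-hf/LLaVA-NeXT-Video-7B-hf",
    "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4",
    "Describe the video in detail",
)
# To actually send it: urllib.request.urlopen(request, timeout=60)
```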
......@@ -194,6 +220,20 @@ vllm_configs = {
args=["--model", "Qwen/Qwen2.5-VL-7B-Instruct"],
timeout=360,
),
"multimodal_video_agg": VLLMConfig(
name="multimodal_video_agg",
directory="/workspace/examples/multimodal",
script_name="video_agg.sh",
marks=[pytest.mark.gpu_2, pytest.mark.vllm],
endpoints=["v1/chat/completions"],
response_handlers=[
chat_completions_response_handler,
],
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
delayed_start=0,
args=["--model", "llava-hf/LLaVA-NeXT-Video-7B-hf"],
timeout=360,
),
# TODO: Enable this test case when we have 4 GPUs runners.
# "multimodal_disagg": VLLMConfig(
# name="multimodal_disagg",
......