Unverified Commit 75503dae authored by Indrajit Bhosale's avatar Indrajit Bhosale Committed by GitHub
Browse files

feat: Video support with Dynamo (#1443)

Added support in examples for multimodal video with aggregated and disaggregated architecture using LLaVA-NeXT-Video-7B
parent 6f8c68c1
......@@ -265,3 +265,182 @@ curl localhost:8000/v1/chat/completions \
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
For more details on managing deployments, testing, and troubleshooting, please refer to the [Operator Deployment Guide](../../docs/guides/dynamo_deploy/operator_deployment.md).
## Multimodal Aggregated Video Serving
This example demonstrates deploying an aggregated multimodal model that can process video inputs.
### Dependency
Video example relies on `av` package for video preprocessing inside the encode_worker.
Please install `av` inside the dynamo container to enable video example.
`pip install av`
### Components
- workers: For video serving, we have two workers, [video_encode_worker](components/video_encode_worker.py) for decoding video into frames, and [video_decode_worker](components/video_decode_worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the decode worker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have two workers, `video_encode_worker` and `video_decode_worker`.
The `video_encode_worker` is responsible for decoding the video into a series of frames. Unlike the image pipeline which generates embeddings, this pipeline passes the raw frames directly to the `video_decode_worker`. This transfer is done efficiently using RDMA.
The `video_decode_worker` then receives these frames, and performs prefill and decode steps with the model. Separating the video processing from the language model inference allows for flexible scaling.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --> video_decode_worker
video_decode_worker --> processor
video_decode_worker --video_url--> video_encode_worker
video_encode_worker --frames--> video_decode_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal
# Serve a LLaVA-NeXT-Video-7B model:
dynamo serve graphs.agg_video:Frontend -f ./configs/agg_video.yaml
```
### Client
In another terminal:
```bash
curl -X 'POST' 'http://localhost:8000/v1/chat/completions' -H 'Content-Type: application/json' -d '{
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the video in detail"
},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}' | jq
```
You should see a response describing the video's content similar to
```json
{
"id": "b5714626-5889-4bb7-8c51-f3bca65b4683",
"object": "chat.completion",
"created": 1749772533,
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": " Sure! The video features a group of anthropomorphic animals who appear human-like. They're out in a meadow, which is a large, open area covered in grasses, and have given human qualities like speaking and a desire to go on adventures. The animals are seen play-fighting with each other clearly seen glancing at the camera when they sense it, blinking, and Roman the second can be directly heard by the camera reciting the line, \"When the challenge becomes insane, the behavior becomes erratic.\" A white rabbit is the first in shot and he winks the left eye and flips the right ear before shaking with the mouse and squirrel friends on a blurry rock ledge under the sky. At some point, the rabbit turns towards the camera and starts playing with the thing, and there's a distant mountain in the background. Furthermore, a little animal from a tree in the background flies with two rocks, and it's joined by the rest of the group of friends. That outro is an elder turtle in the Ramden musical style saturated with a horn-like thing pattern."
},
"finish_reason": "stop"
}
]
}
```
## Multimodal Disaggregated Video Serving
This example demonstrates deploying a disaggregated multimodal model that can process video inputs.
### Dependency
Video example relies on `av` package for video preprocessing inside the encode_worker.
Please install `av` inside the dynamo container to enable video example.
`pip install av`
### Components
- workers: For disaggregated video serving, we have three workers, [video_encode_worker](components/video_encode_worker.py) for decoding video into frames, [video_decode_worker](components/video_decode_worker.py) for decoding, and [video_prefill_worker](components/video_prefill_worker.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the decode worker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have three workers, `video_encode_worker`, `video_decode_worker`, and `video_prefill_worker`.
For the LLaVA-NeXT-Video-7B model, frames are only required during the prefill stage. As such, the `video_encode_worker` is connected directly to the `video_prefill_worker`.
The `video_encode_worker` is responsible for decoding the video into a series of frames and passing them to the `video_prefill_worker` via RDMA.
The `video_prefill_worker` performs the prefilling step and forwards the KV cache to the `video_decode_worker` for decoding.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --> video_decode_worker
video_decode_worker --> processor
video_decode_worker --> video_prefill_worker
video_prefill_worker --> video_decode_worker
video_prefill_worker --video_url--> video_encode_worker
video_encode_worker --frames--> video_prefill_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal
# Serve a LLaVA-NeXT-Video-7B model:
dynamo serve graphs.disagg_video:Frontend -f ./configs/disagg_video.yaml
```
### Client
In another terminal:
```bash
curl -X 'POST' 'http://localhost:8000/v1/chat/completions' -H 'Content-Type: application/json' -d '{
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the video in detail"
},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}' | jq
```
You should see a response describing the video's content similar to
```json
{
"id": "d1d641b1-4daf-48d3-9d06-6a60743b5a42",
"object": "chat.completion",
"created": 1749775300,
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": " The video features two animals in a lush, green outdoor environment. On the ground, there is a white rabbit with big brown eyes, a playful expression, and two antlers. The rabbit is accompanied by a uniquely colored bird with orange pupils, possibly a squirrel or a hamster, sitting on its head. These two animals seem to have embarked on an unlikely journey, flying together in the sky. The backdrop showcases rolling green hills and trees under the pleasant weather. The sky is clear, indicating a beautiful day. The colors and contrast suggest the landscape is during spring or summer, signifying the rabbit and bird could also be engaging in outdoor activities during those seasons. Overall, it's a charming scene depicting an unlikely yet harmonious pair, enjoying a surprise adventure in nature."
},
"finish_reason": "stop"
}
]
}
```
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import signal
from typing import Optional, Union
import connect
import torch
from components.disagg_router import PyDisaggregatedRouter
from components.video_encode_worker import VllmEncodeWorker
from components.video_prefill_worker import VllmPrefillWorker
from transformers import AutoProcessor
from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import EncodeRequest, MyRequestOutput, vLLMMultimodalRequest
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
# Constants for the shape and dtype of the INCOMING FRAMES tensor from EncodeWorker.
# IMPORTANT ASSUMPTION: EncodeWorker must provide frames of this fixed shape and dtype.
INCOMING_FRAMES_DTYPE = torch.uint8
INCOMING_FRAMES_DEVICE = "cpu"
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmDecodeWorker:
# For disaggregated serving, we need to link the prefill worker to the vllm worker
prefill_worker = depends(VllmPrefillWorker)
# For aggregated serving, we need to link the encode worker to the vllm worker.
encode_worker = depends(VllmEncodeWorker)
def _expand_video_tokens_in_prompt(
self,
original_tokens: list[int],
num_frames_to_expand_to: int,
image_token_id: int, # This should be the ID from hf_processor.tokenizer
add_dummy_tokens: bool,
dummy_token_id: int = 0,
num_dummy_tokens_per_frame: int = 0,
) -> list[int]:
"""
Expands the first occurrence of image_token_id in original_tokens
to num_frames_to_expand_to occurrences. Optionally adds dummy tokens.
"""
expanded_prompt_list = []
token_expanded_successfully = False
for token_id_val in original_tokens:
if token_id_val == image_token_id and not token_expanded_successfully:
for _ in range(num_frames_to_expand_to):
expanded_prompt_list.append(image_token_id)
if add_dummy_tokens:
dummy_tokens_to_add = [
dummy_token_id
] * num_dummy_tokens_per_frame
expanded_prompt_list.extend(dummy_tokens_to_add)
token_expanded_successfully = True
else:
expanded_prompt_list.append(token_id_val)
if not token_expanded_successfully:
# If the specific video token ID isn't found (e.g. prompt had no video placeholder),
# it implies the original prompt didn't intend for video.
# This might be an issue if video data is expected.
logger.warning(
f"Image token ID {image_token_id} for expansion not found in prompt tokenized by hf_processor. Prompt: {original_tokens}. This might be okay if no video was intended in this specific prompt structure."
)
return list(original_tokens) # Return original if no video token to expand
return expanded_prompt_list
def __init__(self):
self.client = None
self.min_workers = 1
self.disaggregated_router: Optional[PyDisaggregatedRouter] = None
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_path = self.engine_args.model
self.num_sampled_frames = getattr(self.engine_args, "num_sampled_frames", 8)
self.frame_height = getattr(self.engine_args, "frame_height", 336)
self.frame_width = getattr(self.engine_args, "frame_width", 336)
self.frame_channels = getattr(self.engine_args, "frame_channels", 3)
self.dummy_token_id = getattr(self.engine_args, "dummy_token_id", 0)
self.video_token_id = getattr(self.engine_args, "video_token_id", 32000)
self.dummy_tokens_per_frame = getattr(
self.engine_args, "dummy_tokens_per_frame", 144
)
self.do_remote_prefill = self.engine_args.remote_prefill
self.model_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
)
self._prefill_queue_stream_name = self.model_name
logger.info(
f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
)
if self.engine_args.remote_prefill:
if self.engine_args.enable_chunked_prefill is not False:
logger.info("Chunked prefill is not supported yet, setting to False")
self.engine_args.enable_chunked_prefill = False
if self.engine_args.preemption_mode != "swap":
logger.info("Preemption mode is not supported yet, setting to swap")
self.engine_args.preemption_mode = "swap"
if self.engine_args.pipeline_parallel_size != 1:
logger.info("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.router == "kv":
raise NotImplementedError(
"Multimodal requests are not supported for kv router mode"
)
signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
@async_on_start
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
if self.engine_args.router == "kv":
raise NotImplementedError(
"Multimodal requests are not supported for kv router mode"
)
# Load the Hugging Face processor
try:
self.hf_processor = AutoProcessor.from_pretrained(
self.model_path, trust_remote_code=True
)
logger.info(f"Successfully loaded AutoProcessor from: {self.model_path}")
if (
not hasattr(self.hf_processor, "tokenizer")
or self.hf_processor.tokenizer is None
):
logger.warning(
f"Loaded AutoProcessor from {self.model_path} but it does not have a 'tokenizer' attribute or it is None."
)
except Exception as e:
logger.error(
f"Failed to load AutoProcessor from {self.model_path}: {e}",
exc_info=True,
)
# Depending on the desired behavior, you might want to raise the error
# or allow the worker to start without a processor if it's optional for some paths.
# For this change, processor is critical.
raise RuntimeError(f"Failed to initialize AutoProcessor: {e}") from e
runtime = dynamo_context["runtime"]
# Common setup for interacting with EncodeWorker (NIXL, client)
# This is needed for aggregated mode OR for local prefill in disaggregated mode.
enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(enc_comp_ns)
.component(enc_comp_name)
.endpoint("encode")
.client()
)
# Initialize the connector for RDMA transfers within the specified namespace.
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
# NIXL buffer for receiving raw video frames.
incoming_frames_shape = (
self.num_sampled_frames,
self.frame_height,
self.frame_width,
self.frame_channels,
)
raw_frames_tensor = torch.empty(
incoming_frames_shape,
dtype=INCOMING_FRAMES_DTYPE,
device=INCOMING_FRAMES_DEVICE,
)
# Create a descriptor for the tensor to make it available for remote access.
descriptor = connect.Descriptor(raw_frames_tensor)
# Register the memory with the connector, making it discoverable.
descriptor.register_memory(self._connector)
self._frames_descriptor = (raw_frames_tensor, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers)
if self.do_remote_prefill: # Disaggregated mode specific setup
metadata = self.engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo", runtime)
await metadata_store.put(metadata.engine_id, metadata)
if self.engine_args.conditional_disagg:
self.disaggregated_router = PyDisaggregatedRouter(
runtime,
self.model_name,
max_local_prefill_length=self.engine_args.max_local_prefill_length,
max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
)
await self.disaggregated_router.async_init()
else:
self.disaggregated_router = (
None # Always remote prefill if not conditional_disagg
)
# embedding_size is used for dummy token calculation in remote prefill case.
# For LLaVA-NeXT-Video, the model architecture processes each frame into a 12x12 grid
# of visual tokens, resulting in 144 tokens per frame. This is a fixed architectural
# constant. For more details on the vision tower architecture, see the LLaVA-1.5 paper
# which LLaVA-NeXT is based on: https://arxiv.org/abs/2310.03744
self.embedding_size = 144
logger.info(
f"Disaggregated mode: Using LLaVA-NeXT-Video embedding size: {self.embedding_size}"
)
else: # Aggregated mode specific setup
self.disaggregated_router = (
None # No disaggregated router in aggregated mode
)
logger.info(
"Aggregated mode: VllmDecodeWorker will handle multimodal data directly via NIXL."
)
logger.info("Initialization complete.")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
logger.info("Shutdown complete.")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
def get_remote_prefill_request_callback(self):
async def callback(request: RemotePrefillRequest):
try:
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
await prefill_queue.enqueue_prefill_request(request)
logger.info(
f"DecodeWorker {request.request_id}: Successfully enqueued remote prefill request."
)
except Exception as e:
logger.error(
f"DecodeWorker {request.request_id}: Failed to enqueue remote prefill request: {e}",
exc_info=True,
)
return callback
@endpoint()
async def generate(self, request: vLLMMultimodalRequest):
request_id = request.request_id
video_url = request.video_url # Video path for EncodeWorker
# TODO: Fix existing tokenizer <video> not found error and remove this.
user_text_prompt = request.engine_prompt.get(
"text_prompt", "Describe the video."
)
logger.info(
f"Received multimodal request {{ id: {request_id} }} with user text: '{user_text_prompt}'."
)
# Constants for token manipulation
# For LLaVA-NeXT-Video models, the video token ID is 32000, not 32001
# 32001 is for image tokens in LLaVA-NeXT-Video, 32000 is for video tokens
VIDEO_TOKEN_ID_FOR_EXPANSION = 32000
DUMMY_TOKEN_ID = 0
# Variables to be set based on processing path
prompt_argument_for_vllm: Union[str, TokensPrompt]
current_received_multimodal_data_tensor: Optional[torch.Tensor] = None
current_remote_prefill_params: Optional[RemotePrefillParams] = None
multi_modal_data_for_engine: Optional[dict] = None
if self.do_remote_prefill:
logger.info(f"Disaggregated mode: request {{ id: {request_id} }}.")
# Tokenize the prompt string to get base IDs for router length check and potential remote prefill manipulation
base_prompt_ids_for_router = request.engine_prompt["prompt_token_ids"]
if (
isinstance(base_prompt_ids_for_router, list)
and len(base_prompt_ids_for_router) > 0
and isinstance(base_prompt_ids_for_router[0], list)
and len(base_prompt_ids_for_router) == 1
):
base_prompt_ids_for_router = base_prompt_ids_for_router[0]
should_prefill_remotely_decision = True
if self.disaggregated_router:
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
prefill_queue_size = await prefill_queue.get_queue_size()
should_prefill_remotely_decision = (
await self.disaggregated_router.prefill_remote(
len(base_prompt_ids_for_router),
request.prefix_hit_rate,
prefill_queue_size,
)
)
if should_prefill_remotely_decision:
logger.info(
f"Disaggregated: Prefilling REMOTELY for request {{ id: {request_id} }} (orig prompt len {len(base_prompt_ids_for_router)})"
)
current_remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
multimodal_data_source={"video_url": video_url},
)
num_dummies = self.embedding_size - 1
# For remote prefill, expand the *single* video token from base_prompt_ids and add dummies
expanded_and_dummied_ids = self._expand_video_tokens_in_prompt(
base_prompt_ids_for_router, # Use the tokenized output of chat_template
self.num_sampled_frames,
VIDEO_TOKEN_ID_FOR_EXPANSION,
add_dummy_tokens=True,
dummy_token_id=DUMMY_TOKEN_ID,
num_dummy_tokens_per_frame=num_dummies,
)
prompt_argument_for_vllm = TokensPrompt(
prompt_token_ids=expanded_and_dummied_ids, multi_modal_data=None
)
multi_modal_data_for_engine = None # Handled by prefill worker
else: # Local prefill in disaggregated mode
logger.info(
f"Disaggregated: Prefilling LOCALLY for request {{ id: {request_id} }} (orig prompt len {len(base_prompt_ids_for_router)})"
)
raw_frames_tensor_from_nixl, desc = self._frames_descriptor
# Create a writable operation handle for the remote EncodeWorker.
# This allows the EncodeWorker to write directly into this worker's `raw_frames_tensor_from_nixl`.
with self._connector.create_writable(desc) as writable:
enc_req = EncodeRequest(
request_id=request_id,
video_url=video_url,
# Serialize the writable handle to send it to the EncodeWorker.
serialized_request=writable.to_serialized(),
)
async for _ in await self.encode_worker_client.round_robin(
enc_req.model_dump_json()
):
pass
# Wait for the remote write from the EncodeWorker to complete.
await writable.wait_for_completion()
current_received_multimodal_data_tensor = raw_frames_tensor_from_nixl
# The vLLM engine's processor for raw visual data expects a CPU-based NumPy array.
# Therefore, we must first move the tensor from the GPU to the CPU memory
# before converting it to a NumPy array.
# See vLLM's official example for raw image inputs: https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py
video_numpy = current_received_multimodal_data_tensor.numpy()
multi_modal_data_for_engine = {"video": video_numpy}
prompt_argument_for_vllm = request.engine_prompt[
"prompt_token_ids"
] # Pass raw string to vLLM
current_remote_prefill_params = None
else: # AGGREGATED MODE
logger.info(
f"Aggregated mode: request {{ id: {request_id} }}. Fetching frames directly."
)
raw_frames_tensor_from_nixl, desc = self._frames_descriptor
# Create a writable operation handle for the remote EncodeWorker.
# This allows the EncodeWorker to write directly into this worker's `raw_frames_tensor_from_nixl`.
with self._connector.create_writable(desc) as writable:
enc_req = EncodeRequest(
request_id=request_id,
video_url=video_url,
# Serialize the writable handle to send it to the EncodeWorker.
serialized_request=writable.to_serialized(),
)
async for _ in await self.encode_worker_client.round_robin(
enc_req.model_dump_json()
):
pass
# Wait for the remote write from the EncodeWorker to complete.
await writable.wait_for_completion()
current_received_multimodal_data_tensor = raw_frames_tensor_from_nixl
# The vLLM engine's processor for raw visual data expects a CPU-based NumPy array.
# Therefore, we must first move the tensor from the GPU to the CPU memory
# before converting it to a NumPy array.
# See vLLM's official example for raw image inputs: https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py
video_numpy = current_received_multimodal_data_tensor.numpy()
multi_modal_data_for_engine = {"video": video_numpy}
prompt_argument_for_vllm = request.engine_prompt[
"prompt_token_ids"
] # Pass raw string to vLLM
current_remote_prefill_params = None
request.sampling_params.output_kind = RequestOutputKind.DELTA
# Prepare the first argument for vLLM engine's generate call
final_vllm_input: Union[str, dict]
if isinstance(prompt_argument_for_vllm, dict):
# This handles the remote prefill path where we have a TokensPrompt,
# which is a dict-like object.
final_vllm_input = prompt_argument_for_vllm
elif isinstance(prompt_argument_for_vllm, list):
# This handles the local prefill (aggregated or disaggregated) path
# where we have a list of token IDs and raw video data.
final_vllm_input = {
"prompt_token_ids": prompt_argument_for_vllm,
"multi_modal_data": multi_modal_data_for_engine,
}
else:
logger.error(
f"Unexpected type for prompt_argument_for_vllm: {type(prompt_argument_for_vllm)}"
)
raise TypeError("Invalid type for vLLM prompt argument.")
async for response in self.engine_client.generate(
final_vllm_input, # This is now the prompts argument (dict)
sampling_params=request.sampling_params,
request_id=request.request_id,
remote_prefill_params=current_remote_prefill_params,
):
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import json
import logging
import os
from io import BytesIO
from queue import Queue
from typing import AsyncIterator, Optional
from urllib.parse import urlparse
import av
import connect
import httpx
import numpy as np
import torch
import torch.nn.functional as F
from utils.protocol import EncodeRequest
from utils.vllm import parse_vllm_args
from dynamo.sdk import async_on_start, endpoint, service
logger = logging.getLogger(__name__)
CACHE_SIZE_MAXIMUM = 8
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmEncodeWorker:
def __init__(self) -> None:
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.num_frames_to_sample = getattr(self.engine_args, "num_sampled_frames", 8)
self.frame_height = getattr(self.engine_args, "frame_height", 336)
self.frame_width = getattr(self.engine_args, "frame_width", 336)
self.frame_channels = getattr(self.engine_args, "frame_channels", 3)
self.dummy_token_id = getattr(self.engine_args, "dummy_token_id", 0)
self.video_token_id = getattr(self.engine_args, "video_token_id", 32000)
self.dummy_tokens_per_frame = getattr(
self.engine_args, "dummy_tokens_per_frame", 144
)
self._video_content_cache: dict[str, BytesIO] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
self._http_client: Optional[httpx.AsyncClient] = None
self._http_timeout = 60.0
async def _read_video_pyav(
self, container: av.container.InputContainer, indices: np.ndarray
) -> np.ndarray:
"""
Decode the video with PyAV decoder. Async wrapper.
"""
def blocking_decode():
container.seek(0) # Reset container for decoding
processed_indices = set(indices)
# Determine min/max index to optimize decoding loop slightly
min_idx = 0
max_idx = -1
if len(indices) > 0:
min_idx = np.min(indices)
max_idx = np.max(indices)
if (
not processed_indices
and container.streams.video
and container.streams.video[0].frames > 0
):
logger.warning(
"_read_video_pyav called with empty indices for a non-empty video, attempting to read first frame."
)
try:
frame = next(container.decode(video=0))
return np.stack([frame.to_ndarray(format="rgb24")])
except StopIteration:
logger.error(
"Failed to read even the first frame despite non-empty indices check."
)
return np.array([])
decoded_frames_list = []
for i, frame in enumerate(container.decode(video=0)):
if i > max_idx and max_idx != -1: # max_idx is -1 if indices is empty
break
if i >= min_idx and i in processed_indices:
decoded_frames_list.append(frame)
if not decoded_frames_list and len(processed_indices) > 0:
actual_decoded_count = 0
try:
container.seek(0) # Reset for counting
for _ in container.decode(video=0):
actual_decoded_count += 1
except Exception: # Handle cases where re-decoding/counting fails
pass # Keep original error message
raise ValueError(
f"Could not decode any frames for the given indices: {indices.tolist()}. "
f"Video might be shorter than expected or indices out of bounds. "
f"Actual decodable frames in container (approx): {actual_decoded_count}."
)
return (
np.stack([x.to_ndarray(format="rgb24") for x in decoded_frames_list])
if decoded_frames_list
else np.array([])
)
return await asyncio.to_thread(blocking_decode)
async def _load_video_content(self, video_url: str) -> BytesIO:
parsed_url = urlparse(video_url)
video_url_lower = video_url.lower()
if parsed_url.scheme in ("http", "https"):
if video_url_lower in self._video_content_cache:
logger.info(f"Video content found in cache for URL: {video_url}")
cached_content = self._video_content_cache[video_url_lower]
cached_content.seek(0)
return cached_content
try:
video_data: BytesIO
if parsed_url.scheme == "data":
if not parsed_url.path.startswith(
("video/", "application/octet-stream")
):
raise ValueError("Data URL must be a video type or octet-stream")
media_type_and_data = parsed_url.path.split(",", 1)
if len(media_type_and_data) != 2:
raise ValueError("Invalid Data URL format: missing comma separator")
media_type, data_segment = media_type_and_data
if ";base64" not in media_type:
raise ValueError("Video Data URL currently must be base64 encoded")
try:
video_bytes = base64.b64decode(data_segment)
video_data = BytesIO(video_bytes)
except binascii.Error as e:
raise ValueError(
f"Invalid base64 encoding for video data: {e}"
) from e
elif parsed_url.scheme in ("http", "https"):
if not self._http_client:
await self._init_http_client()
if not self._http_client: # Double check after initialization
raise RuntimeError("Failed to initialize HTTP client")
logger.info(f"Downloading video from URL: {video_url}")
response = await self._http_client.get(
video_url, timeout=self._http_timeout
)
response.raise_for_status()
if not response.content:
raise ValueError(
f"Empty response content from video URL: {video_url}"
)
video_data = BytesIO(response.content)
video_data.seek(0)
logger.info(
f"Video downloaded from {video_url}, size: {len(response.content)} bytes."
)
elif parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path if parsed_url.scheme else video_url
# Ensure path is absolute or resolve relative to a known base if necessary
# For simplicity, assuming it's an accessible path.
if not os.path.exists(file_path):
raise FileNotFoundError(f"Error reading file: {file_path}")
with open(file_path, "rb") as f:
video_bytes = f.read()
video_data = BytesIO(video_bytes)
else:
raise ValueError(
f"Unsupported video source scheme: {parsed_url.scheme} for URL {video_url}"
)
if parsed_url.scheme in (
"http",
"https",
): # Cache successfully downloaded content
if self._cache_queue.full():
oldest_url = self._cache_queue.get_nowait()
if oldest_url in self._video_content_cache:
del self._video_content_cache[oldest_url]
# Store the BytesIO object directly; it will be seek(0)'d when retrieved
self._video_content_cache[video_url_lower] = video_data
self._cache_queue.put(video_url_lower)
return video_data
except httpx.HTTPStatusError as e:
logger.error(
f"HTTP error {e.response.status_code} loading video {video_url}: {e.response.text[:200]}"
)
raise ValueError(
f"Failed to download video {video_url}: HTTP {e.response.status_code}"
) from e
except httpx.RequestError as e:
logger.error(f"Request error loading video {video_url}: {e}")
raise ValueError(f"Network request failed for video {video_url}") from e
except FileNotFoundError as e:
logger.error(f"File error loading video {video_url}: {e}")
raise
except Exception as e:
logger.error(
f"Error loading video content from {video_url}: {type(e).__name__} - {e}"
)
raise ValueError(f"Failed to load video content: {e}") from e
@endpoint()
async def encode(self, request: EncodeRequest) -> AsyncIterator[str]:
request_id = request.request_id
video_url = getattr(request, "video_url", None)
if not video_url:
logger.error(f"Request {request_id}: 'video_url' not provided.")
raise ValueError("'video_url' is required for encoding.")
if request.serialized_request is None:
logger.error(
f"Request serialized_request is None for request: {{ id: {request_id} }}."
)
raise ValueError("'serialized_request' is required for encoding.")
logger.info(
f"Received encode request: {{ id: {request_id}, video_url: '{video_url[:100]}...' }}"
)
container: Optional[av.container.InputContainer] = None
try:
video_content_stream = await self._load_video_content(video_url)
def open_video_container_sync():
try:
return av.open(video_content_stream, mode="r")
except av.FFmpegError as ave:
logger.error(
f"PyAV error opening video stream from {video_url}: {ave}"
)
raise ValueError(
f"Invalid video format or corrupted data from {video_url}."
) from ave
except Exception as e:
logger.error(
f"Unexpected error opening video stream from {video_url} with PyAV: {e}"
)
raise ValueError(
f"Unexpected error opening video from {video_url}."
) from e
container = await asyncio.to_thread(open_video_container_sync)
if not container or not container.streams.video:
logger.error(f"No video stream found in {video_url}.")
raise ValueError(f"No video stream in {video_url}.")
stream_info = container.streams.video[0]
total_frames = stream_info.frames
# Duration can be useful for streams where total_frames is 0
if stream_info.duration and stream_info.time_base:
duration_sec = float(stream_info.duration * stream_info.time_base)
else:
duration_sec = 0
if total_frames == 0 and duration_sec == 0:
logger.error(f"Video file '{video_url}' has 0 frames and 0 duration.")
raise ValueError(f"Video {video_url} has 0 frames and 0 duration.")
if total_frames == 0 and duration_sec > 0:
logger.warning(
f"Video {video_url} reports 0 frames but has duration {duration_sec:.2f}s. Frame sampling may be based on requested count directly."
)
logger.debug(
f"Video {video_url} has {total_frames} frames (duration: {duration_sec:.2f}s). Sampling {self.num_frames_to_sample} frames."
)
indices: np.ndarray
if total_frames > 0:
if total_frames < self.num_frames_to_sample:
logger.warning(
f"Video frames ({total_frames}) < samples ({self.num_frames_to_sample}). Using all {total_frames} available frames."
)
indices = np.arange(0, total_frames).astype(int)
else:
indices = np.linspace(
0, total_frames - 1, self.num_frames_to_sample, dtype=int
)
else: # total_frames is 0 (likely a stream), sample by count.
logger.warning(
f"Video {video_url} frame count is 0. Attempting to sample {self.num_frames_to_sample} frames by index. This might fail if stream is too short."
)
indices = np.arange(0, self.num_frames_to_sample).astype(int)
# Ensure indices are unique, especially after linspace for small numbers.
indices = np.unique(indices)
if (
len(indices) == 0 and total_frames > 0
): # Safety for linspace oddities with few frames
# If unique resulted in empty but there are frames, sample at least one or up to num_frames_to_sample
actual_samples = min(self.num_frames_to_sample, total_frames)
indices = np.arange(0, actual_samples).astype(int)
elif len(indices) == 0 and total_frames == 0:
# If indices is empty and total_frames is 0, this means num_frames_to_sample might be 0 or indices logic failed.
# This case implies we might not be able to sample any frames.
# _read_video_pyav handles empty indices with non-empty video by trying to read the first frame.
# If indices is empty here due to num_frames_to_sample=0, _read_video_pyav will return empty.
pass # Let _read_video_pyav handle this.
logger.info(f"Selected frame indices for {video_url}: {indices.tolist()}")
if not container:
raise ValueError(f"Container is None for {video_url}")
clip_np: np.ndarray = await self._read_video_pyav(container, indices)
if clip_np.size == 0:
raise ValueError(
f"Failed to extract any video frames from {video_url} for indices {indices.tolist()}. Clip is empty."
)
logger.info(
f"Successfully extracted {len(clip_np) if clip_np.ndim > 1 and clip_np.shape[0] > 0 else 0} frames for {video_url} with original shape {clip_np.shape}."
)
# Convert the NumPy array from the video decoder into a PyTorch tensor.
# This is a required step to use PyTorch functions for GPU-accelerated image processing.
frames_tensor_orig_res = torch.from_numpy(clip_np) # Shape: (T, H, W, C)
# Permute to (T, C, H, W) for interpolate
frames_tensor_chw = frames_tensor_orig_res.permute(
0, 3, 1, 2
).float() # Ensure float for interpolate
# Resize
resized_frames_tensor_chw = F.interpolate(
frames_tensor_chw,
size=(self.frame_height, self.frame_width),
mode="bilinear",
align_corners=False,
)
# Permute back to (T, H_new, W_new, C)
resized_frames_tensor_hwc = resized_frames_tensor_chw.permute(0, 2, 3, 1)
logger.debug(f"Resized frames to shape: {resized_frames_tensor_hwc.shape}")
# Ensure the tensor is contiguous, on CUDA and uint8 for the NIXL buffer.
tensor_for_descriptor: torch.Tensor = resized_frames_tensor_hwc.to(
device="cpu", dtype=torch.uint8
).contiguous()
logger.info(
f"Req {request_id}: Preparing raw frames tensor (shape: {tensor_for_descriptor.shape}, "
f"dtype: {tensor_for_descriptor.dtype}, device: {tensor_for_descriptor.device}, "
f"contiguous: {tensor_for_descriptor.is_contiguous()}) for RDMA."
)
# Create a descriptor for the tensor to be sent via the connector.
descriptor = connect.Descriptor(tensor_for_descriptor)
logger.info(f"Req {request_id}: Beginning connector write operation.")
# Pass the remote worker's SerializedRequest (representing its WritableOperation) to begin_write.
# This initiates the data transfer to the memory buffer on the other worker.
write_op = await self._connector.begin_write(
descriptor, request.serialized_request
)
# Wait for the RDMA/transfer operation to complete.
await write_op.wait_for_completion()
logger.info(f"Req {request_id}: Connector write operation completed.")
# Yield a simplified response, assuming EncodeResponse in protocol is adapted
final_response_data = {
"request_id": request.request_id,
}
yield json.dumps(final_response_data)
logger.info(f"Encode request {request_id} processed successfully.")
except (
FileNotFoundError,
av.FFmpegError,
ValueError,
) as e:
logger.error(
f"Error processing request {request_id} ({video_url[:100]}...): {type(e).__name__} - {e}"
)
raise # Re-raise to be handled by the service framework
except Exception as e:
logger.exception(
f"Unexpected error processing request {request_id} ({video_url[:100]}...): {e}"
)
raise
finally:
if container:
await asyncio.to_thread(container.close)
async def _init_http_client(self):
if (
not self._http_client or self._http_client.is_closed
): # Check if closed as well
self._http_client = httpx.AsyncClient(timeout=self._http_timeout)
logger.info("HTTP client (re)initialized.")
@async_on_start
async def async_init(self):
logger.info(f"{self.__class__.__name__} async_init started.")
# Initialize the connector for RDMA transfers.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Dynamo connector initialized.")
await self._init_http_client()
logger.info(
f"{self.__class__.__name__} async_init completed. Ready to encode video frames."
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from components.video_processor import Processor
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from utils.protocol import MultiModalRequest
from dynamo.sdk import DYNAMO_IMAGE, api, depends, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
image=DYNAMO_IMAGE,
app=FastAPI(title="Multimodal Example"),
)
class Frontend:
processor = depends(Processor)
@api(name="v1/chat/completions")
async def generate(self, request: MultiModalRequest):
async def content_generator():
async for response in self.processor.generate(request.model_dump_json()):
try:
s = json.loads(response)
yield s
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to parse JSON response: {e}") from e
return StreamingResponse(content_generator(), media_type="text/event-stream")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import signal
from typing import Optional
import connect
import torch
from components.video_encode_worker import VllmEncodeWorker
from pydantic import BaseModel
from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import EncodeRequest
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
# Constants for the shape and dtype of the INCOMING FRAMES tensor.
# Other constants taken from yaml as they are model dependent.
INCOMING_FRAMES_DTYPE = torch.uint8
INCOMING_FRAMES_DEVICE = "cpu"
class RequestType(BaseModel):
text: str
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmPrefillWorker:
encode_worker = depends(VllmEncodeWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_path = self.engine_args.model # Store model path for AutoProcessor
self.num_sampled_frames = getattr(self.engine_args, "num_sampled_frames", 8)
self.frame_height = getattr(self.engine_args, "frame_height", 336)
self.frame_width = getattr(self.engine_args, "frame_width", 336)
self.frame_channels = getattr(self.engine_args, "frame_channels", 3)
self.dummy_token_id = getattr(self.engine_args, "dummy_token_id", 0)
self.video_token_id = getattr(self.engine_args, "video_token_id", 32000)
self.dummy_tokens_per_frame = getattr(
self.engine_args, "dummy_tokens_per_frame", 144
)
self._loaded_metadata = set()
self.initialized = False
self.min_workers = 1
# IMPORTANT: PrefillWorker MUST remove dummy tokens before passing to vLLM
# Only the actual video tokens (32000) should remain as placeholders for multimodal embeddings
if self.engine_args.enable_chunked_prefill is not False:
logger.info("Chunked prefill is not supported yet, setting to False")
self.engine_args.enable_chunked_prefill = False
if self.engine_args.pipeline_parallel_size != 1:
logger.info("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.disable_async_output_proc is not True:
logger.info("Async output processing is not supported yet, setting to True")
self.engine_args.disable_async_output_proc = True
if self.engine_args.enforce_eager is not True:
logger.info("Prefill must be done eagerly, setting to True")
self.engine_args.enforce_eager = True
if self.engine_args.enable_prefix_caching is not False:
logger.info(
"Prefix caching is not supported yet in prefill worker, setting to False"
)
self.engine_args.enable_prefix_caching = False
signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
@async_on_start
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
# NOTE: PrefillWorker no longer needs AutoProcessor since it uses the original
# tokenized prompt from DecodeWorker instead of creating its own.
logger.info(
"PrefillWorker: Skipping AutoProcessor initialization - using original tokens from DecodeWorker"
)
runtime = dynamo_context["runtime"]
enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(enc_comp_ns)
.component(enc_comp_name)
.endpoint("encode")
.client()
)
# Initialize the connector for RDMA transfers within the specified namespace.
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
incoming_frames_shape = (
self.num_sampled_frames,
self.frame_height,
self.frame_width,
self.frame_channels,
)
# Pre-allocate a tensor on the CPU to receive frame data.
frames_tensor = torch.empty(
incoming_frames_shape,
dtype=INCOMING_FRAMES_DTYPE,
device=INCOMING_FRAMES_DEVICE,
)
# Create a descriptor for the tensor to make it available for remote access.
descriptor = connect.Descriptor(frames_tensor)
# Register the memory with the connector, making it discoverable.
descriptor.register_memory(self._connector)
self._frames_descriptor = (frames_tensor, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers)
metadata = self.engine_client.nixl_metadata
self._metadata_store = NixlMetadataStore("dynamo", runtime)
await self._metadata_store.put(metadata.engine_id, metadata)
logger.info("PrefillWorker: Creating prefill_queue_handler task.")
task = asyncio.create_task(self.prefill_queue_handler())
def prefill_queue_handler_cb(fut):
try:
fut.result()
logger.info(
"PrefillWorker: prefill_queue_handler task exited successfully."
)
except asyncio.CancelledError:
logger.info("PrefillWorker: prefill_queue_handler task was cancelled.")
except Exception as e:
logger.error(
f"PrefillWorker: prefill_queue_handler task failed with exception: {e!r}",
exc_info=True,
)
task.add_done_callback(prefill_queue_handler_cb)
logger.info("PrefillWorker: async_init complete.")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Shutdown started, signal {signum} received.")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
logger.info("Shutdown complete.")
async def prefill_queue_handler(self):
logger.info("PrefillWorker: Prefill queue handler task started.")
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
prefill_queue_stream_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
logger.info(
f"PrefillWorker: Connecting to prefill queue: {prefill_queue_nats_server}, stream: '{prefill_queue_stream_name}'"
)
self.initialized = True
try:
async with PrefillQueue.get_instance(
nats_server=prefill_queue_nats_server,
stream_name=prefill_queue_stream_name,
) as prefill_queue:
logger.info(
f"PrefillWorker: Entering dequeue loop for stream '{prefill_queue_stream_name}'."
)
while True:
prefill_request: Optional[RemotePrefillRequest] = None
try:
prefill_request = await prefill_queue.dequeue_prefill_request()
except Exception as e:
logger.error(
f"PrefillWorker: Exception during dequeue_prefill_request: {e}",
exc_info=True,
)
await asyncio.sleep(5)
continue
if prefill_request is not None:
logger.info(
f"PrefillWorker: Dequeued prefill request: {prefill_request.request_id}"
)
try:
async for _ in self.generate(prefill_request):
pass
logger.info(
f"PrefillWorker: Successfully processed prefill request {prefill_request.request_id}."
)
except Exception as e:
logger.error(
f"PrefillWorker: Error processing prefill request {prefill_request.request_id} in self.generate: {e}",
exc_info=True,
)
else:
await asyncio.sleep(0.1)
except Exception as e:
logger.error(
f"PrefillWorker: Prefill queue handler CRASHED: {e}", exc_info=True
)
async def generate(self, request: RemotePrefillRequest):
video_url = request.multimodal_data_source.get("video_url")
if video_url is None:
raise ValueError(
"No video_url provided in multimodal_data_source for prefill request"
)
request_id = request.request_id
engine_id = request.engine_id
logger.info(
f"PrefillWorker {request_id}: Received prefill request for video_url: {video_url}."
)
raw_frames_tensor, descriptor = self._frames_descriptor
logger.debug(
f"PrefillWorker {request_id}: Requesting frames from EncodeWorker for {video_url}"
)
# Create a writable operation handle for the remote EncodeWorker.
# This allows the EncodeWorker to write directly into this worker's `frames_tensor`.
with self._connector.create_writable(descriptor) as writable:
encode_generator = await self.encode_worker_client.round_robin(
EncodeRequest(
request_id=request_id,
video_url=video_url,
# Serialize the writable handle to send it to the EncodeWorker.
serialized_request=writable.to_serialized(),
).model_dump_json()
)
async for _ in encode_generator:
pass
# Wait for the remote write from the EncodeWorker to complete.
await writable.wait_for_completion()
logger.debug(
f"PrefillWorker {request_id}: Frames received from EncodeWorker, shape: {raw_frames_tensor.shape}"
)
if not request.prompt_token_ids:
logger.error(
f"PrefillWorker {request_id}: No prompt_token_ids provided in request!"
)
raise ValueError(
"PrefillWorker requires prompt_token_ids from DecodeWorker"
)
# Constants for token manipulation
DUMMY_TOKEN_ID = self.dummy_token_id
VIDEO_TOKEN_ID = self.video_token_id
# Step 1: Find all video token positions
video_token_positions = [
i
for i, token in enumerate(request.prompt_token_ids)
if token == VIDEO_TOKEN_ID
]
logger.debug(
f"PrefillWorker {request_id}: Found {len(video_token_positions)} video tokens at positions: {video_token_positions}"
)
# Step 2: Process tokens from end to start to avoid position shifting
processed_tokens = list(request.prompt_token_ids)
for pos in reversed(video_token_positions):
# Calculate range of tokens to remove (dummy tokens after this video token)
start_idx = pos + 1
end_idx = start_idx + self.dummy_tokens_per_frame
# Check if we have enough tokens to remove
if end_idx > len(processed_tokens):
logger.warning(
f"PrefillWorker {request_id}: Not enough tokens to remove at position {pos}"
)
continue
# Remove the dummy tokens
processed_tokens = processed_tokens[:start_idx] + processed_tokens[end_idx:]
# Step 3: Verify we have exactly one video token left
final_video_count = sum(
1 for token in processed_tokens if token == VIDEO_TOKEN_ID
)
if final_video_count != 1:
logger.error(
f"PrefillWorker {request_id}: Wrong number of video tokens! Expected 1, got {final_video_count}"
)
# Step 4: Check for any remaining dummy tokens (should be none)
remaining_dummies = sum(
1 for token in processed_tokens if token == DUMMY_TOKEN_ID
)
if remaining_dummies > 0:
logger.warning(
f"PrefillWorker {request_id}: Found {remaining_dummies} remaining dummy tokens!"
)
# Create the input for vLLM
prefill_vllm_input = TokensPrompt(
prompt_token_ids=processed_tokens,
multi_modal_data={"video": raw_frames_tensor.numpy()},
)
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=engine_id,
decode_computed_block_ids=request.computed_block_ids,
)
if engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
logger.info(
f"Loaded nixl metadata from engine {engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(engine_id)
logger.debug(
f"PrefillWorker {request_id}: Calling engine_client.generate for prefill."
)
async for _ in self.engine_client.generate(
prefill_vllm_input,
sampling_params=sampling_params,
request_id=request_id,
remote_prefill_params=remote_prefill_params,
):
yield
logger.info(f"PrefillWorker {request_id}: Finished processing prefill request.")
@endpoint()
async def mock(self, req: RequestType):
yield f"mock_response: {req}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
from components.video_decode_worker import VllmDecodeWorker
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers
from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.runtime import EtcdKvCache
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
worker = depends(VllmDecodeWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.min_workers = 1
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore
self.worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
await check_required_workers(self.worker_client, self.min_workers)
self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
"/dynamo/processor/",
{"router": self.engine_args.router},
)
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
image: str,
request_type: RequestType,
):
request_id = str(uuid.uuid4())
logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_request = vLLMMultimodalRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
video_url=image,
)
router_mode = (await self.etcd_kv_cache.get("router")).decode()
if router_mode == "kv":
# The current KV router does not support multimodal requests because
# it performs cache lookup based solely on prompt tokens. At this stage,
# multimodal data (e.g., image features) is not yet available, so the router
# cannot select the optimal worker using both prompt and image inputs.
raise NotImplementedError(
"Multimodal requests are not supported for kv router mode"
)
if router_mode == "random":
response_generator = await self.worker_client.generate(
worker_request.model_dump_json()
)
elif router_mode == "round-robin":
response_generator = await self.worker_client.round_robin(
worker_request.model_dump_json()
)
else:
raise NotImplementedError(f"Router mode {router_mode} not implemented")
output = self._generate_responses(response_generator, request_type)
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
request_type: RequestType,
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
prompt_idx = 0
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
elif request_type == RequestType.COMPLETION:
# Completion requests can have multiple prompts and stream generator requires the prompt index
yield (prompt_idx, request_output)
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
# The generate endpoint will be used by the frontend to handle incoming requests.
@endpoint()
async def generate(self, raw_request: MultiModalRequest):
# TODO: Make the template consumed from the config file
msg = {
"role": "user",
"content": "USER: <video>\nQuestion:"
+ raw_request.messages[0].content[0].text
+ " Answer:",
}
chat_request = ChatCompletionRequest(
model=raw_request.model,
messages=[msg],
stream=raw_request.stream,
max_tokens=raw_request.max_tokens,
request_id=str(uuid.uuid4()),
)
video_url = None
for message in raw_request.messages:
for item in message.content:
if item.type == "video_url":
video_url = item.video_url.url
if video_url is None:
raise ValueError("Video URL is required")
async for response in self._generate(chat_request, video_url, RequestType.CHAT):
yield json.dumps(response)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: llava-hf/LLaVA-NeXT-Video-7B-hf
block-size: 64
max-model-len: 4096
num-sampled-frames: 8
frame-height: 336
frame-width: 336
frame-channels: 3
dummy-token-id: 0
video-token-id: 32000
dummy-tokens-per-frame: 144
Processor:
router: round-robin
common-configs: [model, block-size, max-model-len]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, block-size, max-model-len, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: llava-hf/LLaVA-NeXT-Video-7B-hf
block-size: 64
max-model-len: 8192
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
num-sampled-frames: 8
frame-height: 336
frame-width: 336
frame-channels: 3
dummy-token-id: 0
video-token-id: 32000
dummy-tokens-per-frame: 144
Processor:
router: round-robin
common-configs: [model, block-size]
VllmDecodeWorker:
remote-prefill: true
conditional-disagg: false
max-local-prefill-length: 50
max-prefill-queue-size: 2
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
VllmPrefillWorker:
max-num-batched-tokens: 16384
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.video_decode_worker import VllmDecodeWorker
from components.video_encode_worker import VllmEncodeWorker
from components.video_frontend import Frontend
from components.video_processor import Processor
Frontend.link(Processor).link(VllmDecodeWorker).link(VllmEncodeWorker)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.video_decode_worker import VllmDecodeWorker
from components.video_encode_worker import VllmEncodeWorker
from components.video_frontend import Frontend
from components.video_prefill_worker import VllmPrefillWorker
from components.video_processor import Processor
Frontend.link(Processor).link(VllmDecodeWorker).link(VllmPrefillWorker).link(
VllmEncodeWorker
)
......@@ -106,7 +106,16 @@ class ImageContent(BaseModel):
image_url: ImageURLDetail
MessageContent = Union[TextContent, ImageContent]
class VideoURLDetail(BaseModel):
url: str
class VideoContent(BaseModel):
type: Literal["video_url"]
video_url: VideoURLDetail
MessageContent = Union[TextContent, ImageContent, VideoContent]
class ChatMessage(BaseModel):
......@@ -125,16 +134,19 @@ class MultiModalRequest(BaseModel):
class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str
image_url: Optional[str] = None
video_url: Optional[str] = None
class EncodeRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
Serializable class for encoding requests for both images and videos
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str
image_url: Optional[str] = None
video_url: Optional[str] = None
num_frames: Optional[int] = None
request_id: str
serialized_request: Optional[connect.SerializedRequest] = None
......@@ -144,6 +156,7 @@ class EncodeResponse(BaseModel):
request_id: str
image_grid_thw: Optional[List[Any]] = None
image_sizes: Optional[List[Any]] = None
raw_frames: Optional[List[List[List[List[int]]]]] = None
class MyRequestOutput(BaseModel):
......
......@@ -69,6 +69,48 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
default="<prompt>",
help="Prompt template to use for the model",
)
parser.add_argument(
"--num-sampled-frames",
type=int,
default=8,
help="Number of frames to sample from the video",
)
parser.add_argument(
"--frame-height",
type=int,
default=336,
help="Height of the video frames",
)
parser.add_argument(
"--frame-width",
type=int,
default=336,
help="Width of the video frames",
)
parser.add_argument(
"--frame-channels",
type=int,
default=3,
help="Number of channels in the video frames",
)
parser.add_argument(
"--dummy-token-id",
type=int,
default=0,
help="Dummy token ID",
)
parser.add_argument(
"--video-token-id",
type=int,
default=32000,
help="Video token ID",
)
parser.add_argument(
"--dummy-tokens-per-frame",
type=int,
default=144,
help="Number of dummy tokens per frame",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(vllm_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
......@@ -78,6 +120,13 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
engine_args.max_local_prefill_length = args.max_local_prefill_length
engine_args.max_prefill_queue_size = args.max_prefill_queue_size
engine_args.prompt_template = args.prompt_template
engine_args.num_sampled_frames = args.num_sampled_frames
engine_args.frame_height = args.frame_height
engine_args.frame_width = args.frame_width
engine_args.frame_channels = args.frame_channels
engine_args.dummy_token_id = args.dummy_token_id
engine_args.video_token_id = args.video_token_id
engine_args.dummy_tokens_per_frame = args.dummy_tokens_per_frame
engine_args.num_patches = args.num_patches
engine_args.image_token_id = args.image_token_id
return engine_args
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment