"examples/backends/trtllm/vscode:/vscode.git/clone" did not exist on "c8770464abcb5665343c0355e80abb6ab060bb2a"
Unverified Commit 4dc529a1 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

chore: remove vLLM v0 multimodal example (#2099)

parent 384e449d
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import signal
from typing import Optional
import connect
import torch
from components.disagg_router import PyDisaggregatedRouter
from components.encode_worker import VllmEncodeWorker
from components.prefill_worker import VllmPrefillWorker
from utils.logging import check_required_workers
from utils.model import construct_mm_data, get_vision_embeddings_info
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import (
EncodeRequest,
EncodeResponse,
MyRequestOutput,
vLLMMultimodalRequest,
)
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmDecodeWorker:
# For disaggregated serving, we need to link the prefill worker to the vllm worker
prefill_worker = depends(VllmPrefillWorker)
# For aggregated serving, we need to link the encode worker to the vllm worker.
encode_worker = depends(VllmEncodeWorker)
def __init__(self):
self.client = None
self.min_workers = 1
self.disaggregated_router: Optional[PyDisaggregatedRouter] = None
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.do_remote_prefill = self.engine_args.remote_prefill
self.model_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
)
self._prefill_queue_stream_name = self.model_name
logger.info(
f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
)
if self.engine_args.remote_prefill:
if self.engine_args.enable_chunked_prefill is not False:
logger.info("Chunked prefill is not supported yet, setting to False")
self.engine_args.enable_chunked_prefill = False
if self.engine_args.preemption_mode != "swap":
logger.info("Preemption mode is not supported yet, setting to swap")
self.engine_args.preemption_mode = "swap"
if self.engine_args.pipeline_parallel_size != 1:
logger.info("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.router == "kv":
raise NotImplementedError(
"Multimodal requests are not supported for kv router mode"
)
signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
@async_on_start
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
if self.engine_args.router == "kv":
raise NotImplementedError(
"Multimodal requests are not supported for kv router mode"
)
runtime = dynamo_context["runtime"]
embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
self.engine_args.model, self.engine_args.num_patches
)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]
if self.do_remote_prefill:
metadata = self.engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo", runtime)
await metadata_store.put(metadata.engine_id, metadata)
if self.engine_args.conditional_disagg:
self.disaggregated_router = PyDisaggregatedRouter(
runtime,
self.model_name,
max_local_prefill_length=self.engine_args.max_local_prefill_length,
max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
)
await self.disaggregated_router.async_init()
else:
self.disaggregated_router = None
else:
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"
enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(enc_comp_ns)
.component(enc_comp_name)
.endpoint("encode")
.client()
)
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty(
embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers)
self.disaggregated_router = None
configuration = "Disaggregated" if self.do_remote_prefill else "Aggregated"
logger.info("Initialization complete { configuration: %s }.", configuration)
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
logger.info("Shutdown complete.")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
def get_remote_prefill_request_callback(self):
async def callback(request: RemotePrefillRequest):
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
await prefill_queue.enqueue_prefill_request(request)
return callback
@endpoint()
async def generate(self, request: vLLMMultimodalRequest):
request_id = request.request_id
logger.info(f"Received multimodal request {{ id: {request_id} }}.")
if self.do_remote_prefill:
(
prompt_ids,
multi_modal_data,
remote_prefill_params,
) = await self.remote_prefill(request)
else:
(
prompt_ids,
multi_modal_data,
remote_prefill_params,
) = await self.local_prefill(request)
logger.debug(f"Prompt ids: {prompt_ids}")
logger.debug(f"Multi modal data: {multi_modal_data}")
logger.debug(f"Remote prefill params: {remote_prefill_params}")
# rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA
async for response in self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=prompt_ids,
multi_modal_data=multi_modal_data,
),
sampling_params=request.sampling_params,
request_id=request.request_id,
remote_prefill_params=remote_prefill_params,
):
logger.debug(
f"Yielding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
)
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
async def local_prefill(self, request: vLLMMultimodalRequest) -> tuple:
"""
Handles local prefill in aggregated serving mode.
Interacts with the encode worker to obtain image embeddings and returns
the original prompt tokens with multi-modal data for local processing.
Args:
request: The multimodal request containing image URL and prompt data
Returns:
Tuple of (prompt_ids, multi_modal_data, remote_prefill_params)
"""
logger.debug(
f"Aggregated: request {{ id: {request.request_id} }}"
" no prefill worker available, embeddings directly from encode worker."
)
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
# Doing this avoids unnessesary memory de/registration with NIXL.
embeddings, descriptor = self._embeddings_descriptor
with self._connector.create_writable(descriptor) as writable:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_request = EncodeRequest(
request_id=request.request_id,
image_url=request.image_url,
serialized_request=writable.to_serialized(),
)
logger.debug(f"Encode request: {encode_request.model_dump_json()}")
encode_generator = await self.encode_worker_client.round_robin(
encode_request.model_dump_json()
)
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(
encode_response.data()
)
logger.info(f"Received response: {{ id: {encode_output.request_id} }}")
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await writable.wait_for_completion()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params = None
logger.debug(
f"Prefilling locally for request {{ id: {request.request_id} }} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
prompt_ids = request.engine_prompt["prompt_token_ids"]
logger.debug(
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
# When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
multi_modal_data = construct_mm_data(
self.engine_args.model, encode_output, embeddings, self.embeddings_dtype
)
return prompt_ids, multi_modal_data, remote_prefill_params
async def remote_prefill(self, request: vLLMMultimodalRequest) -> tuple:
"""
Handles remote prefill in disaggregated serving mode.
Creates remote prefill parameters and inserts dummy tokens for proper
memory allocation. No direct encode worker interaction is required.
Args:
request: The multimodal request containing image URL and prompt data
Returns:
Tuple of (prompt_ids, multi_modal_data, remote_prefill_params)
"""
logger.debug(
f"Disaggregated: request {{ id: {request.request_id} }}"
" prefill worker will populate the decode model's key-value cache ahead of time;"
" no direct encode worker interaction required."
)
if self.disaggregated_router is not None:
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
prefill_queue_size = await prefill_queue.get_queue_size()
disagg_router_decision = await self.disaggregated_router.prefill_remote(
len(request.engine_prompt["prompt_token_ids"]),
request.prefix_hit_rate,
prefill_queue_size,
)
else:
# always prefill remotely if no disaggregated router is provided
disagg_router_decision = True
if self.do_remote_prefill and disagg_router_decision:
logger.debug(
f"Prefilling remotely for request {{ id: {request.request_id} }} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
multimodal_data_source={
"image_url": request.image_url,
},
)
else:
remote_prefill_params = None
logger.debug(
f"Prefilling locally for request {{ id: {request.request_id} }} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
# As a workaround, here we manually insert some placeholder dummy tokens based on the embedding size
# so that decode worker can pre-allocate the memory with the correct size.
# The structure of the prompt will be like: "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:".
# Since the "<image>" token is included in the prompt, only need to insert embedding_size dummy tokens after the image token.
DUMMY_TOKEN_ID = 0
# Find the index of the image token in the prompt token ids
image_token_index = request.engine_prompt["prompt_token_ids"].index(
self.engine_args.image_token_id
)
dummy_token_index = image_token_index + 1
prompt_ids = (
request.engine_prompt["prompt_token_ids"][:dummy_token_index]
+ [DUMMY_TOKEN_ID] * self.embedding_size
+ request.engine_prompt["prompt_token_ids"][dummy_token_index:]
)
logger.debug(
"Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
)
# When using aggregated serving, the encode worker will have provided the key-value cache updates via the prefill worker.
multi_modal_data = None
return prompt_ids, multi_modal_data, remote_prefill_params
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dynamo.runtime import EtcdKvCache
from dynamo.sdk import dynamo_context
logger = logging.getLogger(__name__)
class PyDisaggregatedRouter:
def __init__(
self,
runtime,
served_model_name,
max_local_prefill_length=1000,
max_prefill_queue_size=2,
):
self.runtime = runtime
self.served_model_name = served_model_name
self.max_local_prefill_length = max_local_prefill_length
self.max_prefill_queue_size = max_prefill_queue_size
async def async_init(self):
runtime = dynamo_context["runtime"]
self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
"/dynamo/disagg_router/",
{
"max_local_prefill_length": str(self.max_local_prefill_length),
"max_prefill_queue_size": str(self.max_prefill_queue_size),
},
)
async def prefill_remote(
self, prompt_length: int, prefix_hit_rate: float, queue_size: int
):
max_local_prefill_length = int(
await self.etcd_kv_cache.get("max_local_prefill_length")
)
max_prefill_queue_size = int(
await self.etcd_kv_cache.get("max_prefill_queue_size")
)
absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
# TODO: consider size of each request in the queue when making the decision
decision = (
absolute_prefill_length > max_local_prefill_length
and queue_size < max_prefill_queue_size
)
logger.info(
f"Remote prefill: {decision} (prefill length: {absolute_prefill_length}/{max_local_prefill_length}, prefill queue size: {queue_size}/{max_prefill_queue_size})"
)
return decision
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import logging
from io import BytesIO
from queue import Queue
from typing import AsyncIterator, Optional
from urllib.parse import urlparse
import connect
import httpx
import torch
from PIL import Image
from transformers import AutoImageProcessor
from utils.model import load_vision_model
from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args
from dynamo.sdk import async_on_start, endpoint, service
logger = logging.getLogger(__name__)
try:
import cupy as array_module
if not array_module.cuda.is_available():
raise ImportError("CUDA is not available.")
DEVICE = "cuda"
logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
import numpy as array_module
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmEncodeWorker:
def __init__(self) -> None:
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.MODEL_ID = self.engine_args.model
self.image_processor = AutoImageProcessor.from_pretrained(
self.MODEL_ID, trust_remote_code=True
)
self.vision_model = load_vision_model(self.MODEL_ID)
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
self._http_client: Optional[httpx.AsyncClient] = None
self._http_timeout = 30.0
async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url)
# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if image_url_lower in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
return self._image_cache[image_url_lower]
try:
if parsed_url.scheme == "data":
# Parse data URL format: data:[<media type>][;base64],<data>
if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1)
if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded")
try:
image_bytes = base64.b64decode(data)
image_data = BytesIO(image_bytes)
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
if not self._http_client:
raise RuntimeError("HTTP client not initialized")
response = await self._http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
# PIL is sync, so offload to a thread to avoid blocking the event loop
image = await asyncio.to_thread(Image.open, image_data)
# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
self._cache_queue.put(image_url_lower)
return image
except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")
raise
except Exception as e:
logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}")
@endpoint()
async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")
request_id = request.request_id
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
try:
image = await self.load_image(request.image_url)
logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
# Add a batch dimension to everything
for item in image_embeds:
image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
logger.debug(f"Image embeds: {image_embeds}")
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
image_sizes = (
image_embeds["image_sizes"].tolist()
if "image_sizes" in image_embeds
else [image.size]
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
with torch.no_grad():
embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
if isinstance(embeddings, tuple) or isinstance(embeddings, list):
# The result multimodal_embeddings may be a list or tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
# TODO: for multi-image support, this result will contain multiple tensors.
embeddings = embeddings[0].unsqueeze(0)
logger.debug(
f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
)
if request.serialized_request is None:
logger.error(
f"Request serialized_request is None for request: {{ id: {request_id} }}."
)
# Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
descriptor = connect.Descriptor(embeddings)
# Create a write operation using the serialized request and the descriptor.
# This will begin the RDMA transfer of the embeddings to the remote worker.
write_op = await self._connector.begin_write(
descriptor,
request.serialized_request,
)
# Await for the write operation to complete.
# This will block until the data has been written to the remote worker or an error occurs.
await write_op.wait_for_completion()
yield EncodeResponse(
request_id=request.request_id,
image_grid_thw=image_grid_thw,
image_sizes=image_sizes,
).model_dump_json()
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
raise
@async_on_start
async def async_init(self):
logger.info("Startup started.")
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
# Initialize HTTP client with default limits
self._http_client = httpx.AsyncClient(timeout=self._http_timeout)
logger.info("Startup completed.")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import signal
import sys
import connect
import torch
from components.encode_worker import VllmEncodeWorker
from pydantic import BaseModel
from utils.logging import check_required_workers
from utils.model import construct_mm_data, get_vision_embeddings_info
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
EMBEDDINGS_DEVICE = "cuda"
class RequestType(BaseModel):
text: str
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmPrefillWorker:
encode_worker = depends(VllmEncodeWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self._loaded_metadata = set()
self.initialized = False
self.min_workers = 1
if self.engine_args.enable_chunked_prefill is not False:
logger.info("Chunked prefill is not supported yet, setting to False")
self.engine_args.enable_chunked_prefill = False
if self.engine_args.pipeline_parallel_size != 1:
logger.info("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.disable_async_output_proc is not True:
logger.info("Async output processing is not supported yet, setting to True")
self.engine_args.disable_async_output_proc = True
if self.engine_args.enforce_eager is not True:
logger.info("Prefill must be done eagerly, setting to True")
self.engine_args.enforce_eager = True
if self.engine_args.enable_prefix_caching is not False:
logger.info(
"Prefix caching is not supported yet in prefill worker, setting to False"
)
self.engine_args.enable_prefix_caching = False
signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
@async_on_start
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
runtime = dynamo_context["runtime"]
enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(enc_comp_ns)
.component(enc_comp_name)
.endpoint("encode")
.client()
)
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
self.engine_args.model, self.engine_args.num_patches
)
embeddings = torch.empty(
embeddings_shape,
dtype=self.embeddings_dtype,
device=EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsystem will take care of this automatically).
descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers)
metadata = self.engine_client.nixl_metadata
self._metadata_store = NixlMetadataStore("dynamo", runtime)
await self._metadata_store.put(metadata.engine_id, metadata)
task = asyncio.create_task(self.prefill_queue_handler())
def prefill_queue_handler_cb(fut):
try:
fut.result()
logger.info("prefill queue handler exited successfully")
except Exception as e:
logger.error(f"[ERROR] prefill queue handler failed: {e!r}")
sys.exit(1)
task.add_done_callback(prefill_queue_handler_cb)
logger.info("Initialization complete.")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Shutdown started, signal {signum} received.")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
logger.info("Shutdown complete.")
async def prefill_queue_handler(self):
logger.info("Prefill queue handler entered")
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
prefill_queue_stream_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
logger.info(
f"Prefill queue: {prefill_queue_nats_server}:{prefill_queue_stream_name}"
)
self.initialized = True
# TODO: integrate prefill_queue to a dynamo endpoint
async with PrefillQueue.get_instance(
nats_server=prefill_queue_nats_server,
stream_name=prefill_queue_stream_name,
) as prefill_queue:
logger.info("prefill queue handler started")
while True:
# TODO: this might add a small overhead to pull prefill from nats
# need to test and check how much overhead it is
prefill_request = await prefill_queue.dequeue_prefill_request()
if prefill_request is not None:
logger.info(
f"Dequeued prefill request: {prefill_request.request_id}"
)
async for _ in self.generate(prefill_request):
pass
async def generate(self, request: RemotePrefillRequest):
if request.multimodal_data_source["image_url"] is None:
raise ValueError("No image url provided for prefill request")
request_id = request.request_id
engine_id = request.engine_id
image_url = request.multimodal_data_source["image_url"]
logger.info(
f"Received prefill request {{ id: {request_id}, engine_id: {engine_id} }}."
)
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
# Doing this avoids unnecessary memory de/registration with NIXL.
embeddings, descriptor = self._embeddings_descriptor
# Create a new writable operation from the descriptor.
with self._connector.create_writable(descriptor) as writable:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_generator = await self.encode_worker_client.round_robin(
EncodeRequest(
request_id=request_id,
image_url=image_url,
serialized_request=writable.to_serialized(),
).model_dump_json()
)
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(
encode_response.data(),
)
logger.debug(
f"Received response: {{ id: {encode_output.request_id} }}."
)
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await writable.wait_for_completion()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=engine_id,
decode_computed_block_ids=request.computed_block_ids,
)
# TODO check if metadata has changed
# and reload - currently only loading once
if engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
logger.info(
f"Loaded nixl metadata from engine {engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(engine_id)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens are inserted based on the embedding size in the worker.py.
# TODO: make this more flexible/model-dependent
embedding_size = embeddings.shape[1]
padding_size = embedding_size
image_token_index = request.prompt_token_ids.index(
self.engine_args.image_token_id
)
dummy_token_index = image_token_index + 1
prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index]
+ request.prompt_token_ids[dummy_token_index + padding_size :]
)
async for _ in self.engine_client.generate(
request_id=request_id,
prompt=TokensPrompt(
prompt_token_ids=prompt_token_ids,
multi_modal_data=construct_mm_data(
self.engine_args.model,
encode_output,
embeddings,
self.embeddings_dtype,
),
),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
@endpoint()
async def mock(self, req: RequestType):
yield f"mock_response: {req}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
from components.decode_worker import VllmDecodeWorker
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers
from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.runtime import EtcdKvCache
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
worker = depends(VllmDecodeWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.min_workers = 1
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore
self.worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
await check_required_workers(self.worker_client, self.min_workers)
self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
"/dynamo/processor/",
{"router": self.engine_args.router},
)
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
image: str,
request_type: RequestType,
):
request_id = str(uuid.uuid4())
logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_request = vLLMMultimodalRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
image_url=image,
)
router_mode = (await self.etcd_kv_cache.get("router")).decode()
if router_mode == "kv":
# The current KV router does not support multimodal requests because
# it performs cache lookup based solely on prompt tokens. At this stage,
# multimodal data (e.g., image features) is not yet available, so the router
# cannot select the optimal worker using both prompt and image inputs.
raise NotImplementedError(
"Multimodal requests are not supported for kv router mode"
)
if router_mode == "random":
response_generator = await self.worker_client.generate(
worker_request.model_dump_json()
)
elif router_mode == "round-robin":
response_generator = await self.worker_client.round_robin(
worker_request.model_dump_json()
)
else:
raise NotImplementedError(f"Router mode {router_mode} not implemented")
output = self._generate_responses(response_generator, request_type)
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
request_type: RequestType,
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
prompt_idx = 0
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
elif request_type == RequestType.COMPLETION:
# Completion requests can have multiple prompts and stream generator requires the prompt index
yield (prompt_idx, request_output)
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
# The generate endpoint will be used by the frontend to handle incoming requests.
@endpoint()
async def generate(self, raw_request: MultiModalRequest):
# Ensure the configured template includes the placeholder
template = self.engine_args.prompt_template
if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
prompt = template.replace("<prompt>", user_text)
msg = {
"role": "user",
"content": prompt,
}
chat_request = ChatCompletionRequest(
model=raw_request.model,
messages=[msg],
stream=raw_request.stream,
max_tokens=raw_request.max_tokens,
temperature=raw_request.temperature,
request_id=str(uuid.uuid4()),
)
image_url = None
for message in raw_request.messages:
for item in message.content:
if item.type == "image_url":
image_url = item.image_url.url
if image_url is None:
raise ValueError("Image URL is required")
async for response in self._generate(chat_request, image_url, RequestType.CHAT):
yield json.dumps(response)
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import json
import logging
import os
from io import BytesIO
from queue import Queue
from typing import AsyncIterator, Optional
from urllib.parse import urlparse
import av
import connect
import httpx
import numpy as np
import torch
import torch.nn.functional as F
from utils.protocol import EncodeRequest
from utils.vllm import parse_vllm_args
from dynamo.sdk import async_on_start, endpoint, service
logger = logging.getLogger(__name__)
CACHE_SIZE_MAXIMUM = 8
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmEncodeWorker:
def __init__(self) -> None:
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.num_frames_to_sample = getattr(self.engine_args, "num_sampled_frames", 8)
self.frame_height = getattr(self.engine_args, "frame_height", 336)
self.frame_width = getattr(self.engine_args, "frame_width", 336)
self.frame_channels = getattr(self.engine_args, "frame_channels", 3)
self.dummy_token_id = getattr(self.engine_args, "dummy_token_id", 0)
self.video_token_id = getattr(self.engine_args, "video_token_id", 32000)
self.dummy_tokens_per_frame = getattr(
self.engine_args, "dummy_tokens_per_frame", 144
)
self._video_content_cache: dict[str, BytesIO] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
self._http_client: Optional[httpx.AsyncClient] = None
self._http_timeout = 60.0
async def _read_video_pyav(
self, container: av.container.InputContainer, indices: np.ndarray
) -> np.ndarray:
"""
Decode the video with PyAV decoder. Async wrapper.
"""
def blocking_decode():
container.seek(0) # Reset container for decoding
processed_indices = set(indices)
# Determine min/max index to optimize decoding loop slightly
min_idx = 0
max_idx = -1
if len(indices) > 0:
min_idx = np.min(indices)
max_idx = np.max(indices)
if (
not processed_indices
and container.streams.video
and container.streams.video[0].frames > 0
):
logger.warning(
"_read_video_pyav called with empty indices for a non-empty video, attempting to read first frame."
)
try:
frame = next(container.decode(video=0))
return np.stack([frame.to_ndarray(format="rgb24")])
except StopIteration:
logger.error(
"Failed to read even the first frame despite non-empty indices check."
)
return np.array([])
decoded_frames_list = []
for i, frame in enumerate(container.decode(video=0)):
if i > max_idx and max_idx != -1: # max_idx is -1 if indices is empty
break
if i >= min_idx and i in processed_indices:
decoded_frames_list.append(frame)
if not decoded_frames_list and len(processed_indices) > 0:
actual_decoded_count = 0
try:
container.seek(0) # Reset for counting
for _ in container.decode(video=0):
actual_decoded_count += 1
except Exception: # Handle cases where re-decoding/counting fails
pass # Keep original error message
raise ValueError(
f"Could not decode any frames for the given indices: {indices.tolist()}. "
f"Video might be shorter than expected or indices out of bounds. "
f"Actual decodable frames in container (approx): {actual_decoded_count}."
)
return (
np.stack([x.to_ndarray(format="rgb24") for x in decoded_frames_list])
if decoded_frames_list
else np.array([])
)
return await asyncio.to_thread(blocking_decode)
async def _load_video_content(self, video_url: str) -> BytesIO:
parsed_url = urlparse(video_url)
video_url_lower = video_url.lower()
if parsed_url.scheme in ("http", "https"):
if video_url_lower in self._video_content_cache:
logger.info(f"Video content found in cache for URL: {video_url}")
cached_content = self._video_content_cache[video_url_lower]
cached_content.seek(0)
return cached_content
try:
video_data: BytesIO
if parsed_url.scheme == "data":
if not parsed_url.path.startswith(
("video/", "application/octet-stream")
):
raise ValueError("Data URL must be a video type or octet-stream")
media_type_and_data = parsed_url.path.split(",", 1)
if len(media_type_and_data) != 2:
raise ValueError("Invalid Data URL format: missing comma separator")
media_type, data_segment = media_type_and_data
if ";base64" not in media_type:
raise ValueError("Video Data URL currently must be base64 encoded")
try:
video_bytes = base64.b64decode(data_segment)
video_data = BytesIO(video_bytes)
except binascii.Error as e:
raise ValueError(
f"Invalid base64 encoding for video data: {e}"
) from e
elif parsed_url.scheme in ("http", "https"):
if not self._http_client:
await self._init_http_client()
if not self._http_client: # Double check after initialization
raise RuntimeError("Failed to initialize HTTP client")
logger.info(f"Downloading video from URL: {video_url}")
response = await self._http_client.get(
video_url, timeout=self._http_timeout
)
response.raise_for_status()
if not response.content:
raise ValueError(
f"Empty response content from video URL: {video_url}"
)
video_data = BytesIO(response.content)
video_data.seek(0)
logger.info(
f"Video downloaded from {video_url}, size: {len(response.content)} bytes."
)
elif parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path if parsed_url.scheme else video_url
# Ensure path is absolute or resolve relative to a known base if necessary
# For simplicity, assuming it's an accessible path.
if not os.path.exists(file_path):
raise FileNotFoundError(f"Error reading file: {file_path}")
with open(file_path, "rb") as f:
video_bytes = f.read()
video_data = BytesIO(video_bytes)
else:
raise ValueError(
f"Unsupported video source scheme: {parsed_url.scheme} for URL {video_url}"
)
if parsed_url.scheme in (
"http",
"https",
): # Cache successfully downloaded content
if self._cache_queue.full():
oldest_url = self._cache_queue.get_nowait()
if oldest_url in self._video_content_cache:
del self._video_content_cache[oldest_url]
# Store the BytesIO object directly; it will be seek(0)'d when retrieved
self._video_content_cache[video_url_lower] = video_data
self._cache_queue.put(video_url_lower)
return video_data
except httpx.HTTPStatusError as e:
logger.error(
f"HTTP error {e.response.status_code} loading video {video_url}: {e.response.text[:200]}"
)
raise ValueError(
f"Failed to download video {video_url}: HTTP {e.response.status_code}"
) from e
except httpx.RequestError as e:
logger.error(f"Request error loading video {video_url}: {e}")
raise ValueError(f"Network request failed for video {video_url}") from e
except FileNotFoundError as e:
logger.error(f"File error loading video {video_url}: {e}")
raise
except Exception as e:
logger.error(
f"Error loading video content from {video_url}: {type(e).__name__} - {e}"
)
raise ValueError(f"Failed to load video content: {e}") from e
@endpoint()
async def encode(self, request: EncodeRequest) -> AsyncIterator[str]:
request_id = request.request_id
video_url = getattr(request, "video_url", None)
if not video_url:
logger.error(f"Request {request_id}: 'video_url' not provided.")
raise ValueError("'video_url' is required for encoding.")
if request.serialized_request is None:
logger.error(
f"Request serialized_request is None for request: {{ id: {request_id} }}."
)
raise ValueError("'serialized_request' is required for encoding.")
logger.info(
f"Received encode request: {{ id: {request_id}, video_url: '{video_url[:100]}...' }}"
)
container: Optional[av.container.InputContainer] = None
try:
video_content_stream = await self._load_video_content(video_url)
def open_video_container_sync():
try:
return av.open(video_content_stream, mode="r")
except av.FFmpegError as ave:
logger.error(
f"PyAV error opening video stream from {video_url}: {ave}"
)
raise ValueError(
f"Invalid video format or corrupted data from {video_url}."
) from ave
except Exception as e:
logger.error(
f"Unexpected error opening video stream from {video_url} with PyAV: {e}"
)
raise ValueError(
f"Unexpected error opening video from {video_url}."
) from e
container = await asyncio.to_thread(open_video_container_sync)
if not container or not container.streams.video:
logger.error(f"No video stream found in {video_url}.")
raise ValueError(f"No video stream in {video_url}.")
stream_info = container.streams.video[0]
total_frames = stream_info.frames
# Duration can be useful for streams where total_frames is 0
if stream_info.duration and stream_info.time_base:
duration_sec = float(stream_info.duration * stream_info.time_base)
else:
duration_sec = 0
if total_frames == 0 and duration_sec == 0:
logger.error(f"Video file '{video_url}' has 0 frames and 0 duration.")
raise ValueError(f"Video {video_url} has 0 frames and 0 duration.")
if total_frames == 0 and duration_sec > 0:
logger.warning(
f"Video {video_url} reports 0 frames but has duration {duration_sec:.2f}s. Frame sampling may be based on requested count directly."
)
logger.debug(
f"Video {video_url} has {total_frames} frames (duration: {duration_sec:.2f}s). Sampling {self.num_frames_to_sample} frames."
)
indices: np.ndarray
if total_frames > 0:
if total_frames < self.num_frames_to_sample:
logger.warning(
f"Video frames ({total_frames}) < samples ({self.num_frames_to_sample}). Using all {total_frames} available frames."
)
indices = np.arange(0, total_frames).astype(int)
else:
indices = np.linspace(
0, total_frames - 1, self.num_frames_to_sample, dtype=int
)
else: # total_frames is 0 (likely a stream), sample by count.
logger.warning(
f"Video {video_url} frame count is 0. Attempting to sample {self.num_frames_to_sample} frames by index. This might fail if stream is too short."
)
indices = np.arange(0, self.num_frames_to_sample).astype(int)
# Ensure indices are unique, especially after linspace for small numbers.
indices = np.unique(indices)
if (
len(indices) == 0 and total_frames > 0
): # Safety for linspace oddities with few frames
# If unique resulted in empty but there are frames, sample at least one or up to num_frames_to_sample
actual_samples = min(self.num_frames_to_sample, total_frames)
indices = np.arange(0, actual_samples).astype(int)
elif len(indices) == 0 and total_frames == 0:
# If indices is empty and total_frames is 0, this means num_frames_to_sample might be 0 or indices logic failed.
# This case implies we might not be able to sample any frames.
# _read_video_pyav handles empty indices with non-empty video by trying to read the first frame.
# If indices is empty here due to num_frames_to_sample=0, _read_video_pyav will return empty.
pass # Let _read_video_pyav handle this.
logger.info(f"Selected frame indices for {video_url}: {indices.tolist()}")
if not container:
raise ValueError(f"Container is None for {video_url}")
clip_np: np.ndarray = await self._read_video_pyav(container, indices)
if clip_np.size == 0:
raise ValueError(
f"Failed to extract any video frames from {video_url} for indices {indices.tolist()}. Clip is empty."
)
logger.info(
f"Successfully extracted {len(clip_np) if clip_np.ndim > 1 and clip_np.shape[0] > 0 else 0} frames for {video_url} with original shape {clip_np.shape}."
)
# Convert the NumPy array from the video decoder into a PyTorch tensor.
# This is a required step to use PyTorch functions for GPU-accelerated image processing.
frames_tensor_orig_res = torch.from_numpy(clip_np) # Shape: (T, H, W, C)
# Permute to (T, C, H, W) for interpolate
frames_tensor_chw = frames_tensor_orig_res.permute(
0, 3, 1, 2
).float() # Ensure float for interpolate
# Resize
resized_frames_tensor_chw = F.interpolate(
frames_tensor_chw,
size=(self.frame_height, self.frame_width),
mode="bilinear",
align_corners=False,
)
# Permute back to (T, H_new, W_new, C)
resized_frames_tensor_hwc = resized_frames_tensor_chw.permute(0, 2, 3, 1)
logger.debug(f"Resized frames to shape: {resized_frames_tensor_hwc.shape}")
# Ensure the tensor is contiguous, on CUDA and uint8 for the NIXL buffer.
tensor_for_descriptor: torch.Tensor = resized_frames_tensor_hwc.to(
device="cpu", dtype=torch.uint8
).contiguous()
logger.info(
f"Req {request_id}: Preparing raw frames tensor (shape: {tensor_for_descriptor.shape}, "
f"dtype: {tensor_for_descriptor.dtype}, device: {tensor_for_descriptor.device}, "
f"contiguous: {tensor_for_descriptor.is_contiguous()}) for RDMA."
)
# Create a descriptor for the tensor to be sent via the connector.
descriptor = connect.Descriptor(tensor_for_descriptor)
logger.info(f"Req {request_id}: Beginning connector write operation.")
# Pass the remote worker's SerializedRequest (representing its WritableOperation) to begin_write.
# This initiates the data transfer to the memory buffer on the other worker.
write_op = await self._connector.begin_write(
descriptor, request.serialized_request
)
# Wait for the RDMA/transfer operation to complete.
await write_op.wait_for_completion()
logger.info(f"Req {request_id}: Connector write operation completed.")
# Yield a simplified response, assuming EncodeResponse in protocol is adapted
final_response_data = {
"request_id": request.request_id,
}
yield json.dumps(final_response_data)
logger.info(f"Encode request {request_id} processed successfully.")
except (
FileNotFoundError,
av.FFmpegError,
ValueError,
) as e:
logger.error(
f"Error processing request {request_id} ({video_url[:100]}...): {type(e).__name__} - {e}"
)
raise # Re-raise to be handled by the service framework
except Exception as e:
logger.exception(
f"Unexpected error processing request {request_id} ({video_url[:100]}...): {e}"
)
raise
finally:
if container:
await asyncio.to_thread(container.close)
async def _init_http_client(self):
if (
not self._http_client or self._http_client.is_closed
): # Check if closed as well
self._http_client = httpx.AsyncClient(timeout=self._http_timeout)
logger.info("HTTP client (re)initialized.")
@async_on_start
async def async_init(self):
logger.info(f"{self.__class__.__name__} async_init started.")
# Initialize the connector for RDMA transfers.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Dynamo connector initialized.")
await self._init_http_client()
logger.info(
f"{self.__class__.__name__} async_init completed. Ready to encode video frames."
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from components.video_processor import Processor
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from utils.protocol import MultiModalRequest
from utils.vllm import parse_vllm_args
from dynamo.sdk import DYNAMO_IMAGE, api, depends, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
image=DYNAMO_IMAGE,
app=FastAPI(title="Multimodal Example"),
)
class Frontend:
processor = depends(Processor)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
@api(name="v1/chat/completions")
async def generate(self, request: MultiModalRequest):
if self.engine_args.model != request.model:
return JSONResponse(
{"error": f"Model '{request.model}' not found"},
status_code=404,
)
async def content_generator():
async for response in self.processor.generate(request.model_dump_json()):
try:
s = json.loads(response)
yield s
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to parse JSON response: {e}") from e
return StreamingResponse(content_generator(), media_type="text/event-stream")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import signal
from typing import Optional
import connect
import torch
from components.video_encode_worker import VllmEncodeWorker
from pydantic import BaseModel
from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import EncodeRequest
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
# Constants for the shape and dtype of the INCOMING FRAMES tensor.
# Other constants taken from yaml as they are model dependent.
INCOMING_FRAMES_DTYPE = torch.uint8
INCOMING_FRAMES_DEVICE = "cpu"
class RequestType(BaseModel):
text: str
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmPrefillWorker:
encode_worker = depends(VllmEncodeWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_path = self.engine_args.model # Store model path for AutoProcessor
self.num_sampled_frames = getattr(self.engine_args, "num_sampled_frames", 8)
self.frame_height = getattr(self.engine_args, "frame_height", 336)
self.frame_width = getattr(self.engine_args, "frame_width", 336)
self.frame_channels = getattr(self.engine_args, "frame_channels", 3)
self.dummy_token_id = getattr(self.engine_args, "dummy_token_id", 0)
self.video_token_id = getattr(self.engine_args, "video_token_id", 32000)
self.dummy_tokens_per_frame = getattr(
self.engine_args, "dummy_tokens_per_frame", 144
)
self._loaded_metadata = set()
self.initialized = False
self.min_workers = 1
# IMPORTANT: PrefillWorker MUST remove dummy tokens before passing to vLLM
# Only the actual video tokens (32000) should remain as placeholders for multimodal embeddings
if self.engine_args.enable_chunked_prefill is not False:
logger.info("Chunked prefill is not supported yet, setting to False")
self.engine_args.enable_chunked_prefill = False
if self.engine_args.pipeline_parallel_size != 1:
logger.info("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.disable_async_output_proc is not True:
logger.info("Async output processing is not supported yet, setting to True")
self.engine_args.disable_async_output_proc = True
if self.engine_args.enforce_eager is not True:
logger.info("Prefill must be done eagerly, setting to True")
self.engine_args.enforce_eager = True
if self.engine_args.enable_prefix_caching is not False:
logger.info(
"Prefix caching is not supported yet in prefill worker, setting to False"
)
self.engine_args.enable_prefix_caching = False
signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
@async_on_start
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
# NOTE: PrefillWorker no longer needs AutoProcessor since it uses the original
# tokenized prompt from DecodeWorker instead of creating its own.
logger.info(
"PrefillWorker: Skipping AutoProcessor initialization - using original tokens from DecodeWorker"
)
runtime = dynamo_context["runtime"]
enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(enc_comp_ns)
.component(enc_comp_name)
.endpoint("encode")
.client()
)
# Initialize the connector for RDMA transfers within the specified namespace.
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
incoming_frames_shape = (
self.num_sampled_frames,
self.frame_height,
self.frame_width,
self.frame_channels,
)
# Pre-allocate a tensor on the CPU to receive frame data.
frames_tensor = torch.empty(
incoming_frames_shape,
dtype=INCOMING_FRAMES_DTYPE,
device=INCOMING_FRAMES_DEVICE,
)
# Create a descriptor for the tensor to make it available for remote access.
descriptor = connect.Descriptor(frames_tensor)
# Register the memory with the connector, making it discoverable.
descriptor.register_memory(self._connector)
self._frames_descriptor = (frames_tensor, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers)
metadata = self.engine_client.nixl_metadata
self._metadata_store = NixlMetadataStore("dynamo", runtime)
await self._metadata_store.put(metadata.engine_id, metadata)
logger.info("PrefillWorker: Creating prefill_queue_handler task.")
task = asyncio.create_task(self.prefill_queue_handler())
def prefill_queue_handler_cb(fut):
try:
fut.result()
logger.info(
"PrefillWorker: prefill_queue_handler task exited successfully."
)
except asyncio.CancelledError:
logger.info("PrefillWorker: prefill_queue_handler task was cancelled.")
except Exception as e:
logger.error(
f"PrefillWorker: prefill_queue_handler task failed with exception: {e!r}",
exc_info=True,
)
task.add_done_callback(prefill_queue_handler_cb)
logger.info("PrefillWorker: async_init complete.")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Shutdown started, signal {signum} received.")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
logger.info("Shutdown complete.")
async def prefill_queue_handler(self):
logger.info("PrefillWorker: Prefill queue handler task started.")
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
prefill_queue_stream_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
logger.info(
f"PrefillWorker: Connecting to prefill queue: {prefill_queue_nats_server}, stream: '{prefill_queue_stream_name}'"
)
self.initialized = True
try:
async with PrefillQueue.get_instance(
nats_server=prefill_queue_nats_server,
stream_name=prefill_queue_stream_name,
) as prefill_queue:
logger.info(
f"PrefillWorker: Entering dequeue loop for stream '{prefill_queue_stream_name}'."
)
while True:
prefill_request: Optional[RemotePrefillRequest] = None
try:
prefill_request = await prefill_queue.dequeue_prefill_request()
except Exception as e:
logger.error(
f"PrefillWorker: Exception during dequeue_prefill_request: {e}",
exc_info=True,
)
await asyncio.sleep(5)
continue
if prefill_request is not None:
logger.info(
f"PrefillWorker: Dequeued prefill request: {prefill_request.request_id}"
)
try:
async for _ in self.generate(prefill_request):
pass
logger.info(
f"PrefillWorker: Successfully processed prefill request {prefill_request.request_id}."
)
except Exception as e:
logger.error(
f"PrefillWorker: Error processing prefill request {prefill_request.request_id} in self.generate: {e}",
exc_info=True,
)
else:
await asyncio.sleep(0.1)
except Exception as e:
logger.error(
f"PrefillWorker: Prefill queue handler CRASHED: {e}", exc_info=True
)
async def generate(self, request: RemotePrefillRequest):
video_url = request.multimodal_data_source.get("video_url")
if video_url is None:
raise ValueError(
"No video_url provided in multimodal_data_source for prefill request"
)
request_id = request.request_id
engine_id = request.engine_id
logger.info(
f"PrefillWorker {request_id}: Received prefill request for video_url: {video_url}."
)
raw_frames_tensor, descriptor = self._frames_descriptor
logger.debug(
f"PrefillWorker {request_id}: Requesting frames from EncodeWorker for {video_url}"
)
# Create a writable operation handle for the remote EncodeWorker.
# This allows the EncodeWorker to write directly into this worker's `frames_tensor`.
with self._connector.create_writable(descriptor) as writable:
encode_generator = await self.encode_worker_client.round_robin(
EncodeRequest(
request_id=request_id,
video_url=video_url,
# Serialize the writable handle to send it to the EncodeWorker.
serialized_request=writable.to_serialized(),
).model_dump_json()
)
async for _ in encode_generator:
pass
# Wait for the remote write from the EncodeWorker to complete.
await writable.wait_for_completion()
logger.debug(
f"PrefillWorker {request_id}: Frames received from EncodeWorker, shape: {raw_frames_tensor.shape}"
)
if not request.prompt_token_ids:
logger.error(
f"PrefillWorker {request_id}: No prompt_token_ids provided in request!"
)
raise ValueError(
"PrefillWorker requires prompt_token_ids from DecodeWorker"
)
# Constants for token manipulation
DUMMY_TOKEN_ID = self.dummy_token_id
VIDEO_TOKEN_ID = self.video_token_id
# Step 1: Find all video token positions
video_token_positions = [
i
for i, token in enumerate(request.prompt_token_ids)
if token == VIDEO_TOKEN_ID
]
logger.debug(
f"PrefillWorker {request_id}: Found {len(video_token_positions)} video tokens at positions: {video_token_positions}"
)
# Step 2: Process tokens from end to start to avoid position shifting
processed_tokens = list(request.prompt_token_ids)
for pos in reversed(video_token_positions):
# Calculate range of tokens to remove (dummy tokens after this video token)
start_idx = pos + 1
end_idx = start_idx + self.dummy_tokens_per_frame
# Check if we have enough tokens to remove
if end_idx > len(processed_tokens):
logger.warning(
f"PrefillWorker {request_id}: Not enough tokens to remove at position {pos}"
)
continue
# Remove the dummy tokens
processed_tokens = processed_tokens[:start_idx] + processed_tokens[end_idx:]
# Step 3: Verify we have exactly one video token left
final_video_count = sum(
1 for token in processed_tokens if token == VIDEO_TOKEN_ID
)
if final_video_count != 1:
logger.error(
f"PrefillWorker {request_id}: Wrong number of video tokens! Expected 1, got {final_video_count}"
)
# Step 4: Check for any remaining dummy tokens (should be none)
remaining_dummies = sum(
1 for token in processed_tokens if token == DUMMY_TOKEN_ID
)
if remaining_dummies > 0:
logger.warning(
f"PrefillWorker {request_id}: Found {remaining_dummies} remaining dummy tokens!"
)
# Create the input for vLLM
prefill_vllm_input = TokensPrompt(
prompt_token_ids=processed_tokens,
multi_modal_data={"video": raw_frames_tensor.numpy()},
)
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=engine_id,
decode_computed_block_ids=request.computed_block_ids,
)
if engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
logger.info(
f"Loaded nixl metadata from engine {engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(engine_id)
logger.debug(
f"PrefillWorker {request_id}: Calling engine_client.generate for prefill."
)
async for _ in self.engine_client.generate(
prefill_vllm_input,
sampling_params=sampling_params,
request_id=request_id,
remote_prefill_params=remote_prefill_params,
):
yield
logger.info(f"PrefillWorker {request_id}: Finished processing prefill request.")
@endpoint()
async def mock(self, req: RequestType):
yield f"mock_response: {req}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
from components.video_decode_worker import VllmDecodeWorker
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers
from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.runtime import EtcdKvCache
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
worker = depends(VllmDecodeWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.min_workers = 1
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore
self.worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
await check_required_workers(self.worker_client, self.min_workers)
self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
"/dynamo/processor/",
{"router": self.engine_args.router},
)
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
image: str,
request_type: RequestType,
):
request_id = str(uuid.uuid4())
logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_request = vLLMMultimodalRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
video_url=image,
)
router_mode = (await self.etcd_kv_cache.get("router")).decode()
if router_mode == "kv":
# The current KV router does not support multimodal requests because
# it performs cache lookup based solely on prompt tokens. At this stage,
# multimodal data (e.g., image features) is not yet available, so the router
# cannot select the optimal worker using both prompt and image inputs.
raise NotImplementedError(
"Multimodal requests are not supported for kv router mode"
)
if router_mode == "random":
response_generator = await self.worker_client.generate(
worker_request.model_dump_json()
)
elif router_mode == "round-robin":
response_generator = await self.worker_client.round_robin(
worker_request.model_dump_json()
)
else:
raise NotImplementedError(f"Router mode {router_mode} not implemented")
output = self._generate_responses(response_generator, request_type)
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
request_type: RequestType,
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
prompt_idx = 0
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
elif request_type == RequestType.COMPLETION:
# Completion requests can have multiple prompts and stream generator requires the prompt index
yield (prompt_idx, request_output)
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
# The generate endpoint will be used by the frontend to handle incoming requests.
@endpoint()
async def generate(self, raw_request: MultiModalRequest):
# TODO: Make the template consumed from the config file
msg = {
"role": "user",
"content": "USER: <video>\nQuestion:"
+ raw_request.messages[0].content[0].text
+ " Answer:",
}
chat_request = ChatCompletionRequest(
model=raw_request.model,
messages=[msg],
stream=raw_request.stream,
max_tokens=raw_request.max_tokens,
request_id=str(uuid.uuid4()),
)
video_url = None
for message in raw_request.messages:
for item in message.content:
if item.type == "video_url":
video_url = item.video_url.url
if video_url is None:
raise ValueError("Video URL is required")
async for response in self._generate(chat_request, video_url, RequestType.CHAT):
yield json.dumps(response)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from components.processor import Processor
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from utils.protocol import MultiModalRequest
from utils.vllm import parse_vllm_args
from dynamo.sdk import DYNAMO_IMAGE, api, depends, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
image=DYNAMO_IMAGE,
app=FastAPI(title="Multimodal Example"),
)
class Frontend:
processor = depends(Processor)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
@api(name="v1/chat/completions")
async def generate(self, request: MultiModalRequest):
if self.engine_args.model != request.model:
return JSONResponse(
{"error": f"Model '{request.model}' not found"},
status_code=404,
)
async def content_generator():
async for response in self.processor.generate(request.model_dump_json()):
try:
s = json.loads(response)
yield s
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to parse JSON response: {e}")
return StreamingResponse(content_generator(), media_type="text/event-stream")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: llava-hf/llava-1.5-7b-hf
block-size: 64
max-model-len: 4096
Frontend:
common-configs: [model]
Processor:
router: round-robin
prompt-template: "USER: <image>\n<prompt> ASSISTANT:"
common-configs: [model, block-size, max-model-len]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
image-token-id: 32000
num-patches: 576
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: microsoft/Phi-3.5-vision-instruct
block-size: 64
max-model-len: 4096
trust-remote-code: true
Frontend:
common-configs: [model]
Processor:
router: round-robin
prompt-template: "<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
common-configs: [model, block-size, max-model-len, trust-remote-code]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
max-num-seqs: 2
mm-processor-kwargs:
num_crops: 16
enable-prefix-caching: true
image-token-id: 32000
num-patches: 757
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, trust-remote-code]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: Qwen/Qwen2.5-VL-7B-Instruct
block-size: 64
max-model-len: 4096
Frontend:
common-configs: [model]
Processor:
router: round-robin
prompt-template: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
common-configs: [model, block-size, max-model-len]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
max-num-seqs: 5
mm-processor-kwargs:
min_pixels: 784
max_pixels: 1003520
fps: 1
enable-prefix-caching: true
image-token-id: 151655
num-patches: 345
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: llava-hf/LLaVA-NeXT-Video-7B-hf
block-size: 64
max-model-len: 4096
num-sampled-frames: 8
frame-height: 336
frame-width: 336
frame-channels: 3
dummy-token-id: 0
video-token-id: 32000
dummy-tokens-per-frame: 144
Frontend:
common-configs: [model]
Processor:
router: round-robin
common-configs: [model, block-size, max-model-len]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, block-size, max-model-len, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: llava-hf/llava-1.5-7b-hf
block-size: 64
max-model-len: 4096
image-token-id: 32000
num-patches: 576
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
Frontend:
common-configs: [model]
Processor:
router: round-robin
prompt-template: "USER: <image>\n<prompt> ASSISTANT:"
common-configs: [model, block-size]
VllmDecodeWorker:
remote-prefill: true
conditional-disagg: true
max-local-prefill-length: 10
max-prefill-queue-size: 2
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, image-token-id, max-model-len, num-patches, kv-transfer-config]
VllmPrefillWorker:
max-num-batched-tokens: 16384
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, image-token-id, max-model-len, num-patches, kv-transfer-config]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: llava-hf/LLaVA-NeXT-Video-7B-hf
block-size: 64
max-model-len: 8192
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
num-sampled-frames: 8
frame-height: 336
frame-width: 336
frame-channels: 3
dummy-token-id: 0
video-token-id: 32000
dummy-tokens-per-frame: 144
Frontend:
common-configs: [model]
Processor:
router: round-robin
common-configs: [model, block-size]
VllmDecodeWorker:
remote-prefill: true
conditional-disagg: false
max-local-prefill-length: 50
max-prefill-queue-size: 2
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
VllmPrefillWorker:
max-num-batched-tokens: 16384
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, num-sampled-frames, frame-height, frame-width, frame-channels, dummy-token-id, video-token-id, dummy-tokens-per-frame]
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Dynamo Connect
Dynamo connect provides a Pythonic interface to the NIXL base RDMA subsystem via a set of Python classes.
The primary goal of this library to simplify the integration of NIXL based RDMA into inference applications.
All operations using the Connect library begin with the [`Connector`](#connector) class and the type of operation required.
There are four types of supported operations:
- **Register local readable memory**:
Register local memory buffer(s) with the RDMA subsystem to enable a remote worker to read from.
- **Register local writable memory**:
Register local memory buffer(s) with the RDMA subsystem to enable a remote worker to write to.
- **Read from registered, remote memory**:
Read remote memory buffer(s), registered by a remote worker to be readable, into local memory buffer(s).
- **Write to registered, remote memory**:
Write local memory buffer(s) to remote memory buffer(s) registered by a remote worker to writable.
By connecting correctly paired operations, high-throughput GPU Direct RDMA data transfers can be completed.
Given the list above, the correct pairing of operations would be 1 & 3 or 2 & 4.
Where one side is a "(read|write)-able operation" and the other is its correctly paired "(read|write) operation".
Specifically, a read operation must be paired with a readable operation, and a write operation must be paired with a writable operation.
## Examples
### Generic Example
In the diagram below, Local creates a [`WritableOperation`](#writableoperation) intended to receive data from Remote.
Local then sends metadata about the requuested RDMA operation to Remote.
Remote then uses the metadata to create a [`WriteOperation`](#writeoperation) which will perform the GPU Direct RDMA memory transfer from Remote's GPU memory to Local's GPU memory.
```mermaid
---
title: Write Operation Between Two Workers
---
flowchart LR
c1[Remote] --"3: .begin_write()"--- WriteOperation
WriteOperation e1@=="4: GPU Direct RDMA"==> WritableOperation
WritableOperation --"1: .create_writable()"--- c2[Local]
c2 e2@--"2: RDMA Metadata via HTTP"--> c1
e1@{ animate: true; }
e2@{ animate: true; }
```
### Multimodal Example
In the case of the [Dynamo Multimodal Disaggregated Example](../README.md):
1. The HTTP frontend accepts a text prompt and a URL to an image.
2. The prompt and URL are then enqueued with the Processor before being dispatched to the first available Decode Worker.
3. Decode Worker then requests a Prefill Worker to provide key-value data for the LLM powering the Decode Worker.
4. Prefill Worker then requests that the image be processed and provided as embeddings by the Encode Worker.
5. Encode Worker acquires the image, processes it, performs inference on the image using a specialized vision model, and finally provides the embeddings to Prefill Worker.
6. Prefill Worker receives the embeddings from Encode Worker and generates a key-value cache (KV$) update for Decode Worker's LLM and writes the update directly to the GPU memory reserved for the data.
7. Finally, Decode Worker performs the requested inference.
```mermaid
---
title: Multimodal Disaggregated Workflow
---
flowchart LR
p0[HTTP Frontend] i0@--"text prompt"-->p1[Processor]
p0 i1@--"url"-->p1
p1 i2@--"prompt"-->dw[Decode Worker]
p1 i3@--"url"-->dw
dw i4@--"prompt"-->pw[Prefill Worker]
dw i5@--"url"-->pw
pw i6@--"url"-->ew[Encode Worker]
ew o0@=="image embeddings"==>pw
pw o1@=="kv_cache updates"==>dw
dw o2@--"inference results"-->p0
i0@{ animate: true; }
i1@{ animate: true; }
i2@{ animate: true; }
i3@{ animate: true; }
i4@{ animate: true; }
i5@{ animate: true; }
i6@{ animate: true; }
o0@{ animate: true; }
o1@{ animate: true; }
o2@{ animate: true; }
```
_Note: In this example, it is the data transfer between the Prefill Worker and the Encode Worker that utilizes the Dynamo Connect library. The KV Cache transfer between Decode Worker and Prefill Worker utilizes the NIXL base RDMA subsystem directly without using the Dynamo Connect library._
#### Code Examples
See [prefill_worker](../components/prefill_worker.py#L199) or [decode_worker](../components/decode_worker.py#L239),
for how they coordinate directly with the Encode Worker by creating a [`WritableOperation`](#writableoperation),
sending the operation's metadata via Dynamo's round-robin dispatcher, and awaiting the operation for completion before making use of the transferred data.
See [encode_worker](../components/encode_worker.py#L190),
for how the resulting embeddings are registered with the RDMA subsystem by creating a [`Descriptor`](#descriptor),
a [`WriteOperation`](#writeoperation) is created using the metadata provided by the requesting worker,
and the worker awaits for the data transfer to complete for yielding a response.
## Python Classes
### Connector
Core class for managing the connection between workers in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote workers.
This class is responsible for interfacing with the NIXL-based RDMA subsystem and providing a "Pythonic" interface
with which to utilize GPU Direct RDMA accelerated data transfers between models hosted by different workers in a Dynamo pipeline.
The connector provides two methods of moving data between workers:
- Preparing local memory to be written to by a remote worker.
- Preparing local memory to be read by a remote worker.
In both cases, local memory is registered with the NIXL-based RDMA subsystem via the [`Descriptor`](#descriptor) class and provided to the connector.
The connector then configures the RDMA subsystem to expose the memory for the requested operation and returns an operation control object.
The operation control object, either a [`ReadableOperation`](#readableoperation) or a [`WritableOperation`](#writableoperation),
provides RDMA metadata via its [`.to_serialized()`](#to_serialized) method as well as functionality to know when the operation has been completed or cancel the operation prior to completion.
The RDMA metadata must be provided to the remote worker expected to complete the operation.
The metadata contains required information (identifiers, keys, etc.) which enables the remote worker to interact with the provided memory.
#### Methods
##### `begin_read`
> Creates a [`ReadOperation`](#readoperation) for transferring data from a remote worker.
>
> To create the operation, the serialized request from a remote worker's [`ReadableOperation`](#readableoperation)
> along with a matching set of local memory descriptors which reference memory intended to receive data from the remote worker
> must be provided.
> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Once created, the operation will begin reading immediately.
> Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `begin_write`
> Creates a write operation for transferring data to a remote worker.
>
> To create the operation, the serialized request from a remote worker's [`WritableOperation`](#writableoperation)
> along with a matching set of local memory descriptors which reference memory to be transferred to the remote worker
> must be provided.
> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Once created, the operation will begin writing immediately.
> Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `create_readable`
> Creates a [`ReadableOperation`](#readableoperation) for transferring data to a remote worker.
>
> To create the operation, a set of local memory descriptors must be provided that reference memory intended to be transferred to
> a remote worker.
> Once created, the memory referenced by the provided descriptors becomes immediately readable by a remote worker with the necessary metadata.
> The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Disposable of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `create_writable`
> Creates a [`WritableOperation`](#writableoperation) for transferring data from a remote worker.
>
> To create the operation, a set of local memory descriptors must be provided which reference memory intended to receive data from
> a remote worker.
> Once created, the memory referenced by the provided descriptors becomes immediately writable by a remote worker with the necessary metadata.
> The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Disposable of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
### Descriptor
Memory descriptor that ensures memory is registered with the NIXL base RDMA subsystem.
Memory must be registered with the RDMA subsystem to enable interaction with the memory.
Descriptor objects are administrative and do not copy, move, or otherwise modify the registered memory.
There are four ways to create a descriptor:
1. From a `torch.Tensor` object. Device information will be derived from the provided object.
2. From a `tuple` containing either a NumPy or CuPy `ndarray` and information desribing where the memory resides (Host/CPU vs GPU).
3. From a Python `bytes` object. Memory is assumed to reside in CPU addressable host memory.
4. From a `tuple` comprised of the address of the memory, its size in bytes, and device information.
An optional reference to a Python object can be provided to avoid garbage collection issues.
### Device
Device describes the device, or kind of memory, a given allocation resides in.
Usually host (`"cpu"`) or GPU (`"cuda"`) memory.
When a system contains multiple GPU devices, specific GPU devices can be identified by including their ordinal index number.
For example, to reference the second GPU in a system `"cuda:1"` can be used.
By default, when `"cuda"` is provided, it is assumed to be `"cuda:0"` or the first GPU enumerated by the system.
### ReadOperation
An operation which transfers data from a remote worker to the local worker.
To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`ReadableOperation`](#readableoperation)
along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory intended to receive data from the remote worker must be provided.
The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
Once created, the operation will begin reading immediately.
Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `cancel`
> Instructs the RDMA subsystem to cancel the operation.
> Completed operations cannot be cancelled.
##### `wait_for_completion`
> Blocks the caller until the memory from the remote worker has been transferred to the provided buffers.
### ReadableOperation
An operation which enables a remote worker to read data from the local worker.
To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided that reference memory intended to be transferred to a remote worker.
Once created, the memory referenced by the provided descriptors becomes immediately readable by a remote worker with the necessary metadata.
The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `to_serialized`
> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to read from the operation.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
##### `wait_for_completion`
> Blocks the caller until the operation has received a completion signal from a remote worker.
### WriteOperation
An operation which transfers data from the local worker to a remote worker.
To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`WritableOperation`](#writableoperation)
along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory to be transferred to the remote worker must be provided.
The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
Once created, the operation will begin writing immediately.
Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `cancel`
> Instructs the RDMA subsystem to cancel the operation.
> Completed operations cannot be cancelled.
##### `wait_for_completion`
> Blocks the caller until all provided buffers have been transferred to the remote worker.
### WritableOperation
An operation which enables a remote worker to write data to the local worker.
To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided which reference memory intended to receive data from a remote worker.
Once created, the memory referenced by the provided descriptors becomes immediately writable by a remote worker with the necessary metadata.
The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `to_serialized`
> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to write to the operation.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
##### `wait_for_completion`
> Blocks the caller until the operation has received a completion signal from a remote worker.
### SerializedRequest
A Pydantic type intended to provide JSON serialized RDMA metadata about a [`ReadableOperation`](#readableoperation) or [`WritableOperation`](#writableoperation) object.
Use the [`.to_serialized()`](#to_serialized) method on either of the above types to generate a `SerializedRequest` object for an operation.
## References
- [NVIDIA Dynamo](https://developer.nvidia.com/dynamo) @ [GitHub](https://github.com/ai-dynamo/dynamo)
- [NVIDIA Inference Transfer Library (NIXL)](https://developer.nvidia.com/blog/introducing-nvidia-dynamo-a-low-latency-distributed-inference-framework-for-scaling-reasoning-ai-models/#nvidia_inference_transfer_library_nixl_low-latency_hardware-agnostic_communication%C2%A0) @ [GitHub](https://github.com/ai-dynamo/nixl)
- [Dynamo Multimodal Example](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal)
- [NVIDIA GPU Direct](https://developer.nvidia.com/gpudirect)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment