"vscode:/vscode.git/clone" did not exist on "20ad730cfde470de79a59eae6ed20938a23ace3c"
Unverified Commit 9b87c89c authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: multi-modal example with vLLM v1 and UX v2 (#2040)


Co-authored-by: default avatarkrishung5 <krish@nvidia.com>
parent 1327e3bb
...@@ -12,7 +12,7 @@ ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda" ...@@ -12,7 +12,7 @@ ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda"
ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04" ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04"
# Make sure to update the dependency version in pyproject.toml when updating this # Make sure to update the dependency version in pyproject.toml when updating this
ARG VLLM_REF="f4135232b9a8c4845f8961fb1cd17581c56ae2ce" ARG VLLM_REF="ba81acbdc1eec643ba815a76628ae3e4b2263b76"
ARG TORCH_BACKEND="cu128" ARG TORCH_BACKEND="cu128"
# Match 0.10.0 vLLM release # Match 0.10.0 vLLM release
...@@ -186,6 +186,7 @@ RUN if [ "$ARCH" = "arm64" ]; then \ ...@@ -186,6 +186,7 @@ RUN if [ "$ARCH" = "arm64" ]; then \
# Install vllm - keep this early in Dockerfile to avoid # Install vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes # rebuilds from unrelated source code changes
ARG VLLM_REF ARG VLLM_REF
ARG VLLM_GIT_URL
ARG DEEPGEMM_REF ARG DEEPGEMM_REF
ARG FLASHINF_REF ARG FLASHINF_REF
......
...@@ -20,7 +20,8 @@ set -euo pipefail ...@@ -20,7 +20,8 @@ set -euo pipefail
# Parse arguments # Parse arguments
EDITABLE=true EDITABLE=true
VLLM_REF="f4135232b9a8c4845f8961fb1cd17581c56ae2ce" VLLM_REF="ba81acbdc1eec643ba815a76628ae3e4b2263b76"
VLLM_GIT_URL="https://github.com/vllm-project/vllm.git"
MAX_JOBS=16 MAX_JOBS=16
INSTALLATION_DIR=/tmp INSTALLATION_DIR=/tmp
ARCH=$(uname -m) ARCH=$(uname -m)
...@@ -49,6 +50,10 @@ while [[ $# -gt 0 ]]; do ...@@ -49,6 +50,10 @@ while [[ $# -gt 0 ]]; do
VLLM_REF="$2" VLLM_REF="$2"
shift 2 shift 2
;; ;;
--vllm-git-url)
VLLM_GIT_URL="$2"
shift 2
;;
--max-jobs) --max-jobs)
MAX_JOBS="$2" MAX_JOBS="$2"
shift 2 shift 2
...@@ -113,7 +118,7 @@ uv pip install lmcache ...@@ -113,7 +118,7 @@ uv pip install lmcache
# Create vllm directory and clone # Create vllm directory and clone
mkdir -p $INSTALLATION_DIR mkdir -p $INSTALLATION_DIR
cd $INSTALLATION_DIR cd $INSTALLATION_DIR
git clone https://github.com/vllm-project/vllm.git git clone $VLLM_GIT_URL vllm
cd vllm cd vllm
git checkout $VLLM_REF git checkout $VLLM_REF
...@@ -148,7 +153,7 @@ fi ...@@ -148,7 +153,7 @@ fi
# Install ep_kernels and DeepGEMM # Install ep_kernels and DeepGEMM
echo "Installing ep_kernels and DeepGEMM" echo "Installing ep_kernels and DeepGEMM"
cd tools/ep_kernels cd tools/ep_kernels
bash install_python_libraries.sh # These libraries aren't pinned. TORCH_CUDA_ARCH_LIST="9.0;10.0" bash install_python_libraries.sh # These libraries aren't pinned.
cd ep_kernels_workspace cd ep_kernels_workspace
git clone https://github.com/deepseek-ai/DeepGEMM.git git clone https://github.com/deepseek-ai/DeepGEMM.git
cd DeepGEMM cd DeepGEMM
......
...@@ -42,7 +42,10 @@ class MetricsPublisher(StatLoggerBase): ...@@ -42,7 +42,10 @@ class MetricsPublisher(StatLoggerBase):
logger.info(f"ZMQ publisher initialized on port {port}") logger.info(f"ZMQ publisher initialized on port {port}")
def record( def record(
self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats] self,
scheduler_stats: SchedulerStats,
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
): ):
# Send metrics over ZMQ # Send metrics over ZMQ
metrics_data = { metrics_data = {
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Multimodal Deployment Examples
This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo and vLLM v1.
## Use the Latest Release
We recommend using the latest stable release of dynamo to avoid breaking changes:
[![GitHub Release](https://img.shields.io/github/v/release/ai-dynamo/dynamo)](https://github.com/ai-dynamo/dynamo/releases/latest)
You can find the latest release [here](https://github.com/ai-dynamo/dynamo/releases/latest) and check out the corresponding branch with:
```bash
git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
```
## Multimodal Aggregated Serving
### Components
- workers: For aggregated serving, we have two workers, [VllmEncodeWorker](components/encode_worker.py) for encoding and [VllmPDWorker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have two workers, [VllmEncodeWorker](components/encode_worker.py) and [VllmPDWorker](components/worker.py).
The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the VllmPDWorker via a combination of NATS and RDMA.
The work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
Its VllmPDWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
By separating the encode from the prefill and decode stages, we can have a more flexible deployment and scale the
VllmEncodeWorker independently from the prefill and decode workers if needed.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> encode_worker
encode_worker --> processor
encode_worker --embeddings--> pd_worker
pd_worker --> encode_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
# Serve a LLaVA 1.5 7B model:
bash launch/agg.sh --model llava-hf/llava-1.5-7b-hf
# Serve a Qwen2.5-VL model:
# bash launch/agg.sh --model Qwen/Qwen2.5-VL-7B-Instruct
# Serve a Phi3V model:
# bash launch/agg.sh --model microsoft/Phi-3.5-vision-instruct
```
### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/llava-1.5-7b-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
You should see a response similar to this:
```json
{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
```
## Multimodal Disaggregated Serving
### Components
- workers: For disaggregated serving, we have three workers, [VllmEncodeWorker](components/encode_worker.py) for encoding, [VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have three workers, [VllmEncodeWorker](components/encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
For the Llava model, embeddings are only required during the prefill stage. As such, the VllmEncodeWorker is connected directly to the prefill worker.
The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the prefill worker via a combination of NATS and RDMA.
Its work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
The prefill worker performs the prefilling step and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> encode_worker
encode_worker --> processor
encode_worker --embeddings--> prefill_worker
prefill_worker --> encode_worker
prefill_worker --> decode_worker
decode_worker --> prefill_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
bash launch/disagg.sh --model llava-hf/llava-1.5-7b-hf
```
### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/llava-1.5-7b-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
You should see a response similar to this:
```json
{"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
```
***Note***: disaggregation is currently only confirmed to work with LLaVA. Qwen VL and PhiV are not confirmed to be supported.
## Llama 4 family Serving
The family of Llama 4 models is natively multimodal, however, different
from Llava, they do not directly consume image embedding as input
(see the [support metrics](https://docs.vllm.ai/en/latest/models/supported_models.html#text-generation_1)
from vLLM for the types of multi-modal inputs supported by the model).
Therefore, encoder worker will not be used in the following example and the
encoding will be done along side with prefill.
`meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8` will be used as an example
for the content below. And the system will be H100x8 which can hold one instance
of the model per node.
### Multimodal Aggregated Serving
#### Components
- workers: For aggregated serving, we have one worker, [VllmPDWorker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the VllmPDWorker.
- frontend: HTTP endpoint to handle incoming requests.
#### Graph
In this graph, we have [VllmPDWorker](components/worker.py) which will encode the image, prefill and decode the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> pd_worker
pd_worker --> processor
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
bash launch/agg_llama.sh
```
#### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
You should see a response similar to this:
```json
{"id": "b8f060fa95584e34b9204eaba7b105cc", "object": "chat.completion", "created": 1752706281, "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "choices": [{"index": 0, "message": {"role": "assistant", "content": "The image depicts a street scene with a trolley bus as the central focus. The trolley bus is positioned on the left side of the road, facing the camera, and features a white and yellow color scheme. A prominent sign on the front of the bus reads \"OUT OF SERVICE\" in orange letters.\n\n**Key Elements:**\n\n* **Trolley Bus:** The bus is the main subject of the image, showcasing its distinctive design and color.\n* **Sign:** The \"OUT OF SERVICE\" sign is clearly visible on the front of the bus, indicating its current status.\n* **Street Scene:** The surrounding environment includes trees, buildings, and power lines, creating a sense of context and atmosphere.\n* **Lighting:** The image is characterized by a misty or foggy quality, with soft lighting that adds to the overall ambiance.\n\n**Overall Impression:**\n\nThe image presents a serene and somewhat melancholic scene, with the out-of-service trolley bus serving as a focal point. The misty atmosphere and soft lighting contribute to a dreamy or nostalgic feel, inviting the viewer to reflect on the scene."}, "finish_reason": "stop"}]}
```
### Multimodal Disaggregated Serving
#### Components
- workers: For disaggregated serving, we have two workers, [VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for encoding and prefilling.
- processor: Tokenizes the prompt and passes it to the VllmPDWorker.
- frontend: HTTP endpoint to handle incoming requests.
#### Graph
In this graph, we have two workers, [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
The prefill worker performs the encoding and prefilling steps and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> prefill_worker
prefill_worker --> processor
prefill_worker --> decode_worker
decode_worker --> prefill_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
bash launch/disagg_llama.sh --head-node
# On a separate node that has finished standard dynamo setup, i.e.
# the worker node needs NATS_SERVER and ETCD_ENDPOINTS environment variables
# pointing to the head node's external IP address for distributed coordination
cd $DYNAMO_HOME/examples/multimodal_v1
bash launch/disagg_llama.sh
```
#### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
You should see a response similar to this:
```json
{"id": "6cc99123ad6948d685b8695428238d4b", "object": "chat.completion", "created": 1752708043, "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "choices": [{"index": 0, "message": {"role": "assistant", "content": "The image depicts a street scene with a trolley bus as the central focus. The trolley bus is positioned on the left side of the road, facing the camera, and features a white and yellow color scheme. A prominent sign on the front of the bus reads \"OUT OF SERVICE\" in orange letters.\n\n**Key Elements:**\n\n* **Trolley Bus:** The bus is the main subject of the image, showcasing its distinctive design and color.\n* **Sign:** The \"OUT OF SERVICE\" sign is clearly visible on the front of the bus, indicating its current status.\n* **Street Scene:** The surrounding environment includes trees, buildings, and power lines, creating a sense of context and atmosphere.\n* **Lighting:** The image is characterized by a misty or foggy quality, with soft lighting that adds to the overall mood.\n\n**Overall Impression:**\n\nThe image presents a serene and somewhat melancholic scene, with the out-of-service trolley bus serving as a focal point. The misty atmosphere and soft lighting contribute to a contemplative ambiance, inviting the viewer to reflect on the situation."}, "finish_reason": "stop"}]}
```
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import logging
import os
import signal
import sys
from typing import AsyncIterator, Tuple
import torch
import uvloop
from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
import connect
from utils.args import Config, base_parse_args, parse_endpoint
from utils.image_loader import ImageLoader
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
configure_dynamo_logging()
logger = logging.getLogger(__name__)
try:
import cupy as array_module
if not array_module.cuda.is_available():
raise ImportError("CUDA is not available.")
DEVICE = "cuda"
logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
import numpy as array_module
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
class VllmEncodeWorker:
def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs) -> None:
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = engine_args
self.model = self.engine_args.model
self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
self.image_processor = AutoImageProcessor.from_pretrained(
self.model, trust_remote_code=True
)
# self.vision_model = load_vision_model(self.model)
self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.model, device_map="auto", torch_dtype=torch.float16
).eval()
self.min_workers = 1
def cleanup(self):
pass
async def generate(
self, request: vLLMMultimodalRequest
) -> AsyncIterator[MyRequestOutput]:
logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")
request_id = request.request_id
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
try:
image = await self.image_loader.load_image(request.image_url)
logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
# [gluo NOTE] The commented section is for VLM generalization support,
# will use more generic approach once utils/model.py is fixed,
# see utils/models.py for details.
# # Add a batch dimension to everything
# for item in image_embeds:
# image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
# logger.debug(f"Image embeds: {image_embeds}")
# image_grid_thw = (
# image_embeds["image_grid_thw"].tolist()
# if "image_grid_thw" in image_embeds
# else None
# )
# image_sizes = (
# image_embeds["image_sizes"].tolist()
# if "image_sizes" in image_embeds
# else [image.size]
# )
# logger.debug(
# f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
# )
# with torch.no_grad():
# embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
# if isinstance(embeddings, tuple) or isinstance(embeddings, list):
# # The result multimodal_embeddings may be a list or tuple of tensors, with each
# # tensor corresponding to a multimodal data item (image or video).
# # TODO: for multi-image support, this result will contain multiple tensors.
# embeddings = embeddings[0].unsqueeze(0)
# logger.debug(
# f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
# )
with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
logger.debug("Vision model completed.")
embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)
descriptor = connect.Descriptor(embeddings)
with self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.to_serialized()
# Clear the image URL as hint that the image is passed as embeddings.
request.image_url = None
logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator
response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json()
)
await readable.wait_for_completion()
async for response in response_generator:
output = MyRequestOutput.model_validate_json(response.data())
yield MyRequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
).model_dump_json()
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
raise
async def async_init(self, runtime: DistributedRuntime):
logger.info("Startup started.")
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
self.downstream_endpoint
)
self.pd_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector(runtime=runtime, namespace=parsed_namespace)
await self._connector.initialize()
logger.info("Startup completed.")
@classmethod
def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
DEFAULT_ENDPOINT = "dyn://dynamo.encoder.generate"
DEFAULT_DOWNSTREAM_ENDPOINT = "dyn://dynamo.llm.generate"
parser = FlexibleArgumentParser(
description="vLLM based encoder for Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
default=DEFAULT_DOWNSTREAM_ENDPOINT,
help=f"The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
)
args, config = base_parse_args(parser)
return args, config
async def graceful_shutdown(runtime):
"""
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
However, in-flight requests will still be processed until they are finished.
After all in-flight requests are finished, the `serve_endpoint` functions will return
and the engine will be shutdown by Python's garbage collector.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Runtime setup
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown")
# worker setup
args, config = VllmEncodeWorker.parse_args()
await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
generate_endpoint = component.endpoint(config.endpoint)
handler = VllmEncodeWorker(args, config.engine_args)
await handler.async_init(runtime)
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
handler.cleanup()
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import json
import logging
import os
import signal
import sys
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
import uvloop
from transformers import AutoTokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import FlexibleArgumentParser
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# To import example local module
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
configure_dynamo_logging()
logger = logging.getLogger(__name__)
prompt_template = "USER: <image>\n<prompt> ASSISTANT:"
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
@classmethod
def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
DEFAULT_ENDPOINT = "dyn://dynamo.processor.generate"
DEFAULT_DOWNSTREAM_ENDPOINT = "dyn://dynamo.encoder.generate"
parser = FlexibleArgumentParser(
description="vLLM based processor for Dynamo LLM."
)
parser.add_argument(
"--prompt-template",
type=str,
required=True,
help=(
"Different multi-modal models expect the prompt to contain different special media prompts. "
"The processor will use this argument to construct the final prompt. "
"User prompt will replace '<prompt>' in the provided template. "
"For example, if the user prompt is 'please describe the image' and the prompt template is "
"'USER: <image> <prompt> ASSISTANT:', the resulting prompt is "
"'USER: <image> please describe the image ASSISTANT:'."
),
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
default=DEFAULT_DOWNSTREAM_ENDPOINT,
help=f"The endpoint string of the downstream encoder in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
)
args, config = base_parse_args(parser)
return args, config
def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs):
self.prompt_template = args.prompt_template
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
def cleanup(self):
pass
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
async def async_init(self, runtime: DistributedRuntime):
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
self.downstream_endpoint
)
self.encode_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
image: str,
request_type: RequestType,
):
request_id = str(uuid.uuid4().hex)
logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_request = vLLMMultimodalRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
image_url=image,
)
# model_dump_json() serializes the request to JSON string
# This API could accept Pydantic class, but SamplingParams
# in vLLMMultimodalRequest is not a Pydantic class and will
# cause TypeError: unsupported type SamplingParams
response_generator = await self.encode_worker_client.round_robin(
worker_request.model_dump_json()
)
output = self._generate_responses(response_generator, request_type)
# Stream the processed responses
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
request_type: RequestType,
):
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
# The generate endpoint will be used by the frontend to handle incoming requests.
async def generate(self, raw_request: MultiModalRequest):
logger.debug(f"Got raw request: {raw_request}")
if not isinstance(raw_request, MultiModalRequest):
# If the request is not MultiModalRequest, convert it to MultiModalRequest
raw_request = MultiModalRequest.model_validate(raw_request)
# Ensure the configured template includes the placeholder
template = self.prompt_template
if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
prompt = template.replace("<prompt>", user_text)
msg = {
"role": "user",
"content": prompt,
}
chat_request = ChatCompletionRequest(
model=raw_request.model,
messages=[msg],
stream=raw_request.stream,
max_tokens=raw_request.max_tokens,
temperature=raw_request.temperature,
request_id=str(uuid.uuid4()),
)
image_url = None
for message in raw_request.messages:
for item in message.content:
if item.type == "image_url":
image_url = item.image_url.url
if image_url is None:
raise ValueError("Image URL is required")
async for response in self._generate(chat_request, image_url, RequestType.CHAT):
logger.debug(
f"Generated response type {type(response)}, content: {response}"
)
# reconstructing back the OpenAI chat response as dynamo egress expects it
if response.startswith("data: [DONE]"):
break
response = json.loads(response.lstrip("data: "))
yield response
async def graceful_shutdown(runtime):
"""
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
However, in-flight requests will still be processed until they are finished.
After all in-flight requests are finished, the `serve_endpoint` functions will return
and the engine will be shutdown by Python's garbage collector.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Runtime setup
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown")
# worker setup
args, config = Processor.parse_args()
await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
generate_endpoint = component.endpoint(config.endpoint)
handler = Processor(args, config.engine_args)
await handler.async_init(runtime)
# Register the endpoint as entrypoint to a model
await register_llm(
ModelType.Chat, # Custom processor is used and this type bypasses SDK processor
generate_endpoint,
config.model,
config.served_model_name,
kv_cache_block_size=config.engine_args.block_size,
)
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
handler.cleanup()
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
SpecDecodeStats,
WorkerMetricsPublisher,
WorkerStats,
)
from dynamo.runtime import Component
class NullStatLogger(StatLoggerBase):
def __init__(self):
pass
def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
):
pass
def log_engine_initialized(self):
pass
class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
def __init__(self, component: Component, dp_rank: int) -> None:
self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component)
self.dp_rank = dp_rank
self.num_gpu_block = 1
self.request_total_slots = 1
# TODO: Remove this and pass as metadata through etcd
def set_num_gpu_block(self, num_blocks):
self.num_gpu_block = num_blocks
# TODO: Remove this and pass as metadata through etcd
def set_num_request_total_slots(self, request_total_slots):
self.request_total_slots = request_total_slots
def record(
self,
scheduler_stats: SchedulerStats,
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
):
# request_total_slots and kv_total_blocks are properties of model + gpu
# we should only publish them once, not every metric update
# they should be part of some runtime metadata tied to MDC or put in etcd ?
hit_rate = 0
if scheduler_stats.prefix_cache_stats.queries > 0:
hit_rate = (
scheduler_stats.prefix_cache_stats.hits
/ scheduler_stats.prefix_cache_stats.queries
)
worker_stats = WorkerStats(
request_active_slots=scheduler_stats.num_running_reqs,
request_total_slots=self.request_total_slots,
num_requests_waiting=scheduler_stats.num_waiting_reqs,
data_parallel_rank=self.dp_rank,
)
kv_stats = KvStats(
kv_active_blocks=int(self.num_gpu_block * scheduler_stats.kv_cache_usage),
kv_total_blocks=self.num_gpu_block,
gpu_cache_usage_perc=scheduler_stats.kv_cache_usage,
gpu_prefix_cache_hit_rate=hit_rate, # TODO: This is a point in time update, not cumulative. Will be problematic on router side if we try to use it.
)
spec_dec_stats = scheduler_stats.spec_decoding_stats
if spec_dec_stats:
spec_dec_stats = SpecDecodeStats(
num_spec_tokens=spec_dec_stats.num_spec_tokens,
num_drafts=spec_dec_stats.num_drafts,
num_draft_tokens=spec_dec_stats.num_draft_tokens,
num_accepted_tokens=spec_dec_stats.num_accepted_tokens,
num_accepted_tokens_per_pos=spec_dec_stats.num_accepted_tokens_per_pos,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats,
kv_stats=kv_stats,
spec_decode_stats=spec_dec_stats,
)
self.inner.publish(metrics)
def init_publish(self):
worker_stats = WorkerStats(
request_active_slots=0,
request_total_slots=self.request_total_slots,
num_requests_waiting=0,
data_parallel_rank=self.dp_rank,
)
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=self.num_gpu_block,
gpu_cache_usage_perc=0,
gpu_prefix_cache_hit_rate=0,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats,
kv_stats=kv_stats,
spec_decode_stats=None,
)
self.inner.publish(metrics)
def log_engine_initialized(self) -> None:
pass
class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM."""
def __init__(self, component: Component, dp_rank: int = 0) -> None:
self.component = component
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
return NullStatLogger()
logger = DynamoStatLoggerPublisher(self.component, dp_rank)
self.created_logger = logger
return logger
def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase:
return self.create_stat_logger(dp_rank=dp_rank)
# TODO Remove once we publish metadata to etcd
def set_num_gpu_blocks_all(self, num_blocks):
if self.created_logger:
self.created_logger.set_num_gpu_block(num_blocks)
def set_request_total_slots_all(self, request_total_slots):
if self.created_logger:
self.created_logger.set_num_request_total_slots(request_total_slots)
def init_publish(self):
if self.created_logger:
self.created_logger.init_publish()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import copy
import logging
import os
import signal
import sys
from typing import Tuple
import torch
import uvloop
from transformers import AutoImageProcessor
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.v1.engine.async_llm import AsyncLLM
from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.runtime import Component, DistributedRuntime, Endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
import connect
from publisher import StatLoggerFactory
from utils.args import (
Config,
base_parse_args,
configure_ports_with_etcd,
overwrite_args,
parse_endpoint,
)
from utils.image_loader import ImageLoader
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class VllmBaseWorker:
@classmethod
def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
parser = FlexibleArgumentParser(
description="vLLM based encoder for Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
help="Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default value will vary based on the worker type, see --worker-type for details.",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
help="The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default value will vary based on the worker type, see --worker-type for details.",
)
parser.add_argument(
"--worker-type",
type=str,
choices=["prefill", "decode", "encode_prefill"],
required=True,
help="Specify the type of worker. Must be one of: 'prefill', 'decode', 'encode_prefill'",
)
parser.add_argument(
"--enable-disagg",
action="store_true",
help="Enable disaggregated mode, where prefill and decode are handled by separate workers."
" If not set, the '*prefill' worker type will handle both prefill and decode.",
)
# use endpoint_overwrite to set the default endpoint based on worker type
def endpoint_overwrite(args):
# default endpoint for this worker
if args.worker_type == "prefill":
args.endpoint = args.endpoint or "dyn://dynamo.llm.generate"
elif args.worker_type == "decode":
args.endpoint = args.endpoint or "dyn://dynamo.decoder.generate"
elif args.worker_type == "encode_prefill":
args.endpoint = args.endpoint or "dyn://dynamo.encoder.generate"
# set downstream endpoint for disaggregated workers
if args.enable_disagg:
args.downstream_endpoint = (
args.downstream_endpoint or "dyn://dynamo.decoder.generate"
)
return args
args, config = base_parse_args(parser, endpoint_overwrite)
return args, config
def __init__(
self,
args: argparse.Namespace,
engine_args: AsyncEngineArgs,
component: Component,
endpoint: Endpoint,
):
self.enable_disagg = args.enable_disagg
self.endpoint = args.endpoint
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = engine_args
self.setup_vllm_engine(component, endpoint)
async def async_init(self, runtime: DistributedRuntime):
pass
def setup_vllm_engine(self, component: Component, endpoint: Endpoint):
"""Initialize the vLLM engine.
This method sets up the vLLM engine client, and configures the dynamo-aware KV
event publisher and metrics stats logger based on component and endpoint.
"""
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Load default sampling params from `generation_config.json`
self.default_sampling_params = (
self.engine_args.create_model_config().get_diff_sampling_param()
)
# Taken from build_async_engine_client_from_engine_args()
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = self.engine_args.create_engine_config(usage_context=usage_context)
# Create vLLM engine with metrics logger and KV event publisher attached
self.stats_logger = StatLoggerFactory(
component, self.engine_args.data_parallel_rank or 0
)
self.engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=[self.stats_logger],
disable_log_requests=self.engine_args.disable_log_requests,
disable_log_stats=self.engine_args.disable_log_stats,
)
# TODO Hack to get data, move this to registering in ETCD
self.stats_logger.set_num_gpu_blocks_all(
vllm_config.cache_config.num_gpu_blocks
)
self.stats_logger.set_request_total_slots_all(
vllm_config.scheduler_config.max_num_seqs
)
self.stats_logger.init_publish()
# TODO: We start off with a valid endpoint, then we increment it by dp_rank
# May no longer be valid. Lets remove the increment behavior from vLLM and here
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
self.engine_args.kv_events_config.endpoint,
data_parallel_rank=self.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
self.kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
logger.info(f"Reading Events from {zmq_endpoint}")
logger.info(f"VllmWorker for {self.engine_args.model} has been initialized")
async def generate(self, request: vLLMMultimodalRequest):
raise NotImplementedError(
"This method should be implemented in subclasses to handle the generation logic."
)
async def clear_kv_blocks(self, request=None):
try:
await self.engine_client.reset_prefix_cache()
yield {"status": "success", "message": "KV cache cleared"}
except Exception as e:
yield {"status": "error", "message": str(e)}
def cleanup(self):
"""Override in subclasses if cleanup is needed."""
pass
class VllmDecodeWorker(VllmBaseWorker):
async def generate(self, request: vLLMMultimodalRequest):
logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
# Decode worker doesn't process embeddings, so we pass None or empty tensor
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
),
sampling_params=request.sampling_params,
request_id=request.request_id,
)
async for response in gen:
logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}")
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
class VllmPDWorker(VllmBaseWorker):
async def async_init(self, runtime: DistributedRuntime):
logger.info("Startup started.")
if self.enable_disagg:
(
parsed_namespace,
parsed_component_name,
parsed_endpoint_name,
) = parse_endpoint(self.downstream_endpoint)
self.decode_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
parsed_namespace, _, _ = parse_endpoint(self.endpoint)
self._connector = connect.Connector(runtime=runtime, namespace=parsed_namespace)
await self._connector.initialize()
# embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
# self.engine_args.model, self.engine_args.num_patches
# )
# [gluo NOTE] Hardcoded for now, will use more generic approach once utils/model.py
# is fixed, see utils/models.py for details.
embeddings_shape = (1, 577, 4096)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]
embeddings = torch.empty(
embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
# descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
self.image_loader = ImageLoader()
self.image_processor = AutoImageProcessor.from_pretrained(
self.engine_args.model, trust_remote_code=True
)
logger.info("VllmPDWorker has been initialized")
async def generate(self, request: vLLMMultimodalRequest):
logger.debug(f"Got raw request: {request}")
if type(request) is not vLLMMultimodalRequest:
if type(request) is str:
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
if request.image_url is None:
# Process embeddings using the connector
embeddings, descriptor = self._embeddings_descriptor
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op = await self._connector.begin_read(
request.serialized_request, descriptor
)
await read_op.wait_for_completion()
logger.debug(f"in PD worker, image features: {embeddings}")
multi_modal_data = embeddings
else:
# Use PIL image instead of image embeddings
multi_modal_data = await self.image_loader.load_image(request.image_url)
# multi_modal_data = self.image_processor(images=image, return_tensors="pt")["pixel_values"].to(dtype=torch.float16)
# image input is expected to be (image_num, channel, height, width)
# logger.info(f"Image features shape: {multi_modal_data.shape}")
# multi_modal_data = multi_modal_data.unsqueeze(0)
# Remove the image features from the request as they are not required
request.image_url = None
request.serialized_request = None
pd_request = copy.deepcopy(request)
# Do prefill and remote decode if enable_disagg is true
if self.enable_disagg:
extra_args = pd_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
pd_request.sampling_params.extra_args = extra_args
pd_request.sampling_params.max_tokens = 1
pd_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", pd_request)
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
multi_modal_data={"image": multi_modal_data},
),
sampling_params=pd_request.sampling_params,
request_id=pd_request.request_id,
)
if self.enable_disagg:
decode_request = copy.deepcopy(request)
async for prefill_response in gen:
# Update the prompt token id in the decode request to the one
# in response, which has image templated filled in. So that
# the decode worker will fetch correct amount of KV blocks.
decode_request.engine_prompt[
"prompt_token_ids"
] = prefill_response.prompt_token_ids
logger.debug(
f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
)
extra_args = decode_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
extra_args.pop("serialized_request", None)
decode_request.sampling_params.extra_args = extra_args
logger.debug("Decode request: %s", decode_request)
async for decode_response in await self.decode_worker_client.round_robin(
decode_request.model_dump_json()
):
output = MyRequestOutput.model_validate_json(decode_response.data())
yield MyRequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
kv_transfer_params=output.kv_transfer_params,
).model_dump_json()
else:
async for response in gen:
logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
async def graceful_shutdown(runtime):
"""
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
However, in-flight requests will still be processed until they are finished.
After all in-flight requests are finished, the `serve_endpoint` functions will return
and the engine will be shutdown by Python's garbage collector.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Runtime setup
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown")
# worker setup
args, config = VllmBaseWorker.parse_args()
# vLLM config overwrites
etcd_client = runtime.etcd_client()
await configure_ports_with_etcd(config, etcd_client)
overwrite_args(config)
await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks")
if args.worker_type in ["prefill", "encode_prefill"]:
handler: VllmBaseWorker = VllmPDWorker(
args, config.engine_args, component, generate_endpoint
)
elif args.worker_type == "decode":
handler = VllmDecodeWorker(
args, config.engine_args, component, generate_endpoint
)
await handler.async_init(runtime)
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
handler.cleanup()
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Dynamo Connect
Dynamo connect provides a Pythonic interface to the NIXL base RDMA subsystem via a set of Python classes.
The primary goal of this library to simplify the integration of NIXL based RDMA into inference applications.
All operations using the Connect library begin with the [`Connector`](#connector) class and the type of operation required.
There are four types of supported operations:
- **Register local readable memory**:
Register local memory buffer(s) with the RDMA subsystem to enable a remote worker to read from.
- **Register local writable memory**:
Register local memory buffer(s) with the RDMA subsystem to enable a remote worker to write to.
- **Read from registered, remote memory**:
Read remote memory buffer(s), registered by a remote worker to be readable, into local memory buffer(s).
- **Write to registered, remote memory**:
Write local memory buffer(s) to remote memory buffer(s) registered by a remote worker to writable.
By connecting correctly paired operations, high-throughput GPU Direct RDMA data transfers can be completed.
Given the list above, the correct pairing of operations would be 1 & 3 or 2 & 4.
Where one side is a "(read|write)-able operation" and the other is its correctly paired "(read|write) operation".
Specifically, a read operation must be paired with a readable operation, and a write operation must be paired with a writable operation.
## Examples
### Generic Example
In the diagram below, Local creates a [`WritableOperation`](#writableoperation) intended to receive data from Remote.
Local then sends metadata about the requested RDMA operation to Remote.
Remote then uses the metadata to create a [`WriteOperation`](#writeoperation) which will perform the GPU Direct RDMA memory transfer from Remote's GPU memory to Local's GPU memory.
```mermaid
---
title: Write Operation Between Two Workers
---
flowchart LR
c1[Remote] --"3: .begin_write()"--- WriteOperation
WriteOperation e1@=="4: GPU Direct RDMA"==> WritableOperation
WritableOperation --"1: .create_writable()"--- c2[Local]
c2 e2@--"2: RDMA Metadata via HTTP"--> c1
e1@{ animate: true; }
e2@{ animate: true; }
```
### Multimodal Example
In the case of the [Dynamo Multimodal Disaggregated Example](../README.md):
1. The HTTP frontend accepts a text prompt and a URL to an image.
2. The prompt and URL are then enqueued with the Processor before being dispatched to the first available Decode Worker.
3. Decode Worker then requests a Prefill Worker to provide key-value data for the LLM powering the Decode Worker.
4. Prefill Worker then requests that the image be processed and provided as embeddings by the Encode Worker.
5. Encode Worker acquires the image, processes it, performs inference on the image using a specialized vision model, and finally provides the embeddings to Prefill Worker.
6. Prefill Worker receives the embeddings from Encode Worker and generates a key-value cache (KV$) update for Decode Worker's LLM and writes the update directly to the GPU memory reserved for the data.
7. Finally, Decode Worker performs the requested inference.
```mermaid
---
title: Multimodal Disaggregated Workflow
---
flowchart LR
p0[HTTP Frontend] i0@--"text prompt"-->p1[Processor]
p0 i1@--"url"-->p1
p1 i2@--"prompt"-->dw[Decode Worker]
p1 i3@--"url"-->dw
dw i4@--"prompt"-->pw[Prefill Worker]
dw i5@--"url"-->pw
pw i6@--"url"-->ew[Encode Worker]
ew o0@=="image embeddings"==>pw
pw o1@=="kv_cache updates"==>dw
dw o2@--"inference results"-->p0
i0@{ animate: true; }
i1@{ animate: true; }
i2@{ animate: true; }
i3@{ animate: true; }
i4@{ animate: true; }
i5@{ animate: true; }
i6@{ animate: true; }
o0@{ animate: true; }
o1@{ animate: true; }
o2@{ animate: true; }
```
_Note: In this example, it is the data transfer between the Prefill Worker and the Encode Worker that utilizes the Dynamo Connect library. The KV Cache transfer between Decode Worker and Prefill Worker utilizes the NIXL base RDMA subsystem directly without using the Dynamo Connect library._
#### Code Examples
See [prefill_worker](../components/prefill_worker.py#L199) or [decode_worker](../components/decode_worker.py#L239),
for how they coordinate directly with the Encode Worker by creating a [`WritableOperation`](#writableoperation),
sending the operation's metadata via Dynamo's round-robin dispatcher, and awaiting the operation for completion before making use of the transferred data.
See [encode_worker](../components/encode_worker.py#L190),
for how the resulting embeddings are registered with the RDMA subsystem by creating a [`Descriptor`](#descriptor),
a [`WriteOperation`](#writeoperation) is created using the metadata provided by the requesting worker,
and the worker awaits for the data transfer to complete for yielding a response.
## Python Classes
### Connector
Core class for managing the connection between workers in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote workers.
This class is responsible for interfacing with the NIXL-based RDMA subsystem and providing a "Pythonic" interface
with which to utilize GPU Direct RDMA accelerated data transfers between models hosted by different workers in a Dynamo pipeline.
The connector provides two methods of moving data between workers:
- Preparing local memory to be written to by a remote worker.
- Preparing local memory to be read by a remote worker.
In both cases, local memory is registered with the NIXL-based RDMA subsystem via the [`Descriptor`](#descriptor) class and provided to the connector.
The connector then configures the RDMA subsystem to expose the memory for the requested operation and returns an operation control object.
The operation control object, either a [`ReadableOperation`](#readableoperation) or a [`WritableOperation`](#writableoperation),
provides RDMA metadata via its [`.to_serialized()`](#to_serialized) method as well as functionality to know when the operation has been completed or cancel the operation prior to completion.
The RDMA metadata must be provided to the remote worker expected to complete the operation.
The metadata contains required information (identifiers, keys, etc.) which enables the remote worker to interact with the provided memory.
#### Methods
##### `begin_read`
> Creates a [`ReadOperation`](#readoperation) for transferring data from a remote worker.
>
> To create the operation, the serialized request from a remote worker's [`ReadableOperation`](#readableoperation)
> along with a matching set of local memory descriptors which reference memory intended to receive data from the remote worker
> must be provided.
> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Once created, the operation will begin reading immediately.
> Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `begin_write`
> Creates a write operation for transferring data to a remote worker.
>
> To create the operation, the serialized request from a remote worker's [`WritableOperation`](#writableoperation)
> along with a matching set of local memory descriptors which reference memory to be transferred to the remote worker
> must be provided.
> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Once created, the operation will begin writing immediately.
> Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `create_readable`
> Creates a [`ReadableOperation`](#readableoperation) for transferring data to a remote worker.
>
> To create the operation, a set of local memory descriptors must be provided that reference memory intended to be transferred to
> a remote worker.
> Once created, the memory referenced by the provided descriptors becomes immediately readable by a remote worker with the necessary metadata.
> The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Disposable of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `create_writable`
> Creates a [`WritableOperation`](#writableoperation) for transferring data from a remote worker.
>
> To create the operation, a set of local memory descriptors must be provided which reference memory intended to receive data from
> a remote worker.
> Once created, the memory referenced by the provided descriptors becomes immediately writable by a remote worker with the necessary metadata.
> The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Disposable of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
### Descriptor
Memory descriptor that ensures memory is registered with the NIXL base RDMA subsystem.
Memory must be registered with the RDMA subsystem to enable interaction with the memory.
Descriptor objects are administrative and do not copy, move, or otherwise modify the registered memory.
There are four ways to create a descriptor:
1. From a `torch.Tensor` object. Device information will be derived from the provided object.
2. From a `tuple` containing either a NumPy or CuPy `ndarray` and information describing where the memory resides (Host/CPU vs GPU).
3. From a Python `bytes` object. Memory is assumed to reside in CPU addressable host memory.
4. From a `tuple` comprised of the address of the memory, its size in bytes, and device information.
An optional reference to a Python object can be provided to avoid garbage collection issues.
### Device
Device describes the device, or kind of memory, a given allocation resides in.
Usually host (`"cpu"`) or GPU (`"cuda"`) memory.
When a system contains multiple GPU devices, specific GPU devices can be identified by including their ordinal index number.
For example, to reference the second GPU in a system `"cuda:1"` can be used.
By default, when `"cuda"` is provided, it is assumed to be `"cuda:0"` or the first GPU enumerated by the system.
### ReadOperation
An operation which transfers data from a remote worker to the local worker.
To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`ReadableOperation`](#readableoperation)
along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory intended to receive data from the remote worker must be provided.
The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
Once created, the operation will begin reading immediately.
Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `cancel`
> Instructs the RDMA subsystem to cancel the operation.
> Completed operations cannot be cancelled.
##### `wait_for_completion`
> Blocks the caller until the memory from the remote worker has been transferred to the provided buffers.
### ReadableOperation
An operation which enables a remote worker to read data from the local worker.
To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided that reference memory intended to be transferred to a remote worker.
Once created, the memory referenced by the provided descriptors becomes immediately readable by a remote worker with the necessary metadata.
The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `to_serialized`
> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to read from the operation.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
##### `wait_for_completion`
> Blocks the caller until the operation has received a completion signal from a remote worker.
### WriteOperation
An operation which transfers data from the local worker to a remote worker.
To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`WritableOperation`](#writableoperation)
along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory to be transferred to the remote worker must be provided.
The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
Once created, the operation will begin writing immediately.
Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `cancel`
> Instructs the RDMA subsystem to cancel the operation.
> Completed operations cannot be cancelled.
##### `wait_for_completion`
> Blocks the caller until all provided buffers have been transferred to the remote worker.
### WritableOperation
An operation which enables a remote worker to write data to the local worker.
To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided which reference memory intended to receive data from a remote worker.
Once created, the memory referenced by the provided descriptors becomes immediately writable by a remote worker with the necessary metadata.
The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `to_serialized`
> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to write to the operation.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
##### `wait_for_completion`
> Blocks the caller until the operation has received a completion signal from a remote worker.
### SerializedRequest
A Pydantic type intended to provide JSON serialized RDMA metadata about a [`ReadableOperation`](#readableoperation) or [`WritableOperation`](#writableoperation) object.
Use the [`.to_serialized()`](#to_serialized) method on either of the above types to generate a `SerializedRequest` object for an operation.
## References
- [NVIDIA Dynamo](https://developer.nvidia.com/dynamo) @ [GitHub](https://github.com/ai-dynamo/dynamo)
- [NVIDIA Inference Transfer Library (NIXL)](https://developer.nvidia.com/blog/introducing-nvidia-dynamo-a-low-latency-distributed-inference-framework-for-scaling-reasoning-ai-models/#nvidia_inference_transfer_library_nixl_low-latency_hardware-agnostic_communication%C2%A0) @ [GitHub](https://github.com/ai-dynamo/nixl)
- [Dynamo Multimodal Example](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal)
- [NVIDIA GPU Direct](https://developer.nvidia.com/gpudirect)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import logging
import socket
import uuid
import zlib
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import cached_property
from typing import Any, List, Optional
import nixl._api as nixl_api
import nixl._bindings as nixl_bindings
import torch
from pydantic import BaseModel, ConfigDict, field_validator
from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__)
try:
import cupy as array_module
from cupy_backends.cuda.api.runtime import CUDARuntimeError
logger.info("Utilizing cupy to enable GPU acceleration.")
except ImportError:
try:
import numpy as array_module
logger.warning("Failed to load cupy for GPU acceleration, utilizing numpy to provide CPU based operations.")
except ImportError as e:
raise ImportError("Numpy or cupy must be installed to use this module.") from e
class AbstractOperation(ABC):
"""
Abstract base class for awaitable NIXL based RDMA operations.
"""
def __init__(
self,
connector: Connector,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Optional[Descriptor | list[Descriptor]],
notification_key: Optional[str],
) -> None:
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
if not (
isinstance(local_descriptors, (Descriptor, list))
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if (
remote_descriptors is not None
and not (
isinstance(remote_descriptors, Descriptor)
or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors))
)
):
raise TypeError("Argument `remote_descriptors` must be dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`, or `None`.")
if isinstance(local_descriptors, list) and len(local_descriptors) == 0:
raise ValueError("Argument `local_descriptors` must not be an empty list.")
if (
remote_descriptors is not None
and isinstance(remote_descriptors, list)
and len(remote_descriptors) == 0
):
raise ValueError("Argument `remote_descriptors` must not be an empty list.")
notification_key = str(uuid.uuid4()) if notification_key is None else notification_key
if not isinstance(notification_key, str):
raise TypeError("Argument `notification_key` must be `str` or `None`.")
if len(notification_key) == 0:
raise ValueError("Argument `notification_key` must not be an empty string.")
self._notification_key: str = notification_key
self._connector: Connector = connector
self._operation_kind: OperationKind = operation_kind
self._local_descriptors: Descriptor | list[Descriptor] = local_descriptors
self._local_dlist: Optional[list[tuple[int, int, int]]] = None
self._local_memtype: DeviceKind = DeviceKind.UNSPECIFIED
self._remote_descriptors: Optional[Descriptor | list[Descriptor]] = None if remote_descriptors is None else remote_descriptors
self._remote_dlist: Optional[list[tuple[int, int, int]]] = None
self._remote_memtype: DeviceKind = DeviceKind.UNSPECIFIED
# Register local descriptors with NIXL.
# Note: Only local descriptors should be registered with NIXL,
if isinstance(local_descriptors, list):
for d in local_descriptors:
d.register_memory(self._connector)
else:
local_descriptors.register_memory(self._connector)
# Record local descriptors.
memtype, dtlist = self._create_dlist(local_descriptors)
self._local_dlist = dtlist
self._local_memtype = memtype
# Record remote descriptors when provided.
if remote_descriptors is not None:
memtype, dtlist = self._create_dlist(remote_descriptors)
self._remote_dlist = dtlist
self._remote_memtype = memtype
def __del__(self) -> None:
self._release()
def __enter__(self) -> AbstractOperation:
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self._release()
def _release(self) -> None:
"""
Private method to release resources. Only to be called by `self`.
"""
pass
@property
def connector(self) -> Connector:
"""
Gets the local associated with this operation.
"""
return self._connector
@property
def operation_kind(self) -> OperationKind:
"""
Gets the kind of operation.
"""
return self._operation_kind
@abstractmethod
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
# Private Methods
def _create_dlist(
self,
descriptors: Descriptor | list[Descriptor],
) -> tuple[DeviceKind, list[tuple[int, int, int]]]:
"""
Helper function to create a list of tuples (ptr, size, device) from descriptors.
"""
dlist: list[tuple[int, int, int]] = []
memtype: DeviceKind = DeviceKind.UNSPECIFIED
if isinstance(descriptors, list):
memtype = descriptors[0].device.kind
for desc in descriptors:
if memtype != desc.device.kind:
raise ValueError("All local descriptors must have the same memory type.")
dlist.append((desc.ptr, desc.size, desc.device.id))
else:
memtype = descriptors.device.kind
dlist.append((descriptors.ptr, descriptors.size, descriptors.device.id))
return (memtype, dlist)
class ActiveOperation(AbstractOperation):
"""
Abstract class for active operations that initiates a NIXL based RDMA transfer based `SerializedRequest`
provided by the remote worker's corresponding `PassiveOperation`.
"""
def __init__(
self,
remote: Remote,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Descriptor | list[Descriptor],
notification_key: str,
) -> None:
if not isinstance(remote, Remote) or remote._connector is None:
raise TypeError("Argument `remote` must be valid `dynamo.connect.Remote`.")
if not isinstance(operation_kind, OperationKind):
raise TypeError("Argument `operation_kind` must `dynamo.connect.OperationKind`.")
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if not (
isinstance(remote_descriptors, Descriptor)
or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors))
):
raise TypeError("Argument `remote_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
# Unpack single descriptors from lists if they are provided as single descriptors.
if isinstance(local_descriptors, list) and len(local_descriptors) == 1:
local_descriptors = local_descriptors[0]
if isinstance(remote_descriptors, list) and len(remote_descriptors) == 1:
remote_descriptors = remote_descriptors[0]
if (isinstance(local_descriptors, list) and isinstance(remote_descriptors, list) and len(local_descriptors) != len(remote_descriptors)):
raise ValueError("When `local_descriptors` and `remote_descriptors` are lists, they must have the same length.")
elif isinstance(local_descriptors, list) != isinstance(remote_descriptors, list):
raise ValueError("Both `local_descriptors` and `remote_descriptors` must be either lists or single descriptors.")
if not isinstance(notification_key, str):
raise TypeError("Argument `notification_key` must be `str`.")
if len(notification_key) == 0:
raise ValueError("Argument `notification_key` must not be an empty string.")
self._remote = remote
self._status = OperationStatus.UNINTIALIZED
super().__init__(remote.connector, operation_kind, local_descriptors, remote_descriptors, notification_key)
# Quick check to ensure remote descriptors are not None to make static analysis happy.
if self._local_dlist is None or self._remote_dlist is None:
raise RuntimeError("NIXL descriptor list(s) not bound to operation.")
self._local_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._remote_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._xfer_hndl: Optional[nixl_api.nixl_xfer_handle] = None
self._local_xfer_descs = self._connector._nixl.get_xfer_descs(
descs=self._local_dlist,
mem_type=str(self._local_memtype),
)
logger.debug(f"Created local NIXL xfer descs: {self._local_xfer_descs}")
self._remote_xfer_descs = self._connector._nixl.get_xfer_descs(
descs=self._remote_dlist,
mem_type=str(self._remote_memtype),
)
logger.debug(f"Created remote NIXL xfer descs: {self._remote_xfer_descs}")
self._xfer_hndl = self._connector._nixl.initialize_xfer(
operation=str(operation_kind),
local_descs=self._local_xfer_descs,
remote_descs=self._remote_xfer_descs,
remote_agent=self._remote.name,
notif_msg=self._notification_key.encode("utf-8"),
)
logger.debug(f"Created NIXL transfer handle: {self._xfer_hndl}")
def __del__(self) -> None:
super().__del__()
self._release()
def __enter__(self) -> ActiveOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
match self.status:
case OperationStatus.IN_PROGRESS | OperationStatus.INITIALIZED:
self._status = OperationStatus.CANCELLED
self._release()
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"operation_kind={self._operation_kind}, "
f"local_descriptors={self._local_descriptors}, "
f"remote_descriptors={self._remote_descriptors}, "
f"notification_key='{self._notification_key}', "
f"remote='{self._remote.name}', "
f"status='{self._status}'"
f")"
)
def _release(self) -> None:
"""
Private method to release resources.
"""
error: Optional[Exception] = None
if self._xfer_hndl is not None:
try:
logger.debug(f"NIXL transfer handle {self._xfer_hndl} released.")
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
except Exception as e:
logger.error(f"Failed to release resources: {e}")
error = e
finally:
self._xfer_hndl = None
try:
super()._release()
except Exception as e:
logger.error(f"Failed to release WaitableOperation resources: {e}")
if error is not None:
e.__cause__ = error
error = e
if error is not None:
raise error
def _cancel_(self) -> None:
if self._xfer_hndl is None:
return
if self.status == OperationStatus.ERRORED:
raise RuntimeError("Operation is errored, unable to cancel the operation.")
logger.info(f"Cancellation requested for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', status={self._status} }}.")
# NIXL will cancel the transfer if it is in progress when the handle is released.
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
self._status = OperationStatus.CANCELLED
self._xfer_hndl = None
async def _wait_for_completion_(self) -> None:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
iteration_count = 0
while True:
if iteration_count % 10 == 0:
logger.debug(f"Waiting for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', duration={iteration_count / 10}s }}.")
match self.status:
# "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED:
await asyncio.sleep(0.1)
case OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1)
# Any other state indicates completion or error.
case _:
return
iteration_count += 1
@abstractmethod
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
@property
def remote(self) -> Remote:
"""
Gets the remote worker associated with this operation.
"""
return self._remote
@property
def status(self) -> OperationStatus:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED:
return self._status
if self._xfer_hndl is None:
raise RuntimeError("NIXL transfer handle is invalid.")
old_status = self._status
if self._status == OperationStatus.UNINTIALIZED:
state = self._connector._nixl.transfer(self._xfer_hndl, self._notification_key.encode("utf-8"))
logger.debug(f"NIXL reported transfer state: {state}")
if state == "ERR":
self._status = OperationStatus.ERRORED
elif state == "DONE":
self._status = OperationStatus.COMPLETE
else:
self._status = OperationStatus.INITIALIZED
else:
state = self._connector._nixl.check_xfer_state(self._xfer_hndl)
logger.debug(f"NIXL reported transfer state: {state}")
if state == "ERR":
self._status = OperationStatus.ERRORED
elif state == "DONE":
self._status = OperationStatus.COMPLETE
else:
self._status = OperationStatus.IN_PROGRESS
if self._status != old_status:
logger.debug(f"{self.__class__.__name__} {{ remote: '{self._remote.name}' status: '{old_status}' => '{self._status}' }}.")
return self._status
class Connector:
"""
Core class for managing the connection between workers in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote workers.
"""
def __init__(
self,
namespace: Optional[str] = None,
runtime: Optional[DistributedRuntime] = None,
worker_id: Optional[str] = None,
) -> None:
"""
Creates a new Connector instance.
Parameters
----------
namespace : Optional[str], optional
Dynamo namespace of the component, defaults to "dynamo" when `None`.
runtime : Optional[DistributedRuntime], optional
Reference the dynamo runtime used by the compenent, attempts to use the current runtime when `None`.
worker_id : Optional[str], optional
Unique identifier of the worker, defaults to a new UUID when `None`.
Raises
------
TypeError
When `namespace` is provied and not of type 'str'.
TypeError
When `runtime` iis provied and not of type `dynamo.runtime.DistributedRuntime`.
TypeError
When `worker_id` is provied and not of type `uuid.UUID`.
"""
namespace = "dynamo" if namespace is None else namespace
if not isinstance(namespace, str):
raise TypeError("Argument `namespace` must be `str` or `None`.")
if not isinstance(runtime, DistributedRuntime) or runtime is None:
raise TypeError("Argument `runtime` must be `dynamo.runtime.DistributedRuntime` or `None`.")
worker_id = worker_id if worker_id is not None else str(uuid.uuid4())
if not isinstance(worker_id, str) or len(worker_id) == 0:
raise TypeError("Argument `worker_id` must be a non-empty `str` or `None`.")
self._worker_id = worker_id
self._is_initialized = False
self._runtime = runtime
self._namespace = namespace
self._nixl = nixl_api.nixl_agent(self._worker_id)
self._hostname = socket.gethostname()
self._agent_metadata: Optional[bytes] = None
logger.debug(f"Created {self.__repr__()}.")
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"worker_id='{self._worker_id}', "
f"namespace={self._namespace}, "
f"hostname={self._hostname}, "
f"metadata=<{0 if self._agent_metadata is None else len(self._agent_metadata)} bytes>"
")"
)
def __str__(self) -> str:
return self._worker_id
@cached_property
def is_cuda_available(self) -> bool:
# Note: cuda.is_avalailable initializes cuda
# and can't be called when forking subprocesses
# care should be taken to only call it within
# subprocesses or use 'spawn'
try:
return array_module.cuda is not None and array_module.cuda.is_available()
except CUDARuntimeError:
return False
@property
def metadata(self) -> bytes:
"""
Get the metadata of the worker.
"""
return self._nixl.get_agent_metadata()
@property
def name(self) -> str | None:
"""
Get the name of the worker.
"""
return self._worker_id
@property
def namespace(self) -> str:
"""
Get the namespace of the local.
"""
return self._namespace
@property
def runtime(self) -> DistributedRuntime:
"""
Get the runtime of the local.
"""
if self._runtime is None:
raise RuntimeError("Runtime is not set. This Connector was not initialized with a runtime.")
return self._runtime
async def begin_read(
self,
remote_request: SerializedRequest,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadOperation:
"""
Creates a read operation for fulfilling a remote readable operation.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to receive data from the remote worker described by `remote_request`.
Returns
-------
ReadOperation
Awaitable read operation that can be used to transfer data from a remote worker.
Raises
------
TypeError
When `remote_request` is not of type `SerializedRequest`.
TypeError
When `local_descriptors` is not of type `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if remote_request is None or not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if remote_request.operation_kind != OperationKind.READ.value:
raise RuntimeError("Cannot create a `dynamo.connect.ReadOperation` to read from a remote `dynamo.connect.WritableOperation`.")
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = ReadOperation(self, remote_request, local_descriptors)
return op
async def begin_write(
self,
local_descriptors: Descriptor | list[Descriptor],
remote_request: SerializedRequest,
) -> WriteOperation:
"""
Creates a write operation for transferring data to a remote worker.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptors of one or more data objects to be transferred to the remote worker.
"""
if remote_request is None or not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`.")
if remote_request.operation_kind != OperationKind.WRITE:
raise RuntimeError("Cannot create a `WriteOperation` to write to a remote `ReadableOperation`.")
if not isinstance(remote_request.nixl_metadata, str):
raise TypeError("Argument `remote_request.nixl_metadata` must be `str`.")
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = WriteOperation(self, local_descriptors, remote_request)
return op
def create_readable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation:
"""
Creates a readable operation for transferring data from a remote worker.
Returns
-------
ReadableOperation
A readable operation that can be used to transfer data from a remote worker.
"""
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = ReadableOperation(self, local_descriptors)
return op
def create_writable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation:
"""
Creates a writable operation for transferring data to a remote worker.
Returns
-------
WritableOperation
A writable operation that can be used to transfer data to a remote worker.
"""
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = WritableOperation(self, local_descriptors)
return op
async def initialize(self) -> None:
# Only initialize the connector once.
if self._is_initialized:
return
self._is_initialized = True
# This method is a no-op for now, in the future it may be used to initialize the connector.
logger.debug(f"Initialized Connector {{ name: '{self._worker_id}', namespace '{self._namespace}' }} completed.")
class Descriptor:
"""
Memory descriptor that ensures memory is registered w/ NIXL, used for transferring data between workers.
"""
def __init__(
self,
data: torch.Tensor | tuple[array_module.ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any],
) -> None:
"""
Memory descriptor for transferring data between workers.
Parameters
----------
data : torch.Tensor | tuple[ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any]
The data to be transferred.
When `torch.Tensor` is provided, the attributes of the tensor will be used to create the descriptor.
When `tuple[ndarray, Device]` is provided, the tuple must contain:
- `ndarray`: The CuPy or NumPy array to be transferred.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
When `bytes` is provided, the pointer and size derived from the bytes object and memory type will be assumed to be CPU.
When `tuple[int, int, Device|str, Any]` is provided, the tuple must contain the following elements:
- `int`: Pointer to the data in memory.
- `int`: Size of the data in bytes.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
- `Any`: Optional reference to the data (e.g., the original tensor or bytes object).
This is useful for keeping a reference to the data in memory, but it is not required.
Raises
------
ValueError
When `data` is `None`.
TypeError
When `data` is not a valid type (i.e., not `torch.Tensor`, `bytes`, or a valid tuple).
TypeError
When `data` is a tuple but the elements are not of the expected types (i.e., [`ndarray`, `Device|str`] OR [`int`, `int`, `Device|str`, `Any`]).
"""
TYPE_ERROR_MESSAGE = "Argument `data` must be `torch.Tensor`, `tuple[ndarray, Device|str]`, `bytes`, or `tuple[int, int, Device|str, Any]`."
if data is None:
raise ValueError("Argument `data` cannot be `None`.")
if not (isinstance(data, torch.Tensor) or isinstance(data, bytes) or isinstance(data, tuple)):
raise TypeError(TYPE_ERROR_MESSAGE)
self._data_device: Device = Device("cpu")
self._data_ptr: int = 0
self._data_ref: Optional[Any] = None
self._data_size: int = 0
# Member fields for managing NIXL memory registration.
# Note: ONLY local descriptors should be registered with NIXL,
# remote descriptors do not have a valid memory address and registration will fault.
self._connector: Optional[Connector] = None
self._nixl_hndl: Optional[nixl_bindings.nixlRegDList] = None
# Initially `None` cached serialized descriptor reference, populated when `to_serialized()` is called.
self._serialized: Optional[SerializedDescriptor] = None
# Data is `torch.Tensor`.
if isinstance(data, torch.Tensor):
self._data_ptr = data.data_ptr()
self._data_size = data.numel() * data.element_size()
if data.is_cuda:
self._data_device = Device((DeviceKind.CUDA, data.get_device()))
self._data_ref = data
logger.debug(f"Created {self.__repr__()} from `torch.Tensor`.")
# Data is `tuple[ndarray, Device]`.
elif (
isinstance(data, tuple)
and len(data) == 2
and isinstance(data[0], array_module.ndarray)
and (isinstance(data[1], Device) or isinstance(data[1], str))
):
if hasattr(data[0], "__array_interface__"):
self._data_ptr = data[0].__array_interface__["data"][0]
elif hasattr(data[0], "__cuda_array_interface__"):
self._data_ptr = data[0].__cuda_array_interface__["data"][0]
else:
raise TypeError("Argument `data[0]` must be a `ndarray` with a valid array interface.")
self._data_size = data[0].nbytes
self._data_device = data[1] if isinstance(data[1], Device) else Device(data[1])
self._data_ref = data[0]
logger.debug(f"Created {self.__repr__()} from `tuple[ndarray, Device|str]`.")
# Data is `bytes`.
elif isinstance(data, bytes):
self._data_ptr = id(data)
self._data_size = len(data)
self._data_ref = data
logger.debug(f"Created {self.__repr__()} from `bytes`.")
# Data is `tuple[int, int, Device, dtype, tuple, Any]`.
elif isinstance(data, tuple) and len(data) >= 2 and isinstance(data[0], int) and isinstance(data[1], int):
if len(data) >= 3 and not (isinstance(data[2], Device) or isinstance(data[2], str)):
raise TypeError("Argument `data` must be a `tuple[int, int, Device|str, Any]`.")
self._data_ptr = data[0]
self._data_size = data[1]
if len(data) >= 3:
self._data_device = data[2] if isinstance(data[2], Device) else Device(data[2])
self._data_ref = data[3] if len(data) >=4 else None
logger.debug(f"Created {self.__repr__()} from `tuple[int, int, Device|str, Any]`.")
else:
raise TypeError(TYPE_ERROR_MESSAGE)
def __del__(self) -> None:
if self._nixl_hndl is not None and self._connector is not None:
# Unregister the memory with NIXL.
self._connector._nixl.deregister_memory(self._nixl_hndl)
self._nixl_hndl = None
if self._data_ref is not None:
# Release the reference to the data.
del self._data_ref
logger.debug(f"Deleted {self.__repr__()}.")
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self})"
def __str__(self) -> str:
return f"ptr={hex(self._data_ptr)}, size={self._data_size}, device={self._data_device}"
@property
def device(self) -> Device:
"""
Gets the device the of the descriptor.
"""
return self._data_device
@property
def ptr(self) -> int:
"""
Gets the pointer of the descriptor.
"""
return self._data_ptr
@property
def size(self) -> int:
"""
Gets the size of the descriptor.
"""
return self._data_size
@staticmethod
def from_serialized(
serialized: SerializedDescriptor,
) -> Descriptor:
"""
Deserializes a `SerializedDescriptor` into a `Descriptor` object.
Parameters
----------
serialized : SerializedDescriptor
The serialized descriptor to deserialize.
Returns
-------
Descriptor
The deserialized descriptor.
"""
if not isinstance(serialized, SerializedDescriptor):
raise TypeError("Argument `serialized` must be `SerializedDescriptor`.")
return serialized.to_descriptor()
def register_memory(
self,
connector: Connector,
) -> None:
"""
Registers the memory of the descriptor with NIXL.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if self._data_ptr == 0:
raise ValueError("Cannot register memory with a null pointer.")
if not (self._nixl_hndl is None and self._connector is None):
return
# Register the memory with NIXL.
self._connector = connector
if isinstance(self._data_ref, torch.Tensor):
self._nixl_hndl = connector._nixl.register_memory(self._data_ref)
else:
mem_type = str(self._data_device.kind)
reg_list = [(self._data_ptr, self._data_size, self._data_device.id, mem_type)]
self._nixl_hndl = connector._nixl.register_memory(reg_list, mem_type)
logger.debug(f"Registered {self.__repr__()} with NIXL.")
def to_serialized(self) -> SerializedDescriptor:
"""
Serializes the descriptor into a `SerializedDescriptor` object.
"""
if self._serialized is None:
self._serialized = SerializedDescriptor(
device=f"{self._data_device}",
ptr=self._data_ptr,
size=self._data_size,
)
return self._serialized
class Device:
"""
Represents a device in the system.
"""
def __init__(
self,
metadata: str | tuple[DeviceKind, int],
) -> None:
if metadata is None:
raise ValueError("Argument `metadata` cannot be `None`.")
if isinstance(metadata, tuple) and len(metadata) == 2 and isinstance(metadata[0], DeviceKind) and isinstance(metadata[1], int):
kind, device_id = metadata
elif isinstance(metadata, str):
metadata = metadata.strip().lower()
if metadata.startswith("cuda") or metadata.startswith("gpu"):
kind = DeviceKind.CUDA
device_id = 0 if metadata.find(":") == -1 else int(metadata.split(":")[1])
elif metadata.startswith("cpu") or metadata.startswith("host"):
kind = DeviceKind.HOST
device_id = 0
else:
raise ValueError("Argument `metadata` must be in the format 'cuda:<device_id>' or 'cpu'.")
else:
raise TypeError("Argument `metadata` must be a `tuple[MemoryKind, int]` or a `str`.")
self._device_id = device_id
self._kind = kind
def __repr__(self) -> str:
return f"{self.__class__.__name__}(kind={self._kind}, id={self._device_id})"
def __str__(self) -> str:
return f"{self._kind}:{self._device_id}" if self._kind is DeviceKind.CUDA else f"{self._kind}"
@property
def id(self) -> int:
"""
Gets the device ID of the device.
"""
return self._device_id
@property
def kind(self) -> DeviceKind:
"""
Gets the memory kind of the device.
"""
return self._kind
class DeviceKind(IntEnum):
"""
Type of memory a descriptor has been allocated to.
"""
UNSPECIFIED = 0
HOST = 1
CUDA = 2
def __str__(self) -> str:
if self == DeviceKind.HOST:
return "cpu"
elif self == DeviceKind.CUDA:
return "cuda"
else:
return "<invalid>"
class OperationKind(IntEnum):
"""
Kind of an operation.
"""
UNSPECIFIED = 0
READ = 1
WRITE = 2
def __str__(self) -> str:
if self == OperationKind.READ:
return "READ"
elif self == OperationKind.WRITE:
return "WRITE"
else:
return "<invalid>"
class OperationStatus(IntEnum):
"""
Status of an operation.
"""
UNINTIALIZED = 0
INITIALIZED = 1
IN_PROGRESS = 2
COMPLETE = 3
CANCELLED = 4
ERRORED = 5
def __str__(self) -> str:
if self == OperationStatus.INITIALIZED:
return "INIT"
elif self == OperationStatus.IN_PROGRESS:
return "PROC"
elif self == OperationStatus.COMPLETE:
return "DONE"
elif self == OperationStatus.ERRORED:
return "ERR"
elif self == OperationStatus.CANCELLED:
return "STOP"
else:
return "<invalid>"
class PassiveOperation(AbstractOperation):
"""
Abstract class for common functionality of passive operations.
"""
def __init__(
self,
connector: Connector,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
self._status = OperationStatus.UNINTIALIZED
super().__init__(connector, operation_kind, local_descriptors, None, None)
self._serialized_request: Optional[SerializedRequest] = None
self._status = OperationStatus.INITIALIZED
def __del__(self) -> None:
super().__del__()
def __enter__(self) -> AbstractOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"operation_kind={self._operation_kind}, "
f"local_descriptors={self._local_descriptors}, "
f"notification_key='{self._notification_key}', "
f"status='{self._status}'"
f")"
)
async def _wait_for_completion_(self) -> None:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
while True:
match self.status:
# "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED:
await asyncio.sleep(0.1)
case OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1)
# Any other state indicates completion or error.
case _:
return
@property
def status(self) -> OperationStatus:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED:
return self._status
old_status = self._status
# Query NIXL for any notifications.
notifications = self._connector._nixl.update_notifs()
if isinstance(notifications, dict):
remote_state = OperationStatus.IN_PROGRESS
logger.debug(f"NIXL reported notifications: {len(notifications)}.")
for key, values in notifications.items():
if not isinstance(values, list):
raise TypeError(f"Expected `dict[str, list[bytes]]` from NIXL notification query; got {type(notifications)}.")
for value in values:
if not isinstance(value, bytes):
continue
notification_key = value.decode("utf-8")
# Once we've found the notification key, we know the operation is complete.
if notification_key == self._notification_key:
remote_state = OperationStatus.COMPLETE
break
if remote_state == OperationStatus.COMPLETE:
self._status = remote_state
logger.debug(f"{self.__class__.__name__} {{ remote: '{self._connector.name}' status: '{old_status}' => '{self._status}' }}.")
return self._status
def to_serialized(self) -> SerializedRequest:
"""
Gets the request descriptor for the operation.
"""
if self._serialized_request is None:
# When we've not yet cached the serialized request, we need to generate one before returning it.
# Handle both cases: multiple and single descriptors.
if isinstance(self._local_descriptors, list):
descriptors = [desc.to_serialized() for desc in self._local_descriptors]
else:
descriptors = [self._local_descriptors.to_serialized()]
original_len = len(self._connector.metadata)
nixl_metadata = self._connector.metadata
nixl_metadata = zlib.compress(nixl_metadata, level=6)
compressed_len = len(nixl_metadata)
logger.debug(f"Compressed NIXL metadata from {original_len} bytes to {compressed_len} bytes.")
if compressed_len > original_len:
logger.warning(f"Compressed NIXL metadata is larger than original ({compressed_len} > {original_len}).")
self._serialized_request = SerializedRequest(
descriptors=descriptors,
nixl_metadata=nixl_metadata.hex(),
notification_key=self._notification_key,
operation_kind=int(self._operation_kind),
)
return self._serialized_request
@abstractmethod
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
class ReadOperation(ActiveOperation):
"""
Operation that initiates an RDMA read operation to transfer data from a remote worker's `ReadableOperation`,
as described by `remote_request`, to local buffers.
"""
def __init__(
self,
connector: Connector,
remote_request: SerializedRequest,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
"""
Creates a new instance of `ReadOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA read operation which will transfer data described by `remote_request`
to `local_descriptors`.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
remote_request : SerializedRequest
Serialized request from the remote worker.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to to receive the data from the remote worker.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `dynamo.connect.RequestDescriptor`.")
if remote_request.operation_kind != OperationKind.READ.value:
raise ValueError("Argument `remote_request` must be of kind `READ`.")
remote = Remote(connector, remote_request.nixl_metadata)
remote_descriptors = remote_request.to_descriptors()
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`.")
super().__init__(remote, OperationKind.READ, local_descriptors, remote_descriptors, remote_request.notification_key)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> ReadOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or been cancelled.
"""
super()._cancel_()
def results(self) -> list[Descriptor]:
"""
Gets the results of the operation.
Returns a single descriptor if only one was requested, or a list of descriptors if multiple were requested.
"""
if self._status != OperationStatus.COMPLETE:
raise RuntimeError("Operation has not completed yet, cannot get results.")
return self._local_descriptors if isinstance(self._local_descriptors, list) else [self._local_descriptors]
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class ReadableOperation(PassiveOperation):
"""
Operation that can be awaited until a remote worker has completed a `ReadOperation`.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
super().__init__(connector, OperationKind.READ, local_descriptors)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> ReadableOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class Remote:
"""
Identifies a remote NIXL enabled worker relative to a local NIXL enabled worker.
"""
def __init__(
self,
connector: Connector,
nixl_metadata: bytes | str,
) -> None:
if not isinstance(connector, Connector):
raise TypeError("Argument `local` must be `dynamo.connect.Connector`.")
if not (isinstance(nixl_metadata, bytes) or isinstance(nixl_metadata, str)):
raise TypeError("Argument `nixl_metadata` must be `bytes` or `str`.")
if len(nixl_metadata) == 0:
raise ValueError("Argument `nixl_metadata` cannot be empty.")
self._connector = connector
# When `nixl_metadata` is a string, it is assumed to have come from a remote worker
# via a `SerializedRequest` object and therefore can assumed be a hex-encoded, compressed
# representation of the NIXL metadata.
if isinstance(nixl_metadata, str):
# Decode the hex-encoded string into bytes.
nixl_metadata = bytes.fromhex(nixl_metadata)
# Decompress the NIXL metadata.
nixl_metadata = zlib.decompress(nixl_metadata)
self._name = connector._nixl.add_remote_agent(nixl_metadata)
if isinstance(self._name, bytes):
self._name = self._name.decode("utf-8")
logger.debug(f"Created {self.__repr__()}.")
def __del__(self) -> None:
self._release()
def __enter__(self) -> Remote:
"""
Context manager entry method. Returns the current instance.
"""
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
"""
Context manager exit method. Cleans up the instance.
"""
self._release()
def __repr__(self) -> str:
return f"Remote(name={self._name}, connector={self._connector.name})"
def __str__(self) -> str:
return self._name
def _release(self) -> None:
"""
Private method for releasing NIXL resources. Not intended for public use.
"""
# We have to unregister the remote agent from NIXL because we cannot know if the remote worker has updated its descriptors or not, and
# NIXL will return an error if we attempt to register a remote agent with the same name but different descriptors (aka conn_info).
self._connector._nixl.remove_remote_agent(self._name)
logger.debug(f"dynamo.connect.{self.__class__.__name__}: Unregistered NIXL remote {{ name: \"{self._name}\" }}.")
@property
def connector(self) -> Connector:
"""
Gets the local connector associated with this remote worker.
"""
return self._connector
@property
def name(self) -> str:
"""
Gets the name of the remote worker.
"""
return self._name
class SerializedDescriptor(BaseModel):
"""
Pydantic serialization type for memory descriptors.
"""
model_config = ConfigDict(
extra="forbid",
frozen=True,
arbitrary_types_allowed=True,
)
device: str = "cpu"
ptr: int = 0
size: int = 0
def to_descriptor(self) -> Descriptor:
"""
Deserialize the serialized descriptor into a `Descriptor` object.
"""
return Descriptor(data=(self.ptr, self.size, self.device, None))
@field_validator("device")
def validate_memtype(cls, v: str) -> str:
if not isinstance(v, str):
raise TypeError("Argument `device` must be `str`.")
v = v.strip().lower()
if not (v.startswith("cuda") or v == "cpu"):
raise ValueError("Argument `device` must be one of 'cpu' or 'cuda:<device_id>'.")
return v
@field_validator("ptr")
def validate_ptr(cls, v: int) -> int:
if v == 0:
raise ValueError("Argument `ptr` cannot be zero (aka `null` or `None`).")
return v
@field_validator("size")
def validate_size(cls, v: int) -> int:
if v < 0:
raise ValueError("Argument `size` must be an integer greater than or equal to zero.")
return v
class SerializedRequest(BaseModel):
"""
Pydantic serialization type for describing the passive side of a transfer.
"""
model_config = ConfigDict(
extra="forbid",
frozen=True,
arbitrary_types_allowed=True,
)
descriptors: List[SerializedDescriptor] = []
nixl_metadata: str = ""
notification_key: str = ""
operation_kind: int = 0
def to_descriptors(self) -> Descriptor | list[Descriptor]:
"""
Deserializes the request descriptor into a `dynamo.connect.Descriptor` or list of `dynamo.connect.Descriptor` objects.
"""
if len(self.descriptors) == 0:
raise ValueError("Request descriptor must contain at least one serialized descriptor.")
if len(self.descriptors) == 1:
return self.descriptors[0].to_descriptor()
return [item.to_descriptor() for item in self.descriptors]
@field_validator("operation_kind")
def validate_operation_kind(cls, v: int) -> int:
if v < 1 or v > 3:
raise TypeError("Argument `operation_kind` must be an integer value of `dynamo.connect.OperationKind`.")
return v
class WritableOperation(PassiveOperation):
"""
Operation which can be awaited until written to by a `WriteOperation` from a remote worker.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
"""
Creates a new instance of `WritableOperation`, registers the operation and descriptors w/ NIXL,
and enables an RDMA write operation to occur.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Descriptors to receive data from a remote worker.
Raises
TypeError
When `local` is not a `dynamo.connect.Connector`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
super().__init__(connector, OperationKind.WRITE, local_descriptors)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> WritableOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class WriteOperation(ActiveOperation):
"""
Awaitable write operation which initiates an RDMA write operation to a remote worker
which provided a `SerializedRequest` object from a `WritableOperation`.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
remote_request: SerializedRequest,
) -> None:
"""
Creates a new instance of `WriteOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA write operation which will transfer from `local_descriptors` to
remote target(s) described by `remote_request`
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to send from, to the remote worker.
remote_request : SerializedRequest
Serialized request from the remote worker that describes the target(s) to send to.
Raises
TypeError
When `connector` is not a `dynamo.connect.Connector`.
TypeError
When `remote_request` is not a `dynamo.connect.RequestDescriptor`.
ValueError
When `remote_request` is not of kind `WRITE`.
ValueError
When `remote_request.nixl_metadata` is not a non-empty `str`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `dynamo.connect.RequestDescriptor`.")
if remote_request.operation_kind != OperationKind.WRITE.value:
raise ValueError("Argument `remote_request` must be of kind `WRITE`.")
remote = Remote(connector, remote_request.nixl_metadata)
remote_descriptors = remote_request.to_descriptors()
super().__init__(remote, OperationKind.WRITE, local_descriptors, remote_descriptors, remote_request.notification_key)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> WriteOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
super()._cancel_()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf"
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
PROVIDED_PROMPT_TEMPLATE=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL_NAME=$2
shift 2
;;
--prompt-template)
PROVIDED_PROMPT_TEMPLATE=$2
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates."
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use --help for usage information"
exit 1
;;
esac
done
# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
else
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
echo "Please provide a prompt template using --prompt-template option."
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
exit 1
fi
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" &
# run E/P/D workers
CUDA_VISIBLE_DEVICES=0 python3 components/encode_worker.py --model $MODEL_NAME &
CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill &
# Wait for all background processes to complete
wait
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -ex
trap 'echo Cleaning up...; kill 0' EXIT
MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "<|image|>\n<prompt>" &
# LLama 4 doesn't support image embedding input, so the prefill worker will also
# handle image encoding.
# run EP/D workers
python3 components/worker.py --model $MODEL_NAME --worker-type encode_prefill --tensor-parallel-size=8 --max-model-len=208960 &
# Wait for all background processes to complete
wait
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf"
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
PROVIDED_PROMPT_TEMPLATE=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL_NAME=$2
shift 2
;;
--prompt-template)
PROVIDED_PROMPT_TEMPLATE=$2
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates."
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use --help for usage information"
exit 1
;;
esac
done
# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
else
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
echo "Please provide a prompt template using --prompt-template option."
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
exit 1
fi
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" &
# run E/P/D workers
CUDA_VISIBLE_DEVICES=0 python3 components/encode_worker.py --model $MODEL_NAME &
CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill --enable-disagg &
CUDA_VISIBLE_DEVICES=2 python3 components/worker.py --model $MODEL_NAME --worker-type decode --enable-disagg &
# Wait for all background processes to complete
wait
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -ex
# Default values
HEAD_NODE=0
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--head-node)
HEAD_NODE=1
shift 1
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --head-node Run as head node. Head node will run the HTTP server, processor and prefill worker."
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use --help for usage information"
exit 1
;;
esac
done
trap 'echo Cleaning up...; kill 0' EXIT
MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
if [[ $HEAD_NODE -eq 1 ]]; then
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "<|image|>\n<prompt>" &
# LLama 4 doesn't support image embedding input, so the prefill worker will also
# handle image encoding.
# run EP/D workers
python3 components/worker.py --model $MODEL_NAME --worker-type encode_prefill --enable-disagg --tensor-parallel-size=8 --max-model-len=208960 &
else
# run decode worker on non-head node
python3 components/worker.py --model $MODEL_NAME --worker-type decode --enable-disagg --tensor-parallel-size=8 --max-model-len=208960 &
fi
# Wait for all background processes to complete
wait
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import json
import logging
import os
import socket
import sys
import time
from typing import Callable, List, Optional, Tuple
from vllm.config import KVTransferConfig
from vllm.distributed.kv_events import KVEventsConfig
from vllm.engine.arg_utils import AsyncEngineArgs
logger = logging.getLogger(__name__)
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
class Config:
"""Command line parameters or defaults"""
# dynamo specific
namespace: str
component: str
endpoint: str
kv_port: Optional[int] = None
side_channel_port: Optional[int] = None
# mirror vLLM
model: str
served_model_name: Optional[str]
# rest vLLM args
engine_args: AsyncEngineArgs
def parse_endpoint(endpoint: str) -> List[str]:
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logger.error(
f"Invalid endpoint format: '{endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
return endpoint_parts
def base_parse_args(
parser: argparse.ArgumentParser, endpoint_overwrite: Optional[Callable] = None
) -> Tuple[argparse.Namespace, Config]:
"""
Basic parsing logic for any dynamo vLLM deployment. The caller will use
'parser' and 'endpoint_overwrite' to apply use case specific customization.
Args:
parser (argparse.ArgumentParser): The argument parser which has use case
specific arguments added.
endpoint_overwrite (Callable): A user provided function to overwrite the endpoints
the given the parsed arguments. This function should return the overwritten args.
A typical selector will check the worker type and return specific endpoints.
Returns:
Tuple[argparse.Namespace, Config]: A tuple containing the parsed arguments
and a Config object with the relevant settings.
"""
if not any(arg.dest == "endpoint" for arg in parser._actions):
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
config = Config()
config.model = args.model
if args.served_model_name:
assert (
len(args.served_model_name) <= 1
), "We do not support multiple model names."
config.served_model_name = args.served_model_name[0]
else:
# This becomes an `Option` on the Rust side
config.served_model_name = None
if endpoint_overwrite is not None:
args = endpoint_overwrite(args)
endpoint = args.endpoint
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint
)
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.engine_args = engine_args
if config.engine_args.block_size is None:
config.engine_args.block_size = 16
logger.debug(
f"Setting reasonable default of {config.engine_args.block_size} for block_size"
)
return args, config
async def allocate_and_reserve_port(
namespace,
etcd_client,
worker_id: str,
reason: str,
max_attempts: int = 100,
) -> int:
"""
Get an OS-assigned port and atomically reserve it in ETCD.
Retries until successful or max_attempts reached.
Args:
max_attempts: Maximum number of ports to try (default: 100)
Raises:
RuntimeError: If unable to reserve a port within max_attempts
OSError: If unable to create sockets (system resource issues)
"""
node_name = socket.gethostname()
for attempt in range(1, max_attempts + 1):
# Hold socket open just long enough to reserve in ETCD
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("", 0))
port = sock.getsockname()[1]
# Reserve in ETCD while holding the socket
key = f"dyn://{namespace}/ports/{node_name}/{port}"
value = {
"worker_id": worker_id,
"reason": reason,
"reserved_at": time.time(),
"pid": os.getpid(),
}
try:
await etcd_client.kv_create(
key=key,
value=json.dumps(value).encode(),
lease_id=etcd_client.primary_lease_id(),
)
logger.debug(f"Reserved OS-assigned port {port} for {worker_id}")
return port
except Exception as e:
logger.debug(
f"Port {port} on {node_name} was already reserved (attempt {attempt}): {e}"
)
if attempt < max_attempts:
await asyncio.sleep(0.01)
raise RuntimeError(
f"Failed to allocate and reserve a port after {max_attempts} attempts"
)
async def configure_ports_with_etcd(config: Config, etcd_client):
"""Configure all settings that require ETCD, including port allocation and vLLM overrides."""
# First, allocate ports
dp_rank = config.engine_args.data_parallel_rank or 0
worker_id = f"vllm-{config.component}-dp{dp_rank}"
# Allocate KV events port
kv_port = await allocate_and_reserve_port(
namespace=config.namespace,
etcd_client=etcd_client,
worker_id=f"{worker_id}",
reason="zmq_kv_event_port",
)
# Allocate side channel port
side_channel_port = await allocate_and_reserve_port(
namespace=config.namespace,
etcd_client=etcd_client,
worker_id=f"{worker_id}",
reason="nixl_side_channel_port",
)
# Update config with allocated ports
config.kv_port = kv_port
config.side_channel_port = side_channel_port
def overwrite_args(config):
"""Set vLLM defaults for Dynamo."""
assert (
config.kv_port is not None
), "Must set the kv_port, use configure_ports_with_etcd"
assert (
config.side_channel_port is not None
), "Must set the side_channel_port, use configure_ports_with_etcd"
dp_rank = config.engine_args.data_parallel_rank or 0
defaults = {
"task": "generate",
"skip_tokenizer_init": False,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
# Always setting up kv transfer for disagg
"kv_transfer_config": KVTransferConfig(
kv_connector="NixlConnector", kv_role="kv_both"
),
"kv_events_config": KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=f"tcp://*:{config.kv_port - dp_rank}", # vLLM will iterate dp_rank for us, so we need to subtract it out TODO: fix in vLLM
),
}
set_side_channel_host_and_port(config)
logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items():
if hasattr(config.engine_args, key):
setattr(config.engine_args, key, value)
logger.debug(f" engine_args.{key} = {value}")
else:
raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.")
def set_side_channel_host_and_port(config: Config, hostname: Optional[str] = None):
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.
This sets the port number for the side channel.
"""
if hostname is None:
hostname = socket.gethostname()
# Test if hostname is usable by attempting to bind to it
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket:
test_socket.bind((hostname, 0))
except (socket.error, socket.gaierror):
# If hostname is not usable, fall back to localhost
logger.warning(
f"Hostname '{hostname}' is not usable, falling back to '127.0.0.1'"
)
hostname = "127.0.0.1"
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(config.side_channel_port)
logger.debug(f"Set NIXL side channel to {hostname}:{config.side_channel_port}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
# TODO: Revisit this later when adding multi-modal support for the frontend.
# If no chat template is provided and tokenizer doesn't have one,
# use a simple format that just concatenates messages
if not request.chat_template and not self.tokenizer.chat_template:
chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}Assistant:"
else:
chat_template = request.chat_template or self.tokenizer.chat_template
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if request.stream:
# Handle streaming response
num_output_text_so_far = 0
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
yield raw_response
break
# Parse the response
response = json.loads(raw_response.lstrip("data: "))
# Process delta content to extract only new text
if "choices" in response and len(response["choices"]) > 0:
if "delta" in response["choices"][0]:
content = response["choices"][0]["delta"].get("content", "")
if content:
# Extract only the new part from the full content
new_content = content[num_output_text_so_far:]
response["choices"][0]["delta"]["content"] = new_content
num_output_text_so_far = len(content)
# Yield the processed response
yield f"data: {json.dumps(response)}\n\n"
else:
# Handle non-streaming response
# Collect all chunks into a single response
full_response = None
num_output_text_so_far = 0
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
if full_response is None:
# Initialize the full response structure
full_response = {
"id": response.get("id", ""),
"object": "chat.completion",
"created": int(time.time()),
"model": request.model,
"choices": [
{
"index": response.get("index", 0),
"message": {"role": "assistant", "content": ""},
"finish_reason": None,
}
],
}
# Concatenate content if it exists. Each delta contains the full text so far.
if "choices" in response and len(response["choices"]) > 0:
if "delta" in response["choices"][0]:
content = response["choices"][0]["delta"].get("content", "")
if content:
# Extract only the new part from the full content
new_content = content[num_output_text_so_far:]
full_response["choices"][0]["message"][
"content"
] += new_content
num_output_text_so_far = len(content)
# Update finish reason if present
if "finish_reason" in response["choices"][0]:
full_response["choices"][0]["finish_reason"] = response[
"choices"
][0]["finish_reason"]
if full_response is not None:
yield json.dumps(full_response)
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import logging
from io import BytesIO
from urllib.parse import urlparse
import httpx
from PIL import Image
logger = logging.getLogger(__name__)
class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
def __init__(self, cache_size: int = CACHE_SIZE_MAXIMUM):
self._http_timeout = 30.0
self._http_client = httpx.AsyncClient(timeout=self._http_timeout)
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url)
# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if image_url_lower in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
return self._image_cache[image_url_lower]
try:
if parsed_url.scheme == "data":
# Parse data URL format: data:[<media type>][;base64],<data>
if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1)
if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded")
try:
image_bytes = base64.b64decode(data)
image_data = BytesIO(image_bytes)
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
if not self._http_client:
raise RuntimeError("HTTP client not initialized")
response = await self._http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
# PIL is sync, so offload to a thread to avoid blocking the event loop
image = await asyncio.to_thread(Image.open, image_data)
# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = await self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
await self._cache_queue.put(image_url_lower)
return image_converted
except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")
raise
except Exception as e:
logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Tuple
import torch
from transformers import AutoConfig
from utils.protocol import EncodeResponse
from vllm import AsyncEngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker import Worker
# from transformers import AutoImageProcessor, LlavaForConditionalGeneration
# from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
logger = logging.getLogger(__name__)
# [gluo NOTE] in vLLM v1, Worker() usage below will results in NotImplementedError,
# must find another way to properly load the vision model given the model name (model_id).
def load_vision_model(model_id: str) -> torch.nn.Module:
"""
Load a vision model from a HuggingFace model ID.
"""
engine_args = AsyncEngineArgs(model=model_id, trust_remote_code=True)
engine_config = engine_args.create_engine_config()
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
worker = Worker(
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=True,
)
# Initialize the worker.
worker.init_device()
worker.load_model()
return worker.model_runner.model
# model = LlavaForConditionalGeneration.from_pretrained(
# model_id, device_map="auto", torch_dtype=torch.float16
# ).eval()
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
# model_id, torch_dtype="auto", device_map="auto"
# ).eval()
# return model
def get_vision_embeddings_info(
model_id: str, num_patches: int
) -> Tuple[Tuple[int, int, int], torch.dtype]:
"""Calculate vision embeddings size and dtype using model config
Returns a tuple of (batch_size, num_patches, hidden_dim), dtype.
"""
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
assert num_patches > 0, "Number of patches must be positive"
if not hasattr(config, "torch_dtype"):
raise ValueError("Model config missing required 'torch_dtype' attribute")
if not hasattr(config, "hidden_size"):
logger.warning(
"Model config missing required 'hidden_size' attribute, using 4096"
)
hidden_size = 4096
else:
hidden_size = config.hidden_size
return (1, num_patches, hidden_size), config.torch_dtype
def construct_mm_data(
model: str,
encode_output: EncodeResponse,
image_embeds: torch.Tensor,
embeddings_dtype: torch.dtype,
) -> Dict[str, torch.Tensor | Dict[str, Any]]:
"""Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings"""
image_embeds = image_embeds.to(embeddings_dtype)
if "Qwen2" in model:
return {
"image": {
"image_embeds": image_embeds.squeeze(0),
"image_grid_thw": torch.tensor(encode_output.image_grid_thw).squeeze(0),
}
}
elif "MiniCPM-V" in model:
return {
"image": {
"image_embeds": image_embeds,
"image_sizes": encode_output.image_sizes,
}
}
else:
return {"image": image_embeds}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Literal, Optional, Union
import connect
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import PromptLogprobs, RequestMetrics
class Request(BaseModel):
prompt: str
sampling_params: dict
class Tokens(BaseModel):
tokens: list[int]
class PrefillRequest(Request):
request_id: str
class Response(BaseModel):
text: str
class PrefillResponse(BaseModel):
prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# Sampling params is a mspspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
prefix_hit_rate: Optional[float] = 0.0
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
model_config = ConfigDict(
arbitrary_types_allowed=True,
json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)},
)
class TextContent(BaseModel):
type: Literal["text"]
text: str
class ImageURLDetail(BaseModel):
url: str
class ImageContent(BaseModel):
type: Literal["image_url"]
image_url: ImageURLDetail
MessageContent = Union[TextContent, ImageContent]
class ChatMessage(BaseModel):
role: Literal["user", "system", "assistant"]
content: List[MessageContent]
class MultiModalRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model: str
messages: List[ChatMessage]
max_tokens: Optional[int] = None
temperature: Optional[float] = None
stream: Optional[bool] = True
class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: Optional[str] = None
# image_features: Optional[List[List[List[float]]]] = None # Remove once have NIXL support
serialized_request: Optional[connect.SerializedRequest] = None
class EncodeRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str
request_id: str
serialized_request: Optional[connect.SerializedRequest] = None
class EncodeResponse(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
image_grid_thw: Optional[List[Any]] = None
image_sizes: Optional[List[Any]] = None
serialized_request: Optional[connect.SerializedRequest] = None
image_features: List[List[List[float]]] # Remove once have NIXL support
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
kv_transfer_params: Optional[dict[str, Any]] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
...@@ -26,6 +26,7 @@ from dynamo._core import Backend as Backend ...@@ -26,6 +26,7 @@ from dynamo._core import Backend as Backend
from dynamo._core import Client as Client from dynamo._core import Client as Client
from dynamo._core import Component as Component from dynamo._core import Component as Component
from dynamo._core import DistributedRuntime as DistributedRuntime from dynamo._core import DistributedRuntime as DistributedRuntime
from dynamo._core import Endpoint as Endpoint
from dynamo._core import EtcdKvCache as EtcdKvCache from dynamo._core import EtcdKvCache as EtcdKvCache
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment