"lib/bindings/vscode:/vscode.git/clone" did not exist on "d0a63635849ab1c29f4b3cbe419a19730a575da1"
Unverified Commit 353146e2 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: add vLLM v1 multi-modal example. Add llama4 Maverick example (#1990)


Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: default avatarkrishung5 <krish@nvidia.com>
parent 1f07dab7
......@@ -167,7 +167,11 @@ RUN uv pip install /workspace/wheels/nixl/*.whl
# Install vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
ARG VLLM_REF="059d4cd"
# [gluo NOTE] currently using a fork of vllm until the fix
# for multi-modal disaggregated serving is merged upstream.
# see https://github.com/vllm-project/vllm/pull/21074
ARG VLLM_REPO=https://github.com/GuanLuo/vllm.git
ARG VLLM_REF="eaadf838ebe93e29a38a6fc1bab5a9801abe7d2c"
ARG MAX_JOBS=16
ENV MAX_JOBS=$MAX_JOBS
ENV CUDA_HOME=/usr/local/cuda
......@@ -177,7 +181,7 @@ RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
uv pip install pip cuda-python && \
mkdir /opt/vllm && \
cd /opt/vllm && \
git clone https://github.com/vllm-project/vllm.git && \
git clone $VLLM_REPO && \
cd vllm && \
git checkout $VLLM_REF && \
uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 && \
......@@ -198,7 +202,7 @@ RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
uv pip install pip cuda-python && \
mkdir /opt/vllm && \
cd /opt/vllm && \
git clone https://github.com/vllm-project/vllm.git && \
git clone $VLLM_REPO && \
cd vllm && \
git checkout $VLLM_REF && \
VLLM_USE_PRECOMPILED=1 uv pip install -e . && \
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Multimodal Deployment Examples
This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo and vLLM v1.
## Use the Latest Release
We recommend using the latest stable release of dynamo to avoid breaking changes:
[![GitHub Release](https://img.shields.io/github/v/release/ai-dynamo/dynamo)](https://github.com/ai-dynamo/dynamo/releases/latest)
You can find the latest release [here](https://github.com/ai-dynamo/dynamo/releases/latest) and check out the corresponding branch with:
```bash
git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
```
## Multimodal Aggregated Serving
### Components
- workers: For aggregated serving, we have two workers, [VllmEncodeWorker](components/encode_worker.py) for encoding and [VllmPDWorker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have two workers, [VllmEncodeWorker](components/encode_worker.py) and [VllmPDWorker](components/worker.py).
The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the VllmPDWorker via a combination of NATS and RDMA.
The work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
Its VllmPDWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](../llm/README.md) example.
By separating the encode from the prefill and decode stages, we can have a more flexible deployment and scale the
VllmEncodeWorker independently from the prefill and decode workers if needed.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> encode_worker
encode_worker --> processor
encode_worker --embeddings--> pd_worker
pd_worker --> encode_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
# Serve a LLaVA 1.5 7B model:
dynamo serve graphs.agg:Frontend -f ./configs/agg-llava.yaml
# Serve a Qwen2.5-VL model:
# dynamo serve graphs.agg:Frontend -f ./configs/agg-qwen.yaml
# Serve a Phi3V model:
# dynamo serve graphs.agg:Frontend -f ./configs/agg-phi3v.yaml
```
### Client
In another terminal:
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/llava-1.5-7b-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
You should see a response similar to this:
```json
{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
```
## Multimodal Disaggregated Serving
### Components
- workers: For disaggregated serving, we have three workers, [VllmEncodeWorker](components/encode_worker.py) for encoding, [VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have three workers, [VllmEncodeWorker](components/encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
For the Llava model, embeddings are only required during the prefill stage. As such, the VllmEncodeWorker is connected directly to the prefill worker.
The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the prefill worker via a combination of NATS and RDMA.
Its work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
The prefill worker performs the prefilling step and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](../llm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> encode_worker
encode_worker --> processor
encode_worker --embeddings--> prefill_worker
prefill_worker --> encode_worker
prefill_worker --> decode_worker
decode_worker --> prefill_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
dynamo serve graphs.disagg:Frontend -f configs/disagg.yaml
```
### Client
In another terminal:
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/llava-1.5-7b-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
You should see a response similar to this:
```json
{"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
```
***Note***: disaggregation is currently only confirmed to work with LLaVA. Qwen VL and PhiV are not confirmed to be supported.
## Llama 4 family Serving
The family of Llama 4 models is natively multimodal, however, different
from Llava, they do not directly consume image embedding as input
(see the [support metrics](https://docs.vllm.ai/en/latest/models/supported_models.html#text-generation_1)
from vLLM for the types of multi-modal inputs supported by the model).
Therefore, encoder worker will not be used in the following example and the
encoding will be done along side with prefill.
`meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8` will be used as an example
for the content below. And the system will be H100x8 which can hold one instance
of the model per node.
### Multimodal Aggregated Serving
#### Components
- workers: For aggregated serving, we have one worker, [VllmPDWorker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the VllmPDWorker.
- frontend: HTTP endpoint to handle incoming requests.
#### Graph
In this graph, we have [VllmPDWorker](components/worker.py) which will encode the image, prefill and decode the prompt, just like the [LLM aggregated serving](../llm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> pd_worker
pd_worker --> processor
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
export CONFIG_FILE=configs/llama.yaml
# start components individually as the model is too large that addition
# node will be needed to scale up number of workers. And graph deployment
# doesn't work well in multi-node case.
dynamo serve components.web:Frontend --service-name Frontend -f $CONFIG_FILE &
dynamo serve components.direct_processor:Processor --service-name Processor -f $CONFIG_FILE &
dynamo serve components.worker:VllmPDWorker --service-name VllmPDWorker -f $CONFIG_FILE &
```
#### Client
In another terminal:
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
You should see a response similar to this:
```json
{"id": "b8f060fa95584e34b9204eaba7b105cc", "object": "chat.completion", "created": 1752706281, "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "choices": [{"index": 0, "message": {"role": "assistant", "content": "The image depicts a street scene with a trolley bus as the central focus. The trolley bus is positioned on the left side of the road, facing the camera, and features a white and yellow color scheme. A prominent sign on the front of the bus reads \"OUT OF SERVICE\" in orange letters.\n\n**Key Elements:**\n\n* **Trolley Bus:** The bus is the main subject of the image, showcasing its distinctive design and color.\n* **Sign:** The \"OUT OF SERVICE\" sign is clearly visible on the front of the bus, indicating its current status.\n* **Street Scene:** The surrounding environment includes trees, buildings, and power lines, creating a sense of context and atmosphere.\n* **Lighting:** The image is characterized by a misty or foggy quality, with soft lighting that adds to the overall ambiance.\n\n**Overall Impression:**\n\nThe image presents a serene and somewhat melancholic scene, with the out-of-service trolley bus serving as a focal point. The misty atmosphere and soft lighting contribute to a dreamy or nostalgic feel, inviting the viewer to reflect on the scene."}, "finish_reason": "stop"}]}
```
### Multimodal Disaggregated Serving
#### Components
- workers: For disaggregated serving, we have two workers, [VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for encoding and prefilling.
- processor: Tokenizes the prompt and passes it to the VllmPDWorker.
- frontend: HTTP endpoint to handle incoming requests.
#### Graph
In this graph, we have two workers, [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
The prefill worker performs the encoding and prefilling steps and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](../llm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> prefill_worker
prefill_worker --> processor
prefill_worker --> decode_worker
decode_worker --> prefill_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
export CONFIG_FILE=configs/llama.yaml
# start components individually as the model is too large that addition
# node will be needed to scale up number of workers. And graph deployment
# doesn't work well in multi-node case.
dynamo serve components.web:Frontend --service-name Frontend -f $CONFIG_FILE &
dynamo serve components.direct_processor:Processor --service-name Processor -f $CONFIG_FILE &
dynamo serve components.worker:VllmPDWorker --service-name VllmPDWorker --VllmPDWorker.enable_disagg true -f $CONFIG_FILE &
# On a separate node with standard dynamo setup
# (i.e. nats and etcd environment variables are set)
dynamo serve components.worker:VllmDecodeWorker --service-name VllmDecodeWorker -f $CONFIG_FILE &
```
#### Client
In another terminal:
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
You should see a response similar to this:
```json
{"id": "6cc99123ad6948d685b8695428238d4b", "object": "chat.completion", "created": 1752708043, "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "choices": [{"index": 0, "message": {"role": "assistant", "content": "The image depicts a street scene with a trolley bus as the central focus. The trolley bus is positioned on the left side of the road, facing the camera, and features a white and yellow color scheme. A prominent sign on the front of the bus reads \"OUT OF SERVICE\" in orange letters.\n\n**Key Elements:**\n\n* **Trolley Bus:** The bus is the main subject of the image, showcasing its distinctive design and color.\n* **Sign:** The \"OUT OF SERVICE\" sign is clearly visible on the front of the bus, indicating its current status.\n* **Street Scene:** The surrounding environment includes trees, buildings, and power lines, creating a sense of context and atmosphere.\n* **Lighting:** The image is characterized by a misty or foggy quality, with soft lighting that adds to the overall mood.\n\n**Overall Impression:**\n\nThe image presents a serene and somewhat melancholic scene, with the out-of-service trolley bus serving as a focal point. The misty atmosphere and soft lighting contribute to a contemplative ambiance, inviting the viewer to reflect on the situation."}, "finish_reason": "stop"}]}
```
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
> [!WARNING]
> The content of this README is derived from `examples/multimodal` and have not been validated in `examples/multimodal_v1`.
> If needed, should validate the content and port to `README.md`
## Deployment with Dynamo Operator
These multimodal examples can be deployed to a Kubernetes cluster using [Dynamo Cloud](../../docs/guides/dynamo_deploy/dynamo_cloud.md) and the Dynamo CLI.
### Prerequisites
You must have first followed the instructions in [deploy/cloud/helm/README.md](../../deploy/cloud/helm/README.md) to install Dynamo Cloud on your Kubernetes cluster.
**Note**: The `KUBE_NS` variable in the following steps must match the Kubernetes namespace where you installed Dynamo Cloud. You must also expose the `dynamo-store` service externally. This will be the endpoint the CLI uses to interface with Dynamo Cloud.
### Deployment Steps
For detailed deployment instructions, please refer to the [Operator Deployment Guide](../../docs/guides/dynamo_deploy/operator_deployment.md). The following are the specific commands for the multimodal examples:
```bash
# Set your project root directory
export PROJECT_ROOT=$(pwd)
# Configure environment variables (see operator_deployment.md for details)
export KUBE_NS=dynamo-cloud
export DYNAMO_CLOUD=http://localhost:8080 # If using port-forward
# OR
# export DYNAMO_CLOUD=https://dynamo-cloud.nvidia.com # If using Ingress/VirtualService
# Build the Dynamo base image (see operator_deployment.md for details)
export DYNAMO_IMAGE=<your-registry>/<your-image-name>:<your-tag>
# TODO: Apply Dynamo graph deployment for the example
```
**Note**: To avoid rate limiting from unauthenticated requests to HuggingFace (HF), you can provide your `HF_TOKEN` as a secret in your deployment. See the [operator deployment guide](../../docs/guides/dynamo_deploy/operator_deployment.md#referencing-secrets-in-your-deployment) for instructions on referencing secrets like `HF_TOKEN` in your deployment configuration.
**Note**: Optionally add `--Planner.no-operation=false` at the end of the deployment command to enable the planner component to take scaling actions on your deployment.
### Testing the Deployment
Once the deployment is complete, you can test it. If you have ingress available for your deployment, you can directly call the url returned
in `dynamo deployment get ${DEPLOYMENT_NAME}` and skip the steps to find and forward the frontend pod.
```bash
# Find your frontend pod
export FRONTEND_POD=$(kubectl get pods -n ${KUBE_NS} | grep "${DEPLOYMENT_NAME}-frontend" | sort -k1 | tail -n1 | awk '{print $1}')
# Forward the pod's port to localhost
kubectl port-forward pod/$FRONTEND_POD 8000:8000 -n ${KUBE_NS}
# Test the API endpoint
curl localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/llava-1.5-7b-hf",
"messages": [
{
"role": "user",
"content": [
{ "type": "text", "text": "What is in this image?" },
{ "type": "image_url", "image_url": { "url": "http://images.cocodataset.org/test2017/000000155781.jpg" } }
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
For more details on managing deployments, testing, and troubleshooting, please refer to the [Operator Deployment Guide](../../docs/guides/dynamo_deploy/operator_deployment.md).
## Multimodal Aggregated Video Serving
This example demonstrates deploying an aggregated multimodal model that can process video inputs.
### Components
- workers: For video serving, we have two workers, [video_encode_worker](components/video_encode_worker.py) for decoding video into frames, and [video_decode_worker](components/video_decode_worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the decode worker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have two workers, `video_encode_worker` and `video_decode_worker`.
The `video_encode_worker` is responsible for decoding the video into a series of frames. Unlike the image pipeline which generates embeddings, this pipeline passes the raw frames directly to the `video_decode_worker`. This transfer is done efficiently using RDMA.
The `video_decode_worker` then receives these frames, and performs prefill and decode steps with the model. Separating the video processing from the language model inference allows for flexible scaling.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --> video_decode_worker
video_decode_worker --> processor
video_decode_worker --video_url--> video_encode_worker
video_encode_worker --frames--> video_decode_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal
# Serve a LLaVA-NeXT-Video-7B model:
dynamo serve graphs.agg_video:Frontend -f ./configs/agg_video.yaml
```
### Client
In another terminal:
```bash
curl -X 'POST' 'http://localhost:8000/v1/chat/completions' -H 'Content-Type: application/json' -d '{
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the video in detail"
},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}' | jq
```
You should see a response describing the video's content similar to
```json
{
"id": "b5714626-5889-4bb7-8c51-f3bca65b4683",
"object": "chat.completion",
"created": 1749772533,
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": " Sure! The video features a group of anthropomorphic animals who appear human-like. They're out in a meadow, which is a large, open area covered in grasses, and have given human qualities like speaking and a desire to go on adventures. The animals are seen play-fighting with each other clearly seen glancing at the camera when they sense it, blinking, and Roman the second can be directly heard by the camera reciting the line, \"When the challenge becomes insane, the behavior becomes erratic.\" A white rabbit is the first in shot and he winks the left eye and flips the right ear before shaking with the mouse and squirrel friends on a blurry rock ledge under the sky. At some point, the rabbit turns towards the camera and starts playing with the thing, and there's a distant mountain in the background. Furthermore, a little animal from a tree in the background flies with two rocks, and it's joined by the rest of the group of friends. That outro is an elder turtle in the Ramden musical style saturated with a horn-like thing pattern."
},
"finish_reason": "stop"
}
]
}
```
## Multimodal Disaggregated Video Serving
This example demonstrates deploying a disaggregated multimodal model that can process video inputs.
### Dependency
Video example relies on `av` package for video preprocessing inside the encode_worker.
Please install `av` inside the dynamo container to enable video example.
`pip install av`
### Components
- workers: For disaggregated video serving, we have three workers, [video_encode_worker](components/video_encode_worker.py) for decoding video into frames, [video_decode_worker](components/video_decode_worker.py) for decoding, and [video_prefill_worker](components/video_prefill_worker.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the decode worker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have three workers, `video_encode_worker`, `video_decode_worker`, and `video_prefill_worker`.
For the LLaVA-NeXT-Video-7B model, frames are only required during the prefill stage. As such, the `video_encode_worker` is connected directly to the `video_prefill_worker`.
The `video_encode_worker` is responsible for decoding the video into a series of frames and passing them to the `video_prefill_worker` via RDMA.
The `video_prefill_worker` performs the prefilling step and forwards the KV cache to the `video_decode_worker` for decoding.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --> video_decode_worker
video_decode_worker --> processor
video_decode_worker --> video_prefill_worker
video_prefill_worker --> video_decode_worker
video_prefill_worker --video_url--> video_encode_worker
video_encode_worker --frames--> video_prefill_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal
# Serve a LLaVA-NeXT-Video-7B model:
dynamo serve graphs.disagg_video:Frontend -f ./configs/disagg_video.yaml
```
### Client
In another terminal:
```bash
curl -X 'POST' 'http://localhost:8000/v1/chat/completions' -H 'Content-Type: application/json' -d '{
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the video in detail"
},
{
"type": "video_url",
"video_url": {
"url": "https://storage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4"
}
}
]
}
],
"max_tokens": 300,
"stream": false
}' | jq
```
You should see a response describing the video's content similar to
```json
{
"id": "d1d641b1-4daf-48d3-9d06-6a60743b5a42",
"object": "chat.completion",
"created": 1749775300,
"model": "llava-hf/LLaVA-NeXT-Video-7B-hf",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": " The video features two animals in a lush, green outdoor environment. On the ground, there is a white rabbit with big brown eyes, a playful expression, and two antlers. The rabbit is accompanied by a uniquely colored bird with orange pupils, possibly a squirrel or a hamster, sitting on its head. These two animals seem to have embarked on an unlikely journey, flying together in the sky. The backdrop showcases rolling green hills and trees under the pleasant weather. The sky is clear, indicating a beautiful day. The colors and contrast suggest the landscape is during spring or summer, signifying the rabbit and bird could also be engaging in outdoor activities during those seasons. Overall, it's a charming scene depicting an unlikely yet harmonious pair, enjoying a surprise adventure in nature."
},
"finish_reason": "stop"
}
]
}
```
## Deploying Multimodal Examples on Kubernetes
This guide will help you quickly deploy and clean up the multimodal example services in Kubernetes.
### Prerequisites
- **Dynamo Cloud** is already deployed in your target Kubernetes namespace.
- You have `kubectl` access to your cluster and the correct namespace set in `$NAMESPACE`.
### Create a secret with huggingface token
```bash
export HF_TOKEN="huggingfacehub token with read permission to models"
kubectl create secret generic hf-token-secret --from-literal=HF_TOKEN=$HF_TOKEN -n $KUBE_NS || true
```
---
Choose the example you want to deploy or delete. The YAML files are located in `examples/multimodal/deploy/k8s/`.
### Deploy the Multimodal Example
```bash
kubectl apply -f examples/multimodal/deploy/k8s/<Example yaml file> -n $NAMESPACE
```
### Uninstall the Multimodal Example
```bash
kubectl delete -f examples/multimodal/deploy/k8s/<Example yaml file> -n $NAMESPACE
```
### Using a different dynamo container
To customize the container image used in your deployment, you will need to update the manifest before applying it.
You can use [`yq`](https://github.com/mikefarah/yq?tab=readme-ov-file#install), a portable command-line YAML processor.
Please follow the [installation instructions](https://github.com/mikefarah/yq?tab=readme-ov-file#install) for your platform if you do not already have `yq` installed. After installing `yq`, you can generate and apply your manifest as follows:
```bash
export DYNAMO_IMAGE=my-registry/my-image:tag
yq '.spec.services.[].extraPodSpec.mainContainer.image = env(DYNAMO_IMAGE)' $EXAMPLE_FILE > my_example_manifest.yaml
# install the dynamo example
kubectl apply -f my_example_manifest.yaml -n $NAMESPACE
# uninstall the dynamo example
kubectl delete -f my_example_manifest.yaml -n $NAMESPACE
```
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
from components.worker import VllmPDWorker
from transformers import AutoTokenizer
from utils.args import parse_vllm_args
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers
from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
pd_worker = depends(VllmPDWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.min_workers = 1
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmPDWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
await check_required_workers(self.encode_worker_client, self.min_workers)
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
image: str,
request_type: RequestType,
):
request_id = str(uuid.uuid4().hex)
logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_request = vLLMMultimodalRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
image_url=image,
)
response_generator = await self.encode_worker_client.round_robin(
worker_request.model_dump_json()
)
output = self._generate_responses(response_generator, request_type)
# Stream the processed responses
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
request_type: RequestType,
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
# The generate endpoint will be used by the frontend to handle incoming requests.
@endpoint()
async def generate(self, raw_request: MultiModalRequest):
# Ensure the configured template includes the placeholder
template = self.engine_args.prompt_template
if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
prompt = template.replace("<prompt>", user_text)
msg = {
"role": "user",
"content": prompt,
}
chat_request = ChatCompletionRequest(
model=raw_request.model,
messages=[msg],
stream=raw_request.stream,
max_tokens=raw_request.max_tokens,
temperature=raw_request.temperature,
request_id=str(uuid.uuid4()),
)
image_url = None
for message in raw_request.messages:
for item in message.content:
if item.type == "image_url":
image_url = item.image_url.url
if image_url is None:
raise ValueError("Image URL is required")
async for response in self._generate(chat_request, image_url, RequestType.CHAT):
yield json.dumps(response)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import AsyncIterator
import connect
import torch
from components.worker import VllmPDWorker
from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from utils.args import parse_vllm_args
from utils.image_loader import ImageLoader
from utils.logging import check_required_workers
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
try:
import cupy as array_module
if not array_module.cuda.is_available():
raise ImportError("CUDA is not available.")
DEVICE = "cuda"
logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
import numpy as array_module
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmEncodeWorker:
decode_worker = depends(VllmPDWorker)
def __init__(self) -> None:
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.MODEL_ID = self.engine_args.model
self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
self.image_processor = AutoImageProcessor.from_pretrained(
self.MODEL_ID, trust_remote_code=True
)
# self.vision_model = load_vision_model(self.MODEL_ID)
self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
).eval()
self.min_workers = 1
@endpoint()
async def encode(
self, request: vLLMMultimodalRequest
) -> AsyncIterator[MyRequestOutput]:
logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")
request_id = request.request_id
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
try:
image = await self.image_loader.load_image(request.image_url)
logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
# # Add a batch dimension to everything
# for item in image_embeds:
# image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
# logger.debug(f"Image embeds: {image_embeds}")
# image_grid_thw = (
# image_embeds["image_grid_thw"].tolist()
# if "image_grid_thw" in image_embeds
# else None
# )
# image_sizes = (
# image_embeds["image_sizes"].tolist()
# if "image_sizes" in image_embeds
# else [image.size]
# )
# logger.debug(
# f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
# )
# with torch.no_grad():
# embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
# if isinstance(embeddings, tuple) or isinstance(embeddings, list):
# # The result multimodal_embeddings may be a list or tuple of tensors, with each
# # tensor corresponding to a multimodal data item (image or video).
# # TODO: for multi-image support, this result will contain multiple tensors.
# embeddings = embeddings[0].unsqueeze(0)
# logger.debug(
# f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
# )
# yield EncodeResponse(
# request_id=request.request_id,
# image_grid_thw=image_grid_thw,
# image_sizes=image_sizes,
# ).model_dump_json()
with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
logger.debug("Vision model completed.")
embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)
descriptor = connect.Descriptor(embeddings)
with self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.to_serialized()
# Clear the image URL as hint that the image is passed as embeddings.
request.image_url = None
logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator
response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json()
)
await readable.wait_for_completion()
async for response in response_generator:
output = MyRequestOutput.model_validate_json(response.data())
yield MyRequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
).model_dump_json()
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
raise
@async_on_start
async def async_init(self):
logger.info("Startup started.")
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmPDWorker.dynamo_address() # type: ignore
self.pd_worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
await check_required_workers(self.pd_worker_client, self.min_workers)
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector(runtime=runtime, namespace=comp_ns)
await self._connector.initialize()
logger.info("Startup completed.")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
from components.encode_worker import VllmEncodeWorker
from transformers import AutoTokenizer
from utils.args import parse_vllm_args
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers
from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
encode_worker = depends(VllmEncodeWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.min_workers = 1
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("encode")
.client()
)
await check_required_workers(self.encode_worker_client, self.min_workers)
# self.etcd_kv_cache = await EtcdKvCache.create(
# runtime.etcd_client(),
# "/dynamo/processor/",
# {"router": self.engine_args.router},
# )
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
image: str,
request_type: RequestType,
):
request_id = str(uuid.uuid4().hex)
logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_request = vLLMMultimodalRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
image_url=image,
)
response_generator = await self.encode_worker_client.round_robin(
worker_request.model_dump_json()
)
output = self._generate_responses(response_generator, request_type)
# Stream the processed responses
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
request_type: RequestType,
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
# The generate endpoint will be used by the frontend to handle incoming requests.
@endpoint()
async def generate(self, raw_request: MultiModalRequest):
# Ensure the configured template includes the placeholder
template = self.engine_args.prompt_template
if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
prompt = template.replace("<prompt>", user_text)
msg = {
"role": "user",
"content": prompt,
}
chat_request = ChatCompletionRequest(
model=raw_request.model,
messages=[msg],
stream=raw_request.stream,
max_tokens=raw_request.max_tokens,
temperature=raw_request.temperature,
request_id=str(uuid.uuid4()),
)
image_url = None
for message in raw_request.messages:
for item in message.content:
if item.type == "image_url":
image_url = item.image_url.url
if image_url is None:
raise ValueError("Image URL is required")
async for response in self._generate(chat_request, image_url, RequestType.CHAT):
yield json.dumps(response)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from components.processor import Processor
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from utils.args import parse_vllm_args
from utils.protocol import MultiModalRequest
from dynamo.sdk import DYNAMO_IMAGE, api, depends, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
image=DYNAMO_IMAGE,
app=FastAPI(title="Multimodal Example"),
)
class Frontend:
processor = depends(Processor)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
@api(name="v1/chat/completions")
async def generate(self, request: MultiModalRequest):
if self.engine_args.model != request.model:
return JSONResponse(
{"error": f"Model '{request.model}' not found"},
status_code=404,
)
async def content_generator():
async for response in self.processor.generate(request.model_dump_json()):
try:
s = json.loads(response)
yield s
except json.JSONDecodeError as e:
raise RuntimeError(f"Failed to parse JSON response: {e}")
return StreamingResponse(content_generator(), media_type="text/event-stream")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import logging
import os
import signal
import socket
from typing import Optional
import connect
import torch
from transformers import AutoImageProcessor
from utils.args import parse_vllm_args
from utils.image_loader import ImageLoader
from utils.logging import check_required_workers
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
class VllmBaseWorker:
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
model_config = self.engine_args.create_model_config()
self.default_sampling_params = model_config.get_diff_sampling_param()
self.enable_disagg = self.engine_args.enable_disagg
self.min_workers = 1
signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
self.set_side_channel_host_and_port()
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
logger.info("VllmWorker has been initialized")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
logger.info("VllmWorker shutdown complete")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
def set_side_channel_host_and_port(
self, hostname: Optional[str] = None, port: Optional[int] = None
):
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.
This sets the port number for the side channel.
"""
if hostname is None:
hostname = socket.gethostname()
if port is None:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) # Bind to a free port provided by the host.
port = s.getsockname()[1] # Get the port number assigned.
logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_HOST to %s", hostname)
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname
logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_PORT to %s", port)
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmDecodeWorker(VllmBaseWorker):
@async_on_start
async def async_init(self):
await super().async_init()
logger.info("VllmDecodeWorker has been initialized")
@endpoint()
async def generate(self, request: vLLMMultimodalRequest):
logger.debug(
f"Received generate request in DecodeWorker: {{ id: {request.request_id} }}."
)
# Decode worker doesn't process embeddings, so we pass None or empty tensor
gen = self.engine_client.generate(
# prompt=request.engine_prompt,
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
# multi_modal_data={"image": None}
),
sampling_params=request.sampling_params,
request_id=request.request_id,
)
async for response in gen:
logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}")
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmPDWorker(VllmBaseWorker):
decode_worker = depends(VllmDecodeWorker)
@async_on_start
async def async_init(self):
await super().async_init()
if self.enable_disagg:
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore
self.decode_worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
await check_required_workers(self.decode_worker_client, self.min_workers)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
# embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
# self.engine_args.model, self.engine_args.num_patches
# )
embeddings_shape = (1, 577, 4096)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]
embeddings = torch.empty(
embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
# descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
self.image_loader = ImageLoader()
self.image_processor = AutoImageProcessor.from_pretrained(
self.engine_args.model, trust_remote_code=True
)
logger.info("VllmPDWorker has been initialized")
@endpoint()
async def generate(self, request: vLLMMultimodalRequest):
logger.debug(
f"Received generate request in PDWorker: {{ id: {request.request_id} }}."
)
if request.image_url is None:
# Process embeddings using the connector
embeddings, descriptor = self._embeddings_descriptor
if descriptor is None:
logger.error("in PD worker, descriptor is None")
read_op = await self._connector.begin_read(
request.serialized_request, descriptor
)
await read_op.wait_for_completion()
logger.debug(f"in PD worker, image features: {embeddings}")
multi_modal_data = embeddings
else:
# Use PIL image instead of image embeddings
multi_modal_data = await self.image_loader.load_image(request.image_url)
# multi_modal_data = self.image_processor(images=image, return_tensors="pt")["pixel_values"].to(dtype=torch.float16)
# image input is expected to be (image_num, channel, height, width)
# logger.info(f"Image features shape: {multi_modal_data.shape}")
# multi_modal_data = multi_modal_data.unsqueeze(0)
# Remove the image features from the request as they are not required
request.image_url = None
request.serialized_request = None
pd_request = copy.deepcopy(request)
# Do prefill and remote decode if enable_disagg is true
if self.enable_disagg:
extra_args = pd_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
pd_request.sampling_params.extra_args = extra_args
pd_request.sampling_params.max_tokens = 1
pd_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", pd_request)
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
multi_modal_data={"image": multi_modal_data},
),
sampling_params=pd_request.sampling_params,
request_id=pd_request.request_id,
)
if self.enable_disagg:
decode_request = copy.deepcopy(request)
async for prefill_response in gen:
# Update the prompt token id in the decode request to the one
# in response, which has image templated filled in. So that
# the decode worker will fetch correct amount of KV blocks.
decode_request.engine_prompt[
"prompt_token_ids"
] = prefill_response.prompt_token_ids
# logger.debug(f"Prefill response: {prefill_response}")
# request_output = MyRequestOutput.model_validate_json(prefill_response.model_dump_json())
logger.debug(
f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
)
extra_args = decode_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
extra_args.pop("serialized_request", None)
decode_request.sampling_params.extra_args = extra_args
logger.debug("Decode request: %s", decode_request)
async for decode_response in await self.decode_worker_client.round_robin(
decode_request.model_dump_json()
):
output = MyRequestOutput.model_validate_json(decode_response.data())
yield MyRequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
kv_transfer_params=output.kv_transfer_params,
).model_dump_json()
else:
async for response in gen:
logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: llava-hf/llava-1.5-7b-hf
block-size: 64
max-model-len: 4096
image-token-id: 32000
num-patches: 576
Frontend:
common-configs: [model]
Processor:
router: round-robin
prompt-template: "USER: <image>\n<prompt> ASSISTANT:"
common-configs: [model, block-size, max-model-len]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
VllmPDWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
enable_disagg: false
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, image-token-id, num-patches]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: microsoft/Phi-3.5-vision-instruct
block-size: 64
max-model-len: 4096
trust-remote-code: true
Frontend:
common-configs: [model]
Processor:
router: round-robin
prompt-template: "<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
common-configs: [model, block-size, max-model-len, trust-remote-code]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
max-num-seqs: 2
mm-processor-kwargs:
num_crops: 16
enable-prefix-caching: true
image-token-id: 32000
num-patches: 757
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, trust-remote-code]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: Qwen/Qwen2.5-VL-7B-Instruct
block-size: 64
max-model-len: 4096
Frontend:
common-configs: [model]
Processor:
router: round-robin
prompt-template: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
common-configs: [model, block-size, max-model-len]
VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
max-num-seqs: 5
mm-processor-kwargs:
min_pixels: 784
max_pixels: 1003520
fps: 1
enable-prefix-caching: true
image-token-id: 151655
num-patches: 345
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: llava-hf/llava-1.5-7b-hf
kv-transfer-config: '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
block-size: 64
# max-model-len: 4096
image-token-id: 32000
num-patches: 576
Frontend:
common-configs: [model]
Processor:
router: round-robin
prompt-template: "USER: <image>\n<prompt> ASSISTANT:"
common-configs: [model]
VllmEncodeWorker:
tensor-parallel-size: 1
router: random
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model]
VllmPDWorker:
enforce-eager: true
# max-num-batched-tokens: 16384
enable-prefix-caching: true
enable_disagg: true
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, kv-transfer-config, image-token-id, num-patches, block-size]
VllmDecodeWorker:
enforce-eager: true
# max-num-batched-tokens: 16384
enable-prefix-caching: true
router: random
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, kv-transfer-config, image-token-id, num-patches, block-size]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
kv-transfer-config: '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
max-model-len: 208960
Frontend:
common-configs: [model]
Processor:
router: round-robin
prompt-template: "<|image|>\n<prompt>"
common-configs: [model, block-size, max-model-len]
VllmDecodeWorker:
router: random
tensor_parallel_size: 8
ServiceArgs:
workers: 1
resources:
gpu: '8'
common-configs: [model, kv-transfer-config, block-size, max-model-len]
VllmPDWorker:
enable_disagg: false
router: random
tensor_parallel_size: 8
ServiceArgs:
workers: 1
resources:
gpu: '8'
common-configs: [model, kv-transfer-config, block-size, max-model-len]
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Dynamo Connect
Dynamo connect provides a Pythonic interface to the NIXL base RDMA subsystem via a set of Python classes.
The primary goal of this library to simplify the integration of NIXL based RDMA into inference applications.
All operations using the Connect library begin with the [`Connector`](#connector) class and the type of operation required.
There are four types of supported operations:
- **Register local readable memory**:
Register local memory buffer(s) with the RDMA subsystem to enable a remote worker to read from.
- **Register local writable memory**:
Register local memory buffer(s) with the RDMA subsystem to enable a remote worker to write to.
- **Read from registered, remote memory**:
Read remote memory buffer(s), registered by a remote worker to be readable, into local memory buffer(s).
- **Write to registered, remote memory**:
Write local memory buffer(s) to remote memory buffer(s) registered by a remote worker to writable.
By connecting correctly paired operations, high-throughput GPU Direct RDMA data transfers can be completed.
Given the list above, the correct pairing of operations would be 1 & 3 or 2 & 4.
Where one side is a "(read|write)-able operation" and the other is its correctly paired "(read|write) operation".
Specifically, a read operation must be paired with a readable operation, and a write operation must be paired with a writable operation.
## Examples
### Generic Example
In the diagram below, Local creates a [`WritableOperation`](#writableoperation) intended to receive data from Remote.
Local then sends metadata about the requuested RDMA operation to Remote.
Remote then uses the metadata to create a [`WriteOperation`](#writeoperation) which will perform the GPU Direct RDMA memory transfer from Remote's GPU memory to Local's GPU memory.
```mermaid
---
title: Write Operation Between Two Workers
---
flowchart LR
c1[Remote] --"3: .begin_write()"--- WriteOperation
WriteOperation e1@=="4: GPU Direct RDMA"==> WritableOperation
WritableOperation --"1: .create_writable()"--- c2[Local]
c2 e2@--"2: RDMA Metadata via HTTP"--> c1
e1@{ animate: true; }
e2@{ animate: true; }
```
### Multimodal Example
In the case of the [Dynamo Multimodal Disaggregated Example](../README.md):
1. The HTTP frontend accepts a text prompt and a URL to an image.
2. The prompt and URL are then enqueued with the Processor before being dispatched to the first available Decode Worker.
3. Decode Worker then requests a Prefill Worker to provide key-value data for the LLM powering the Decode Worker.
4. Prefill Worker then requests that the image be processed and provided as embeddings by the Encode Worker.
5. Encode Worker acquires the image, processes it, performs inference on the image using a specialized vision model, and finally provides the embeddings to Prefill Worker.
6. Prefill Worker receives the embeddings from Encode Worker and generates a key-value cache (KV$) update for Decode Worker's LLM and writes the update directly to the GPU memory reserved for the data.
7. Finally, Decode Worker performs the requested inference.
```mermaid
---
title: Multimodal Disaggregated Workflow
---
flowchart LR
p0[HTTP Frontend] i0@--"text prompt"-->p1[Processor]
p0 i1@--"url"-->p1
p1 i2@--"prompt"-->dw[Decode Worker]
p1 i3@--"url"-->dw
dw i4@--"prompt"-->pw[Prefill Worker]
dw i5@--"url"-->pw
pw i6@--"url"-->ew[Encode Worker]
ew o0@=="image embeddings"==>pw
pw o1@=="kv_cache updates"==>dw
dw o2@--"inference results"-->p0
i0@{ animate: true; }
i1@{ animate: true; }
i2@{ animate: true; }
i3@{ animate: true; }
i4@{ animate: true; }
i5@{ animate: true; }
i6@{ animate: true; }
o0@{ animate: true; }
o1@{ animate: true; }
o2@{ animate: true; }
```
_Note: In this example, it is the data transfer between the Prefill Worker and the Encode Worker that utilizes the Dynamo Connect library. The KV Cache transfer between Decode Worker and Prefill Worker utilizes the NIXL base RDMA subsystem directly without using the Dynamo Connect library._
#### Code Examples
See [prefill_worker](../components/prefill_worker.py#L199) or [decode_worker](../components/decode_worker.py#L239),
for how they coordinate directly with the Encode Worker by creating a [`WritableOperation`](#writableoperation),
sending the operation's metadata via Dynamo's round-robin dispatcher, and awaiting the operation for completion before making use of the transferred data.
See [encode_worker](../components/encode_worker.py#L190),
for how the resulting embeddings are registered with the RDMA subsystem by creating a [`Descriptor`](#descriptor),
a [`WriteOperation`](#writeoperation) is created using the metadata provided by the requesting worker,
and the worker awaits for the data transfer to complete for yielding a response.
## Python Classes
### Connector
Core class for managing the connection between workers in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote workers.
This class is responsible for interfacing with the NIXL-based RDMA subsystem and providing a "Pythonic" interface
with which to utilize GPU Direct RDMA accelerated data transfers between models hosted by different workers in a Dynamo pipeline.
The connector provides two methods of moving data between workers:
- Preparing local memory to be written to by a remote worker.
- Preparing local memory to be read by a remote worker.
In both cases, local memory is registered with the NIXL-based RDMA subsystem via the [`Descriptor`](#descriptor) class and provided to the connector.
The connector then configures the RDMA subsystem to expose the memory for the requested operation and returns an operation control object.
The operation control object, either a [`ReadableOperation`](#readableoperation) or a [`WritableOperation`](#writableoperation),
provides RDMA metadata via its [`.to_serialized()`](#to_serialized) method as well as functionality to know when the operation has been completed or cancel the operation prior to completion.
The RDMA metadata must be provided to the remote worker expected to complete the operation.
The metadata contains required information (identifiers, keys, etc.) which enables the remote worker to interact with the provided memory.
#### Methods
##### `begin_read`
> Creates a [`ReadOperation`](#readoperation) for transferring data from a remote worker.
>
> To create the operation, the serialized request from a remote worker's [`ReadableOperation`](#readableoperation)
> along with a matching set of local memory descriptors which reference memory intended to receive data from the remote worker
> must be provided.
> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Once created, the operation will begin reading immediately.
> Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `begin_write`
> Creates a write operation for transferring data to a remote worker.
>
> To create the operation, the serialized request from a remote worker's [`WritableOperation`](#writableoperation)
> along with a matching set of local memory descriptors which reference memory to be transferred to the remote worker
> must be provided.
> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Once created, the operation will begin writing immediately.
> Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `create_readable`
> Creates a [`ReadableOperation`](#readableoperation) for transferring data to a remote worker.
>
> To create the operation, a set of local memory descriptors must be provided that reference memory intended to be transferred to
> a remote worker.
> Once created, the memory referenced by the provided descriptors becomes immediately readable by a remote worker with the necessary metadata.
> The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Disposable of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `create_writable`
> Creates a [`WritableOperation`](#writableoperation) for transferring data from a remote worker.
>
> To create the operation, a set of local memory descriptors must be provided which reference memory intended to receive data from
> a remote worker.
> Once created, the memory referenced by the provided descriptors becomes immediately writable by a remote worker with the necessary metadata.
> The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Disposable of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
### Descriptor
Memory descriptor that ensures memory is registered with the NIXL base RDMA subsystem.
Memory must be registered with the RDMA subsystem to enable interaction with the memory.
Descriptor objects are administrative and do not copy, move, or otherwise modify the registered memory.
There are four ways to create a descriptor:
1. From a `torch.Tensor` object. Device information will be derived from the provided object.
2. From a `tuple` containing either a NumPy or CuPy `ndarray` and information desribing where the memory resides (Host/CPU vs GPU).
3. From a Python `bytes` object. Memory is assumed to reside in CPU addressable host memory.
4. From a `tuple` comprised of the address of the memory, its size in bytes, and device information.
An optional reference to a Python object can be provided to avoid garbage collection issues.
### Device
Device describes the device, or kind of memory, a given allocation resides in.
Usually host (`"cpu"`) or GPU (`"cuda"`) memory.
When a system contains multiple GPU devices, specific GPU devices can be identified by including their ordinal index number.
For example, to reference the second GPU in a system `"cuda:1"` can be used.
By default, when `"cuda"` is provided, it is assumed to be `"cuda:0"` or the first GPU enumerated by the system.
### ReadOperation
An operation which transfers data from a remote worker to the local worker.
To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`ReadableOperation`](#readableoperation)
along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory intended to receive data from the remote worker must be provided.
The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
Once created, the operation will begin reading immediately.
Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `cancel`
> Instructs the RDMA subsystem to cancel the operation.
> Completed operations cannot be cancelled.
##### `wait_for_completion`
> Blocks the caller until the memory from the remote worker has been transferred to the provided buffers.
### ReadableOperation
An operation which enables a remote worker to read data from the local worker.
To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided that reference memory intended to be transferred to a remote worker.
Once created, the memory referenced by the provided descriptors becomes immediately readable by a remote worker with the necessary metadata.
The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `to_serialized`
> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to read from the operation.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
##### `wait_for_completion`
> Blocks the caller until the operation has received a completion signal from a remote worker.
### WriteOperation
An operation which transfers data from the local worker to a remote worker.
To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`WritableOperation`](#writableoperation)
along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory to be transferred to the remote worker must be provided.
The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
Once created, the operation will begin writing immediately.
Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `cancel`
> Instructs the RDMA subsystem to cancel the operation.
> Completed operations cannot be cancelled.
##### `wait_for_completion`
> Blocks the caller until all provided buffers have been transferred to the remote worker.
### WritableOperation
An operation which enables a remote worker to write data to the local worker.
To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided which reference memory intended to receive data from a remote worker.
Once created, the memory referenced by the provided descriptors becomes immediately writable by a remote worker with the necessary metadata.
The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `to_serialized`
> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to write to the operation.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
##### `wait_for_completion`
> Blocks the caller until the operation has received a completion signal from a remote worker.
### SerializedRequest
A Pydantic type intended to provide JSON serialized RDMA metadata about a [`ReadableOperation`](#readableoperation) or [`WritableOperation`](#writableoperation) object.
Use the [`.to_serialized()`](#to_serialized) method on either of the above types to generate a `SerializedRequest` object for an operation.
## References
- [NVIDIA Dynamo](https://developer.nvidia.com/dynamo) @ [GitHub](https://github.com/ai-dynamo/dynamo)
- [NVIDIA Inference Transfer Library (NIXL)](https://developer.nvidia.com/blog/introducing-nvidia-dynamo-a-low-latency-distributed-inference-framework-for-scaling-reasoning-ai-models/#nvidia_inference_transfer_library_nixl_low-latency_hardware-agnostic_communication%C2%A0) @ [GitHub](https://github.com/ai-dynamo/nixl)
- [Dynamo Multimodal Example](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal)
- [NVIDIA GPU Direct](https://developer.nvidia.com/gpudirect)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import logging
import socket
import uuid
import zlib
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import cached_property
from typing import Any, List, Optional
import nixl._api as nixl_api
import nixl._bindings as nixl_bindings
import torch
from pydantic import BaseModel, ConfigDict, field_validator
from dynamo.runtime import DistributedRuntime
from dynamo.sdk import dynamo_context
logger = logging.getLogger(__name__)
try:
import cupy as array_module
from cupy_backends.cuda.api.runtime import CUDARuntimeError
logger.info("Utilizing cupy to enable GPU acceleration.")
except ImportError:
try:
import numpy as array_module
logger.warning("Failed to load cupy for GPU acceleration, utilizing numpy to provide CPU based operations.")
except ImportError as e:
raise ImportError("Numpy or cupy must be installed to use this module.") from e
class AbstractOperation(ABC):
"""
Abstract base class for awaitable NIXL based RDMA operations.
"""
def __init__(
self,
connector: Connector,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Optional[Descriptor | list[Descriptor]],
notification_key: Optional[str],
) -> None:
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
if not (
isinstance(local_descriptors, (Descriptor, list))
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if (
remote_descriptors is not None
and not (
isinstance(remote_descriptors, Descriptor)
or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors))
)
):
raise TypeError("Argument `remote_descriptors` must be dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`, or `None`.")
if isinstance(local_descriptors, list) and len(local_descriptors) == 0:
raise ValueError("Argument `local_descriptors` must not be an empty list.")
if (
remote_descriptors is not None
and isinstance(remote_descriptors, list)
and len(remote_descriptors) == 0
):
raise ValueError("Argument `remote_descriptors` must not be an empty list.")
notification_key = str(uuid.uuid4()) if notification_key is None else notification_key
if not isinstance(notification_key, str):
raise TypeError("Argument `notification_key` must be `str` or `None`.")
if len(notification_key) == 0:
raise ValueError("Argument `notification_key` must not be an empty string.")
self._notification_key: str = "" if notification_key is None else notification_key
self._connector: Connector = connector
self._operation_kind: OperationKind = operation_kind
self._local_descriptors: Descriptor | list[Descriptor] = local_descriptors
self._local_dlist: Optional[list[tuple[int, int, int]]] = None
self._local_memtype: DeviceKind = DeviceKind.UNSPECIFIED
self._remote_descriptors: Optional[Descriptor | list[Descriptor]] = None if remote_descriptors is None else remote_descriptors
self._remote_dlist: Optional[list[tuple[int, int, int]]] = None
self._remote_memtype: DeviceKind = DeviceKind.UNSPECIFIED
# Register local descriptors with NIXL.
# Note: Only local descriptors should be registered with NIXL,
if isinstance(local_descriptors, list):
for d in local_descriptors:
d.register_memory(self._connector)
else:
local_descriptors.register_memory(self._connector)
# Record local descriptors.
memtype, dtlist = self._create_dlist(local_descriptors)
self._local_dlist = dtlist
self._local_memtype = memtype
# Record remote descriptors when provided.
if remote_descriptors is not None:
memtype, dtlist = self._create_dlist(remote_descriptors)
self._remote_dlist = dtlist
self._remote_memtype = memtype
def __del__(self) -> None:
self._release()
def __enter__(self) -> AbstractOperation:
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self._release()
def _release(self) -> None:
"""
Private method to release resources. Only to be called by `self`.
"""
pass
@property
def connector(self) -> Connector:
"""
Gets the local associated with this operation.
"""
return self._connector
@property
def operation_kind(self) -> OperationKind:
"""
Gets the kind of operation.
"""
return self._operation_kind
@abstractmethod
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
# Private Methods
def _create_dlist(
self,
descriptors: Descriptor | list[Descriptor],
) -> tuple[DeviceKind, list[tuple[int, int, int]]]:
"""
Helper function to create a list of tuples (ptr, size, device) from descriptors.
"""
dlist: list[tuple[int, int, int]] = []
memtype: DeviceKind = DeviceKind.UNSPECIFIED
if isinstance(descriptors, list):
memtype = descriptors[0].device.kind
for desc in descriptors:
if memtype != desc.device.kind:
raise ValueError("All local descriptors must have the same memory type.")
dlist.append((desc.ptr, desc.size, desc.device.id))
else:
memtype = descriptors.device.kind
dlist.append((descriptors.ptr, descriptors.size, descriptors.device.id))
return (memtype, dlist)
class ActiveOperation(AbstractOperation):
"""
Abstract class for active operations that initiates a NIXL based RDMA transfer based `SerializedRequest`
provided by the remote worker's corresponding `PassiveOperation`.
"""
def __init__(
self,
remote: Remote,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
remote_descriptors: Descriptor | list[Descriptor],
notification_key: str,
) -> None:
if not isinstance(remote, Remote) or remote._connector is None:
raise TypeError("Argument `remote` must be valid `dynamo.connect.Remote`.")
if not isinstance(operation_kind, OperationKind):
raise TypeError("Argument `operation_kind` must `dynamo.connect.OperationKind`.")
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if not (
isinstance(remote_descriptors, Descriptor)
or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors))
):
raise TypeError("Argument `remote_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
# Unpack single descriptors from lists if they are provided as single descriptors.
if isinstance(local_descriptors, list) and len(local_descriptors) == 1:
local_descriptors = local_descriptors[0]
if isinstance(remote_descriptors, list) and len(remote_descriptors) == 1:
remote_descriptors = remote_descriptors[0]
if (isinstance(local_descriptors, list) and isinstance(remote_descriptors, list) and len(local_descriptors) != len(remote_descriptors)):
raise ValueError("When `local_descriptors` and `remote_descriptors` are lists, they must have the same length.")
elif isinstance(local_descriptors, list) != isinstance(remote_descriptors, list):
raise ValueError("Both `local_descriptors` and `remote_descriptors` must be either lists or single descriptors.")
if not isinstance(notification_key, str):
raise TypeError("Argument `notification_key` must be `str`.")
if len(notification_key) == 0:
raise ValueError("Argument `notification_key` must not be an empty string.")
self._remote = remote
self._status = OperationStatus.UNINTIALIZED
super().__init__(remote.connector, operation_kind, local_descriptors, remote_descriptors, notification_key)
# Quick check to ensure remote descriptors are not None to make static analysis happy.
if self._local_dlist is None or self._remote_dlist is None:
raise RuntimeError("NIXL descriptor list(s) not bound to operation.")
self._local_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._remote_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None
self._xfer_hndl: Optional[nixl_api.nixl_xfer_handle] = None
self._local_xfer_descs = self._connector._nixl.get_xfer_descs(
descs=self._local_dlist,
mem_type=str(self._local_memtype),
)
logger.debug(f"Created local NIXL xfer descs: {self._local_xfer_descs}")
self._remote_xfer_descs = self._connector._nixl.get_xfer_descs(
descs=self._remote_dlist,
mem_type=str(self._remote_memtype),
)
logger.debug(f"Created remote NIXL xfer descs: {self._remote_xfer_descs}")
self._xfer_hndl = self._connector._nixl.initialize_xfer(
operation=str(operation_kind),
local_descs=self._local_xfer_descs,
remote_descs=self._remote_xfer_descs,
remote_agent=self._remote.name,
notif_msg=self._notification_key.encode("utf-8"),
)
logger.debug(f"Created NIXL transfer handle: {self._xfer_hndl}")
def __del__(self) -> None:
super().__del__()
self._release()
def __enter__(self) -> ActiveOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
match self.status:
case OperationStatus.IN_PROGRESS | OperationStatus.INITIALIZED:
self._status = OperationStatus.CANCELLED
self._release()
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"operation_kind={self._operation_kind}, "
f"local_descriptors={self._local_descriptors}, "
f"remote_descriptors={self._remote_descriptors}, "
f"notification_key='{self._notification_key}', "
f"remote='{self._remote.name}', "
f"status='{self._status}'"
f")"
)
def _release(self) -> None:
"""
Private method to release resources.
"""
error: Optional[Exception] = None
if self._xfer_hndl is not None:
try:
logger.debug(f"NIXL transfer handle {self._xfer_hndl} released.")
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
except Exception as e:
logger.error(f"Failed to release resources: {e}")
error = e
finally:
self._xfer_hndl = None
try:
super()._release()
except Exception as e:
logger.error(f"Failed to release WaitableOperation resources: {e}")
if error is not None:
e.__cause__ = error
error = e
if error is not None:
raise error
def _cancel_(self) -> None:
if self._xfer_hndl is None:
return
if self.status == OperationStatus.ERRORED:
raise RuntimeError("Operation is errored, unable to cancel the operation.")
logger.info(f"Cancellation requested for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', status={self._status} }}.")
# NIXL will cancel the transfer if it is in progress when the handle is released.
self._connector._nixl.release_xfer_handle(self._xfer_hndl)
self._status = OperationStatus.CANCELLED
self._xfer_hndl = None
async def _wait_for_completion_(self) -> None:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
iteration_count = 0
while True:
if iteration_count & 10 == 0:
logger.debug(f"Waiting for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', duration={iteration_count / 10}s }}.")
match self.status:
# "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED:
await asyncio.sleep(0.1)
case OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1)
# Any other state indicates completion or error.
case _:
return
@abstractmethod
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
@property
def remote(self) -> Remote:
"""
Gets the remote worker associated with this operation.
"""
return self._remote
@property
def status(self) -> OperationStatus:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED:
return self._status
if self._xfer_hndl is None:
raise RuntimeError("NIXL transfer handle is invalid.")
old_status = self._status
if self._status == OperationStatus.UNINTIALIZED:
state = self._connector._nixl.transfer(self._xfer_hndl, self._notification_key.encode("utf-8"))
logger.debug(f"NIXL reported transfer state: {state}")
if state == "ERR":
self._status = OperationStatus.ERRORED
elif state == "DONE":
self._status = OperationStatus.COMPLETE
else:
self._status = OperationStatus.INITIALIZED
else:
state = self._connector._nixl.check_xfer_state(self._xfer_hndl)
logger.debug(f"NIXL reported transfer state: {state}")
if state == "ERR":
self._status = OperationStatus.ERRORED
elif state == "DONE":
self._status = OperationStatus.COMPLETE
else:
self._status = OperationStatus.IN_PROGRESS
if self._status != old_status:
logger.debug(f"{self.__class__.__name__} {{ remote: '{self._remote.name}' status: '{old_status}' => '{self._status}' }}.")
return self._status
class Connector:
"""
Core class for managing the connection between workers in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote workers.
"""
def __init__(
self,
namespace: Optional[str] = None,
runtime: Optional[DistributedRuntime] = None,
worker_id: Optional[str] = None,
) -> None:
"""
Creates a new Connector instance.
Parameters
----------
namespace : Optional[str], optional
Dynamo namespace of the component, defaults to "dynamo" when `None`.
runtime : Optional[DistributedRuntime], optional
Reference the dynamo runtime used by the compenent, attempts to use the current runtime when `None`.
worker_id : Optional[str], optional
Unique identifier of the worker, defaults to a new UUID when `None`.
Raises
------
TypeError
When `namespace` is provied and not of type 'str'.
TypeError
When `runtime` iis provied and not of type `dynamo.runtime.DistributedRuntime`.
TypeError
When `worker_id` is provied and not of type `uuid.UUID`.
"""
namespace = "dynamo" if namespace is None else namespace
if not isinstance(namespace, str):
raise TypeError("Argument `namespace` must be `str` or `None`.")
if dynamo_context is not None and "runtime" in dynamo_context:
runtime = dynamo_context["runtime"] if runtime is None else runtime
if not isinstance(runtime, DistributedRuntime) or runtime is None:
raise TypeError("Argument `runtime` must be `dynamo.runtime.DistributedRuntime` or `None`.")
worker_id = worker_id if worker_id is not None else str(uuid.uuid4())
if not isinstance(worker_id, str) or len(worker_id) == 0:
raise TypeError("Argument `worker_id` must be a non-empty `str` or `None`.")
self._worker_id = worker_id
self._is_initialized = False
self._runtime = runtime
self._namespace = namespace
self._nixl = nixl_api.nixl_agent(self._worker_id)
self._hostname = socket.gethostname()
self._agent_metadata: Optional[bytes] = None
logger.debug(f"Created {self.__repr__()}.")
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"worker_id='{self._worker_id}', "
f"namespace={self._namespace}, "
f"hostname={self._hostname}, "
f"metadata=<{0 if self._agent_metadata is None else len(self._agent_metadata)} bytes>"
")"
)
def __str__(self) -> str:
return self._worker_id
@cached_property
def is_cuda_available(self) -> bool:
# Note: cuda.is_avalailable initializes cuda
# and can't be called when forking subprocesses
# care should be taken to only call it within
# subprocesses or use 'spawn'
try:
return array_module.cuda is not None and array_module.cuda.is_available()
except CUDARuntimeError:
return False
@property
def metadata(self) -> bytes:
"""
Get the metadata of the worker.
"""
return self._nixl.get_agent_metadata()
@property
def name(self) -> str | None:
"""
Get the name of the worker.
"""
return self._worker_id
@property
def namespace(self) -> str:
"""
Get the namespace of the local.
"""
return self._namespace
@property
def runtime(self) -> DistributedRuntime:
"""
Get the runtime of the local.
"""
if self._runtime is None:
raise RuntimeError("Runtime is not set. This Connector was not initialized with a runtime.")
return self._runtime
async def begin_read(
self,
remote_request: SerializedRequest,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadOperation:
"""
Creates a read operation for fulfilling a remote readable operation.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to receive data from the remote worker described by `remote_request`.
Returns
-------
ReadOperation
Awaitable read operation that can be used to transfer data from a remote worker.
Raises
------
TypeError
When `remote_request` is not of type `SerializedRequest`.
TypeError
When `local_descriptors` is not of type `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if remote_request is None or not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
if remote_request.operation_kind != OperationKind.READ.value:
raise RuntimeError("Cannot create a `dynamo.connect.ReadOperation` to read from a remote `dynamo.connect.WritableOperation`.")
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = ReadOperation(self, remote_request, local_descriptors)
return op
async def begin_write(
self,
local_descriptors: Descriptor | list[Descriptor],
remote_request: SerializedRequest,
) -> WriteOperation:
"""
Creates a write operation for transferring data to a remote worker.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptors of one or more data objects to be transferred to the remote worker.
"""
if remote_request is None or not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`.")
if remote_request.operation_kind != OperationKind.WRITE:
raise RuntimeError("Cannot create a `WriteOperation` to write to a remote `ReadableOperation`.")
if not isinstance(remote_request.nixl_metadata, str):
raise TypeError("Argument `remote_request.nixl_metadata` must be `str`.")
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = WriteOperation(self, local_descriptors, remote_request)
return op
def create_readable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> ReadableOperation:
"""
Creates a readable operation for transferring data from a remote worker.
Returns
-------
ReadableOperation
A readable operation that can be used to transfer data from a remote worker.
"""
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = ReadableOperation(self, local_descriptors)
return op
def create_writable(
self,
local_descriptors: Descriptor | list[Descriptor],
) -> WritableOperation:
"""
Creates a writable operation for transferring data to a remote worker.
Returns
-------
WritableOperation
A writable operation that can be used to transfer data to a remote worker.
"""
if not self._is_initialized:
raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
op = WritableOperation(self, local_descriptors)
return op
async def initialize(self) -> None:
# Only initialize the connector once.
if self._is_initialized:
return
self._is_initialized = True
# This method is a no-op for now, in the future it may be used to initialize the connector.
logger.debug(f"Initialized Connector {{ name: '{self._worker_id}', namespace '{self._namespace}' }} completed.")
class Descriptor:
"""
Memory descriptor that ensures memory is registered w/ NIXL, used for transferring data between workers.
"""
def __init__(
self,
data: torch.Tensor | tuple[array_module.ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any],
) -> None:
"""
Memory descriptor for transferring data between workers.
Parameters
----------
data : torch.Tensor | tuple[ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any]
The data to be transferred.
When `torch.Tensor` is provided, the attributes of the tensor will be used to create the descriptor.
When `tuple[ndarray, Device]` is provided, the tuple must contain:
- `ndarray`: The CuPy or NumPy array to be transferred.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
When `bytes` is provided, the pointer and size derived from the bytes object and memory type will be assumed to be CPU.
When `tuple[int, int, Device|str, Any]` is provided, the tuple must contain the following elements:
- `int`: Pointer to the data in memory.
- `int`: Size of the data in bytes.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
- `Any`: Optional reference to the data (e.g., the original tensor or bytes object).
This is useful for keeping a reference to the data in memory, but it is not required.
Raises
------
ValueError
When `data` is `None`.
TypeError
When `data` is not a valid type (i.e., not `torch.Tensor`, `bytes`, or a valid tuple).
TypeError
When `data` is a tuple but the elements are not of the expected types (i.e., [`ndarray`, `Device|str`] OR [`int`, `int`, `Device|str`, `Any`]).
"""
TYPE_ERROR_MESSAGE = "Argument `data` must be `torch.Tensor`, `tuple[ndarray, Device|str]`, `bytes`, or `tuple[int, int, Device|str, Any]`."
if data is None:
raise ValueError("Argument `data` cannot be `None`.")
if not (isinstance(data, torch.Tensor) or isinstance(data, bytes) or isinstance(data, tuple)):
raise TypeError(TYPE_ERROR_MESSAGE)
self._data_device: Device = Device("cpu")
self._data_ptr: int = 0
self._data_ref: Optional[Any] = None
self._data_size: int = 0
# Member fields for managing NIXL memory registration.
# Note: ONLY local descriptors should be registered with NIXL,
# remote descriptors do not have a valid memory address and registration will fault.
self._connector: Optional[Connector] = None
self._nixl_hndl: Optional[nixl_bindings.nixlRegDList] = None
# Initially `None` cached serialized descriptor reference, populated when `to_serialized()` is called.
self._serialized: Optional[SerializedDescriptor] = None
# Data is `torch.Tensor`.
if isinstance(data, torch.Tensor):
self._data_ptr = data.data_ptr()
self._data_size = data.numel() * data.element_size()
if data.is_cuda:
self._data_device = Device((DeviceKind.CUDA, data.get_device()))
self._data_ref = data
logger.debug(f"Created {self.__repr__()} from `torch.Tensor`.")
# Data is `tuple[ndarray, Device]`.
elif (
isinstance(data, tuple)
and len(data) == 2
and isinstance(data[0], array_module.ndarray)
and (isinstance(data[1], Device) or isinstance(data[1], str))
):
if hasattr(data[0], "__array_interface__"):
self._data_ptr = data[0].__array_interface__["data"][0]
elif hasattr(data[0], "__cuda_array_interface__"):
self._data_ptr = data[0].__cuda_array_interface__["data"][0]
else:
raise TypeError("Argument `data[0]` must be a `ndarray` with a valid array interface.")
self._data_size = data[0].nbytes
self._data_device = data[1] if isinstance(data[1], Device) else Device(data[1])
self._data_ref = data[0]
logger.debug(f"Created {self.__repr__()} from `tuple[ndarray, Device|str]`.")
# Data is `bytes`.
elif isinstance(data, bytes):
self._data_ptr = id(data)
self._data_size = len(data)
self._data_ref = data
logger.debug(f"Created {self.__repr__()} from `bytes`.")
# Data is `tuple[int, int, Device, dtype, tuple, Any]`.
elif isinstance(data, tuple) and len(data) >= 2 and isinstance(data[0], int) and isinstance(data[1], int):
if len(data) >= 3 and not (isinstance(data[2], Device) or isinstance(data[2], str)):
raise TypeError("Argument `data` must be a `tuple[int, int, Device|str, Any]`.")
self._data_ptr = data[0]
self._data_size = data[1]
if len(data) >= 3:
self._data_device = data[2] if isinstance(data[2], Device) else Device(data[2])
self._data_ref = data[3] if len(data) >=4 else None
logger.debug(f"Created {self.__repr__()} from `tuple[int, int, Device|str, Any]`.")
else:
raise TypeError(TYPE_ERROR_MESSAGE)
def __del__(self) -> None:
if self._nixl_hndl is not None and self._connector is not None:
# Unregister the memory with NIXL.
self._connector._nixl.deregister_memory(self._nixl_hndl)
self._nixl_hndl = None
if self._data_ref is not None:
# Release the reference to the data.
del self._data_ref
logger.debug(f"Deleted {self.__repr__()}.")
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self})"
def __str__(self) -> str:
return f"ptr={hex(self._data_ptr)}, size={self._data_size}, device={self._data_device}"
@property
def device(self) -> Device:
"""
Gets the device the of the descriptor.
"""
return self._data_device
@property
def ptr(self) -> int:
"""
Gets the pointer of the descriptor.
"""
return self._data_ptr
@property
def size(self) -> int:
"""
Gets the size of the descriptor.
"""
return self._data_size
@staticmethod
def from_serialized(
serialized: SerializedDescriptor,
) -> Descriptor:
"""
Deserializes a `SerializedDescriptor` into a `Descriptor` object.
Parameters
----------
serialized : SerializedDescriptor
The serialized descriptor to deserialize.
Returns
-------
Descriptor
The deserialized descriptor.
"""
if not isinstance(serialized, SerializedDescriptor):
raise TypeError("Argument `serialized` must be `SerializedDescriptor`.")
return serialized.to_descriptor()
def register_memory(
self,
connector: Connector,
) -> None:
"""
Registers the memory of the descriptor with NIXL.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if self._data_ptr == 0:
raise ValueError("Cannot register memory with a null pointer.")
if not (self._nixl_hndl is None and self._connector is None):
return
# Register the memory with NIXL.
self._connector = connector
if isinstance(self._data_ref, torch.Tensor):
self._nixl_hndl = connector._nixl.register_memory(self._data_ref)
else:
mem_type = str(self._data_device.kind)
reg_list = [(self._data_ptr, self._data_size, self._data_device.id, mem_type)]
self._nixl_hndl = connector._nixl.register_memory(reg_list, mem_type)
logger.debug(f"Registered {self.__repr__()} with NIXL.")
def to_serialized(self) -> SerializedDescriptor:
"""
Serializes the descriptor into a `SerializedDescriptor` object.
"""
if self._serialized is None:
self._serialized = SerializedDescriptor(
device=f"{self._data_device}",
ptr=self._data_ptr,
size=self._data_size,
)
return self._serialized
class Device:
"""
Represents a device in the system.
"""
def __init__(
self,
metadata: str | tuple[DeviceKind, int],
) -> None:
if metadata is None:
raise ValueError("Argument `metadata` cannot be `None`.")
if isinstance(metadata, tuple) and len(metadata) == 2 and isinstance(metadata[0], DeviceKind) and isinstance(metadata[1], int):
kind, device_id = metadata
elif isinstance(metadata, str):
metadata = metadata.strip().lower()
if metadata.startswith("cuda") or metadata.startswith("gpu"):
kind = DeviceKind.CUDA
device_id = 0 if metadata.find(":") == -1 else int(metadata.split(":")[1])
elif metadata.startswith("cpu") or metadata.startswith("host"):
kind = DeviceKind.HOST
device_id = 0
else:
raise ValueError("Argument `metadata` must be in the format 'cuda:<device_id>' or 'cpu'.")
else:
raise TypeError("Argument `metadata` must be a `tuple[MemoryKind, int]` or a `str`.")
self._device_id = device_id
self._kind = kind
def __repr__(self) -> str:
return f"{self.__class__.__name__}(kind={self._kind}, id={self._device_id})"
def __str__(self) -> str:
return f"{self._kind}:{self._device_id}" if self._kind is DeviceKind.CUDA else f"{self._kind}"
@property
def id(self) -> int:
"""
Gets the device ID of the device.
"""
return self._device_id
@property
def kind(self) -> DeviceKind:
"""
Gets the memory kind of the device.
"""
return self._kind
class DeviceKind(IntEnum):
"""
Type of memory a descriptor has been allocated to.
"""
UNSPECIFIED = 0
HOST = 1
CUDA = 2
def __str__(self) -> str:
if self == DeviceKind.HOST:
return "cpu"
elif self == DeviceKind.CUDA:
return "cuda"
else:
return "<invalid>"
class OperationKind(IntEnum):
"""
Kind of an operation.
"""
UNSPECIFIED = 0
READ = 1
WRITE = 2
def __str__(self) -> str:
if self == OperationKind.READ:
return "READ"
elif self == OperationKind.WRITE:
return "WRITE"
else:
return "<invalid>"
class OperationStatus(IntEnum):
"""
Status of an operation.
"""
UNINTIALIZED = 0
INITIALIZED = 1
IN_PROGRESS = 2
COMPLETE = 3
CANCELLED = 4
ERRORED = 5
def __str__(self) -> str:
if self == OperationStatus.INITIALIZED:
return "INIT"
elif self == OperationStatus.IN_PROGRESS:
return "PROC"
elif self == OperationStatus.COMPLETE:
return "DONE"
elif self == OperationStatus.ERRORED:
return "ERR"
elif self == OperationStatus.CANCELLED:
return "STOP"
else:
return "<invalid>"
class PassiveOperation(AbstractOperation):
"""
Abstract class for common functionality of passive operations.
"""
def __init__(
self,
connector: Connector,
operation_kind: OperationKind,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
self._status = OperationStatus.UNINTIALIZED
super().__init__(connector, operation_kind, local_descriptors, None, None)
self._serialized_request: Optional[SerializedRequest] = None
self._status = OperationStatus.INITIALIZED
def __del__(self) -> None:
super().__del__()
def __enter__(self) -> AbstractOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return str(
f"{self.__class__.__name__}("
f"operation_kind={self._operation_kind}, "
f"local_descriptors={self._local_descriptors}, "
f"notification_key='{self._notification_key}', "
f"status='{self._status}'"
f")"
)
async def _wait_for_completion_(self) -> None:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
while True:
match self.status:
# "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED:
await asyncio.sleep(0.1)
case OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1)
# Any other state indicates completion or error.
case _:
return
@property
def status(self) -> OperationStatus:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED:
return self._status
old_status = self._status
# Query NIXL for any notifications.
notifications = self._connector._nixl.update_notifs()
if isinstance(notifications, dict):
remote_state = OperationStatus.IN_PROGRESS
logger.debug(f"NIXL reported notifications: {len(notifications)}.")
for key, values in notifications.items():
if not isinstance(values, list):
raise TypeError(f"Expected `dict[str, list[bytes]]` from NIXL notification query; got {type(notifications)}.")
for value in values:
if not isinstance(value, bytes):
continue
notification_key = value.decode("utf-8")
# Once we've found the notification key, we know the operation is complete.
if notification_key == self._notification_key:
remote_state = OperationStatus.COMPLETE
break
if remote_state == OperationStatus.COMPLETE:
self._status = remote_state
logger.debug(f"{self.__class__.__name__} {{ remote: '{self._connector.name}' status: '{old_status}' => '{self._status}' }}.")
return self._status
def to_serialized(self) -> SerializedRequest:
"""
Gets the request descriptor for the operation.
"""
if self._serialized_request is None:
# When we've not yet cached the serialized request, we need to generate one before returning it.
# Handle both cases: multiple and single descriptors.
if isinstance(self._local_descriptors, list):
descriptors = [desc.to_serialized() for desc in self._local_descriptors]
else:
descriptors = [self._local_descriptors.to_serialized()]
original_len = len(self._connector.metadata)
nixl_metadata = self._connector.metadata
nixl_metadata = zlib.compress(nixl_metadata, level=6)
compressed_len = len(nixl_metadata)
logger.debug(f"Compressed NIXL metadata from {original_len} bytes to {compressed_len} bytes.")
if compressed_len > original_len:
logger.warning(f"Compressed NIXL metadata is larger than original ({compressed_len} > {original_len}).")
self._serialized_request = SerializedRequest(
descriptors=descriptors,
nixl_metadata=nixl_metadata.hex(),
notification_key=self._notification_key,
operation_kind=int(self._operation_kind),
)
return self._serialized_request
@abstractmethod
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise NotImplementedError("Abstract method not implemented by derived class.")
class ReadOperation(ActiveOperation):
"""
Operation that initiates an RDMA read operation to transfer data from a remote worker's `ReadableOperation`,
as described by `remote_request`, to local buffers.
"""
def __init__(
self,
connector: Connector,
remote_request: SerializedRequest,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
"""
Creates a new instance of `ReadOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA read operation which will transfer data described by `remote_request`
to `local_descriptors`.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
remote_request : SerializedRequest
Serialized request from the remote worker.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to to receive the data from the remote worker.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `dynamo.connect.RequestDescriptor`.")
if remote_request.operation_kind != OperationKind.READ.value:
raise ValueError("Argument `remote_request` must be of kind `READ`.")
remote = Remote(connector, remote_request.nixl_metadata)
remote_descriptors = remote_request.to_descriptors()
if not (
isinstance(local_descriptors, Descriptor)
or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
):
raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`.")
super().__init__(remote, OperationKind.READ, local_descriptors, remote_descriptors, remote_request.notification_key)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> ReadOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or been cancelled.
"""
super()._cancel_()
def results(self) -> list[Descriptor]:
"""
Gets the results of the operation.
Returns a single descriptor if only one was requested, or a list of descriptors if multiple were requested.
"""
if self._status != OperationStatus.COMPLETE:
raise RuntimeError("Operation has not completed yet, cannot get results.")
return self._local_descriptors if isinstance(self._local_descriptors, list) else [self._local_descriptors]
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class ReadableOperation(PassiveOperation):
"""
Operation that can be awaited until a remote worker has completed a `ReadOperation`.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
super().__init__(connector, OperationKind.READ, local_descriptors)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> ReadableOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class Remote:
"""
Identifies a remote NIXL enabled worker relative to a local NIXL enabled worker.
"""
def __init__(
self,
connector: Connector,
nixl_metadata: bytes | str,
) -> None:
if not isinstance(connector, Connector):
raise TypeError("Argument `local` must be `dynamo.connect.Connector`.")
if not (isinstance(nixl_metadata, bytes) or isinstance(nixl_metadata, str)):
raise TypeError("Argument `nixl_metadata` must be `bytes` or `str`.")
if len(nixl_metadata) == 0:
raise ValueError("Argument `nixl_metadata` cannot be empty.")
self._connector = connector
# When `nixl_metadata` is a string, it is assumed to have come from a remote worker
# via a `SerializedRequest` object and therefore can assumed be a hex-encoded, compressed
# representation of the NIXL metadata.
if isinstance(nixl_metadata, str):
# Decode the hex-encoded string into bytes.
nixl_metadata = bytes.fromhex(nixl_metadata)
# Decompress the NIXL metadata.
nixl_metadata = zlib.decompress(nixl_metadata)
self._name = connector._nixl.add_remote_agent(nixl_metadata)
if isinstance(self._name, bytes):
self._name = self._name.decode("utf-8")
logger.debug(f"Created {self.__repr__()}.")
def __del__(self) -> None:
self._release()
def __enter__(self) -> Remote:
"""
Context manager entry method. Returns the current instance.
"""
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
"""
Context manager exit method. Cleans up the instance.
"""
self._release()
def __repr__(self) -> str:
return f"Remote(name={self._name}, connector={self._connector.name})"
def __str__(self) -> str:
return self._name
def _release(self) -> None:
"""
Private method for releasing NIXL resources. Not intended for public use.
"""
# We have to unregister the remote agent from NIXL because we cannot know if the remote worker has updated its descriptors or not, and
# NIXL will return an error if we attempt to register a remote agent with the same name but different descriptors (aka conn_info).
self._connector._nixl.remove_remote_agent(self._name)
logger.debug(f"dynamo.connect.{self.__class__.__name__}: Unregistered NIXL remote {{ name: \"{self._name}\" }}.")
@property
def connector(self) -> Connector:
"""
Gets the local connector associated with this remote worker.
"""
return self._connector
@property
def name(self) -> str:
"""
Gets the name of the remote worker.
"""
return self._name
class SerializedDescriptor(BaseModel):
"""
Pydantic serialization type for memory descriptors.
"""
model_config = ConfigDict(
extra="forbid",
frozen=True,
arbitrary_types_allowed=True,
)
device: str = "cpu"
ptr: int = 0
size: int = 0
def to_descriptor(self) -> Descriptor:
"""
Deserialize the serialized descriptor into a `Descriptor` object.
"""
return Descriptor(data=(self.ptr, self.size, self.device, None))
@field_validator("device")
def validate_memtype(cls, v: str) -> str:
if not isinstance(v, str):
raise TypeError("Argument `device` must be `str`.")
v = v.strip().lower()
if not (v.startswith("cuda") or v == "cpu"):
raise ValueError("Argument `device` must be one of 'cpu' or 'cuda:<device_id>'.")
return v
@field_validator("ptr")
def validate_ptr(cls, v: int) -> int:
if v == 0:
raise ValueError("Argument `ptr` cannot be zero (aka `null` or `None`).")
return v
@field_validator("size")
def validate_size(cls, v: int) -> int:
if v < 0:
raise ValueError("Argument `size` must be an integer greater than or equal to zero.")
return v
class SerializedRequest(BaseModel):
"""
Pydantic serialization type for describing the passive side of a transfer.
"""
model_config = ConfigDict(
extra="forbid",
frozen=True,
arbitrary_types_allowed=True,
)
descriptors: List[SerializedDescriptor] = []
nixl_metadata: str = ""
notification_key: str = ""
operation_kind: int = 0
def to_descriptors(self) -> Descriptor | list[Descriptor]:
"""
Deserializes the request descriptor into a `dynamo.connect.Descriptor` or list of `dynamo.connect.Descriptor` objects.
"""
if len(self.descriptors) == 0:
raise ValueError("Request descriptor must contain at least one serialized descriptor.")
if len(self.descriptors) == 1:
return self.descriptors[0].to_descriptor()
return [item.to_descriptor() for item in self.descriptors]
@field_validator("operation_kind")
def validate_operation_kind(cls, v: int) -> int:
if v < 1 or v > 3:
raise TypeError("Argument `operation_kind` must be an integer value of `dynamo.connect.OperationKind`.")
return v
class WritableOperation(PassiveOperation):
"""
Operation which can be awaited until written to by a `WriteOperation` from a remote worker.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
) -> None:
"""
Creates a new instance of `WritableOperation`, registers the operation and descriptors w/ NIXL,
and enables an RDMA write operation to occur.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Descriptors to receive data from a remote worker.
Raises
TypeError
When `local` is not a `dynamo.connect.Connector`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
super().__init__(connector, OperationKind.WRITE, local_descriptors)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> WritableOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
class WriteOperation(ActiveOperation):
"""
Awaitable write operation which initiates an RDMA write operation to a remote worker
which provided a `SerializedRequest` object from a `WritableOperation`.
"""
def __init__(
self,
connector: Connector,
local_descriptors: Descriptor | list[Descriptor],
remote_request: SerializedRequest,
) -> None:
"""
Creates a new instance of `WriteOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA write operation which will transfer from `local_descriptors` to
remote target(s) described by `remote_request`
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to send from, to the remote worker.
remote_request : SerializedRequest
Serialized request from the remote worker that describes the target(s) to send to.
Raises
TypeError
When `connector` is not a `dynamo.connect.Connector`.
TypeError
When `remote_request` is not a `dynamo.connect.RequestDescriptor`.
ValueError
When `remote_request` is not of kind `WRITE`.
ValueError
When `remote_request.nixl_metadata` is not a non-empty `str`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if not isinstance(connector, Connector):
raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
if not isinstance(remote_request, SerializedRequest):
raise TypeError("Argument `remote_request` must be `dynamo.connect.RequestDescriptor`.")
if remote_request.operation_kind != OperationKind.WRITE.value:
raise ValueError("Argument `remote_request` must be of kind `WRITE`.")
remote = Remote(connector, remote_request.nixl_metadata)
remote_descriptors = remote_request.to_descriptors()
super().__init__(remote, OperationKind.WRITE, local_descriptors, remote_descriptors, remote_request.notification_key)
logger.debug(f"Created {self.__repr__()}")
def __del__(self) -> None:
super().__del__()
logger.debug(f"Deleted {self.__repr__()}")
def __enter__(self) -> WriteOperation:
super().__enter__()
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
super().__exit__(exc_type, exc_value, traceback)
def __repr__(self) -> str:
return super().__repr__()
def cancel(self) -> None:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
super()._cancel_()
async def wait_for_completion(self) -> None:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await super()._wait_for_completion_()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.encode_worker import VllmEncodeWorker
from components.processor import Processor
from components.web import Frontend
from components.worker import VllmPDWorker
Frontend.link(Processor).link(VllmEncodeWorker).link(VllmPDWorker)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.encode_worker import VllmEncodeWorker
from components.processor import Processor
from components.web import Frontend
from components.worker import VllmDecodeWorker, VllmPDWorker
Frontend.link(Processor).link(VllmEncodeWorker).link(VllmPDWorker).link(
VllmDecodeWorker
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: rename to avoid ambiguity with vllm package
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
from dynamo.sdk.lib.config import ServiceConfig
def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
config = ServiceConfig.get_instance()
vllm_args = config.as_args(service_name, prefix=prefix)
parser = FlexibleArgumentParser()
parser.add_argument(
"--enable-disagg", action="store_true", help="Enable disaggregation"
)
parser.add_argument(
"--image-token-id",
type=int,
default=32000,
help="Image token ID used to represent image patches in the token sequence",
)
parser.add_argument(
"--num-patches",
type=int,
default=576,
help="Number of patches the input image is divided into (must be positive)",
)
parser.add_argument(
"--prompt-template",
type=str,
default="<prompt>",
help="Prompt template to use for the model",
)
parser.add_argument(
"--router",
type=str,
choices=["random", "round-robin", "kv"],
default="random",
help="Router type to use for scheduling requests to workers",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(vllm_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.enable_disagg = args.enable_disagg
engine_args.image_token_id = args.image_token_id
engine_args.num_patches = args.num_patches
engine_args.prompt_template = args.prompt_template
engine_args.router = args.router
return engine_args
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, engine_client, processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
# TODO: Revisit this later when adding multi-modal support for the frontend.
# If no chat template is provided and tokenizer doesn't have one,
# use a simple format that just concatenates messages
if not request.chat_template and not self.tokenizer.chat_template:
chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}Assistant:"
else:
chat_template = request.chat_template or self.tokenizer.chat_template
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if request.stream:
# Handle streaming response
num_output_text_so_far = 0
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
yield raw_response
break
# Parse the response
response = json.loads(raw_response.lstrip("data: "))
# Process delta content to extract only new text
if "choices" in response and len(response["choices"]) > 0:
if "delta" in response["choices"][0]:
content = response["choices"][0]["delta"].get("content", "")
if content:
# Extract only the new part from the full content
new_content = content[num_output_text_so_far:]
response["choices"][0]["delta"]["content"] = new_content
num_output_text_so_far = len(content)
# Yield the processed response
yield f"data: {json.dumps(response)}\n\n"
else:
# Handle non-streaming response
# Collect all chunks into a single response
full_response = None
num_output_text_so_far = 0
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
if full_response is None:
# Initialize the full response structure
full_response = {
"id": response.get("id", ""),
"object": "chat.completion",
"created": int(time.time()),
"model": request.model,
"choices": [
{
"index": response.get("index", 0),
"message": {"role": "assistant", "content": ""},
"finish_reason": None,
}
],
}
# Concatenate content if it exists. Each delta contains the full text so far.
if "choices" in response and len(response["choices"]) > 0:
if "delta" in response["choices"][0]:
content = response["choices"][0]["delta"].get("content", "")
if content:
# Extract only the new part from the full content
new_content = content[num_output_text_so_far:]
full_response["choices"][0]["message"][
"content"
] += new_content
num_output_text_so_far = len(content)
# Update finish reason if present
if "finish_reason" in response["choices"][0]:
full_response["choices"][0]["finish_reason"] = response[
"choices"
][0]["finish_reason"]
if full_response is not None:
yield json.dumps(full_response)
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import logging
from io import BytesIO
from urllib.parse import urlparse
import httpx
from PIL import Image
logger = logging.getLogger(__name__)
class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
def __init__(self, cache_size: int = CACHE_SIZE_MAXIMUM):
self._http_timeout = 30.0
self._http_client = httpx.AsyncClient(timeout=self._http_timeout)
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url)
# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if image_url_lower in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
return self._image_cache[image_url_lower]
try:
if parsed_url.scheme == "data":
# Parse data URL format: data:[<media type>][;base64],<data>
if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1)
if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded")
try:
image_bytes = base64.b64decode(data)
image_data = BytesIO(image_bytes)
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
if not self._http_client:
raise RuntimeError("HTTP client not initialized")
response = await self._http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
# PIL is sync, so offload to a thread to avoid blocking the event loop
image = await asyncio.to_thread(Image.open, image_data)
# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = await self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
await self._cache_queue.put(image_url_lower)
return image_converted
except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")
raise
except Exception as e:
logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment