Unverified Commit 9b87c89c authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: multi-modal example with vLLM v1 and UX v2 (#2040)


Co-authored-by: default avatarkrishung5 <krish@nvidia.com>
parent 1327e3bb
...@@ -12,7 +12,7 @@ ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda" ...@@ -12,7 +12,7 @@ ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda"
ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04" ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04"
# Make sure to update the dependency version in pyproject.toml when updating this # Make sure to update the dependency version in pyproject.toml when updating this
ARG VLLM_REF="f4135232b9a8c4845f8961fb1cd17581c56ae2ce" ARG VLLM_REF="ba81acbdc1eec643ba815a76628ae3e4b2263b76"
ARG TORCH_BACKEND="cu128" ARG TORCH_BACKEND="cu128"
# Match 0.10.0 vLLM release # Match 0.10.0 vLLM release
...@@ -186,6 +186,7 @@ RUN if [ "$ARCH" = "arm64" ]; then \ ...@@ -186,6 +186,7 @@ RUN if [ "$ARCH" = "arm64" ]; then \
# Install vllm - keep this early in Dockerfile to avoid # Install vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes # rebuilds from unrelated source code changes
ARG VLLM_REF ARG VLLM_REF
ARG VLLM_GIT_URL
ARG DEEPGEMM_REF ARG DEEPGEMM_REF
ARG FLASHINF_REF ARG FLASHINF_REF
......
...@@ -20,7 +20,8 @@ set -euo pipefail ...@@ -20,7 +20,8 @@ set -euo pipefail
# Parse arguments # Parse arguments
EDITABLE=true EDITABLE=true
VLLM_REF="f4135232b9a8c4845f8961fb1cd17581c56ae2ce" VLLM_REF="ba81acbdc1eec643ba815a76628ae3e4b2263b76"
VLLM_GIT_URL="https://github.com/vllm-project/vllm.git"
MAX_JOBS=16 MAX_JOBS=16
INSTALLATION_DIR=/tmp INSTALLATION_DIR=/tmp
ARCH=$(uname -m) ARCH=$(uname -m)
...@@ -49,6 +50,10 @@ while [[ $# -gt 0 ]]; do ...@@ -49,6 +50,10 @@ while [[ $# -gt 0 ]]; do
VLLM_REF="$2" VLLM_REF="$2"
shift 2 shift 2
;; ;;
--vllm-git-url)
VLLM_GIT_URL="$2"
shift 2
;;
--max-jobs) --max-jobs)
MAX_JOBS="$2" MAX_JOBS="$2"
shift 2 shift 2
...@@ -113,7 +118,7 @@ uv pip install lmcache ...@@ -113,7 +118,7 @@ uv pip install lmcache
# Create vllm directory and clone # Create vllm directory and clone
mkdir -p $INSTALLATION_DIR mkdir -p $INSTALLATION_DIR
cd $INSTALLATION_DIR cd $INSTALLATION_DIR
git clone https://github.com/vllm-project/vllm.git git clone $VLLM_GIT_URL vllm
cd vllm cd vllm
git checkout $VLLM_REF git checkout $VLLM_REF
...@@ -148,7 +153,7 @@ fi ...@@ -148,7 +153,7 @@ fi
# Install ep_kernels and DeepGEMM # Install ep_kernels and DeepGEMM
echo "Installing ep_kernels and DeepGEMM" echo "Installing ep_kernels and DeepGEMM"
cd tools/ep_kernels cd tools/ep_kernels
bash install_python_libraries.sh # These libraries aren't pinned. TORCH_CUDA_ARCH_LIST="9.0;10.0" bash install_python_libraries.sh # These libraries aren't pinned.
cd ep_kernels_workspace cd ep_kernels_workspace
git clone https://github.com/deepseek-ai/DeepGEMM.git git clone https://github.com/deepseek-ai/DeepGEMM.git
cd DeepGEMM cd DeepGEMM
......
...@@ -42,7 +42,10 @@ class MetricsPublisher(StatLoggerBase): ...@@ -42,7 +42,10 @@ class MetricsPublisher(StatLoggerBase):
logger.info(f"ZMQ publisher initialized on port {port}") logger.info(f"ZMQ publisher initialized on port {port}")
def record( def record(
self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats] self,
scheduler_stats: SchedulerStats,
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
): ):
# Send metrics over ZMQ # Send metrics over ZMQ
metrics_data = { metrics_data = {
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Multimodal Deployment Examples
This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo and vLLM v1.
## Use the Latest Release
We recommend using the latest stable release of dynamo to avoid breaking changes:
[![GitHub Release](https://img.shields.io/github/v/release/ai-dynamo/dynamo)](https://github.com/ai-dynamo/dynamo/releases/latest)
You can find the latest release [here](https://github.com/ai-dynamo/dynamo/releases/latest) and check out the corresponding branch with:
```bash
git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
```
## Multimodal Aggregated Serving
### Components
- workers: For aggregated serving, we have two workers, [VllmEncodeWorker](components/encode_worker.py) for encoding and [VllmPDWorker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have two workers, [VllmEncodeWorker](components/encode_worker.py) and [VllmPDWorker](components/worker.py).
The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the VllmPDWorker via a combination of NATS and RDMA.
The work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
Its VllmPDWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
By separating the encode from the prefill and decode stages, we can have a more flexible deployment and scale the
VllmEncodeWorker independently from the prefill and decode workers if needed.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> encode_worker
encode_worker --> processor
encode_worker --embeddings--> pd_worker
pd_worker --> encode_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
# Serve a LLaVA 1.5 7B model:
bash launch/agg.sh --model llava-hf/llava-1.5-7b-hf
# Serve a Qwen2.5-VL model:
# bash launch/agg.sh --model Qwen/Qwen2.5-VL-7B-Instruct
# Serve a Phi3V model:
# bash launch/agg.sh --model microsoft/Phi-3.5-vision-instruct
```
### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/llava-1.5-7b-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
You should see a response similar to this:
```json
{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
```
## Multimodal Disaggregated Serving
### Components
- workers: For disaggregated serving, we have three workers, [VllmEncodeWorker](components/encode_worker.py) for encoding, [VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
In this graph, we have three workers, [VllmEncodeWorker](components/encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
For the Llava model, embeddings are only required during the prefill stage. As such, the VllmEncodeWorker is connected directly to the prefill worker.
The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the prefill worker via a combination of NATS and RDMA.
Its work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
The prefill worker performs the prefilling step and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> encode_worker
encode_worker --> processor
encode_worker --embeddings--> prefill_worker
prefill_worker --> encode_worker
prefill_worker --> decode_worker
decode_worker --> prefill_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
bash launch/disagg.sh --model llava-hf/llava-1.5-7b-hf
```
### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llava-hf/llava-1.5-7b-hf",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
You should see a response similar to this:
```json
{"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
```
***Note***: disaggregation is currently only confirmed to work with LLaVA. Qwen VL and PhiV are not confirmed to be supported.
## Llama 4 family Serving
The family of Llama 4 models is natively multimodal, however, different
from Llava, they do not directly consume image embedding as input
(see the [support metrics](https://docs.vllm.ai/en/latest/models/supported_models.html#text-generation_1)
from vLLM for the types of multi-modal inputs supported by the model).
Therefore, encoder worker will not be used in the following example and the
encoding will be done along side with prefill.
`meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8` will be used as an example
for the content below. And the system will be H100x8 which can hold one instance
of the model per node.
### Multimodal Aggregated Serving
#### Components
- workers: For aggregated serving, we have one worker, [VllmPDWorker](components/worker.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the VllmPDWorker.
- frontend: HTTP endpoint to handle incoming requests.
#### Graph
In this graph, we have [VllmPDWorker](components/worker.py) which will encode the image, prefill and decode the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> pd_worker
pd_worker --> processor
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
bash launch/agg_llama.sh
```
#### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
You should see a response similar to this:
```json
{"id": "b8f060fa95584e34b9204eaba7b105cc", "object": "chat.completion", "created": 1752706281, "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "choices": [{"index": 0, "message": {"role": "assistant", "content": "The image depicts a street scene with a trolley bus as the central focus. The trolley bus is positioned on the left side of the road, facing the camera, and features a white and yellow color scheme. A prominent sign on the front of the bus reads \"OUT OF SERVICE\" in orange letters.\n\n**Key Elements:**\n\n* **Trolley Bus:** The bus is the main subject of the image, showcasing its distinctive design and color.\n* **Sign:** The \"OUT OF SERVICE\" sign is clearly visible on the front of the bus, indicating its current status.\n* **Street Scene:** The surrounding environment includes trees, buildings, and power lines, creating a sense of context and atmosphere.\n* **Lighting:** The image is characterized by a misty or foggy quality, with soft lighting that adds to the overall ambiance.\n\n**Overall Impression:**\n\nThe image presents a serene and somewhat melancholic scene, with the out-of-service trolley bus serving as a focal point. The misty atmosphere and soft lighting contribute to a dreamy or nostalgic feel, inviting the viewer to reflect on the scene."}, "finish_reason": "stop"}]}
```
### Multimodal Disaggregated Serving
#### Components
- workers: For disaggregated serving, we have two workers, [VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for encoding and prefilling.
- processor: Tokenizes the prompt and passes it to the VllmPDWorker.
- frontend: HTTP endpoint to handle incoming requests.
#### Graph
In this graph, we have two workers, [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
The prefill worker performs the encoding and prefilling steps and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> prefill_worker
prefill_worker --> processor
prefill_worker --> decode_worker
decode_worker --> prefill_worker
```
```bash
cd $DYNAMO_HOME/examples/multimodal_v1
bash launch/disagg_llama.sh --head-node
# On a separate node that has finished standard dynamo setup, i.e.
# the worker node needs NATS_SERVER and ETCD_ENDPOINTS environment variables
# pointing to the head node's external IP address for distributed coordination
cd $DYNAMO_HOME/examples/multimodal_v1
bash launch/disagg_llama.sh
```
#### Client
In another terminal:
```bash
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What is in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
You should see a response similar to this:
```json
{"id": "6cc99123ad6948d685b8695428238d4b", "object": "chat.completion", "created": 1752708043, "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "choices": [{"index": 0, "message": {"role": "assistant", "content": "The image depicts a street scene with a trolley bus as the central focus. The trolley bus is positioned on the left side of the road, facing the camera, and features a white and yellow color scheme. A prominent sign on the front of the bus reads \"OUT OF SERVICE\" in orange letters.\n\n**Key Elements:**\n\n* **Trolley Bus:** The bus is the main subject of the image, showcasing its distinctive design and color.\n* **Sign:** The \"OUT OF SERVICE\" sign is clearly visible on the front of the bus, indicating its current status.\n* **Street Scene:** The surrounding environment includes trees, buildings, and power lines, creating a sense of context and atmosphere.\n* **Lighting:** The image is characterized by a misty or foggy quality, with soft lighting that adds to the overall mood.\n\n**Overall Impression:**\n\nThe image presents a serene and somewhat melancholic scene, with the out-of-service trolley bus serving as a focal point. The misty atmosphere and soft lighting contribute to a contemplative ambiance, inviting the viewer to reflect on the situation."}, "finish_reason": "stop"}]}
```
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import logging
import os
import signal
import sys
from typing import AsyncIterator, Tuple
import torch
import uvloop
from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
import connect
from utils.args import Config, base_parse_args, parse_endpoint
from utils.image_loader import ImageLoader
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
configure_dynamo_logging()
logger = logging.getLogger(__name__)
try:
import cupy as array_module
if not array_module.cuda.is_available():
raise ImportError("CUDA is not available.")
DEVICE = "cuda"
logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
import numpy as array_module
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
class VllmEncodeWorker:
def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs) -> None:
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = engine_args
self.model = self.engine_args.model
self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
self.image_processor = AutoImageProcessor.from_pretrained(
self.model, trust_remote_code=True
)
# self.vision_model = load_vision_model(self.model)
self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.model, device_map="auto", torch_dtype=torch.float16
).eval()
self.min_workers = 1
def cleanup(self):
pass
async def generate(
self, request: vLLMMultimodalRequest
) -> AsyncIterator[MyRequestOutput]:
logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")
request_id = request.request_id
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
try:
image = await self.image_loader.load_image(request.image_url)
logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
# [gluo NOTE] The commented section is for VLM generalization support,
# will use more generic approach once utils/model.py is fixed,
# see utils/models.py for details.
# # Add a batch dimension to everything
# for item in image_embeds:
# image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
# logger.debug(f"Image embeds: {image_embeds}")
# image_grid_thw = (
# image_embeds["image_grid_thw"].tolist()
# if "image_grid_thw" in image_embeds
# else None
# )
# image_sizes = (
# image_embeds["image_sizes"].tolist()
# if "image_sizes" in image_embeds
# else [image.size]
# )
# logger.debug(
# f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
# )
# with torch.no_grad():
# embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
# if isinstance(embeddings, tuple) or isinstance(embeddings, list):
# # The result multimodal_embeddings may be a list or tuple of tensors, with each
# # tensor corresponding to a multimodal data item (image or video).
# # TODO: for multi-image support, this result will contain multiple tensors.
# embeddings = embeddings[0].unsqueeze(0)
# logger.debug(
# f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
# )
with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
logger.debug("Vision model completed.")
embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)
descriptor = connect.Descriptor(embeddings)
with self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.to_serialized()
# Clear the image URL as hint that the image is passed as embeddings.
request.image_url = None
logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator
response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json()
)
await readable.wait_for_completion()
async for response in response_generator:
output = MyRequestOutput.model_validate_json(response.data())
yield MyRequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
).model_dump_json()
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
raise
async def async_init(self, runtime: DistributedRuntime):
logger.info("Startup started.")
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
self.downstream_endpoint
)
self.pd_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector(runtime=runtime, namespace=parsed_namespace)
await self._connector.initialize()
logger.info("Startup completed.")
@classmethod
def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
DEFAULT_ENDPOINT = "dyn://dynamo.encoder.generate"
DEFAULT_DOWNSTREAM_ENDPOINT = "dyn://dynamo.llm.generate"
parser = FlexibleArgumentParser(
description="vLLM based encoder for Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
default=DEFAULT_DOWNSTREAM_ENDPOINT,
help=f"The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
)
args, config = base_parse_args(parser)
return args, config
async def graceful_shutdown(runtime):
"""
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
However, in-flight requests will still be processed until they are finished.
After all in-flight requests are finished, the `serve_endpoint` functions will return
and the engine will be shutdown by Python's garbage collector.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Runtime setup
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown")
# worker setup
args, config = VllmEncodeWorker.parse_args()
await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
generate_endpoint = component.endpoint(config.endpoint)
handler = VllmEncodeWorker(args, config.engine_args)
await handler.async_init(runtime)
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
handler.cleanup()
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import json
import logging
import os
import signal
import sys
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
import uvloop
from transformers import AutoTokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import FlexibleArgumentParser
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# To import example local module
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
configure_dynamo_logging()
logger = logging.getLogger(__name__)
prompt_template = "USER: <image>\n<prompt> ASSISTANT:"
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
@classmethod
def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
DEFAULT_ENDPOINT = "dyn://dynamo.processor.generate"
DEFAULT_DOWNSTREAM_ENDPOINT = "dyn://dynamo.encoder.generate"
parser = FlexibleArgumentParser(
description="vLLM based processor for Dynamo LLM."
)
parser.add_argument(
"--prompt-template",
type=str,
required=True,
help=(
"Different multi-modal models expect the prompt to contain different special media prompts. "
"The processor will use this argument to construct the final prompt. "
"User prompt will replace '<prompt>' in the provided template. "
"For example, if the user prompt is 'please describe the image' and the prompt template is "
"'USER: <image> <prompt> ASSISTANT:', the resulting prompt is "
"'USER: <image> please describe the image ASSISTANT:'."
),
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
default=DEFAULT_DOWNSTREAM_ENDPOINT,
help=f"The endpoint string of the downstream encoder in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
)
args, config = base_parse_args(parser)
return args, config
def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs):
self.prompt_template = args.prompt_template
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
def cleanup(self):
pass
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
async def async_init(self, runtime: DistributedRuntime):
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
self.downstream_endpoint
)
self.encode_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
image: str,
request_type: RequestType,
):
request_id = str(uuid.uuid4().hex)
logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_request = vLLMMultimodalRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
image_url=image,
)
# model_dump_json() serializes the request to JSON string
# This API could accept Pydantic class, but SamplingParams
# in vLLMMultimodalRequest is not a Pydantic class and will
# cause TypeError: unsupported type SamplingParams
response_generator = await self.encode_worker_client.round_robin(
worker_request.model_dump_json()
)
output = self._generate_responses(response_generator, request_type)
# Stream the processed responses
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
# This method is used to process the responses from the engine generator.
async def _generate_responses(
self,
response_generator: AsyncIterator[RequestOutput],
request_type: RequestType,
):
async for resp in response_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
# The generate endpoint will be used by the frontend to handle incoming requests.
async def generate(self, raw_request: MultiModalRequest):
logger.debug(f"Got raw request: {raw_request}")
if not isinstance(raw_request, MultiModalRequest):
# If the request is not MultiModalRequest, convert it to MultiModalRequest
raw_request = MultiModalRequest.model_validate(raw_request)
# Ensure the configured template includes the placeholder
template = self.prompt_template
if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")
# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")
prompt = template.replace("<prompt>", user_text)
msg = {
"role": "user",
"content": prompt,
}
chat_request = ChatCompletionRequest(
model=raw_request.model,
messages=[msg],
stream=raw_request.stream,
max_tokens=raw_request.max_tokens,
temperature=raw_request.temperature,
request_id=str(uuid.uuid4()),
)
image_url = None
for message in raw_request.messages:
for item in message.content:
if item.type == "image_url":
image_url = item.image_url.url
if image_url is None:
raise ValueError("Image URL is required")
async for response in self._generate(chat_request, image_url, RequestType.CHAT):
logger.debug(
f"Generated response type {type(response)}, content: {response}"
)
# reconstructing back the OpenAI chat response as dynamo egress expects it
if response.startswith("data: [DONE]"):
break
response = json.loads(response.lstrip("data: "))
yield response
async def graceful_shutdown(runtime):
"""
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
However, in-flight requests will still be processed until they are finished.
After all in-flight requests are finished, the `serve_endpoint` functions will return
and the engine will be shutdown by Python's garbage collector.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Runtime setup
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown")
# worker setup
args, config = Processor.parse_args()
await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
generate_endpoint = component.endpoint(config.endpoint)
handler = Processor(args, config.engine_args)
await handler.async_init(runtime)
# Register the endpoint as entrypoint to a model
await register_llm(
ModelType.Chat, # Custom processor is used and this type bypasses SDK processor
generate_endpoint,
config.model,
config.served_model_name,
kv_cache_block_size=config.engine_args.block_size,
)
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
handler.cleanup()
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
SpecDecodeStats,
WorkerMetricsPublisher,
WorkerStats,
)
from dynamo.runtime import Component
class NullStatLogger(StatLoggerBase):
def __init__(self):
pass
def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
):
pass
def log_engine_initialized(self):
pass
class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
def __init__(self, component: Component, dp_rank: int) -> None:
self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component)
self.dp_rank = dp_rank
self.num_gpu_block = 1
self.request_total_slots = 1
# TODO: Remove this and pass as metadata through etcd
def set_num_gpu_block(self, num_blocks):
self.num_gpu_block = num_blocks
# TODO: Remove this and pass as metadata through etcd
def set_num_request_total_slots(self, request_total_slots):
self.request_total_slots = request_total_slots
def record(
self,
scheduler_stats: SchedulerStats,
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
):
# request_total_slots and kv_total_blocks are properties of model + gpu
# we should only publish them once, not every metric update
# they should be part of some runtime metadata tied to MDC or put in etcd ?
hit_rate = 0
if scheduler_stats.prefix_cache_stats.queries > 0:
hit_rate = (
scheduler_stats.prefix_cache_stats.hits
/ scheduler_stats.prefix_cache_stats.queries
)
worker_stats = WorkerStats(
request_active_slots=scheduler_stats.num_running_reqs,
request_total_slots=self.request_total_slots,
num_requests_waiting=scheduler_stats.num_waiting_reqs,
data_parallel_rank=self.dp_rank,
)
kv_stats = KvStats(
kv_active_blocks=int(self.num_gpu_block * scheduler_stats.kv_cache_usage),
kv_total_blocks=self.num_gpu_block,
gpu_cache_usage_perc=scheduler_stats.kv_cache_usage,
gpu_prefix_cache_hit_rate=hit_rate, # TODO: This is a point in time update, not cumulative. Will be problematic on router side if we try to use it.
)
spec_dec_stats = scheduler_stats.spec_decoding_stats
if spec_dec_stats:
spec_dec_stats = SpecDecodeStats(
num_spec_tokens=spec_dec_stats.num_spec_tokens,
num_drafts=spec_dec_stats.num_drafts,
num_draft_tokens=spec_dec_stats.num_draft_tokens,
num_accepted_tokens=spec_dec_stats.num_accepted_tokens,
num_accepted_tokens_per_pos=spec_dec_stats.num_accepted_tokens_per_pos,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats,
kv_stats=kv_stats,
spec_decode_stats=spec_dec_stats,
)
self.inner.publish(metrics)
def init_publish(self):
worker_stats = WorkerStats(
request_active_slots=0,
request_total_slots=self.request_total_slots,
num_requests_waiting=0,
data_parallel_rank=self.dp_rank,
)
kv_stats = KvStats(
kv_active_blocks=0,
kv_total_blocks=self.num_gpu_block,
gpu_cache_usage_perc=0,
gpu_prefix_cache_hit_rate=0,
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats,
kv_stats=kv_stats,
spec_decode_stats=None,
)
self.inner.publish(metrics)
def log_engine_initialized(self) -> None:
pass
class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM."""
def __init__(self, component: Component, dp_rank: int = 0) -> None:
self.component = component
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
return NullStatLogger()
logger = DynamoStatLoggerPublisher(self.component, dp_rank)
self.created_logger = logger
return logger
def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase:
return self.create_stat_logger(dp_rank=dp_rank)
# TODO Remove once we publish metadata to etcd
def set_num_gpu_blocks_all(self, num_blocks):
if self.created_logger:
self.created_logger.set_num_gpu_block(num_blocks)
def set_request_total_slots_all(self, request_total_slots):
if self.created_logger:
self.created_logger.set_num_request_total_slots(request_total_slots)
def init_publish(self):
if self.created_logger:
self.created_logger.init_publish()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import copy
import logging
import os
import signal
import sys
from typing import Tuple
import torch
import uvloop
from transformers import AutoImageProcessor
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.v1.engine.async_llm import AsyncLLM
from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.runtime import Component, DistributedRuntime, Endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
import connect
from publisher import StatLoggerFactory
from utils.args import (
Config,
base_parse_args,
configure_ports_with_etcd,
overwrite_args,
parse_endpoint,
)
from utils.image_loader import ImageLoader
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class VllmBaseWorker:
@classmethod
def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
parser = FlexibleArgumentParser(
description="vLLM based encoder for Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
help="Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default value will vary based on the worker type, see --worker-type for details.",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
help="The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default value will vary based on the worker type, see --worker-type for details.",
)
parser.add_argument(
"--worker-type",
type=str,
choices=["prefill", "decode", "encode_prefill"],
required=True,
help="Specify the type of worker. Must be one of: 'prefill', 'decode', 'encode_prefill'",
)
parser.add_argument(
"--enable-disagg",
action="store_true",
help="Enable disaggregated mode, where prefill and decode are handled by separate workers."
" If not set, the '*prefill' worker type will handle both prefill and decode.",
)
# use endpoint_overwrite to set the default endpoint based on worker type
def endpoint_overwrite(args):
# default endpoint for this worker
if args.worker_type == "prefill":
args.endpoint = args.endpoint or "dyn://dynamo.llm.generate"
elif args.worker_type == "decode":
args.endpoint = args.endpoint or "dyn://dynamo.decoder.generate"
elif args.worker_type == "encode_prefill":
args.endpoint = args.endpoint or "dyn://dynamo.encoder.generate"
# set downstream endpoint for disaggregated workers
if args.enable_disagg:
args.downstream_endpoint = (
args.downstream_endpoint or "dyn://dynamo.decoder.generate"
)
return args
args, config = base_parse_args(parser, endpoint_overwrite)
return args, config
def __init__(
self,
args: argparse.Namespace,
engine_args: AsyncEngineArgs,
component: Component,
endpoint: Endpoint,
):
self.enable_disagg = args.enable_disagg
self.endpoint = args.endpoint
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = engine_args
self.setup_vllm_engine(component, endpoint)
async def async_init(self, runtime: DistributedRuntime):
pass
def setup_vllm_engine(self, component: Component, endpoint: Endpoint):
"""Initialize the vLLM engine.
This method sets up the vLLM engine client, and configures the dynamo-aware KV
event publisher and metrics stats logger based on component and endpoint.
"""
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Load default sampling params from `generation_config.json`
self.default_sampling_params = (
self.engine_args.create_model_config().get_diff_sampling_param()
)
# Taken from build_async_engine_client_from_engine_args()
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = self.engine_args.create_engine_config(usage_context=usage_context)
# Create vLLM engine with metrics logger and KV event publisher attached
self.stats_logger = StatLoggerFactory(
component, self.engine_args.data_parallel_rank or 0
)
self.engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=[self.stats_logger],
disable_log_requests=self.engine_args.disable_log_requests,
disable_log_stats=self.engine_args.disable_log_stats,
)
# TODO Hack to get data, move this to registering in ETCD
self.stats_logger.set_num_gpu_blocks_all(
vllm_config.cache_config.num_gpu_blocks
)
self.stats_logger.set_request_total_slots_all(
vllm_config.scheduler_config.max_num_seqs
)
self.stats_logger.init_publish()
# TODO: We start off with a valid endpoint, then we increment it by dp_rank
# May no longer be valid. Lets remove the increment behavior from vLLM and here
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
self.engine_args.kv_events_config.endpoint,
data_parallel_rank=self.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
self.kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
logger.info(f"Reading Events from {zmq_endpoint}")
logger.info(f"VllmWorker for {self.engine_args.model} has been initialized")
async def generate(self, request: vLLMMultimodalRequest):
raise NotImplementedError(
"This method should be implemented in subclasses to handle the generation logic."
)
async def clear_kv_blocks(self, request=None):
try:
await self.engine_client.reset_prefix_cache()
yield {"status": "success", "message": "KV cache cleared"}
except Exception as e:
yield {"status": "error", "message": str(e)}
def cleanup(self):
"""Override in subclasses if cleanup is needed."""
pass
class VllmDecodeWorker(VllmBaseWorker):
async def generate(self, request: vLLMMultimodalRequest):
logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
# Decode worker doesn't process embeddings, so we pass None or empty tensor
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
),
sampling_params=request.sampling_params,
request_id=request.request_id,
)
async for response in gen:
logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}")
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
class VllmPDWorker(VllmBaseWorker):
async def async_init(self, runtime: DistributedRuntime):
logger.info("Startup started.")
if self.enable_disagg:
(
parsed_namespace,
parsed_component_name,
parsed_endpoint_name,
) = parse_endpoint(self.downstream_endpoint)
self.decode_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cpu"
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
parsed_namespace, _, _ = parse_endpoint(self.endpoint)
self._connector = connect.Connector(runtime=runtime, namespace=parsed_namespace)
await self._connector.initialize()
# embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
# self.engine_args.model, self.engine_args.num_patches
# )
# [gluo NOTE] Hardcoded for now, will use more generic approach once utils/model.py
# is fixed, see utils/models.py for details.
embeddings_shape = (1, 577, 4096)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]
embeddings = torch.empty(
embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
# descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
self.image_loader = ImageLoader()
self.image_processor = AutoImageProcessor.from_pretrained(
self.engine_args.model, trust_remote_code=True
)
logger.info("VllmPDWorker has been initialized")
async def generate(self, request: vLLMMultimodalRequest):
logger.debug(f"Got raw request: {request}")
if type(request) is not vLLMMultimodalRequest:
if type(request) is str:
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
if request.image_url is None:
# Process embeddings using the connector
embeddings, descriptor = self._embeddings_descriptor
if descriptor is None:
raise RuntimeError(
"Descriptor is None in PD worker - cannot process embeddings"
)
read_op = await self._connector.begin_read(
request.serialized_request, descriptor
)
await read_op.wait_for_completion()
logger.debug(f"in PD worker, image features: {embeddings}")
multi_modal_data = embeddings
else:
# Use PIL image instead of image embeddings
multi_modal_data = await self.image_loader.load_image(request.image_url)
# multi_modal_data = self.image_processor(images=image, return_tensors="pt")["pixel_values"].to(dtype=torch.float16)
# image input is expected to be (image_num, channel, height, width)
# logger.info(f"Image features shape: {multi_modal_data.shape}")
# multi_modal_data = multi_modal_data.unsqueeze(0)
# Remove the image features from the request as they are not required
request.image_url = None
request.serialized_request = None
pd_request = copy.deepcopy(request)
# Do prefill and remote decode if enable_disagg is true
if self.enable_disagg:
extra_args = pd_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
pd_request.sampling_params.extra_args = extra_args
pd_request.sampling_params.max_tokens = 1
pd_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", pd_request)
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
multi_modal_data={"image": multi_modal_data},
),
sampling_params=pd_request.sampling_params,
request_id=pd_request.request_id,
)
if self.enable_disagg:
decode_request = copy.deepcopy(request)
async for prefill_response in gen:
# Update the prompt token id in the decode request to the one
# in response, which has image templated filled in. So that
# the decode worker will fetch correct amount of KV blocks.
decode_request.engine_prompt[
"prompt_token_ids"
] = prefill_response.prompt_token_ids
logger.debug(
f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
)
extra_args = decode_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
extra_args.pop("serialized_request", None)
decode_request.sampling_params.extra_args = extra_args
logger.debug("Decode request: %s", decode_request)
async for decode_response in await self.decode_worker_client.round_robin(
decode_request.model_dump_json()
):
output = MyRequestOutput.model_validate_json(decode_response.data())
yield MyRequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
kv_transfer_params=output.kv_transfer_params,
).model_dump_json()
else:
async for response in gen:
logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
async def graceful_shutdown(runtime):
"""
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
However, in-flight requests will still be processed until they are finished.
After all in-flight requests are finished, the `serve_endpoint` functions will return
and the engine will be shutdown by Python's garbage collector.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
# Runtime setup
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown")
# worker setup
args, config = VllmBaseWorker.parse_args()
# vLLM config overwrites
etcd_client = runtime.etcd_client()
await configure_ports_with_etcd(config, etcd_client)
overwrite_args(config)
await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
"""
Instantiate and serve
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks")
if args.worker_type in ["prefill", "encode_prefill"]:
handler: VllmBaseWorker = VllmPDWorker(
args, config.engine_args, component, generate_endpoint
)
elif args.worker_type == "decode":
handler = VllmDecodeWorker(
args, config.engine_args, component, generate_endpoint
)
await handler.async_init(runtime)
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
handler.cleanup()
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Dynamo Connect
Dynamo connect provides a Pythonic interface to the NIXL base RDMA subsystem via a set of Python classes.
The primary goal of this library to simplify the integration of NIXL based RDMA into inference applications.
All operations using the Connect library begin with the [`Connector`](#connector) class and the type of operation required.
There are four types of supported operations:
- **Register local readable memory**:
Register local memory buffer(s) with the RDMA subsystem to enable a remote worker to read from.
- **Register local writable memory**:
Register local memory buffer(s) with the RDMA subsystem to enable a remote worker to write to.
- **Read from registered, remote memory**:
Read remote memory buffer(s), registered by a remote worker to be readable, into local memory buffer(s).
- **Write to registered, remote memory**:
Write local memory buffer(s) to remote memory buffer(s) registered by a remote worker to writable.
By connecting correctly paired operations, high-throughput GPU Direct RDMA data transfers can be completed.
Given the list above, the correct pairing of operations would be 1 & 3 or 2 & 4.
Where one side is a "(read|write)-able operation" and the other is its correctly paired "(read|write) operation".
Specifically, a read operation must be paired with a readable operation, and a write operation must be paired with a writable operation.
## Examples
### Generic Example
In the diagram below, Local creates a [`WritableOperation`](#writableoperation) intended to receive data from Remote.
Local then sends metadata about the requested RDMA operation to Remote.
Remote then uses the metadata to create a [`WriteOperation`](#writeoperation) which will perform the GPU Direct RDMA memory transfer from Remote's GPU memory to Local's GPU memory.
```mermaid
---
title: Write Operation Between Two Workers
---
flowchart LR
c1[Remote] --"3: .begin_write()"--- WriteOperation
WriteOperation e1@=="4: GPU Direct RDMA"==> WritableOperation
WritableOperation --"1: .create_writable()"--- c2[Local]
c2 e2@--"2: RDMA Metadata via HTTP"--> c1
e1@{ animate: true; }
e2@{ animate: true; }
```
### Multimodal Example
In the case of the [Dynamo Multimodal Disaggregated Example](../README.md):
1. The HTTP frontend accepts a text prompt and a URL to an image.
2. The prompt and URL are then enqueued with the Processor before being dispatched to the first available Decode Worker.
3. Decode Worker then requests a Prefill Worker to provide key-value data for the LLM powering the Decode Worker.
4. Prefill Worker then requests that the image be processed and provided as embeddings by the Encode Worker.
5. Encode Worker acquires the image, processes it, performs inference on the image using a specialized vision model, and finally provides the embeddings to Prefill Worker.
6. Prefill Worker receives the embeddings from Encode Worker and generates a key-value cache (KV$) update for Decode Worker's LLM and writes the update directly to the GPU memory reserved for the data.
7. Finally, Decode Worker performs the requested inference.
```mermaid
---
title: Multimodal Disaggregated Workflow
---
flowchart LR
p0[HTTP Frontend] i0@--"text prompt"-->p1[Processor]
p0 i1@--"url"-->p1
p1 i2@--"prompt"-->dw[Decode Worker]
p1 i3@--"url"-->dw
dw i4@--"prompt"-->pw[Prefill Worker]
dw i5@--"url"-->pw
pw i6@--"url"-->ew[Encode Worker]
ew o0@=="image embeddings"==>pw
pw o1@=="kv_cache updates"==>dw
dw o2@--"inference results"-->p0
i0@{ animate: true; }
i1@{ animate: true; }
i2@{ animate: true; }
i3@{ animate: true; }
i4@{ animate: true; }
i5@{ animate: true; }
i6@{ animate: true; }
o0@{ animate: true; }
o1@{ animate: true; }
o2@{ animate: true; }
```
_Note: In this example, it is the data transfer between the Prefill Worker and the Encode Worker that utilizes the Dynamo Connect library. The KV Cache transfer between Decode Worker and Prefill Worker utilizes the NIXL base RDMA subsystem directly without using the Dynamo Connect library._
#### Code Examples
See [prefill_worker](../components/prefill_worker.py#L199) or [decode_worker](../components/decode_worker.py#L239),
for how they coordinate directly with the Encode Worker by creating a [`WritableOperation`](#writableoperation),
sending the operation's metadata via Dynamo's round-robin dispatcher, and awaiting the operation for completion before making use of the transferred data.
See [encode_worker](../components/encode_worker.py#L190),
for how the resulting embeddings are registered with the RDMA subsystem by creating a [`Descriptor`](#descriptor),
a [`WriteOperation`](#writeoperation) is created using the metadata provided by the requesting worker,
and the worker awaits for the data transfer to complete for yielding a response.
## Python Classes
### Connector
Core class for managing the connection between workers in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote workers.
This class is responsible for interfacing with the NIXL-based RDMA subsystem and providing a "Pythonic" interface
with which to utilize GPU Direct RDMA accelerated data transfers between models hosted by different workers in a Dynamo pipeline.
The connector provides two methods of moving data between workers:
- Preparing local memory to be written to by a remote worker.
- Preparing local memory to be read by a remote worker.
In both cases, local memory is registered with the NIXL-based RDMA subsystem via the [`Descriptor`](#descriptor) class and provided to the connector.
The connector then configures the RDMA subsystem to expose the memory for the requested operation and returns an operation control object.
The operation control object, either a [`ReadableOperation`](#readableoperation) or a [`WritableOperation`](#writableoperation),
provides RDMA metadata via its [`.to_serialized()`](#to_serialized) method as well as functionality to know when the operation has been completed or cancel the operation prior to completion.
The RDMA metadata must be provided to the remote worker expected to complete the operation.
The metadata contains required information (identifiers, keys, etc.) which enables the remote worker to interact with the provided memory.
#### Methods
##### `begin_read`
> Creates a [`ReadOperation`](#readoperation) for transferring data from a remote worker.
>
> To create the operation, the serialized request from a remote worker's [`ReadableOperation`](#readableoperation)
> along with a matching set of local memory descriptors which reference memory intended to receive data from the remote worker
> must be provided.
> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Once created, the operation will begin reading immediately.
> Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `begin_write`
> Creates a write operation for transferring data to a remote worker.
>
> To create the operation, the serialized request from a remote worker's [`WritableOperation`](#writableoperation)
> along with a matching set of local memory descriptors which reference memory to be transferred to the remote worker
> must be provided.
> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Once created, the operation will begin writing immediately.
> Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `create_readable`
> Creates a [`ReadableOperation`](#readableoperation) for transferring data to a remote worker.
>
> To create the operation, a set of local memory descriptors must be provided that reference memory intended to be transferred to
> a remote worker.
> Once created, the memory referenced by the provided descriptors becomes immediately readable by a remote worker with the necessary metadata.
> The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Disposable of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
##### `create_writable`
> Creates a [`WritableOperation`](#writableoperation) for transferring data from a remote worker.
>
> To create the operation, a set of local memory descriptors must be provided which reference memory intended to receive data from
> a remote worker.
> Once created, the memory referenced by the provided descriptors becomes immediately writable by a remote worker with the necessary metadata.
> The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
>
> Disposable of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
> therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
### Descriptor
Memory descriptor that ensures memory is registered with the NIXL base RDMA subsystem.
Memory must be registered with the RDMA subsystem to enable interaction with the memory.
Descriptor objects are administrative and do not copy, move, or otherwise modify the registered memory.
There are four ways to create a descriptor:
1. From a `torch.Tensor` object. Device information will be derived from the provided object.
2. From a `tuple` containing either a NumPy or CuPy `ndarray` and information describing where the memory resides (Host/CPU vs GPU).
3. From a Python `bytes` object. Memory is assumed to reside in CPU addressable host memory.
4. From a `tuple` comprised of the address of the memory, its size in bytes, and device information.
An optional reference to a Python object can be provided to avoid garbage collection issues.
### Device
Device describes the device, or kind of memory, a given allocation resides in.
Usually host (`"cpu"`) or GPU (`"cuda"`) memory.
When a system contains multiple GPU devices, specific GPU devices can be identified by including their ordinal index number.
For example, to reference the second GPU in a system `"cuda:1"` can be used.
By default, when `"cuda"` is provided, it is assumed to be `"cuda:0"` or the first GPU enumerated by the system.
### ReadOperation
An operation which transfers data from a remote worker to the local worker.
To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`ReadableOperation`](#readableoperation)
along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory intended to receive data from the remote worker must be provided.
The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
Once created, the operation will begin reading immediately.
Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `cancel`
> Instructs the RDMA subsystem to cancel the operation.
> Completed operations cannot be cancelled.
##### `wait_for_completion`
> Blocks the caller until the memory from the remote worker has been transferred to the provided buffers.
### ReadableOperation
An operation which enables a remote worker to read data from the local worker.
To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided that reference memory intended to be transferred to a remote worker.
Once created, the memory referenced by the provided descriptors becomes immediately readable by a remote worker with the necessary metadata.
The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `to_serialized`
> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to read from the operation.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
##### `wait_for_completion`
> Blocks the caller until the operation has received a completion signal from a remote worker.
### WriteOperation
An operation which transfers data from the local worker to a remote worker.
To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`WritableOperation`](#writableoperation)
along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory to be transferred to the remote worker must be provided.
The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
Once created, the operation will begin writing immediately.
Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `cancel`
> Instructs the RDMA subsystem to cancel the operation.
> Completed operations cannot be cancelled.
##### `wait_for_completion`
> Blocks the caller until all provided buffers have been transferred to the remote worker.
### WritableOperation
An operation which enables a remote worker to write data to the local worker.
To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided which reference memory intended to receive data from a remote worker.
Once created, the memory referenced by the provided descriptors becomes immediately writable by a remote worker with the necessary metadata.
The metadata required to access the memory referenced by the provided descriptors is accessible via the operations `.to_serialized()` method.
Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
therefore the operation should be awaited until complete or and deleted prior to completion when cancellation is intended.
#### Methods
##### `to_serialized`
> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to write to the operation.
> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
##### `wait_for_completion`
> Blocks the caller until the operation has received a completion signal from a remote worker.
### SerializedRequest
A Pydantic type intended to provide JSON serialized RDMA metadata about a [`ReadableOperation`](#readableoperation) or [`WritableOperation`](#writableoperation) object.
Use the [`.to_serialized()`](#to_serialized) method on either of the above types to generate a `SerializedRequest` object for an operation.
## References
- [NVIDIA Dynamo](https://developer.nvidia.com/dynamo) @ [GitHub](https://github.com/ai-dynamo/dynamo)
- [NVIDIA Inference Transfer Library (NIXL)](https://developer.nvidia.com/blog/introducing-nvidia-dynamo-a-low-latency-distributed-inference-framework-for-scaling-reasoning-ai-models/#nvidia_inference_transfer_library_nixl_low-latency_hardware-agnostic_communication%C2%A0) @ [GitHub](https://github.com/ai-dynamo/nixl)
- [Dynamo Multimodal Example](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal)
- [NVIDIA GPU Direct](https://developer.nvidia.com/gpudirect)
This diff is collapsed.
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf"
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
PROVIDED_PROMPT_TEMPLATE=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL_NAME=$2
shift 2
;;
--prompt-template)
PROVIDED_PROMPT_TEMPLATE=$2
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates."
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use --help for usage information"
exit 1
;;
esac
done
# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
else
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
echo "Please provide a prompt template using --prompt-template option."
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
exit 1
fi
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" &
# run E/P/D workers
CUDA_VISIBLE_DEVICES=0 python3 components/encode_worker.py --model $MODEL_NAME &
CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill &
# Wait for all background processes to complete
wait
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -ex
trap 'echo Cleaning up...; kill 0' EXIT
MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "<|image|>\n<prompt>" &
# LLama 4 doesn't support image embedding input, so the prefill worker will also
# handle image encoding.
# run EP/D workers
python3 components/worker.py --model $MODEL_NAME --worker-type encode_prefill --tensor-parallel-size=8 --max-model-len=208960 &
# Wait for all background processes to complete
wait
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf"
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
PROVIDED_PROMPT_TEMPLATE=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL_NAME=$2
shift 2
;;
--prompt-template)
PROVIDED_PROMPT_TEMPLATE=$2
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates."
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use --help for usage information"
exit 1
;;
esac
done
# Set PROMPT_TEMPLATE based on the MODEL_NAME
if [[ -n "$PROVIDED_PROMPT_TEMPLATE" ]]; then
PROMPT_TEMPLATE="$PROVIDED_PROMPT_TEMPLATE"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
PROMPT_TEMPLATE="USER: <image>\n<prompt> ASSISTANT:"
elif [[ "$MODEL_NAME" == "microsoft/Phi-3.5-vision-instruct" ]]; then
PROMPT_TEMPLATE="<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
elif [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
PROMPT_TEMPLATE="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"
else
echo "No multi-modal prompt template is defined for the model: $MODEL_NAME"
echo "Please provide a prompt template using --prompt-template option."
echo "Example: --prompt-template 'USER: <image>\n<prompt> ASSISTANT:'"
exit 1
fi
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" &
# run E/P/D workers
CUDA_VISIBLE_DEVICES=0 python3 components/encode_worker.py --model $MODEL_NAME &
CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill --enable-disagg &
CUDA_VISIBLE_DEVICES=2 python3 components/worker.py --model $MODEL_NAME --worker-type decode --enable-disagg &
# Wait for all background processes to complete
wait
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -ex
# Default values
HEAD_NODE=0
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--head-node)
HEAD_NODE=1
shift 1
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --head-node Run as head node. Head node will run the HTTP server, processor and prefill worker."
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use --help for usage information"
exit 1
;;
esac
done
trap 'echo Cleaning up...; kill 0' EXIT
MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
if [[ $HEAD_NODE -eq 1 ]]; then
# run ingress
python -m dynamo.frontend &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "<|image|>\n<prompt>" &
# LLama 4 doesn't support image embedding input, so the prefill worker will also
# handle image encoding.
# run EP/D workers
python3 components/worker.py --model $MODEL_NAME --worker-type encode_prefill --enable-disagg --tensor-parallel-size=8 --max-model-len=208960 &
else
# run decode worker on non-head node
python3 components/worker.py --model $MODEL_NAME --worker-type decode --enable-disagg --tensor-parallel-size=8 --max-model-len=208960 &
fi
# Wait for all background processes to complete
wait
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import json
import logging
import os
import socket
import sys
import time
from typing import Callable, List, Optional, Tuple
from vllm.config import KVTransferConfig
from vllm.distributed.kv_events import KVEventsConfig
from vllm.engine.arg_utils import AsyncEngineArgs
logger = logging.getLogger(__name__)
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
class Config:
"""Command line parameters or defaults"""
# dynamo specific
namespace: str
component: str
endpoint: str
kv_port: Optional[int] = None
side_channel_port: Optional[int] = None
# mirror vLLM
model: str
served_model_name: Optional[str]
# rest vLLM args
engine_args: AsyncEngineArgs
def parse_endpoint(endpoint: str) -> List[str]:
endpoint_str = endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logger.error(
f"Invalid endpoint format: '{endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
return endpoint_parts
def base_parse_args(
parser: argparse.ArgumentParser, endpoint_overwrite: Optional[Callable] = None
) -> Tuple[argparse.Namespace, Config]:
"""
Basic parsing logic for any dynamo vLLM deployment. The caller will use
'parser' and 'endpoint_overwrite' to apply use case specific customization.
Args:
parser (argparse.ArgumentParser): The argument parser which has use case
specific arguments added.
endpoint_overwrite (Callable): A user provided function to overwrite the endpoints
the given the parsed arguments. This function should return the overwritten args.
A typical selector will check the worker type and return specific endpoints.
Returns:
Tuple[argparse.Namespace, Config]: A tuple containing the parsed arguments
and a Config object with the relevant settings.
"""
if not any(arg.dest == "endpoint" for arg in parser._actions):
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
config = Config()
config.model = args.model
if args.served_model_name:
assert (
len(args.served_model_name) <= 1
), "We do not support multiple model names."
config.served_model_name = args.served_model_name[0]
else:
# This becomes an `Option` on the Rust side
config.served_model_name = None
if endpoint_overwrite is not None:
args = endpoint_overwrite(args)
endpoint = args.endpoint
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
endpoint
)
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.engine_args = engine_args
if config.engine_args.block_size is None:
config.engine_args.block_size = 16
logger.debug(
f"Setting reasonable default of {config.engine_args.block_size} for block_size"
)
return args, config
async def allocate_and_reserve_port(
namespace,
etcd_client,
worker_id: str,
reason: str,
max_attempts: int = 100,
) -> int:
"""
Get an OS-assigned port and atomically reserve it in ETCD.
Retries until successful or max_attempts reached.
Args:
max_attempts: Maximum number of ports to try (default: 100)
Raises:
RuntimeError: If unable to reserve a port within max_attempts
OSError: If unable to create sockets (system resource issues)
"""
node_name = socket.gethostname()
for attempt in range(1, max_attempts + 1):
# Hold socket open just long enough to reserve in ETCD
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("", 0))
port = sock.getsockname()[1]
# Reserve in ETCD while holding the socket
key = f"dyn://{namespace}/ports/{node_name}/{port}"
value = {
"worker_id": worker_id,
"reason": reason,
"reserved_at": time.time(),
"pid": os.getpid(),
}
try:
await etcd_client.kv_create(
key=key,
value=json.dumps(value).encode(),
lease_id=etcd_client.primary_lease_id(),
)
logger.debug(f"Reserved OS-assigned port {port} for {worker_id}")
return port
except Exception as e:
logger.debug(
f"Port {port} on {node_name} was already reserved (attempt {attempt}): {e}"
)
if attempt < max_attempts:
await asyncio.sleep(0.01)
raise RuntimeError(
f"Failed to allocate and reserve a port after {max_attempts} attempts"
)
async def configure_ports_with_etcd(config: Config, etcd_client):
"""Configure all settings that require ETCD, including port allocation and vLLM overrides."""
# First, allocate ports
dp_rank = config.engine_args.data_parallel_rank or 0
worker_id = f"vllm-{config.component}-dp{dp_rank}"
# Allocate KV events port
kv_port = await allocate_and_reserve_port(
namespace=config.namespace,
etcd_client=etcd_client,
worker_id=f"{worker_id}",
reason="zmq_kv_event_port",
)
# Allocate side channel port
side_channel_port = await allocate_and_reserve_port(
namespace=config.namespace,
etcd_client=etcd_client,
worker_id=f"{worker_id}",
reason="nixl_side_channel_port",
)
# Update config with allocated ports
config.kv_port = kv_port
config.side_channel_port = side_channel_port
def overwrite_args(config):
"""Set vLLM defaults for Dynamo."""
assert (
config.kv_port is not None
), "Must set the kv_port, use configure_ports_with_etcd"
assert (
config.side_channel_port is not None
), "Must set the side_channel_port, use configure_ports_with_etcd"
dp_rank = config.engine_args.data_parallel_rank or 0
defaults = {
"task": "generate",
"skip_tokenizer_init": False,
"disable_log_requests": True,
"enable_prefix_caching": True,
# KV routing relies on logging KV metrics
"disable_log_stats": False,
# Always setting up kv transfer for disagg
"kv_transfer_config": KVTransferConfig(
kv_connector="NixlConnector", kv_role="kv_both"
),
"kv_events_config": KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=f"tcp://*:{config.kv_port - dp_rank}", # vLLM will iterate dp_rank for us, so we need to subtract it out TODO: fix in vLLM
),
}
set_side_channel_host_and_port(config)
logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items():
if hasattr(config.engine_args, key):
setattr(config.engine_args, key, value)
logger.debug(f" engine_args.{key} = {value}")
else:
raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.")
def set_side_channel_host_and_port(config: Config, hostname: Optional[str] = None):
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.
This sets the port number for the side channel.
"""
if hostname is None:
hostname = socket.gethostname()
# Test if hostname is usable by attempting to bind to it
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket:
test_socket.bind((hostname, 0))
except (socket.error, socket.gaierror):
# If hostname is not usable, fall back to localhost
logger.warning(
f"Hostname '{hostname}' is not usable, falling back to '127.0.0.1'"
)
hostname = "127.0.0.1"
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(config.side_channel_port)
logger.debug(f"Set NIXL side channel to {hostname}:{config.side_channel_port}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
default_sampling_params: SamplingParams
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
# TODO: Revisit this later when adding multi-modal support for the frontend.
# If no chat template is provided and tokenizer doesn't have one,
# use a simple format that just concatenates messages
if not request.chat_template and not self.tokenizer.chat_template:
chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n{% endif %}{% endfor %}Assistant:"
else:
chat_template = request.chat_template or self.tokenizer.chat_template
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if request.stream:
# Handle streaming response
num_output_text_so_far = 0
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
yield raw_response
break
# Parse the response
response = json.loads(raw_response.lstrip("data: "))
# Process delta content to extract only new text
if "choices" in response and len(response["choices"]) > 0:
if "delta" in response["choices"][0]:
content = response["choices"][0]["delta"].get("content", "")
if content:
# Extract only the new part from the full content
new_content = content[num_output_text_so_far:]
response["choices"][0]["delta"]["content"] = new_content
num_output_text_so_far = len(content)
# Yield the processed response
yield f"data: {json.dumps(response)}\n\n"
else:
# Handle non-streaming response
# Collect all chunks into a single response
full_response = None
num_output_text_so_far = 0
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
enable_force_include_usage=False,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
if full_response is None:
# Initialize the full response structure
full_response = {
"id": response.get("id", ""),
"object": "chat.completion",
"created": int(time.time()),
"model": request.model,
"choices": [
{
"index": response.get("index", 0),
"message": {"role": "assistant", "content": ""},
"finish_reason": None,
}
],
}
# Concatenate content if it exists. Each delta contains the full text so far.
if "choices" in response and len(response["choices"]) > 0:
if "delta" in response["choices"][0]:
content = response["choices"][0]["delta"].get("content", "")
if content:
# Extract only the new part from the full content
new_content = content[num_output_text_so_far:]
full_response["choices"][0]["message"][
"content"
] += new_content
num_output_text_so_far = len(content)
# Update finish reason if present
if "finish_reason" in response["choices"][0]:
full_response["choices"][0]["finish_reason"] = response[
"choices"
][0]["finish_reason"]
if full_response is not None:
yield json.dumps(full_response)
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
import binascii
import logging
from io import BytesIO
from urllib.parse import urlparse
import httpx
from PIL import Image
logger = logging.getLogger(__name__)
class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
def __init__(self, cache_size: int = CACHE_SIZE_MAXIMUM):
self._http_timeout = 30.0
self._http_client = httpx.AsyncClient(timeout=self._http_timeout)
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url)
# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if image_url_lower in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
return self._image_cache[image_url_lower]
try:
if parsed_url.scheme == "data":
# Parse data URL format: data:[<media type>][;base64],<data>
if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1)
if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded")
try:
image_bytes = base64.b64decode(data)
image_data = BytesIO(image_bytes)
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
if not self._http_client:
raise RuntimeError("HTTP client not initialized")
response = await self._http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
# PIL is sync, so offload to a thread to avoid blocking the event loop
image = await asyncio.to_thread(Image.open, image_data)
# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = await self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
await self._cache_queue.put(image_url_lower)
return image_converted
except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")
raise
except Exception as e:
logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Tuple
import torch
from transformers import AutoConfig
from utils.protocol import EncodeResponse
from vllm import AsyncEngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker import Worker
# from transformers import AutoImageProcessor, LlavaForConditionalGeneration
# from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
logger = logging.getLogger(__name__)
# [gluo NOTE] in vLLM v1, Worker() usage below will results in NotImplementedError,
# must find another way to properly load the vision model given the model name (model_id).
def load_vision_model(model_id: str) -> torch.nn.Module:
"""
Load a vision model from a HuggingFace model ID.
"""
engine_args = AsyncEngineArgs(model=model_id, trust_remote_code=True)
engine_config = engine_args.create_engine_config()
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
worker = Worker(
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=True,
)
# Initialize the worker.
worker.init_device()
worker.load_model()
return worker.model_runner.model
# model = LlavaForConditionalGeneration.from_pretrained(
# model_id, device_map="auto", torch_dtype=torch.float16
# ).eval()
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
# model_id, torch_dtype="auto", device_map="auto"
# ).eval()
# return model
def get_vision_embeddings_info(
model_id: str, num_patches: int
) -> Tuple[Tuple[int, int, int], torch.dtype]:
"""Calculate vision embeddings size and dtype using model config
Returns a tuple of (batch_size, num_patches, hidden_dim), dtype.
"""
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
assert num_patches > 0, "Number of patches must be positive"
if not hasattr(config, "torch_dtype"):
raise ValueError("Model config missing required 'torch_dtype' attribute")
if not hasattr(config, "hidden_size"):
logger.warning(
"Model config missing required 'hidden_size' attribute, using 4096"
)
hidden_size = 4096
else:
hidden_size = config.hidden_size
return (1, num_patches, hidden_size), config.torch_dtype
def construct_mm_data(
model: str,
encode_output: EncodeResponse,
image_embeds: torch.Tensor,
embeddings_dtype: torch.dtype,
) -> Dict[str, torch.Tensor | Dict[str, Any]]:
"""Construct multimodal data for a vLLM request for models that require additional parameters alongside the embeddings"""
image_embeds = image_embeds.to(embeddings_dtype)
if "Qwen2" in model:
return {
"image": {
"image_embeds": image_embeds.squeeze(0),
"image_grid_thw": torch.tensor(encode_output.image_grid_thw).squeeze(0),
}
}
elif "MiniCPM-V" in model:
return {
"image": {
"image_embeds": image_embeds,
"image_sizes": encode_output.image_sizes,
}
}
else:
return {"image": image_embeds}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Literal, Optional, Union
import connect
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import PromptLogprobs, RequestMetrics
class Request(BaseModel):
prompt: str
sampling_params: dict
class Tokens(BaseModel):
tokens: list[int]
class PrefillRequest(Request):
request_id: str
class Response(BaseModel):
text: str
class PrefillResponse(BaseModel):
prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# Sampling params is a mspspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
prefix_hit_rate: Optional[float] = 0.0
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
model_config = ConfigDict(
arbitrary_types_allowed=True,
json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)},
)
class TextContent(BaseModel):
type: Literal["text"]
text: str
class ImageURLDetail(BaseModel):
url: str
class ImageContent(BaseModel):
type: Literal["image_url"]
image_url: ImageURLDetail
MessageContent = Union[TextContent, ImageContent]
class ChatMessage(BaseModel):
role: Literal["user", "system", "assistant"]
content: List[MessageContent]
class MultiModalRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model: str
messages: List[ChatMessage]
max_tokens: Optional[int] = None
temperature: Optional[float] = None
stream: Optional[bool] = True
class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: Optional[str] = None
# image_features: Optional[List[List[List[float]]]] = None # Remove once have NIXL support
serialized_request: Optional[connect.SerializedRequest] = None
class EncodeRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str
request_id: str
serialized_request: Optional[connect.SerializedRequest] = None
class EncodeResponse(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
image_grid_thw: Optional[List[Any]] = None
image_sizes: Optional[List[Any]] = None
serialized_request: Optional[connect.SerializedRequest] = None
image_features: List[List[List[float]]] # Remove once have NIXL support
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
kv_transfer_params: Optional[dict[str, Any]] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
...@@ -26,6 +26,7 @@ from dynamo._core import Backend as Backend ...@@ -26,6 +26,7 @@ from dynamo._core import Backend as Backend
from dynamo._core import Client as Client from dynamo._core import Client as Client
from dynamo._core import Component as Component from dynamo._core import Component as Component
from dynamo._core import DistributedRuntime as DistributedRuntime from dynamo._core import DistributedRuntime as DistributedRuntime
from dynamo._core import Endpoint as Endpoint
from dynamo._core import EtcdKvCache as EtcdKvCache from dynamo._core import EtcdKvCache as EtcdKvCache
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment