Unverified commit 617d55c0, authored by Ryan McCormick, committed by GitHub
Browse files

chore: remove deprecated examples/multimodal directory (#8141)


Signed-off-by: Ryan McCormick <rmccormick@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 326a702d
......@@ -103,7 +103,6 @@ xpu:
multimodal:
- changed-files:
- any-glob-to-any-file:
- examples/multimodal/**
- components/src/dynamo/common/memory/multimodal_embedding_cache_manager.py
- components/src/dynamo/common/multimodal/**
- components/src/dynamo/vllm/omni/**
......
......@@ -16,7 +16,6 @@ Cargo.toml @ai-dynamo/dynamo-rust-codeowners
# Examples
/examples/ @ai-dynamo/Devops @ai-dynamo/dynamo-rust-codeowners @ai-dynamo/python-codeowners @ai-dynamo/dynamo-deploy-codeowners
/examples/multimodal/ @ai-dynamo/python-codeowners @ai-dynamo/Devops
# Dynamo deploy
/deploy/ @ai-dynamo/dynamo-deploy-codeowners
......
......@@ -256,7 +256,6 @@ To quickly setup both: `docker compose -f deploy/docker-compose.yml up -d`
[kv-routing]: docs/components/router/README.md
[planner]: docs/components/planner/planner-guide.md
[kvbm]: docs/components/kvbm/README.md
[mm]: examples/multimodal/
[migration]: docs/fault-tolerance/request-migration.md
[lora]: examples/backends/vllm/deploy/lora/README.md
[tools]: docs/agents/tool-calling.md
......@@ -165,5 +165,5 @@ and the worker awaits for the data transfer to complete for yielding a response.
- [NVIDIA Dynamo](https://developer.nvidia.com/dynamo) @ [GitHub](https://github.com/ai-dynamo/dynamo)
- [NVIDIA Inference Transfer Library (NIXL)](https://developer.nvidia.com/blog/introducing-nvidia-dynamo-a-low-latency-distributed-inference-framework-for-scaling-reasoning-ai-models/#nvidia_inference_transfer_library_nixl_low-latency_hardware-agnostic_communication%C2%A0) @ [GitHub](https://github.com/ai-dynamo/nixl)
- [Dynamo Multimodal Example](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal)
- [Dynamo Multimodal Example](https://github.com/ai-dynamo/dynamo/tree/main/examples/backends/vllm/launch)
- [NVIDIA GPU Direct](https://developer.nvidia.com/gpudirect)
......@@ -55,7 +55,6 @@ Reference implementations for deploying multimodal models:
- [vLLM multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/backends/vllm/launch) (image, video)
- [TRT-LLM multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/backends/trtllm/launch)
- [SGLang multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/backends/sglang/launch)
- [Experimental multimodal examples](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal/launch) (audio)
## Backend Documentation
......
......@@ -11,7 +11,7 @@ This deployment pattern enables dynamic LoRA adapter loading from S3-compatible
- Kubernetes cluster with GPU support
- Helm 3.x installed
- `kubectl` configured to access your cluster
- Dynamo Kubernetes Platform installed ([Installation Guide](../../../../docs/kubernetes/installation-guide.md))
- Dynamo Kubernetes Platform installed ([Installation Guide](../../../../../../docs/kubernetes/installation-guide.md))
- HuggingFace token for downloading base and LoRA adapters
## Files in This Directory
......@@ -364,7 +364,7 @@ kubectl delete secret hf-token-secret -n ${NAMESPACE}
## Further Reading
- [Multimodal LoRA Launch Guide](../../launch/lora/README.md) - Local launch with shell scripts
- [LLM LoRA Deployment](../../../backends/vllm/deploy/lora/README.md) - Text-only LoRA deployment pattern
- [Dynamo Kubernetes Guide](../../../../docs/kubernetes/README.md) - Platform setup
- [Installation Guide](../../../../docs/kubernetes/installation-guide.md) - Platform installation
- [Multimodal LoRA Launch Guide](../../../launch/lora/multimodal/README.md) - Local launch with shell scripts
- [LLM LoRA Deployment](../README.md) - Text-only LoRA deployment pattern
- [Dynamo Kubernetes Guide](../../../../../../docs/kubernetes/README.md) - Platform setup
- [Installation Guide](../../../../../../docs/kubernetes/installation-guide.md) - Platform installation
......@@ -12,7 +12,7 @@
# List: curl http://<worker>:9090/v1/loras
# Unload: curl -X DELETE http://<worker>:9090/v1/loras/my-adapter
#
# Matches the pattern in: examples/multimodal/launch/lora/lora_agg.sh
# Matches the pattern in: examples/backends/vllm/launch/lora/multimodal/lora_agg.sh
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
......
......@@ -14,7 +14,7 @@ Serve vision-language models (VLMs) with dynamically loadable LoRA adapters usin
### 1. Launch the server
```bash
cd examples/multimodal/launch/lora
cd examples/backends/vllm/launch/lora/multimodal
./lora_agg.sh
```
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import os
import signal
import sys
from typing import AsyncIterator, Tuple
import uvloop
from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
import dynamo.nixl_connect as connect
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.encode_utils import encode_image_embeddings, get_encoder_components
from utils.image_loader import ImageLoader
from utils.model import load_vision_model
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
# Apply Dynamo's logging configuration before anything below emits a record.
configure_dynamo_logging()
logger = logging.getLogger(__name__)

# Pick the array backend: cupy when it imports and reports a usable CUDA
# runtime, numpy otherwise. Either way the chosen module is bound to
# `array_module` and DEVICE holds the matching device string.
try:
    import cupy as array_module

    if not array_module.cuda.is_available():
        # Raised as ImportError so the numpy fallback below also handles
        # "cupy is installed but there is no usable GPU".
        raise ImportError("CUDA is not available.")
except ImportError as e:
    logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
    import numpy as array_module

    DEVICE = "cpu"
else:
    DEVICE = "cuda"
    logger.info("Using cupy for array operations (GPU mode).")

# Passed as `cache_size` to ImageLoader by the encode worker.
CACHE_SIZE_MAXIMUM = 8
class VllmEncodeWorker:
    """Encode worker: turns an image URL into embeddings and forwards the request.

    For each request the worker loads the image, runs it through the image
    processor and the model-specific vision encoder/projector, exposes the
    resulting embeddings through a NIXL readable operation, and relays the
    downstream worker's streamed responses back to the caller.
    """

    def __init__(
        self,
        args: argparse.Namespace,
        engine_args: AsyncEngineArgs,
        pd_worker_client: Client,
    ) -> None:
        """Load the image processor, vision model and encoder components.

        Args:
            args: Parsed command-line arguments (not read by this constructor).
            engine_args: vLLM engine arguments; only ``model`` is used here.
            pd_worker_client: Client for the downstream worker endpoint that
                receives the request once the embeddings are ready.
        """
        self.pd_worker_client = pd_worker_client
        self.engine_args = engine_args
        self.model = self.engine_args.model

        self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
        self.image_processor = AutoImageProcessor.from_pretrained(
            self.model, trust_remote_code=True
        )
        self.vision_model = load_vision_model(self.model)
        self.min_workers = 1

        # Get encoder components for the model
        self.vision_encoder, self.projector = get_encoder_components(
            self.model, self.vision_model
        )

    def cleanup(self):
        """Release worker resources; currently there is nothing to release."""
        pass

    async def generate(
        self, request: vLLMMultimodalRequest
    ) -> AsyncIterator[str]:
        """Encode the request's image and stream the downstream worker's output.

        Args:
            request: A ``vLLMMultimodalRequest``, or a JSON string / other
                object that ``vLLMMultimodalRequest`` can validate.

        Yields:
            JSON strings produced by ``MyRequestOutput.model_dump_json()``,
            one per response received from the downstream worker.

        Raises:
            ValueError: If the request carries no ``image_url``.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        logger.debug(f"Got raw request: {request}")
        if not isinstance(request, vLLMMultimodalRequest):
            if isinstance(request, str):
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")

        request_id = request.request_id

        # The following steps encode the requested image and provide useful embeddings.
        # 1. Open the image from the provided URL.
        # 2. Process the image using the image processor.
        # 3. Run the image through the vision model's vision tower.
        # 4. Run the results of the vision tower through the multi-modal projector.
        # 5. Create a descriptor for the embeddings.
        # 6. Create a readable operation for the descriptor and attach its
        #    metadata to the request.
        # 7. Send the request downstream and wait for the transfer to complete.
        # 8. Yield the downstream responses.
        try:
            if not request.multimodal_input.image_url:
                raise ValueError("image_url is required for the encode worker.")

            image = await self.image_loader.load_image(
                request.multimodal_input.image_url
            )

            logger.debug(f"Processing image for request: {{ id: {request_id} }}")

            image_embeds = self.image_processor(images=image, return_tensors="pt")

            # Encode the image embeddings using model-specific encoder
            embeddings = encode_image_embeddings(
                model_name=self.model,
                image_embeds=image_embeds,
                vision_encoder=self.vision_encoder,
                projector=self.projector,
            )

            image_grid_thw = (
                image_embeds["image_grid_thw"].tolist()
                if "image_grid_thw" in image_embeds
                else None
            )

            # The statistics need four tensor reductions plus `.item()` calls;
            # only pay for them when the message will actually be emitted.
            if logger.isEnabledFor(logging.DEBUG):
                pixel_values = image_embeds["pixel_values"]
                logger.debug(
                    f"Pixel values stats: mean={pixel_values.mean().item()}, "
                    f"std={pixel_values.std().item()}, "
                    f"min={pixel_values.min().item()}, "
                    f"max={pixel_values.max().item()}"
                )

            request.image_grid_thw = image_grid_thw
            request.embeddings_shape = tuple(embeddings.shape)
            descriptor = connect.Descriptor(embeddings)

            with await self._connector.create_readable(descriptor) as readable:
                request.serialized_request = readable.metadata()
                # Clear the image URL as hint that the image is passed as embeddings.
                request.multimodal_input.image_url = None

                # Serialize once and reuse the payload for logging and sending.
                payload = request.model_dump_json()
                logger.debug(f"Request: {payload}")

                # Get the response generator
                response_generator = await self.pd_worker_client.round_robin(payload)
                await readable.wait_for_completion()

                async for response in response_generator:
                    output = MyRequestOutput.model_validate_json(response.data())
                    yield MyRequestOutput(
                        request_id=output.request_id,
                        prompt=output.prompt,
                        prompt_token_ids=output.prompt_token_ids,
                        prompt_logprobs=output.prompt_logprobs,
                        outputs=output.outputs,
                        finished=output.finished,
                    ).model_dump_json()

        except Exception as e:
            logger.error(f"Error processing request {request_id}: {e}")
            raise

    async def async_init(self, runtime: DistributedRuntime):
        """Create the connector used by ``generate``; call before serving.

        Args:
            runtime: The distributed runtime (not used by this method).
        """
        logger.info("Startup started.")
        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        self._connector = connect.Connector()

        logger.info("Startup completed.")

    @classmethod
    def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
        """Parse command-line arguments for the encode worker.

        The default endpoints are derived from the ``DYN_NAMESPACE``
        environment variable (``"dynamo"`` when unset).

        Returns:
            The ``(args, config)`` pair produced by ``base_parse_args``.
        """
        DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
        DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.encoder.generate"
        DEFAULT_DOWNSTREAM_ENDPOINT = f"dyn://{DYN_NAMESPACE}.llm.generate"

        parser = FlexibleArgumentParser(
            description="vLLM based encoder for Dynamo LLM."
        )
        parser.add_argument(
            "--endpoint",
            type=str,
            default=DEFAULT_ENDPOINT,
            help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
        )
        parser.add_argument(
            "--downstream-endpoint",
            type=str,
            default=DEFAULT_DOWNSTREAM_ENDPOINT,
            help=f"The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
        )

        args, config = base_parse_args(parser)

        return args, config
async def graceful_shutdown(runtime):
    """Shut down the DistributedRuntime after a termination signal.

    Calling `runtime.shutdown()` makes the endpoints unavailable right away,
    while requests that are already in flight keep running until they finish.
    Once they have all finished, the `serve_endpoint` functions return and the
    engine is torn down by Python's garbage collector.
    """
    logging.info("Received shutdown signal, shutting down DistributedRuntime")

    runtime.shutdown()

    logging.info("DistributedRuntime shutdown complete")
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Install shutdown signal handlers, parse arguments and serve the encoder.

    Args:
        runtime: The distributed runtime this worker serves on (presumably
            supplied by the `dynamo_worker` decorator — confirm).
    """
    # Runtime setup
    # Set up signal handler for graceful shutdown
    loop = asyncio.get_running_loop()

    # The event loop only holds weak references to tasks, so keep strong
    # references here until each shutdown task has finished.
    shutdown_tasks = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logging.info("Signal handlers set up for graceful shutdown")

    # worker setup
    args, config = VllmEncodeWorker.parse_args()
    await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
    """Build the encode worker and serve its generate endpoint.

    Resolves this worker's endpoint and a client for the downstream endpoint,
    constructs and initializes the `VllmEncodeWorker`, waits for downstream
    instances, then serves until the endpoint returns. `handler.cleanup()`
    always runs on the way out; serving errors are logged and re-raised.
    """
    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )

    # Split the downstream 'dyn://namespace.component.endpoint' string.
    downstream_parts = parse_endpoint(args.downstream_endpoint)
    parsed_namespace, parsed_component_name, parsed_endpoint_name = downstream_parts

    downstream_endpoint = runtime.endpoint(
        f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
    )
    pd_worker_client = await downstream_endpoint.client()

    handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
    await handler.async_init(runtime)

    logger.info("Waiting for PD Worker Instances ...")
    await pd_worker_client.wait_for_instances()

    logger.info(f"Starting to serve the {args.endpoint} endpoint...")

    model_labels = [("model", config.model)]
    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=model_labels
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
# Script entry point: only runs when the module is executed directly.
if __name__ == "__main__":
    # Install uvloop before asyncio.run() creates the event loop.
    uvloop.install()
    asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import json
import logging
import os
import signal
import sys
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
import uvloop
from transformers import AutoTokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.outputs import RequestOutput
from vllm.tokenizers import TokenizerLike as AnyTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser
from dynamo.llm import ModelInput, ModelType, register_model
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# To import example local module
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
from utils.chat_message_utils import extract_user_text
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import (
MultiModalInput,
MultiModalRequest,
MyRequestOutput,
vLLMMultimodalRequest,
)
# Apply Dynamo's logging configuration, then create this module's logger.
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class RequestType(Enum):
    """Kind of OpenAI-style request being handled by the processor."""

    # Chat-completions request.
    CHAT = "chat"
    # Plain completions request.
    COMPLETION = "completion"
class Processor(ProcessMixIn):
    """
    vLLM pre and post processing

    Converts incoming multimodal chat requests into `vLLMMultimodalRequest`
    objects, sends them to the encode worker, and converts the streamed
    results back into OpenAI-style chat response chunks.
    """

    @classmethod
    def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
        """Parse command-line arguments for the processor.

        The default endpoints are derived from the ``DYN_NAMESPACE``
        environment variable (``"dynamo"`` when unset).

        Returns:
            The ``(args, config)`` pair produced by ``base_parse_args``.
        """
        DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
        DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.processor.generate"
        DEFAULT_DOWNSTREAM_ENDPOINT = f"dyn://{DYN_NAMESPACE}.encoder.generate"

        parser = FlexibleArgumentParser(
            description="vLLM based processor for Dynamo LLM."
        )
        parser.add_argument(
            "--prompt-template",
            type=str,
            required=True,
            help=(
                "Different multi-modal models expect the prompt to contain different special media prompts. "
                "The processor will use this argument to construct the final prompt. "
                "User prompt will replace '<prompt>' in the provided template. "
                "For example, if the user prompt is 'please describe the image' and the prompt template is "
                "'USER: <image> <prompt> ASSISTANT:', the resulting prompt is "
                "'USER: <image> please describe the image ASSISTANT:'."
            ),
        )
        parser.add_argument(
            "--endpoint",
            type=str,
            default=DEFAULT_ENDPOINT,
            help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
        )
        parser.add_argument(
            "--downstream-endpoint",
            type=str,
            default=DEFAULT_DOWNSTREAM_ENDPOINT,
            help=f"The endpoint string of the downstream encoder in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
        )

        args, config = base_parse_args(parser)

        return args, config

    def __init__(
        self,
        args: argparse.Namespace,
        engine_args: AsyncEngineArgs,
        encode_worker_client: Client,
    ):
        """Create the tokenizer and the chat/completions processors.

        Args:
            args: Parsed command-line arguments; ``prompt_template`` is read.
            engine_args: vLLM engine arguments used to build the model config
                and the tokenizer.
            encode_worker_client: Client for the downstream encode worker.
        """
        self.encode_worker_client = encode_worker_client
        self.prompt_template = args.prompt_template
        self.engine_args = engine_args
        self.model_config = self.engine_args.create_model_config()
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.tokenizer = self._create_tokenizer(self.engine_args)
        self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
        self.completions_processor = CompletionsProcessor(
            self.tokenizer, self.model_config
        )

    def cleanup(self):
        """Release processor resources; currently there is nothing to release."""
        pass

    def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
        """Load the Hugging Face tokenizer for ``engine_args.model``.

        The tokenizer is loaded with left padding and left truncation, and is
        returned directly (it is not wrapped in any tokenizer group).
        """
        model_path = engine_args.model

        # Create the base tokenizer with VLLM's typical settings
        base_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            padding_side="left",
            truncation_side="left",
            use_fast=True,  # VLLM might use the fast tokenizer for efficiency
        )

        return base_tokenizer

    # Main method to parse the request and send the request to the vllm worker.
    async def _generate(
        self,
        raw_request: Union[CompletionRequest, ChatCompletionRequest],
        multimodal_input: MultiModalInput,
        request_type: RequestType,
    ):
        """Send one parsed request to the encode worker and stream the results.

        Args:
            raw_request: The OpenAI-style request to parse.
            multimodal_input: Media references to attach to the worker request.
            request_type: Whether this is a chat or completion request.

        Yields:
            The items produced by ``self._stream_response``.
        """
        request_id = uuid.uuid4().hex
        logger.debug(f"Got raw request: {raw_request}")
        (
            request,
            conversation,
            engine_prompt,
            sampling_params,
        ) = await self._parse_raw_request(raw_request)

        worker_request = vLLMMultimodalRequest(
            engine_prompt=engine_prompt,
            sampling_params=sampling_params,
            request_id=request_id,
            model=raw_request.model,
            multimodal_input=multimodal_input,
        )

        # model_dump_json() serializes the request to JSON string
        # This API could accept Pydantic class, but SamplingParams
        # in vLLMMultimodalRequest is not a Pydantic class and will
        # cause TypeError: unsupported type SamplingParams
        response_generator = await self.encode_worker_client.round_robin(
            worker_request.model_dump_json()
        )

        output = self._generate_responses(response_generator, request_type)

        # Stream the processed responses
        async for response in await self._stream_response(
            request, output, request_id, conversation
        ):
            yield response

    # This method is used to process the responses from the engine generator.
    async def _generate_responses(
        self,
        response_generator: AsyncIterator[RequestOutput],
        request_type: RequestType,
    ):
        """Rebuild vLLM ``RequestOutput`` objects from the worker's responses.

        Yields:
            One ``RequestOutput`` per response for chat requests.

        Raises:
            NotImplementedError: When a response arrives for any request type
                other than ``RequestType.CHAT``.
        """
        async for resp in response_generator:
            # Deserialize the response from the engine
            # Creates correct vLLM objects for each field
            output = MyRequestOutput.model_validate_json(resp.data())

            # OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
            request_output = RequestOutput(
                request_id=output.request_id,
                prompt=output.prompt,
                prompt_token_ids=output.prompt_token_ids,
                prompt_logprobs=output.prompt_logprobs,
                outputs=output.outputs,
                finished=output.finished,
                metrics=output.metrics,
            )

            if request_type == RequestType.CHAT:
                # For chat requests, yield the request_output directly.
                yield request_output
            else:
                raise NotImplementedError(
                    f"Request type {request_type} not implemented"
                )

    # The generate endpoint will be used by the frontend to handle incoming requests.
    async def generate(self, raw_request: MultiModalRequest):
        """Handle one frontend request and yield OpenAI chat response chunks.

        Args:
            raw_request: A ``MultiModalRequest`` or an object that
                ``MultiModalRequest.model_validate`` accepts.

        Yields:
            Parsed JSON objects, one per streamed chunk, until the
            ``data: [DONE]`` marker is seen.

        Raises:
            ValueError: If the prompt template lacks ``<prompt>``, if the
                request contains image or video content, or if it contains no
                audio URL.
        """
        logger.debug(f"Got raw request: {raw_request}")
        if not isinstance(raw_request, MultiModalRequest):
            # If the request is not MultiModalRequest, convert it to MultiModalRequest
            raw_request = MultiModalRequest.model_validate(raw_request)

        # Ensure the configured template includes the placeholder
        template = self.prompt_template
        if "<prompt>" not in template:
            raise ValueError("prompt_template must contain '<prompt>' placeholder")

        user_text = extract_user_text(raw_request.messages)

        prompt = template.replace("<prompt>", user_text)

        msg = {
            "role": "user",
            "content": prompt,
        }

        # Set stream=True - the http frontend will handle aggregation of
        # streamed chunks into a single http response, or stream them
        # back as SSE responses based on the stream flag in the request.
        chat_request = ChatCompletionRequest(
            model=raw_request.model,
            messages=[msg],
            stream=True,
            stream_options=raw_request.stream_options,
            max_tokens=raw_request.max_tokens,
            temperature=raw_request.temperature,
            request_id=str(uuid.uuid4()),
        )
        multimodal_input = MultiModalInput()

        for message in raw_request.messages:
            for item in message.content:
                if item.type == "image_url":
                    raise ValueError(
                        "Image requests should use the standard `python -m dynamo.frontend` "
                        "+ `python -m dynamo.vllm --enable-multimodal` flow instead of the "
                        "legacy multimodal example processor."
                    )
                elif item.type == "video_url":
                    raise ValueError(
                        "Video requests should use the standard `python -m dynamo.frontend` "
                        "+ `python -m dynamo.vllm --enable-multimodal` flow instead of the "
                        "legacy multimodal example processor."
                    )
                elif item.type == "audio_url":
                    if (
                        multimodal_input.image_url is not None
                        or multimodal_input.video_url is not None
                    ):
                        raise ValueError("Cannot mix image, video and audio URLs")
                    multimodal_input.audio_url = item.audio_url.url

        if (
            multimodal_input.image_url is None
            and multimodal_input.video_url is None
            and multimodal_input.audio_url is None
        ):
            raise ValueError(
                "Audio requests are the only multimodal mode supported by the "
                "legacy example processor."
            )

        sse_prefix = "data: "
        async for response in self._generate(
            chat_request, multimodal_input, RequestType.CHAT
        ):
            logger.debug(
                f"Generated response type {type(response)}, content: {response}"
            )
            # reconstructing back the OpenAI chat response as dynamo egress expects it
            if response.startswith("data: [DONE]"):
                break
            # Remove the exact SSE prefix. str.lstrip() would instead strip
            # any run of the characters 'd', 'a', 't', ':' and ' ', which can
            # eat into the payload itself.
            if response.startswith(sse_prefix):
                response = response[len(sse_prefix):]
            response = json.loads(response)
            yield response
async def graceful_shutdown(runtime):
    """Shut down the DistributedRuntime after a termination signal.

    Calling `runtime.shutdown()` makes the endpoints unavailable right away,
    while requests that are already in flight keep running until they finish.
    Once they have all finished, the `serve_endpoint` functions return and the
    engine is torn down by Python's garbage collector.
    """
    logging.info("Received shutdown signal, shutting down DistributedRuntime")

    runtime.shutdown()

    logging.info("DistributedRuntime shutdown complete")
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Install shutdown signal handlers, parse arguments and serve the processor.

    Args:
        runtime: The distributed runtime this worker serves on (presumably
            supplied by the `dynamo_worker` decorator — confirm).
    """
    # Runtime setup
    # Set up signal handler for graceful shutdown
    loop = asyncio.get_running_loop()

    # The event loop only holds weak references to tasks, so keep strong
    # references here until each shutdown task has finished.
    shutdown_tasks = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logging.info("Signal handlers set up for graceful shutdown")

    # worker setup
    args, config = Processor.parse_args()
    await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
    """Build the processor, register the model and serve the generate endpoint.

    Resolves this worker's endpoint and a client for the downstream encoder
    endpoint, constructs the `Processor`, waits for encoder instances,
    registers the endpoint as the entry point for the model, then serves
    until the endpoint returns. `handler.cleanup()` always runs on the way
    out; serving errors are logged and re-raised.
    """
    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )

    # Split the downstream 'dyn://namespace.component.endpoint' string.
    downstream_parts = parse_endpoint(args.downstream_endpoint)
    parsed_namespace, parsed_component_name, parsed_endpoint_name = downstream_parts

    downstream_endpoint = runtime.endpoint(
        f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
    )
    encode_worker_client = await downstream_endpoint.client()

    handler = Processor(args, config.engine_args, encode_worker_client)

    logger.info("Waiting for Encoder Worker Instances ...")
    await encode_worker_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model
    await register_model(
        ModelInput.Text,  # Custom processor is used and this type bypasses SDK processor
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info(f"Starting to serve the {args.endpoint} endpoint...")

    model_labels = [("model", config.model)]
    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=model_labels
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
# Script entry point: only runs when the module is executed directly.
if __name__ == "__main__":
    # Install uvloop before asyncio.run() creates the event loop.
    uvloop.install()
    asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from typing import List, Optional, Tuple
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from dynamo.llm import WorkerMetricsPublisher
from dynamo.runtime import Endpoint
class NullStatLogger(StatLoggerBase):
    """Stat logger that discards everything it is given."""

    def __init__(self):
        """Create the logger; no state is kept."""

    def record(
        self,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
        engine_idx: int = 0,
        *args,
        **kwargs,
    ):
        """Accept a stats update and ignore it."""

    def log_engine_initialized(self):
        """Do nothing when the engine reports initialization."""
class DynamoStatLoggerPublisher(StatLoggerBase):
    """Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""

    def __init__(
        self,
        endpoint: Endpoint,
        dp_rank: int,
    ) -> None:
        """Create the publisher and schedule creation of its endpoint.

        Must be constructed while an event loop is running, because the
        endpoint creation is scheduled with ``asyncio.create_task``.

        Args:
            endpoint: Endpoint handed to ``WorkerMetricsPublisher.create_endpoint``.
            dp_rank: Rank value passed along with every published metric.
        """
        self.inner = WorkerMetricsPublisher()
        self._endpoint = endpoint
        self.dp_rank = dp_rank
        # Placeholder until set_num_gpu_block() provides the real block count.
        self.num_gpu_block = 1
        # Schedule async endpoint creation; the task is kept on the instance
        # so that it stays referenced while it runs.
        self._endpoint_task = asyncio.create_task(self._create_endpoint())

    async def _create_endpoint(self) -> None:
        """Create the NATS endpoint asynchronously.

        Raises:
            Exception: Any failure from ``create_endpoint`` is logged with its
                traceback and re-raised.
        """
        try:
            await self.inner.create_endpoint(self._endpoint)
            logging.debug("Multimodal metrics publisher endpoint created")
        except Exception:
            logging.exception("Failed to create multimodal metrics publisher endpoint")
            raise

    # TODO: Remove this and pass as metadata through etcd
    def set_num_gpu_block(self, num_blocks):
        """Set the block count used to scale KV-cache usage in ``record``."""
        self.num_gpu_block = num_blocks

    def record(
        self,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
        engine_idx: int = 0,
        *args,
        **kwargs,
    ):
        """Publish the number of active decode blocks from the scheduler stats.

        The count is ``num_gpu_block * scheduler_stats.kv_cache_usage``
        truncated to an int. Nothing is published when ``scheduler_stats`` is
        ``None``, so an update without scheduler stats does not raise.
        """
        if scheduler_stats is None:
            return
        active_decode_blocks = int(self.num_gpu_block * scheduler_stats.kv_cache_usage)
        self.inner.publish(self.dp_rank, active_decode_blocks)

    def init_publish(self):
        """Publish an initial value of zero active decode blocks."""
        self.inner.publish(self.dp_rank, 0)

    def log_engine_initialized(self) -> None:
        """Do nothing when the engine reports initialization."""
        pass
class StatLoggerFactory:
    """Factory for creating stat logger publishers. Required by vLLM.

    Only the data-parallel rank given at construction gets a real
    `DynamoStatLoggerPublisher`; every other rank gets a `NullStatLogger`.
    """

    def __init__(
        self,
        endpoint: Endpoint,
        dp_rank: int = 0,
        metrics_labels: Optional[List[Tuple[str, str]]] = None,
    ) -> None:
        """Store the endpoint, the rank to publish for, and optional labels."""
        self.endpoint = endpoint
        self.dp_rank = dp_rank
        self.metrics_labels = metrics_labels or []
        # Filled in by create_stat_logger() once the publisher exists.
        self.created_logger: Optional[DynamoStatLoggerPublisher] = None

    def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
        """Return a publisher for the configured rank, a no-op logger otherwise."""
        if dp_rank == self.dp_rank:
            publisher = DynamoStatLoggerPublisher(self.endpoint, dp_rank)
            self.created_logger = publisher
            return publisher
        return NullStatLogger()

    def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase:
        """Factory call signature; `vllm_config` is accepted but not used."""
        return self.create_stat_logger(dp_rank=dp_rank)

    # TODO Remove once we publish metadata to etcd
    def set_num_gpu_blocks_all(self, num_blocks):
        """Forward the block count to the created publisher, if there is one."""
        if not self.created_logger:
            return
        self.created_logger.set_num_gpu_block(num_blocks)

    def init_publish(self):
        """Ask the created publisher, if there is one, to publish its initial value."""
        if not self.created_logger:
            return
        self.created_logger.init_publish()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os

# Give PYTHONHASHSEED a default of "0" without overriding a value that is
# already present in the environment.
os.environ.setdefault("PYTHONHASHSEED", "0")
import argparse
import asyncio
import copy
import logging
import signal
import sys
from typing import Tuple
import torch
import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.inputs.data import TokensPrompt
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.engine.async_llm import AsyncLLM
import dynamo.nixl_connect as connect
from dynamo.llm import KvEventPublisher
from dynamo.runtime import DistributedRuntime, Endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from publisher import StatLoggerFactory
from utils.args import (
Config,
base_parse_args,
configure_ports,
overwrite_args,
parse_endpoint,
)
from utils.image_loader import ImageLoader
from utils.model import construct_mm_data
from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
# Apply Dynamo's logging configuration, then create this module's logger.
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class VllmBaseWorker:
@classmethod
def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
parser = FlexibleArgumentParser(
description="vLLM based encoder for Dynamo LLM."
)
parser.add_argument(
"--endpoint",
type=str,
help="Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default value will vary based on the worker type, see --worker-type for details.",
)
parser.add_argument(
"--downstream-endpoint",
type=str,
help="The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default value will vary based on the worker type, see --worker-type for details.",
)
parser.add_argument(
"--worker-type",
type=str,
choices=["prefill", "decode", "encode_prefill"],
required=True,
help="Specify the type of worker. Must be one of: 'prefill', 'decode', 'encode_prefill'",
)
parser.add_argument(
"--enable-disagg",
action="store_true",
help="Enable disaggregated mode, where prefill and decode are handled by separate workers."
" If not set, the '*prefill' worker type will handle both prefill and decode.",
)
# use endpoint_overwrite to set the default endpoint based on worker type
def endpoint_overwrite(args):
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
# default endpoint for this worker
if args.worker_type == "prefill":
args.endpoint = args.endpoint or f"dyn://{DYN_NAMESPACE}.llm.generate"
elif args.worker_type == "decode":
args.endpoint = (
args.endpoint or f"dyn://{DYN_NAMESPACE}.decoder.generate"
)
elif args.worker_type == "encode_prefill":
args.endpoint = (
args.endpoint or f"dyn://{DYN_NAMESPACE}.encoder.generate"
)
# set downstream endpoint for disaggregated workers
if args.enable_disagg:
args.downstream_endpoint = (
args.downstream_endpoint
or f"dyn://{DYN_NAMESPACE}.decoder.generate"
)
return args
args, config = base_parse_args(parser, endpoint_overwrite)
return args, config
def __init__(
self,
args: argparse.Namespace,
endpoint: Endpoint,
config: Config,
):
self.enable_disagg = args.enable_disagg
self.endpoint = args.endpoint
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = config.engine_args
self.config = config
self.setup_vllm_engine(endpoint)
async def async_init(self, runtime: DistributedRuntime):
pass
def setup_vllm_engine(self, endpoint: Endpoint):
"""Initialize the vLLM engine.
This method sets up the vLLM engine client, and configures the dynamo-aware KV
event publisher and metrics stats logger based on endpoint.
"""
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Load default sampling params from `generation_config.json`
self.default_sampling_params = (
self.engine_args.create_model_config().get_diff_sampling_param()
)
# Taken from build_async_engine_client_from_engine_args()
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = self.engine_args.create_engine_config(usage_context=usage_context)
# Create vLLM engine with metrics logger and KV event publisher attached
self.stats_logger = StatLoggerFactory(
endpoint=endpoint,
dp_rank=self.engine_args.data_parallel_rank or 0,
)
self.engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=[self.stats_logger],
enable_log_requests=self.engine_args.enable_log_requests,
disable_log_stats=self.engine_args.disable_log_stats,
)
# TODO Hack to get data, move this to registering in ETCD
self.stats_logger.set_num_gpu_blocks_all(
vllm_config.cache_config.num_gpu_blocks
)
self.stats_logger.init_publish()
# TODO: We start off with a valid endpoint, then we increment it by dp_rank
# May no longer be valid. Lets remove the increment behavior from vLLM and here
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
self.engine_args.kv_events_config.endpoint,
data_parallel_rank=self.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
self.kv_publisher = KvEventPublisher(
endpoint=endpoint,
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
logger.info(f"Reading Events from {zmq_endpoint}")
logger.info(f"VllmWorker for {self.engine_args.model} has been initialized")
async def generate(self, request: vLLMMultimodalRequest):
    """Abstract generation entry point; concrete workers must override it.

    Raises:
        NotImplementedError: Always, when awaited on the base class.
    """
    message = "This method should be implemented in subclasses to handle the generation logic."
    raise NotImplementedError(message)
async def clear_kv_blocks(self, request=None):
    """Reset the engine's prefix cache and report the outcome.

    Yields exactly one dict with ``status`` and ``message`` keys: a success
    record when ``reset_prefix_cache`` completes, otherwise an error record
    carrying the exception text. ``request`` is accepted but not used.
    """
    try:
        await self.engine_client.reset_prefix_cache()
        outcome = {"status": "success", "message": "KV cache cleared"}
        yield outcome
    except Exception as err:
        failure = {"status": "error", "message": str(err)}
        yield failure
def cleanup(self):
    """No-op cleanup hook; override in subclasses if cleanup is needed."""
    return None
class VllmDecodeWorker(VllmBaseWorker):
    """Worker whose ``generate`` feeds only prompt token ids to the engine."""

    async def generate(self, request: vLLMMultimodalRequest):
        """Stream engine outputs for a decode request as JSON strings.

        ``request`` may arrive as a ``vLLMMultimodalRequest``, as a JSON
        string, or as any other object accepted by ``model_validate``; it is
        normalized first. Each engine response is re-wrapped in a
        ``MyRequestOutput`` and yielded via ``model_dump_json()``.
        """
        logger.debug(f"Got raw request: {request}")
        if isinstance(request, str):
            request = vLLMMultimodalRequest.model_validate_json(request)
        elif not isinstance(request, vLLMMultimodalRequest):
            request = vLLMMultimodalRequest.model_validate(request)
        logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")

        # No multimodal data is attached here: the prompt carries token ids only.
        token_prompt = TokensPrompt(
            prompt_token_ids=request.engine_prompt["prompt_token_ids"],
        )
        stream = self.engine_client.generate(
            prompt=token_prompt,
            sampling_params=request.sampling_params,
            request_id=request.request_id,
        )
        async for chunk in stream:
            logger.debug(f"Response kv_transfer_params: {chunk.kv_transfer_params}")
            wrapped = MyRequestOutput(
                request_id=chunk.request_id,
                prompt=chunk.prompt,
                prompt_token_ids=chunk.prompt_token_ids,
                prompt_logprobs=chunk.prompt_logprobs,
                outputs=chunk.outputs,
                finished=chunk.finished,
                metrics=chunk.metrics,
                kv_transfer_params=chunk.kv_transfer_params,
            )
            yield wrapped.model_dump_json()
class VllmPDWorker(VllmBaseWorker):
    """Worker that runs generation with multimodal data attached to the prompt.

    When ``self.enable_disagg`` is true the local engine is asked for a single
    token with ``do_remote_decode`` set, and the remaining generation is
    forwarded to a downstream decode worker; otherwise the local engine's
    responses are streamed back directly.
    """

    async def async_init(self, runtime: DistributedRuntime):
        """Create the downstream client (disagg only), connector and image loader.

        Args:
            runtime: Runtime used to look up the downstream decode endpoint.
        """
        logger.info("Startup started.")
        if self.enable_disagg:
            # Resolve the downstream (decode) endpoint and open a client to it.
            (
                parsed_namespace,
                parsed_component_name,
                parsed_endpoint_name,
            ) = parse_endpoint(self.downstream_endpoint)
            self.decode_worker_client = await runtime.endpoint(
                f"{parsed_namespace}.{parsed_component_name}.{parsed_endpoint_name}"
            ).client()
        # dtype/device of the local buffer that receives transferred embeddings.
        self.EMBEDDINGS_DTYPE = torch.float16
        self.EMBEDDINGS_DEVICE = "cpu"
        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        # NOTE(review): parsed_namespace is not used after this line - confirm
        # whether the call is kept only to validate self.endpoint.
        parsed_namespace, _, _ = parse_endpoint(self.endpoint)
        self._connector = connect.Connector()
        self.image_loader = ImageLoader()
        logger.info("VllmPDWorker has been initialized")

    async def generate(self, request: vLLMMultimodalRequest):
        """Run generation for a multimodal request and yield JSON-encoded outputs.

        The request is normalized to ``vLLMMultimodalRequest`` (from a JSON
        string or another object accepted by ``model_validate``). Multimodal
        data comes either from embeddings read through the connector (when no
        image/video/audio URL is present) or from an image loaded via
        ``self.image_loader``. Each yielded item is the ``model_dump_json()``
        of a ``MyRequestOutput``.
        """
        logger.debug(f"Got raw request: {request}")
        if type(request) is not vLLMMultimodalRequest:
            if type(request) is str:
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
        if (
            request.multimodal_input.image_url is None
            and request.multimodal_input.video_url is None
            and request.multimodal_input.audio_url is None
        ):
            # Process embeddings using the connector
            # Create a descriptor based on the embedding shape.
            embeddings = torch.empty(
                request.embeddings_shape,
                dtype=self.EMBEDDINGS_DTYPE,
                device=self.EMBEDDINGS_DEVICE,
            )
            descriptor = connect.Descriptor(embeddings)
            # NOTE(review): if connect.Descriptor is a plain class, this guard
            # can never trigger - confirm whether it is still needed.
            if descriptor is None:
                raise RuntimeError(
                    "Descriptor is None in PD worker - cannot process embeddings"
                )
            # Fill the local `embeddings` buffer and wait until the read is done.
            read_op = await self._connector.begin_read(
                request.serialized_request, descriptor
            )
            await read_op.wait_for_completion()
            # The model name decides whether the embeddings are passed as audio
            # or as image embeddings.
            if "audio" in self.engine_args.model.lower():
                multi_modal_data = construct_mm_data(
                    self.engine_args.model,
                    self.EMBEDDINGS_DTYPE,
                    audio_embeds=embeddings,
                )
            else:
                multi_modal_data = construct_mm_data(
                    self.engine_args.model,
                    self.EMBEDDINGS_DTYPE,
                    image_embeds=embeddings,
                    image_grid_thw=request.image_grid_thw,
                )
        else:
            # Use PIL image instead of image embeddings
            # NOTE(review): this branch is also reached when only video_url or
            # audio_url is set, in which case image_url is None - confirm that
            # such requests never reach this worker or that load_image copes.
            multi_modal_data = {
                "image": await self.image_loader.load_image(
                    request.multimodal_input.image_url
                )
            }
        # Remove the image features from the request as they are not required
        request.multimodal_input.image_url = None
        request.multimodal_input.video_url = None
        request.multimodal_input.audio_url = None
        request.serialized_request = None
        # Work on a copy so the prefill-specific sampling overrides below do
        # not leak into `request`, which is copied again for the decode step.
        pd_request = copy.deepcopy(request)
        # Do prefill and remote decode if enable_disagg is true
        if self.enable_disagg:
            extra_args = pd_request.sampling_params.extra_args or {}
            extra_args["kv_transfer_params"] = {
                "do_remote_decode": True,
            }
            pd_request.sampling_params.extra_args = extra_args
            # Exactly one token is requested locally; the rest is left to decode.
            pd_request.sampling_params.max_tokens = 1
            pd_request.sampling_params.min_tokens = 1
            logger.debug("Prefill request: %s", pd_request)
        gen = self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=pd_request.sampling_params,
            request_id=pd_request.request_id,
        )
        if self.enable_disagg:
            decode_request = copy.deepcopy(request)
            async for prefill_response in gen:
                # Update the prompt token id in the decode request to the one
                # in response, which has image templated filled in. So that
                # the decode worker will fetch correct amount of KV blocks.
                decode_request.engine_prompt[
                    "prompt_token_ids"
                ] = prefill_response.prompt_token_ids
                logger.debug(
                    f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
                )
                # Hand the prefill's KV transfer parameters to the decode worker.
                extra_args = decode_request.sampling_params.extra_args or {}
                extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
                extra_args.pop("serialized_request", None)
                decode_request.sampling_params.extra_args = extra_args
                logger.debug("Decode request: %s", decode_request)
                # Forward to a decode worker and relay its responses.
                async for (
                    decode_response
                ) in await self.decode_worker_client.round_robin(
                    decode_request.model_dump_json()
                ):
                    output = MyRequestOutput.model_validate_json(decode_response.data())
                    yield MyRequestOutput(
                        request_id=output.request_id,
                        prompt=output.prompt,
                        prompt_token_ids=output.prompt_token_ids,
                        prompt_logprobs=output.prompt_logprobs,
                        outputs=output.outputs,
                        finished=output.finished,
                        metrics=output.metrics,
                        kv_transfer_params=output.kv_transfer_params,
                    ).model_dump_json()
        else:
            # Aggregated mode: stream the local engine's responses directly.
            async for response in gen:
                logger.debug(
                    f"Response kv_transfer_params: {response.kv_transfer_params}"
                )
                yield MyRequestOutput(
                    request_id=response.request_id,
                    prompt=response.prompt,
                    prompt_token_ids=response.prompt_token_ids,
                    prompt_logprobs=response.prompt_logprobs,
                    outputs=response.outputs,
                    finished=response.finished,
                    metrics=response.metrics,
                    kv_transfer_params=response.kv_transfer_params,
                ).model_dump_json()
async def graceful_shutdown(runtime):
    """Shut down the distributed runtime in response to a termination signal.

    Calling ``runtime.shutdown()`` makes the endpoints unavailable right away,
    while requests already in flight are still processed to completion. Once
    they finish, the ``serve_endpoint`` calls return and the engine is torn
    down by Python's garbage collector.
    """
    logging.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    logging.info("DistributedRuntime shutdown complete")
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Worker entry point: install signal handlers, parse args, then serve.

    SIGTERM and SIGINT each schedule ``graceful_shutdown(runtime)`` on the
    running loop. After that the command line is parsed, the Dynamo-specific
    port and vLLM overrides are applied, and control passes to ``init``.
    """
    event_loop = asyncio.get_running_loop()

    def _request_shutdown():
        asyncio.create_task(graceful_shutdown(runtime))

    for signum in (signal.SIGTERM, signal.SIGINT):
        event_loop.add_signal_handler(signum, _request_shutdown)
    logging.info("Signal handlers set up for graceful shutdown")

    # Worker setup, followed by the vLLM config overwrites.
    args, config = VllmBaseWorker.parse_args()
    configure_ports(config)
    overwrite_args(config)
    await init(runtime, args, config)
async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
    """
    Instantiate and serve

    Builds the ``generate`` and ``clear_kv_blocks`` endpoints from ``config``,
    creates the handler matching ``args.worker_type``, runs its async
    initialization and then serves both endpoints until they return.

    Args:
        runtime: Distributed runtime used to create the endpoints.
        args: Parsed command line; ``worker_type`` selects the handler class.
        config: Parsed configuration (namespace, component, endpoint, model).

    Raises:
        ValueError: If ``args.worker_type`` is not one of ``prefill``,
            ``encode_prefill`` or ``decode``.
    """
    generate_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.{config.endpoint}"
    )
    clear_endpoint = runtime.endpoint(
        f"{config.namespace}.{config.component}.clear_kv_blocks"
    )

    handler: VllmBaseWorker
    if args.worker_type in ["prefill", "encode_prefill"]:
        handler = VllmPDWorker(args, generate_endpoint, config)
    elif args.worker_type == "decode":
        handler = VllmDecodeWorker(args, generate_endpoint, config)
    else:
        # Without this branch `handler` would stay unbound and the failure
        # would surface later as an UnboundLocalError unrelated to the cause.
        raise ValueError(
            f"Unsupported worker type: {args.worker_type!r}. "
            "Expected one of 'prefill', 'encode_prefill' or 'decode'."
        )
    await handler.async_init(runtime)
    logger.info(f"Starting to serve the {args.endpoint} endpoint...")
    metrics_labels = [("model", config.model)]
    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate, metrics_labels=metrics_labels
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks, metrics_labels=metrics_labels
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        # Always give the handler a chance to release resources.
        handler.cleanup()
# Script entry point: install uvloop first so that the event loop created by
# asyncio.run() is the uvloop implementation, then run the worker coroutine.
if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Deployment graph "agg-llava": a frontend plus three worker services
# (encode worker, VLM worker, processor), all running llava-hf/llava-1.5-7b-hf
# on the vLLM backend. Image names are placeholders to be replaced.
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
  name: agg-llava
spec:
  backendFramework: vllm
  services:
    # Frontend service; no GPU limit or secret is configured for it.
    Frontend:
      componentType: frontend
      replicas: 1
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
    # Runs components/encode_worker.py with one GPU.
    EncodeWorker:
      envFromSecret: hf-token-secret
      componentType: worker
      replicas: 1
      resources:
        limits:
          gpu: "1"
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
          workingDir: /workspace/examples/multimodal
          command:
            - /bin/sh
            - -c
          args:
            - python3 components/encode_worker.py --model llava-hf/llava-1.5-7b-hf
    # Runs components/worker.py with --worker-type prefill and one GPU.
    VLMWorker:
      envFromSecret: hf-token-secret
      componentType: worker
      replicas: 1
      resources:
        limits:
          gpu: "1"
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
          workingDir: /workspace/examples/multimodal
          command:
            - /bin/sh
            - -c
          args:
            - python3 components/worker.py --model llava-hf/llava-1.5-7b-hf --worker-type prefill
    # Runs components/processor.py with the LLaVA prompt template; the whole
    # command is single-quoted so the double-quoted template reaches the shell intact.
    Processor:
      envFromSecret: hf-token-secret
      componentType: worker
      replicas: 1
      resources:
        limits:
          gpu: "1"
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
          workingDir: /workspace/examples/multimodal
          command:
            - /bin/sh
            - -c
          args:
            - 'python3 components/processor.py --model llava-hf/llava-1.5-7b-hf --prompt-template "USER: <image>\n<prompt> ASSISTANT:"'
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Deployment graph "agg-qwen": a frontend plus three worker services
# (encode worker, VLM worker, processor), all running Qwen/Qwen2.5-VL-7B-Instruct
# on the vLLM backend. Image names are placeholders to be replaced.
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
  name: agg-qwen
spec:
  backendFramework: vllm
  services:
    # Frontend service; no GPU limit or secret is configured for it.
    Frontend:
      componentType: frontend
      replicas: 1
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
    # Runs components/encode_worker.py with one GPU.
    EncodeWorker:
      envFromSecret: hf-token-secret
      componentType: worker
      replicas: 1
      resources:
        limits:
          gpu: "1"
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
          workingDir: /workspace/examples/multimodal
          command:
            - /bin/sh
            - -c
          args:
            - python3 components/encode_worker.py --model Qwen/Qwen2.5-VL-7B-Instruct
    # Runs components/worker.py with --worker-type prefill and one GPU.
    VLMWorker:
      envFromSecret: hf-token-secret
      componentType: worker
      replicas: 1
      resources:
        limits:
          gpu: "1"
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
          workingDir: /workspace/examples/multimodal
          command:
            - /bin/sh
            - -c
          args:
            - python3 components/worker.py --model Qwen/Qwen2.5-VL-7B-Instruct --worker-type prefill
    # Runs components/processor.py with the Qwen chat prompt template; the whole
    # command is single-quoted so the double-quoted template reaches the shell intact.
    Processor:
      envFromSecret: hf-token-secret
      componentType: worker
      replicas: 1
      resources:
        limits:
          gpu: "1"
      extraPodSpec:
        mainContainer:
          image: my-registry/vllm-runtime:my-tag
          workingDir: /workspace/examples/multimodal
          command:
            - /bin/sh
            - -c
          args:
            - 'python3 components/processor.py --model Qwen/Qwen2.5-VL-7B-Instruct --prompt-template "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><prompt><|im_end|>\n<|im_start|>assistant\n"'
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import ipaddress
import logging
import os
import socket
import sys
from typing import Callable, List, Optional, Tuple
from vllm.config import KVTransferConfig
from vllm.distributed.kv_events import KVEventsConfig
from vllm.engine.arg_utils import AsyncEngineArgs
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Namespace used in the default endpoint; taken from the DYN_NAMESPACE
# environment variable, falling back to "dynamo".
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
# Default endpoint string in 'dyn://namespace.component.endpoint' form.
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
class Config:
    """Command line parameters or defaults"""

    # dynamo specific
    # The three parts of the 'namespace.component.endpoint' endpoint string.
    namespace: str
    component: str
    endpoint: str
    # Port used for the KV events endpoint; None until configure_ports sets it.
    kv_port: Optional[int] = None
    # mirror vLLM
    model: str
    # First served model name from the command line, or None when not given.
    served_model_name: Optional[str]
    # rest vLLM args
    engine_args: AsyncEngineArgs
def parse_endpoint(endpoint: str) -> List[str]:
    """Split an endpoint string into ``[namespace, component, endpoint]``.

    The first ``dyn://`` occurrence is dropped before splitting on dots.
    Anything other than exactly three parts is logged as an error and the
    process exits with status 1.
    """
    parts = endpoint.replace("dyn://", "", 1).split(".")
    if len(parts) == 3:
        return parts
    logger.error(
        f"Invalid endpoint format: '{endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
    )
    sys.exit(1)
def base_parse_args(
    parser: argparse.ArgumentParser, endpoint_overwrite: Optional[Callable] = None
) -> Tuple[argparse.Namespace, Config]:
    """
    Shared argument parsing for dynamo vLLM deployments.

    Callers customize behavior through ``parser`` (pre-populated with their
    own arguments) and ``endpoint_overwrite``.

    Args:
        parser (argparse.ArgumentParser): Parser that already carries the use
            case specific arguments. An ``--endpoint`` option and the vLLM
            engine arguments are added to it here.
        endpoint_overwrite (Callable): Optional function that receives the
            parsed arguments and returns them with the endpoint rewritten,
            typically based on the worker type.

    Returns:
        Tuple[argparse.Namespace, Config]: The parsed arguments (after any
        overwrite) and the populated ``Config``.
    """
    # Only register --endpoint when the caller has not defined it already.
    registered_dests = {action.dest for action in parser._actions}
    if "endpoint" not in registered_dests:
        parser.add_argument(
            "--endpoint",
            type=str,
            default=DEFAULT_ENDPOINT,
            help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
        )
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    engine_args = AsyncEngineArgs.from_cli_args(args)

    config = Config()
    config.model = args.model
    model_names = args.served_model_name
    if model_names:
        assert (
            len(model_names) <= 1
        ), "We do not support multiple model names."
        config.served_model_name = model_names[0]
    else:
        # This becomes an `Option` on the Rust side
        config.served_model_name = None

    if endpoint_overwrite is not None:
        args = endpoint_overwrite(args)
    namespace, component, endpoint_name = parse_endpoint(args.endpoint)
    config.namespace = namespace
    config.component = component
    config.endpoint = endpoint_name

    config.engine_args = engine_args
    if engine_args.block_size is None:
        engine_args.block_size = 16
        logger.debug(
            f"Setting reasonable default of {config.engine_args.block_size} for block_size"
        )
    return args, config
def get_kv_port() -> int:
    """Return the KV events port from ``DYN_VLLM_KV_EVENT_PORT`` (default 20080)."""
    raw_port = os.environ.get("DYN_VLLM_KV_EVENT_PORT", "20080")
    return int(raw_port)
def ensure_side_channel_host():
    """Make sure ``VLLM_NIXL_SIDE_CHANNEL_HOST`` is set, keeping any user value.

    When the variable is unset, a host IP is detected in two steps: first by
    resolving the local hostname (IPv4 and IPv6), then by connecting a UDP
    socket towards a public address and reading the local socket address.
    Loopback, link-local, unspecified and multicast addresses are rejected.

    Raises:
        RuntimeError: If neither strategy produces a routable IP.
    """
    preset = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST")
    if preset:
        logger.info("Using existing VLLM_NIXL_SIDE_CHANNEL_HOST=%s", preset)
        return

    def is_routable(ip_str: str) -> bool:
        try:
            addr = ipaddress.ip_address(ip_str)
        except ValueError:
            return False
        return not any(
            (
                addr.is_loopback,
                addr.is_link_local,
                addr.is_unspecified,
                addr.is_multicast,
            )
        )

    chosen_ip = None
    chosen_via = None

    # Strategy 1: hostname resolution (AF_UNSPEC for IPv4+IPv6)
    try:
        addr_infos = socket.getaddrinfo(
            socket.gethostname(), None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
        for family, socktype, _proto, _canon, sockaddr in addr_infos:
            ip = sockaddr[0]
            # Only accept addresses that can actually be bound on this host.
            try:
                with socket.socket(family, socktype) as probe:
                    probe.bind((ip, 0))
            except OSError:
                continue
            if is_routable(ip):
                chosen_ip = ip
                chosen_via = "hostname resolution"
                break
    except OSError as exc:
        logger.debug("Hostname resolution failed: %s", exc)

    # Strategy 2: UDP connect trick (IPv4 then IPv6)
    if not chosen_ip:
        udp_probes = (
            (socket.AF_INET, ("8.8.8.8", 80), "outbound interface detection (IPv4)"),
            (
                socket.AF_INET6,
                ("2001:4860:4860::8888", 80),
                "outbound interface detection (IPv6)",
            ),
        )
        for family, remote, via in udp_probes:
            try:
                with socket.socket(family, socket.SOCK_DGRAM) as probe:
                    probe.connect(remote)
                    ip = probe.getsockname()[0]
            except OSError:
                continue
            if is_routable(ip):
                chosen_ip = ip
                chosen_via = via
                break

    if not chosen_ip:
        raise RuntimeError(
            "Unable to determine a routable host IP for NIXL side-channel. "
            "Please set the VLLM_NIXL_SIDE_CHANNEL_HOST environment variable to "
            "the IP address that peer nodes can reach this host on."
        )

    os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = chosen_ip
    logger.info(
        "Set VLLM_NIXL_SIDE_CHANNEL_HOST=%s (detected via %s)",
        chosen_ip,
        chosen_via,
    )
def configure_ports(config: Config):
    """Populate ``config.kv_port`` and make sure the side-channel host is set.

    The KV port is assigned unconditionally because ``overwrite_args`` needs
    it whether or not prefix caching was requested.
    """
    kv_port = get_kv_port()
    config.kv_port = kv_port
    ensure_side_channel_host()
def overwrite_args(config):
    """Set vLLM defaults for Dynamo.

    Overrides a fixed set of attributes on ``config.engine_args``; attributes
    that the engine args object does not have are skipped with a debug message.

    Args:
        config: Configuration whose ``engine_args`` are modified in place and
            whose ``kv_port`` is used for the KV events endpoint.

    Raises:
        ValueError: If ``config.kv_port`` has not been set (see
            ``configure_ports``).
    """
    # kv_port is used below regardless of the incoming enable_prefix_caching
    # value (prefix caching is force-enabled here), so validate it always.
    # An explicit raise is used because assert statements vanish under -O.
    if config.kv_port is None:
        raise ValueError("Must set the kv_port, use configure_ports")
    dp_rank = config.engine_args.data_parallel_rank or 0
    defaults = {
        # vLLM 0.13+ renamed 'task' to 'runner'
        "runner": "generate",
        "skip_tokenizer_init": False,
        "enable_log_requests": False,
        "enable_prefix_caching": True,
        # KV routing relies on logging KV metrics
        "disable_log_stats": False,
        # Enable multimodal embeddings input
        "enable_mm_embeds": True,
        # Always setting up kv transfer for disagg
        "kv_transfer_config": KVTransferConfig(
            kv_connector="NixlConnector", kv_role="kv_both"
        ),
        "kv_events_config": KVEventsConfig(
            enable_kv_cache_events=True,
            publisher="zmq",
            endpoint=f"tcp://*:{config.kv_port - dp_rank}",  # vLLM will iterate dp_rank for us, so we need to subtract it out TODO: fix in vLLM
        ),
    }
    logger.debug("Setting Dynamo defaults for vLLM")
    for key, value in defaults.items():
        if hasattr(config.engine_args, key):
            setattr(config.engine_args, key, value)
            logger.debug(f" engine_args.{key} = {value}")
        else:
            logger.debug(
                f" Skipping engine_args.{key} (not available in this vLLM version)"
            )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment