"detection/git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "00af501a6ff6b817118a3ca78989b089a7309d55"
Commit cab65e1a authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: onboard nixl based vllm example to dynamo serve (#120)

parent 8435b993
......@@ -63,6 +63,8 @@ class ServiceConfig(dict):
if isinstance(value, bool):
if value:
args.append(f"--{arg_key}")
elif isinstance(value, dict):
args.extend([f"--{arg_key}", json.dumps(value)])
else:
args.extend([f"--{arg_key}", str(value)])
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
## Overview
Pipeline Architecture:
```
Users/Clients (HTTP)
┌─────────────┐
│ Frontend │ HTTP API endpoint (/generate)
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Middle │
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Backend │
└─────────────┘
```
## Unified serve
1. Launch all three services using a single command -
```bash
cd /workspace/deploy/examples
dynamo hello_world.hello_world:Frontend
```
2. Send request to frontend using curl -
```bash
curl -X 'POST' \
'http://localhost:3000/generate' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"text": "test"
}'
```
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
from dynamo.sdk import api, depends, dynamo_endpoint, service
"""
Pipeline Architecture:
Users/Clients (HTTP)
┌─────────────┐
│ Frontend │ HTTP API endpoint (/generate)
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Middle │
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Backend │
└─────────────┘
"""
class RequestType(BaseModel):
text: str
class ResponseType(BaseModel):
text: str
@service(
resources={"cpu": "2"},
traffic={"timeout": 30},
dynamo={
"enabled": True,
"namespace": "inference",
},
workers=3,
)
class Backend:
def __init__(self) -> None:
print("Starting backend")
@dynamo_endpoint()
async def generate(self, req: RequestType):
"""Generate tokens."""
req_text = req.text
print(f"Backend received: {req_text}")
text = f"{req_text}-back"
for token in text.split():
yield f"Backend: {token}"
@service(
resources={"cpu": "2"},
traffic={"timeout": 30},
dynamo={"enabled": True, "namespace": "inference"},
)
class Middle:
backend = depends(Backend)
def __init__(self) -> None:
print("Starting middle")
@dynamo_endpoint()
async def generate(self, req: RequestType):
"""Forward requests to backend."""
req_text = req.text
print(f"Middle received: {req_text}")
text = f"{req_text}-mid"
next_request = RequestType(text=text).model_dump_json()
async for response in self.backend.generate(next_request):
print(f"Middle received response: {response}")
yield f"Middle: {response}"
@service(resources={"cpu": "1"}, traffic={"timeout": 60}) # Regular HTTP API
class Frontend:
middle = depends(Middle)
def __init__(self) -> None:
print("Starting frontend")
@api
async def generate(self, text):
"""Stream results from the pipeline."""
print(f"Frontend received: {text}")
print(f"Frontend received type: {type(text)}")
txt = RequestType(text=text)
print(f"Frontend sending: {type(txt)}")
async for response in self.middle.generate(txt.model_dump_json()):
yield f"Frontend: {response}"
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
## Prerequisites
Start required services (etcd and NATS):
Option A: Using [Docker Compose](/deploy/docker-compose.yml) (Recommended)
```bash
docker compose -f deploy/docker-compose.yml up -d
```
Option B: Manual Setup
- [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) server with [Jetstream](https://docs.nats.io/nats-concepts/jetstream)
- example: `nats-server -js --trace`
- [etcd](https://etcd.io) server
- follow instructions in [etcd installation](https://etcd.io/docs/v3.5/install/) to start an `etcd-server` locally
## Build docker
```
./container/build.sh
```
## Run container
```
./container/run.sh -it
```
## Run deployment
This figure shows an overview of the major components to deploy:
```
+----------------+
+------| prefill worker |-------+
notify | | | |
finished | +----------------+ | pull
v v
+------+ +-----------+ +------------------+ push +---------------+
| HTTP |----->| processor |----->| decode/monolith |------------>| prefill queue |
| |<-----| |<-----| worker | | |
+------+ +-----------+ +------------------+ +---------------+
| ^ |
query best | | return | publish kv events
worker | | worker_id v
| | +------------------+
| +---------| kv-router |
+------------->| |
+------------------+
```
### Disaggregated vLLM deployment
Serve following components:
- processor: Processor routes the requests to the (decode) workers. Three scheduling strategies are supported: random and kv.
- kv router: The KV Router is a component that aggregates KV Events from all the workers and maintains
a prefix tree of the cached tokens. It makes decisions on which worker to route requests
to based on the length of the prefix match and the load on the workers.
- decode worker: runs on gpu = 0
- prefill worker: runs on gpu = 1
```bash
cd /workspace/deploy/examples/vllm
dynamo serve disaggregated.processor:Processor \
--Processor.model=deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--Processor.tokenizer=deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--Processor.block-size=64 \
--Processor.max-model-len=16384 \
--Processor.router=kv \
--Router.min-workers=1 \
--Router.block-size=64 \
--Router.model-name=deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--VllmWorker.remote-prefill=true \
--VllmWorker.model=deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--VllmWorker.enforce-eager=true \
--VllmWorker.tensor-parallel-size=1 \
--VllmWorker.kv-transfer-config='{"kv_connector": "DynamoNixlConnector"}' \
--VllmWorker.block-size=64 \
--VllmWorker.max-num-batched-tokens=16384 \
--VllmWorker.max-model-len=16384 \
--VllmWorker.router=kv \
--VllmWorker.enable-prefix-caching=true \
--PrefillWorker.model=deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--PrefillWorker.enforce-eager=true \
--PrefillWorker.block-size=64 \
--PrefillWorker.max-model-len=16384 \
--PrefillWorker.max-num-batched-tokens=16384 \
--PrefillWorker.kv-transfer-config='{"kv_connector": "DynamoNixlConnector"}' \
--PrefillWorker.cuda-visible-device-offset=1
```
Add model to dynamo and start http server.
```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B dynamo-init.Processor.chat_completions
TRT_LOG=DEBUG http --port 8181
```
### Client
In another terminal:
```bash
# this test request has around 200 tokens isl
curl localhost:8181/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{
"role": "user",
"content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
}
],
"stream":false,
"max_tokens": 30
}'
```
### Close deployment
Kill all python processes and clean up metadata files:
```
pkill -9 -f python
```
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm.logger import logger as vllm_logger
class PyDisaggregatedRouter:
def __init__(
self,
runtime,
served_model_name,
max_local_prefill_length=1000,
):
self.runtime = runtime
self.served_model_name = served_model_name
self.max_local_prefill_length = max_local_prefill_length
def prefill_remote(self, prompt_length: int, prefix_hit_rate: float):
absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
vllm_logger.info(
f"Remote prefill: {absolute_prefill_length > self.max_local_prefill_length} (prefill length: {absolute_prefill_length}/{prompt_length})"
)
return absolute_prefill_length > self.max_local_prefill_length
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import random
from argparse import Namespace
from typing import AsyncIterator
from utils.protocol import Tokens
from vllm.logger import logger as vllm_logger
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_onstart, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
WorkerId = str
def parse_args(service_name, prefix) -> Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers required before proceeding",
)
parser.add_argument(
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
# TODO: Read block size
parser.add_argument(
"--block-size",
type=int,
default=64,
help="KV block size",
)
parser.add_argument(
"--custom-router",
type=bool,
default=False,
help="Whether to use custom router or not",
)
config = ServiceConfig.get_instance()
config_args = config.as_args(service_name, prefix=prefix)
args = parser.parse_args(config_args)
return args
@service(
dynamo={
"enabled": True,
"namespace": "dynamo-init",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Router:
"""
Request handler for the generate endpoint
"""
def __init__(self):
vllm_logger.info("Initializing Custom Router")
self.args = parse_args(self.__class__.__name__, "")
print("[ROUTER] args = ", self.args)
@async_onstart
async def async_init(self):
self.runtime = dynamo_context["runtime"]
self.workers_client = (
await self.runtime.namespace("dynamo-init")
.component("VllmWorker")
.endpoint("generate")
.client()
)
while len(self.workers_client.endpoint_ids()) < self.args.min_workers:
# TODO: replace print w/ vllm_logger.info
print(
f"Waiting for more workers to be ready.\n"
f" Current: {len(self.workers_client.endpoint_ids())},"
f" Required: {self.args.min_workers}"
)
await asyncio.sleep(2)
kv_listener = self.runtime.namespace("dynamo-init").component("VllmWorker")
await kv_listener.create_service()
self.indexer = KvIndexer(kv_listener, self.args.block_size)
self.metrics_aggregator = KvMetricsAggregator(kv_listener)
print("KV Router initialized")
def _cost_function(
self,
scores: OverlapScores | None,
metrics: AggregatedMetrics | None,
token_length: int,
):
worker_scores = {}
if scores:
for worker_id, score in scores.scores.items():
# score is number of matching blocks we multiply by block_size to get tokens
# and compare to token_length. The larger the cache hit the better
worker_scores[worker_id] = (
score * self.indexer.block_size() / token_length
)
worker_metrics = {}
# pull metrics for each worker
max_waiting = 0.0
if metrics:
print("[ROUTER] metrics.endpoint ", metrics.endpoints)
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
worker_metrics[worker_id] = {
"gpu_cache_usage_perc": endpoint.gpu_cache_usage_perc
if hasattr(endpoint, "gpu_cache_usage_perc")
else 0.0,
"num_requests_waiting": endpoint.num_requests_waiting
if hasattr(endpoint, "num_requests_waiting")
else 0.0,
"gpu_prefix_cache_hit_rate": endpoint.gpu_prefix_cache_hit_rate
if hasattr(endpoint, "gpu_prefix_cache_hit_rate")
else 0.0,
}
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
)
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.endpoint_ids()
worker_logits = {}
for worker_id in worker_ids:
# Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get(
worker_id,
{
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
},
)
normalized_waiting = (
metrics_dict["num_requests_waiting"] / max_waiting
if max_waiting > 0
else 0.0
)
# Have 1 metric that weights towards cache hit
# 2 metrics that penalize overloaded worker and queuing
worker_logits[worker_id] = (
2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
)
vllm_logger.info(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
)
if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
return ""
# Select the worker with the highest logit
if worker_logits:
max_logit = max(worker_logits.values())
best_workers = [
wid for wid, logit in worker_logits.items() if logit == max_logit
]
best_worker_id = random.choice(best_workers)
else:
best_worker_id = ""
# Log the metrics for the selected worker
if best_worker_id:
vllm_logger.info(
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
)
vllm_logger.info(
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
)
metrics_dict = worker_metrics.get(best_worker_id, {})
vllm_logger.info(
f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
)
vllm_logger.info(
f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}"
)
vllm_logger.info(
f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
)
return best_worker_id, worker_scores.get(best_worker_id, 0.0)
@dynamo_endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = ""
try:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
)
except Exception as e:
scores = {}
vllm_logger.exception(f"Error finding matches: {e}")
token_length = len(request.tokens)
metrics = await self.metrics_aggregator.get_metrics()
schedule_result = self._cost_function(scores, metrics, token_length)
if schedule_result == "":
worker_id = ""
prefix_hit_rate = 0.0
else:
worker_id, prefix_hit_rate = schedule_result
vllm_logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
yield f"{worker_id}_{prefix_hit_rate}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from pydantic import BaseModel
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.logger import logger as vllm_logger
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.sdk import (
async_onstart,
dynamo_context,
dynamo_endpoint,
server_context,
service,
)
class RequestType(BaseModel):
text: str
os.environ["VLLM_LOG_LEVEL"] = "DEBUG"
@service(
dynamo={
"enabled": True,
"namespace": "dynamo-init",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class PrefillWorker:
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
gpu_idx = (
self.engine_args.cuda_visible_device_offset
+ server_context.worker_index
- 1
)
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_idx}"
self._loaded_metadata = set()
self.initialized = False
if self.engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
self.engine_args.enable_chunked_prefill = False
if self.engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.disable_async_output_proc is not True:
print("Async output processing is not supported yet, setting to True")
self.engine_args.disable_async_output_proc = True
if self.engine_args.enforce_eager is not True:
print("Prefill must be done eagerly, setting to True")
self.engine_args.enforce_eager = True
print("PrefillWorker initialized")
@async_onstart
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
runtime = dynamo_context["runtime"]
metadata = self.engine_client.nixl_metadata
self._metadata_store = NixlMetadataStore("dynamo-init", runtime)
await self._metadata_store.put(metadata.engine_id, metadata)
task = asyncio.create_task(self.prefill_queue_handler())
task.add_done_callback(lambda _: print("prefill queue handler created"))
async def prefill_queue_handler(self):
print("[DEBUG] prefill queue handler entered")
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
prefill_queue_stream_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
print(f"Prefill queue: {prefill_queue_nats_server}:{prefill_queue_stream_name}")
self.initialized = True
# TODO: integrate prefill_queue to a dynamo endpoint
async with PrefillQueue.get_instance(
nats_server=prefill_queue_nats_server,
stream_name=prefill_queue_stream_name,
) as prefill_queue:
print("prefill queue handler started")
while True:
# TODO: this might add a small overhead to pull prefill from nats
# need to test and check how much overhead it is
prefill_request = await prefill_queue.dequeue_prefill_request()
if prefill_request is not None:
vllm_logger.info(f"Dequeued prefill request: {prefill_request}")
async for _ in self.generate(prefill_request):
pass
async def generate(self, request: RemotePrefillRequest):
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=request.engine_id,
)
# TODO check if metadata has changed
# and reload - currently only loading once
if request.engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
print(
f"Loaded nixl metadata from engine {request.engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(request.engine_id)
async for _ in self.engine_client.generate(
request_id=request.request_id,
prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
@dynamo_endpoint()
async def mock(self, req: RequestType):
yield f"mock_response: {req}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
from disaggregated.kv_router import Router
from disaggregated.worker import VllmWorker
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.sdk import depends, dynamo_context, dynamo_endpoint, service
os.environ["VLLM_LOG_LEVEL"] = "DEBUG"
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
@service(
dynamo={
"enabled": True,
"namespace": "dynamo-init",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
worker = depends(VllmWorker)
router = depends(Router)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config()
print(f"[Processor] self.engine_args: {self.engine_args}")
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.router_mode = self.engine_args.router
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
request_type: RequestType,
):
request_id = str(uuid.uuid4())
vllm_logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmWorker.dynamo_address() # type: ignore
print(f"[Processor] comp_ns: {comp_ns}, comp_name: {comp_name}")
worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
if self.router_mode == "kv":
async for route_response in self.router.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
):
worker_id, prefix_hit_rate = route_response.split("_")
prefix_hit_rate = float(prefix_hit_rate)
vllm_logger.info(
f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
break
if worker_id == "":
engine_generator = await worker_client.generate(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json()
)
else:
engine_generator = await worker_client.direct(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json(),
int(worker_id),
)
elif self.router_mode == "random":
engine_generator = await worker_client.generate(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
# TODO: add round-robin mode
# elif self.router_mode == "round-robin":
# engine_generator = await self.worker.round_robin(
# vLLMGenerateRequest(
# engine_prompt=engine_prompt,
# sampling_params=sampling_params,
# request_id=request_id,
# ).model_dump_json()
# )
output = self._generate_responses(engine_generator, request_type)
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
async def _generate_responses(
self, engine_generator: AsyncIterator[RequestOutput], request_type: RequestType
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
prompt_idx = 0
async for resp in engine_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
elif request_type == RequestType.COMPLETION:
# Completion requests can have multiple prompts and stream generator requires the prompt index
yield (prompt_idx, request_output)
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
@dynamo_endpoint()
async def chat_completions(self, raw_request: ChatCompletionRequest):
async for response in self._generate(raw_request, RequestType.CHAT):
yield response
# @dynamo_endpoint()
# async def completions(self, raw_request: CompletionRequest):
# async for response in self._generate(raw_request, RequestType.COMPLETION):
# yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from disaggregated.disagg_router import PyDisaggregatedRouter
from disaggregated.prefill_worker import PrefillWorker
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import MyRequestOutput, vLLMGenerateRequest
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.logger import logger as vllm_logger
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind
from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import (
async_onstart,
depends,
dynamo_context,
dynamo_endpoint,
server_context,
service,
)
os.environ["VLLM_LOG_LEVEL"] = "DEBUG"
@service(
dynamo={
"enabled": True,
"namespace": "dynamo-init",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmWorker:
prefill_worker = depends(PrefillWorker)
def __init__(self):
self.client = None
self.disaggregated_router: PyDisaggregatedRouter = None # type: ignore
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.do_remote_prefill = self.engine_args.remote_prefill
self.model_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
)
self._prefill_queue_stream_name = self.model_name
vllm_logger.info(
f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
)
if self.engine_args.remote_prefill:
if self.engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
self.engine_args.enable_chunked_prefill = False
if self.engine_args.preemption_mode != "swap":
print("Preemption mode is not supported yet, setting to swap")
self.engine_args.preemption_mode = "swap"
if self.engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.router == "kv":
VLLM_WORKER_ID = dynamo_context["endpoints"][0].lease_id()
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
os.environ["VLLM_KV_NAMESPACE"] = "dynamo-init"
os.environ["VLLM_KV_COMPONENT"] = class_name
vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
# note: worker_index is 1-based, but CUDA_VISIBLE_DEVICES is 0-based
gpu_idx = (
self.engine_args.cuda_visible_device_offset
+ server_context.worker_index
- 1
)
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_idx}"
self.metrics_publisher = KvMetricsPublisher()
@async_onstart
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
if self.engine_args.router == "kv":
assert self.engine_client is not None, "engine_client was not initialized"
self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
self.metrics_publisher.publish(
0, # request_active_slots
1024, # request_total_slots
0, # kv_active_blocks
1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: print("metrics publisher endpoint created")
)
runtime = dynamo_context["runtime"]
if self.engine_args.remote_prefill:
metadata = self.engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo-init", runtime)
await metadata_store.put(metadata.engine_id, metadata)
if self.engine_args.conditional_disagg:
self.disaggregated_router = PyDisaggregatedRouter(
runtime,
self.model_name,
max_local_prefill_length=self.engine_args.max_local_prefill_length,
)
else:
self.disaggregated_router = None
print("VllmWorker has been initialized")
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
await self.metrics_publisher.create_endpoint(component)
def get_remote_prefill_request_callback(self):
# TODO: integrate prefill_queue to dynamo endpoint
async def callback(request: RemotePrefillRequest):
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
await prefill_queue.enqueue_prefill_request(request)
return callback
@dynamo_endpoint()
async def generate(self, request: vLLMGenerateRequest):
# TODO: consider prefix hit when deciding prefill locally or remotely
if self.disaggregated_router is not None:
disagg_router_decision = self.disaggregated_router.prefill_remote(
len(request.engine_prompt["prompt_token_ids"]), request.prefix_hit_rate
)
else:
# always prefill remotely if no disaggregated router is provided
disagg_router_decision = True
if self.do_remote_prefill and disagg_router_decision:
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
)
vllm_logger.info(
f"Prefilling remotely for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
else:
remote_prefill_params = None
vllm_logger.info(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
# rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA
async for response in self.engine_client.generate(
prompt=request.engine_prompt,
sampling_params=request.sampling_params,
request_id=request.request_id,
remote_prefill_params=remote_prefill_params,
):
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, engine_client, processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=request.chat_template or self.tokenizer.chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
from typing import ClassVar, Optional
from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError
class NATSQueue:
_instance: ClassVar[Optional["NATSQueue"]] = None
_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
def __init__(
self,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
self.nats_url = nats_server
self._nc: Optional[NATS] = None
self._js: Optional[JetStreamContext] = None
# TODO: check if this is needed
# Sanitize stream_name to remove path separators
self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
self._subject = f"{self._stream_name}.*"
self.dequeue_timeout = dequeue_timeout
self._subscriber: Optional[JetStreamContext.PullSubscription] = None
@classmethod
@asynccontextmanager
async def get_instance(
cls,
*,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
"""Get or create a singleton instance of NATSq"""
# TODO: check if this _lock is needed with GIL
async with cls._lock:
if cls._instance is None:
cls._instance = cls(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
await cls._instance.connect()
try:
yield cls._instance
except Exception:
if cls._instance:
await cls._instance.close()
cls._instance = None
raise
# TODO: check to see if this can be replaced by something like get_instance().close()
@classmethod
async def shutdown(cls):
"""Explicitly close the singleton instance if it exists"""
async with cls._lock:
if cls._instance:
await cls._instance.close()
cls._instance = None
async def connect(self):
"""Establish connection and create stream if needed"""
try:
if self._nc is None:
self._nc = NATS()
await self._nc.connect(self.nats_url)
self._js = self._nc.jetstream()
# Check if stream exists, if not create it
try:
await self._js.stream_info(self._stream_name)
except NotFoundError:
await self._js.add_stream(
name=self._stream_name, subjects=[self._subject]
)
# Create persistent subscriber
self._subscriber = await self._js.pull_subscribe(
f"{self._stream_name}.queue", durable="worker-group"
)
except NatsError as e:
await self.close()
raise ConnectionError(f"Failed to connect to NATS: {e}")
async def ensure_connection(self):
"""Ensure we have an active connection"""
if self._nc is None or self._nc.is_closed:
await self.connect()
async def close(self):
"""Close the connection when done"""
if self._nc:
await self._nc.close()
self._nc = None
self._js = None
self._subscriber = None
# TODO: is enqueue/dequeue_object a better name for a general queue?
async def enqueue_task(self, task_data: bytes) -> None:
"""
Enqueue a task using msgspec-encoded data
"""
await self.ensure_connection()
try:
await self._js.publish(f"{self._stream_name}.queue", task_data) # type: ignore
except NatsError as e:
raise RuntimeError(f"Failed to enqueue task: {e}")
async def dequeue_task(self) -> Optional[bytes]:
"""Dequeue and return a task as raw bytes, to be decoded with msgspec"""
await self.ensure_connection()
try:
msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout) # type: ignore
if msgs:
msg = msgs[0]
await msg.ack()
return msg.data
return None
except asyncio.TimeoutError:
return None
except NatsError as e:
raise RuntimeError(f"Failed to dequeue task: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
import msgspec
from vllm.distributed.device_communicators.nixl import NixlMetadata
from dynamo.runtime import DistributedRuntime
METADATA_DIR = "/tmp/nixl"
@contextmanager
def temp_metadata_file(engine_id, metadata: NixlMetadata):
os.makedirs(METADATA_DIR, exist_ok=True)
path = f"{METADATA_DIR}/{engine_id}.nixl_meta"
with open(path, "wb") as f:
encoded = msgspec.msgpack.encode(metadata)
print(f"Size of encoded metadata: {len(encoded)}")
f.write(encoded)
try:
yield path
finally:
if os.path.exists(path):
os.remove(path)
def find_remote_metadata(engine_id):
# find and load metadata from METADATA_DIR that do not match engine_id
remote_metadata = []
for file in os.listdir(METADATA_DIR):
if file.endswith(".nixl_meta"):
if file.split(".")[0] != engine_id:
with open(os.path.join(METADATA_DIR, file), "rb") as f:
remote_metadata.append(
msgspec.msgpack.decode(f.read(), type=NixlMetadata)
)
return remote_metadata
class NixlMetadataStore:
NIXL_METADATA_KEY = "nixl_metadata"
def __init__(self, namespace: str, runtime: DistributedRuntime) -> None:
self._namespace = namespace
# TODO Remove metadata from etcd on delete
self._stored: set[str] = set()
self._cached: dict[str, NixlMetadata] = {}
self._client = runtime.etcd_client()
self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}"
async def put(self, engine_id, metadata: NixlMetadata):
serialized_metadata = msgspec.msgpack.encode(metadata)
key = "/".join([self._key_prefix, engine_id])
await self._client.kv_put(key, serialized_metadata, None)
self._stored.add(engine_id)
async def get(self, engine_id) -> NixlMetadata:
try:
if engine_id in self._cached:
return self._cached[engine_id]
key = "/".join([self._key_prefix, engine_id])
key_values = await self._client.kv_get_prefix(key)
deserialized_metadata = None
for item in key_values:
deserialized_metadata = msgspec.msgpack.decode(
item["value"], type=NixlMetadata
)
break
if deserialized_metadata is None:
raise Exception("metadata not found in etcd")
self._cached[engine_id] = deserialized_metadata
# TODO watch for changes and update cache
# self._client.add_watch_callback(
# key,
# self._watch_callback,
# )
except Exception as e:
raise Exception("Error retrieving metadata for engine {engine_id}") from e
return deserialized_metadata
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import msgspec
from utils.nats_queue import NATSQueue
from vllm.remote_prefill import RemotePrefillRequest
class PrefillQueue(NATSQueue):
"""
A wrapper of NATSQueue for PrefillRequest.
The stream name is forced to be "prefill_queue".
"""
def __init__(
self,
stream_name="prefill_queue",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
super().__init__(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
async def enqueue_prefill_request(
self, prefill_request: RemotePrefillRequest
) -> None:
encoded_request = msgspec.json.encode(prefill_request)
await self.enqueue_task(encoded_request)
async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
encoded_request = await self.dequeue_task()
if encoded_request is not None:
prefill_request = msgspec.json.decode(
encoded_request, type=RemotePrefillRequest
)
return prefill_request
else:
return None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Optional
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import PromptLogprobs, RequestMetrics
class Request(BaseModel):
prompt: str
sampling_params: dict
class Tokens(BaseModel):
tokens: list[int]
class PrefillRequest(Request):
request_id: str
class Response(BaseModel):
text: str
class PrefillResponse(BaseModel):
prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# Sampling params is a mspspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
prefix_hit_rate: Optional[float] = 0.0
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
model_config = ConfigDict(
json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)}
)
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: rename to avoid ambiguity with vllm package
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
from dynamo.sdk.lib.config import ServiceConfig
def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
config = ServiceConfig.get_instance()
print(f"[DEBUG] config: {config}")
vllm_args = config.as_args(service_name, prefix=prefix)
print(f"[DEBUG] service_name: {service_name}, vllm_args: {vllm_args}")
parser = FlexibleArgumentParser()
parser.add_argument(
"--router",
type=str,
choices=["random", "round-robin", "kv"],
default="random",
help="Router type to use for scheduling requests to workers",
)
parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill"
)
parser.add_argument(
"--conditional-disagg",
action="store_true",
help="Use disaggregated router to decide whether to prefill locally or remotely",
)
parser.add_argument(
"--max-local-prefill-length",
type=int,
default=1000,
help="Maximum length of local prefill",
)
parser.add_argument(
"--cuda-visible-device-offset",
type=int,
default=0,
help="Offset of CUDA_VISIBLE_DEVICE",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(vllm_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.router = args.router
engine_args.remote_prefill = args.remote_prefill
engine_args.conditional_disagg = args.conditional_disagg
engine_args.max_local_prefill_length = args.max_local_prefill_length
engine_args.cuda_visible_device_offset = args.cuda_visible_device_offset
return engine_args
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment