Commit 19844fc0 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: kv aware router + disagg router + prefill queue (#11)


Signed-off-by: default avatarHongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: default avatarhongkuan <hongkuanz@nvidia.com>
Co-authored-by: default avatarPiotr Tarasiewicz <ptarasiewicz@nvidia.com>
Co-authored-by: default avatarPiotr Tarasiewicz Nvidia <ptarasiewicznv@Piotrs-MacBook-Pro.local>
Co-authored-by: default avataralec-flowers <aflowers@nvidia.com>
Co-authored-by: default avatarNeelay Shah <neelays@nvidia.com>
parent 7567620f
...@@ -17,6 +17,22 @@ limitations under the License. ...@@ -17,6 +17,22 @@ limitations under the License.
> **NOTE**: This example is based on an internal NVIDIA library that will soon be publicly released. The example won't work until the official release. > **NOTE**: This example is based on an internal NVIDIA library that will soon be publicly released. The example won't work until the official release.
## Prerequisites
Start required services (etcd and NATS):
Option A: Using [Docker Compose](/deploy/docker-compose.yml) (Recommended)
```bash
docker compose -f deploy/docker-compose.yml up -d
```
Option B: Manual Setup
- [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) server with [Jetstream](https://docs.nats.io/nats-concepts/jetstream)
- example: `nats-server -js --trace`
- [etcd](https://etcd.io) server
- follow instructions in [etcd installation](https://etcd.io/docs/v3.5/install/) to start an `etcd-server` locally
## Build docker ## Build docker
``` ```
...@@ -35,65 +51,204 @@ All of the commands below are run inside the same container. ...@@ -35,65 +51,204 @@ All of the commands below are run inside the same container.
Add model to dynamo and start http server. Add model to dynamo and start http server.
In terminal 0:
``` ```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B test-nixl.vllm.generate
TRT_LOG=DEBUG http --port 8181 TRT_LOG=DEBUG http --port 8181
``` ```
### Router-less Deployment
Router-less deployment without kv router and disaggregated router.
### Monolithic deployment For router-less deployment, the client should directly hit the vllm.generate endpoint,
```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B dynamo-init.vllm.generate
```
In terminal 1: #### Monolithic
``` ```
cd /workspace/examples/python_rs/llm/vllm_nixl cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=0 python3 worker.py \ CUDA_VISIBLE_DEVICES=0 python3 routerless/worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager --enforce-eager
``` ```
#### Disaggregated
### Disaggregated deployment In disaggregated router-less deployment, the decode worker will directly send requests to a random prefill worker. All the requests will be sent to prefill worker(s) for remote prefill.
In terminal 1: In terminal 1:
``` ```
cd /workspace/examples/python_rs/llm/vllm_nixl cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=0 python prefill_worker.py \ CUDA_VISIBLE_DEVICES=0 python routerless/prefill_worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \ --enforce-eager \
--block-size 64 \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"DynamoNixlConnector"}' '{"kv_connector":"DynamoNixlConnector"}'
``` ```
In terminal 2: In terminal 2:
``` ```
cd /workspace/examples/python_rs/llm/vllm_nixl cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=1,2 python3 worker.py \ CUDA_VISIBLE_DEVICES=1,2 python3 routerless/worker.py \
--remote-prefill \ --remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \ --enforce-eager \
--block-size 64 \
--tensor-parallel-size 2 \ --tensor-parallel-size 2 \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"DynamoNixlConnector"}' '{"kv_connector":"DynamoNixlConnector"}'
``` ```
### Router-based Deployment
Router-based deployment use kv router to schedule the request to the best decode worker and disaggregated router to decide whether to prefill locally or remotely. The remote prefill requests will be sent to a global prefill queue to balance the prefill load.
For router deployment, the client should hit the endpoint of the processor,
```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B dynamo-init.process.chat/completions
```
To launch disaggregated vllm deployment, there are four major components:
1. Processor
2. KV Router
3. Disaggregated Router
4. Prefill and Decode Workers
#### Processor
```
# Processor must take the same args as the worker
# This is temporary until we communicate the ModelDeploymentCard over etcd
# Currently only block-size=64 is supported
cd /workspace/examples/python_rs/llm/vllm_nixl
RUST_LOG=info python3 router/processor.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enable-prefix-caching \
--block-size 64 \
--max-model-len 16384
```
#### KV Router
The KV Router is a component that aggregates KV Events from all the workers and maintains a prefix tree of the cached tokens. It makes decisions on which worker to route requests to based on the length of the prefix match and the load on the workers.
To launch the KV Router, run the following command:
```
RUST_LOG=info python3 router/kv_router.py \
--routing-strategy prefix \
--model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--min-workers 1
```
There is also a custom router that uses a cost function defined in python to make routing decisions. To launch the custom router, run the following command:
```
RUST_LOG=info python3 router/kv_router.py \
--routing-strategy prefix \
--model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--custom-router \
--min-workers 1
```
You can choose only the prefix strategy for now:
- `prefix`: Route requests to the worker that has the longest prefix match.
#### Disaggregated Router
The disaggregated router determines whether a request should be send to a
remote prefill engine or a local prefill engine for prefilling based on the
prefill length. When prefilling locally, the vllm scheduler will prioritize
prefill request and pause any ongoing decode requests.
There are two types of disaggregated router implementations:
* Rust native: provide a simple heuristic to route to prefill engine
if prefill length (including prefix catch hit) is greater than a threshold.
This threshold can by dynamically adjusted at runtime through etcd.
To check the current threshold (this will print out all kv pairs in etcd):
```
curl -s -L http://localhost:2379/v3/kv/range -X POST -d '{"key":"AA==", "range_end":"AA=="}' | jq -r '.kvs[] | "KEY: \(.key | @base64d)\nVALUE: \(.value | @base64d)\n---"'
```
To update the threshold:
```
ETCDCTL_API=3 etcdctl --endpoints=http://localhost:2379 put 'public/components/disagg_router/models/chat/<vllm.served_model_name(default to "vllm")>' '{"max_local_prefill_length": <new_threshold>}'
```
* Python customized: provide a python implementation that can be easily customized.
However, it does not support dynamic threshold adjustment through etcd.
It is recommended to use the custom disaggregated router together with the custom
kv router as the rust kv router does not report kv cache hit ratio.
To use the python disaggregated router, add the following commands when launching
the decode worker:
```
python worker.py \
--custom-disagg-router \
--max-local-prefill-length <length> \
--max-remote-prefill-cache-hit-ratio <ratio>
```
#### Workers
```
# start prefill worker in Terminal 1
# Note: prefix caching is not supported in the prefill for now
cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=0 python3 router/prefill_worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}' \
--block-size 64 \
--max-num-batched-tokens 16384 \
--max-model-len 16384
# start decode worker in Terminal 2
cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=1 python3 router/worker.py \
--remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--tensor-parallel-size 1 \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}' \
--enable-prefix-caching \
--block-size 64 \
--max-num-batched-tokens 16384 \
--max-model-len 16384
```
Alternatively, we also provide a script to launch all workers in one go (with the python customized router):
```
# this TODO: change to dynamo-deploy functionality
./start_single_node.sh
# Usage [--model <model>] [--p_tensor_parallel_size <size>] [--d_tensor_parallel_size <size>] [--max_model_len <len>] [--max_num_batched_tokens <tokens>] [--max_num_seqs <seqs>] [--gpu_memory_utilization <utilization>] [--enable_chunked_prefill <True/False>] [--num_p <p>] [--num_d <d>]
```
### Common Issues
If torch GLOO backend is complaining about file name too long, set
```
export GLOO_SOCKET_IFNAME=lo
```
## Client ## Client
In another terminal: In another terminal:
``` ```
curl localhost:8181/v1/chat/completions \ # this test request has around 200 tokens isl
-H "Content-Type: application/json" \ curl localhost:8181/v1/chat/completions -H "Content-Type: application/json" -d '{
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [ "messages": [
{"role": "user", "content": "What is the capital of France?"} {
"role": "user",
"content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
}
], ],
"max_tokens": 10 "stream":false,
"max_tokens": 30
}' }'
``` ```
...@@ -132,7 +287,6 @@ Kill all python processes and clean up metadata files: ...@@ -132,7 +287,6 @@ Kill all python processes and clean up metadata files:
``` ```
pkill -9 -f python pkill -9 -f python
rm -r /tmp/nixl
``` ```
## TODOs, limitations, known issues ## TODOs, limitations, known issues
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dynamo.llm import DisaggregatedRouter
class PyDisaggregatedRouter:
def __init__(
self,
runtime,
served_model_name,
custom_disagg_router=False,
max_local_prefill_length=1000,
max_remote_prefill_cache_hit_ratio=0.5,
):
self.runtime = runtime
self.served_model_name = served_model_name
self.max_local_prefill_length = max_local_prefill_length
self.max_remote_prefill_cache_hit_ratio = max_remote_prefill_cache_hit_ratio
self.custom_disagg_router = custom_disagg_router
if not self.custom_disagg_router:
# TODO: add max_remote_prefill_cache_hit_ratio to rust router
self.disagg_router = DisaggregatedRouter(
runtime,
served_model_name,
max_local_prefill_length,
)
def prefill_remote(self, prompt_length, cache_hit_length=0):
if self.custom_disagg_router:
# TODO: add max_remote_prefill_cache_hit_ratio to python router
return prompt_length > self.max_local_prefill_length
else:
return self.disagg_router.prefill_remote(prompt_length, cache_hit_length)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from argparse import Namespace
from enum import Enum
from typing import AsyncIterator
import uvloop
from utils.protocol import Tokens
from vllm.logger import logger as vllm_logger
from dynamo.llm import KvIndexer, KvMetricsAggregator, KvRouter
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
WorkerId = str
class RoutingStrategy(Enum):
PREFIX = "prefix"
ROUND_ROBIN = "round_robin"
RANDOM = "random"
class Router:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
router: KvRouter,
routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX,
):
vllm_logger.info(
f"Initializing KV Router with strategy: {routing_strategy.value}"
)
self.router = router
self.routing_strategy = routing_strategy
@dynamo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = None
if self.routing_strategy == RoutingStrategy.PREFIX:
try:
worker_id = await self.router.schedule(request.tokens, lora_id)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
yield str(worker_id)
else:
# TODO: Do we implement round_robin and random here?
# or just skip this router and directly enable in preprocess?
raise NotImplementedError(
f"Routing strategy {self.routing_strategy} not implemented"
)
class CustomRouter:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
indexer: KvIndexer,
metrics_aggregator: KvMetricsAggregator,
):
self.indexer = indexer
self.metrics_aggregator = metrics_aggregator
def _cost_function(self, scores, metrics):
# naive cost function for demonstration purposes
current_best = ("", 0)
for worker_id, score in scores.scores.items():
if score > current_best[1]:
current_best = (worker_id, score)
for endpoint in metrics.endpoints:
if endpoint.worker_id == current_best[0]:
print(f"Metrics of endpoint: {endpoint.worker_id}")
print(
f"request slot usage: {endpoint.request_active_slots} / {endpoint.request_total_slots}"
)
print(
f"KV block usage: {endpoint.kv_active_blocks} / {endpoint.kv_total_blocks}"
)
return current_best[0]
@dynamo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = ""
try:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
)
metrics = await self.metrics_aggregator.get_metrics()
worker_id = self._cost_function(scores, metrics)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"{e}")
worker_id = ""
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
yield str(worker_id)
@dynamo_worker()
async def worker(runtime: DistributedRuntime, args: Namespace):
"""
Set up the worker clients.
Serve the dynamo-init.router.generate endpoint.
"""
workers_client = (
await runtime.namespace("dynamo-init")
.component("vllm")
.endpoint("generate")
.client()
)
wait_task = workers_client.wait_for_endpoints()
await asyncio.sleep(1)
while not wait_task.done():
vllm_logger.info("Waiting for workers to be ready...")
await asyncio.sleep(5)
wait_task.result()
while len(workers_client.endpoint_ids()) < args.min_workers:
vllm_logger.info(
f"Waiting for more workers... Current: {len(workers_client.endpoint_ids())}, Required: {args.min_workers}"
)
await asyncio.sleep(5)
vllm_logger.info(
f"Required number of workers ({args.min_workers}) are ready:\n"
+ "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
)
kv_listener = runtime.namespace("dynamo-init").component("vllm")
await kv_listener.create_service()
router_component = runtime.namespace("dynamo-init").component("router")
await router_component.create_service()
endpoint = router_component.endpoint("generate")
if args.custom_router:
indexer = KvIndexer(kv_listener)
metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint(
CustomRouter(indexer, metrics_aggregator).generate
)
else:
router = KvRouter(runtime, kv_listener)
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
if __name__ == "__main__":
uvloop.install()
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--routing-strategy",
type=RoutingStrategy,
default=RoutingStrategy.PREFIX,
choices=list(RoutingStrategy),
help="Routing strategy to use",
)
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers required before proceeding",
)
parser.add_argument(
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
parser.add_argument(
"--custom-router",
type=bool,
default=False,
help="Whether to use custom router or not",
)
args = parser.parse_args()
asyncio.run(worker(args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import uvloop
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.logger import logger as vllm_logger
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.runtime import DistributedRuntime, dynamo_worker
class RequestHandler:
def __init__(self, engine_client, metadata_store):
self.engine_client = engine_client
self._metadata_store = metadata_store
self._loaded_metadata = set()
print("RequestHandler initialized")
async def generate(self, request: RemotePrefillRequest):
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=request.engine_id,
)
# TODO check if metadata has changed
# and reload - currently only loading once
if request.engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
print(
f"Loaded nixl metadata from engine {request.engine_id} into engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(request.engine_id)
async for _ in self.engine_client.generate(
request_id=request.request_id,
prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
# TODO: we don't need it now, but will need it after the queue is integrated to the runtime
component = runtime.namespace("dynamo-init").component("prefill")
await component.create_service()
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
metadata = engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo-init", runtime)
await metadata_store.put(metadata.engine_id, metadata)
# TODO: move this to prefill_queue.py
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
prefill_queue_stream_name = (
engine_args.served_model_name
if engine_args.served_model_name is not None
else "vllm"
)
vllm_logger.info(
f"Prefill queue: {prefill_queue_nats_server}:{prefill_queue_stream_name}"
)
request_handler = RequestHandler(engine_client, metadata_store)
# TODO: integrate prefill_queue to an triton_distributed endpoint
async with PrefillQueue.get_instance(
nats_server=prefill_queue_nats_server,
stream_name=prefill_queue_stream_name,
) as prefill_queue:
while True:
# TODO: this might add a small overhead to pull prefill from nats
# need to test and check how much overhead it is
prefill_request = await prefill_queue.dequeue_prefill_request()
if prefill_request is not None:
vllm_logger.info(f"Dequeued prefill request: {prefill_request}")
async for _ in request_handler.generate(prefill_request):
pass
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
if engine_args.disable_async_output_proc is not True:
print("Async output processing is not supported yet, setting to True")
engine_args.disable_async_output_proc = True
if engine_args.enforce_eager is not True:
print("Prefill must be done eagerly, setting to True")
engine_args.enforce_eager = True
asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
import uvloop
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
CompletionRequest,
CompletionStreamResponse,
)
from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.runtime import Client, DistributedRuntime, dynamo_endpoint, dynamo_worker
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
def __init__(
self,
engine_args: AsyncEngineArgs,
router_client: Client,
workers_client: Client,
):
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.tokenizer = self._create_tokenizer(engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.router_client = router_client
self.workers_client = workers_client
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
request_type: RequestType,
):
request_id = str(uuid.uuid4())
vllm_logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_id_generator: AsyncIterator = await self.router_client.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
)
worker_id = (
await worker_id_generator.__anext__()
) # only one worker id is returned
worker_id = worker_id.data()
vllm_logger.info(f"Worker ID: {worker_id}")
if worker_id == "":
engine_generator = await self.workers_client.random(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
else:
engine_generator = await self.workers_client.direct(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json(),
int(worker_id),
)
output = self._generate_responses(engine_generator, request_type)
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
async def _generate_responses(
self, engine_generator: AsyncIterator[RequestOutput], request_type: RequestType
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
prompt_idx = 0
async for resp in engine_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
elif request_type == RequestType.COMPLETION:
# Completion requests can have multiple prompts and stream generator requires the prompt index
yield (prompt_idx, request_output)
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
@dynamo_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate_chat(self, raw_request):
async for response in self._generate(raw_request, RequestType.CHAT):
yield response
@dynamo_endpoint(CompletionRequest, CompletionStreamResponse)
async def generate_completions(self, raw_request):
async for response in self._generate(raw_request, RequestType.COMPLETION):
yield response
@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
"""
Set up clients to the router and workers.
Serve the dynamo-init.process.chat/completions endpoint.
"""
workers_client = (
await runtime.namespace("dynamo-init")
.component("vllm")
.endpoint("generate")
.client()
)
router_client = (
await runtime.namespace("dynamo-init")
.component("router")
.endpoint("generate")
.client()
)
preprocess_component = runtime.namespace("dynamo-init").component("process")
await preprocess_component.create_service()
chat_endpoint = preprocess_component.endpoint("chat/completions")
completions_endpoint = preprocess_component.endpoint("completions")
processor = Processor(engine_args, router_client, workers_client)
await asyncio.gather(
chat_endpoint.serve_endpoint(processor.generate_chat),
completions_endpoint.serve_endpoint(processor.generate_completions),
)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
asyncio.run(worker(engine_args))
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# default values
model=deepseek-ai/DeepSeek-R1-Distill-Llama-8B
p_tensor_parallel_size=1
d_tensor_parallel_size=1
max_model_len=16384
max_num_batched_tokens=16384
max_num_seqs=1024
gpu_memory_utilization=0.9
enable_chunked_prefill=False
block_size=64
num_p=2
num_d=2
total_rank=$((num_p + num_d))
curr_kv_rank=0
# Function to display usage
usage() {
echo "Usage: $0 [--model <model>] [--p_tensor_parallel_size <size>] [--d_tensor_parallel_size <size>] [--max_model_len <len>] [--max_num_batched_tokens <tokens>] [--max_num_seqs <seqs>] [--gpu_memory_utilization <utilization>] [--enable_chunked_prefill <True/False>] [--num_p <p>] [--num_d <d>]"
exit 1
}
# Parse the command-line arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--model)
model="$2"
shift 2
;;
--p_tensor_parallel_size)
p_tensor_parallel_size="$2"
shift 2
;;
--d_tensor_parallel_size)
d_tensor_parallel_size="$2"
shift 2
;;
--max_model_len)
max_model_len="$2"
shift 2
;;
--max_num_batched_tokens)
max_num_batched_tokens="$2"
shift 2
;;
--max_num_seqs)
max_num_seqs="$2"
shift 2
;;
--gpu_memory_utilization)
gpu_memory_utilization="$2"
shift 2
;;
--enable_chunked_prefill)
enable_chunked_prefill="$2"
shift 2
;;
--num_p)
num_p="$2"
shift 2
;;
--num_d)
num_d="$2"
shift 2
;;
--total_rank)
total_rank="$2"
shift 2
;;
--curr_kv_rank)
curr_kv_rank="$2"
shift 2
;;
--block_size)
block_size="$2"
shift 2
;;
*)
usage
;;
esac
done
# rank here is GPU rank
curr_rank=0
echo "total rank: "${total_rank}
for (( i=1; i<=num_d; i++ )); do
cuda_devices=$(seq $curr_rank $(($curr_rank + $d_tensor_parallel_size - 1)))
cuda_devices=$(echo $cuda_devices | tr ' ' ',')
echo "starting gpu rank "${cuda_devices}" (decode)"
CUDA_VISIBLE_DEVICES=${cuda_devices} python3 worker.py \
--remote-prefill \
--model ${model} \
--max-model-len ${max_model_len} \
--max-num-batched-tokens ${max_num_batched_tokens} \
--enable-chunked-prefill ${enable_chunked_prefill} \
--gpu-memory-utilization ${gpu_memory_utilization} \
--enforce-eager \
--enable-prefix-caching \
--tensor-parallel-size ${d_tensor_parallel_size} \
--block-size ${block_size} \
--kv-transfer-config '{"kv_connector":"dynamoNixlConnector"}' & disown
curr_rank=$((curr_rank + d_tensor_parallel_size))
curr_kv_rank=$((curr_kv_rank + 1))
done
for (( i=1; i<=num_p; i++ )); do
cuda_devices=$(seq $curr_rank $(($curr_rank + $p_tensor_parallel_size - 1)))
cuda_devices=$(echo $cuda_devices | tr ' ' ',')
echo "starting gpu rank "${cuda_devices}" (prefill)"
CUDA_VISIBLE_DEVICES=${cuda_devices} python3 prefill_worker.py \
--model ${model} \
--max-model-len ${max_model_len} \
--max-num-batched-tokens ${max_num_batched_tokens} \
--enable-chunked-prefill ${enable_chunked_prefill} \
--gpu-memory-utilization ${gpu_memory_utilization} \
--enforce-eager \
--tensor-parallel-size ${p_tensor_parallel_size} \
--block-size ${block_size} \
--kv-transfer-config '{"kv_connector":"dynamoNixlConnector"}' & disown
curr_rank=$((curr_rank + p_tensor_parallel_size))
curr_kv_rank=$((curr_kv_rank + 1))
done
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import uvloop
from disagg_router import PyDisaggregatedRouter
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import MyRequestOutput, vLLMGenerateRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import EngineClient
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.logger import logger as vllm_logger
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind
from dynamo.llm import KvMetricsPublisher
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
class RequestHandler:
def __init__(
self,
model_name: str,
engine_client: EngineClient,
prefill_client,
do_remote_prefill: bool,
disaggregated_router: PyDisaggregatedRouter = None,
):
self.model_name = model_name
self.client = engine_client
self.prefill_client = prefill_client
self.openai_serving_chat = None
self.initialized = False
self.do_remote_prefill = (
do_remote_prefill # remote prefill is still controlled by the router
)
self.disaggregated_router = disaggregated_router
if do_remote_prefill:
assert (
disaggregated_router is not None
), "Disaggregated router is required for remote prefill"
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
)
self._prefill_queue_stream_name = model_name
vllm_logger.info(
f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
)
print("RequestHandler initialized")
def get_remote_prefill_request_callback(self):
# TODO: integrate prefill_queue to an triton_distributed endpoint
async def callback(request: RemotePrefillRequest):
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
await prefill_queue.enqueue_prefill_request(request)
return callback
@dynamo_endpoint(vLLMGenerateRequest, MyRequestOutput)
async def generate(self, request):
# TODO: consider prefix hit when deciding prefill locally or remotely
if self.disaggregated_router is not None:
disagg_router_decision = self.disaggregated_router.prefill_remote(
len(request.engine_prompt["prompt_token_ids"]), 0
)
else:
# always prefill remotely if no disaggregated router is provided
disagg_router_decision = True
if self.do_remote_prefill and disagg_router_decision:
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
)
vllm_logger.info(
f"Prefilling remotely for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
else:
remote_prefill_params = None
vllm_logger.info(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
# rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA
async for response in self.client.generate(
prompt=request.engine_prompt,
sampling_params=request.sampling_params,
request_id=request.request_id,
remote_prefill_params=remote_prefill_params,
):
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("dynamo-init").component("vllm")
await component.create_service()
endpoint = component.endpoint("generate")
prefill_client = (
await runtime.namespace("dynamo-init")
.component("prefill")
.endpoint("generate")
.client()
)
# TODO: do we need these env vars?
VLLM_WORKER_ID = endpoint.lease_id()
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
VLLM_KV_NAMESPACE = "dynamo-init"
os.environ["VLLM_KV_NAMESPACE"] = str(VLLM_KV_NAMESPACE)
VLLM_KV_COMPONENT = "vllm"
os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT)
metrics_publisher = KvMetricsPublisher()
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
served_model_name = (
engine_args.served_model_name
if engine_args.served_model_name is not None
else "vllm"
)
disaggregated_router = PyDisaggregatedRouter(
runtime,
served_model_name,
custom_disagg_router=engine_args.custom_disagg_router,
max_local_prefill_length=engine_args.max_local_prefill_length,
max_remote_prefill_cache_hit_ratio=engine_args.max_remote_prefill_cache_hit_ratio,
)
engine_client.set_metrics_publisher(metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
metrics_publisher.publish(
0,
1024,
0,
1024,
)
metadata = engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo-init", runtime)
await metadata_store.put(metadata.engine_id, metadata)
await asyncio.gather(
endpoint.serve_endpoint(
RequestHandler(
model_name=served_model_name,
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=True,
disaggregated_router=disaggregated_router,
).generate
),
metrics_publisher.create_endpoint(component),
)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.preemption_mode != "swap":
print("Preemption mode is not supported yet, setting to swap")
engine_args.preemption_mode = "swap"
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
asyncio.run(worker(engine_args))
...@@ -18,7 +18,8 @@ import asyncio ...@@ -18,7 +18,8 @@ import asyncio
import msgspec import msgspec
import uvloop import uvloop
from common import NixlMetadataStore, parse_vllm_args from utils.nixl import NixlMetadataStore
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args, build_async_engine_client_from_engine_args,
...@@ -73,14 +74,14 @@ class RequestHandler: ...@@ -73,14 +74,14 @@ class RequestHandler:
@dynamo_worker() @dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("test-nixl").component("prefill") component = runtime.namespace("dynamo-init").component("prefill")
await component.create_service() await component.create_service()
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
async with build_async_engine_client_from_engine_args(engine_args) as engine_client: async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
metadata = engine_client.nixl_metadata metadata = engine_client.nixl_metadata
metadata_store = NixlMetadataStore("test-nixl", runtime) metadata_store = NixlMetadataStore("dynamo-init", runtime)
await metadata_store.put(metadata.engine_id, metadata) await metadata_store.put(metadata.engine_id, metadata)
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
......
...@@ -19,7 +19,8 @@ import json ...@@ -19,7 +19,8 @@ import json
import msgspec import msgspec
import uvloop import uvloop
from common import NixlMetadataStore, parse_vllm_args from utils.nixl import NixlMetadataStore
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import EngineClient from vllm.engine.multiprocessing.client import EngineClient
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
...@@ -111,13 +112,13 @@ class RequestHandler: ...@@ -111,13 +112,13 @@ class RequestHandler:
@dynamo_worker() @dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("test-nixl").component("vllm") component = runtime.namespace("dynamo-init").component("vllm")
await component.create_service() await component.create_service()
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
prefill_client = ( prefill_client = (
await runtime.namespace("test-nixl") await runtime.namespace("dynamo-init")
.component("prefill") .component("prefill")
.endpoint("generate") .endpoint("generate")
.client() .client()
...@@ -128,7 +129,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -128,7 +129,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
if engine_args.remote_prefill: if engine_args.remote_prefill:
metadata = engine_client.nixl_metadata metadata = engine_client.nixl_metadata
metadata_store = NixlMetadataStore("test-nixl", runtime) metadata_store = NixlMetadataStore("dynamo-init", runtime)
await metadata_store.put(metadata.engine_id, metadata) await metadata_store.put(metadata.engine_id, metadata)
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, engine_client, processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=request.chat_template or self.tokenizer.chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
from typing import ClassVar, Optional
from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError
class NATSQueue:
_instance: ClassVar[Optional["NATSQueue"]] = None
_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
def __init__(
self,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
self.nats_url = nats_server
self._nc: Optional[NATS] = None
self._js: Optional[JetStreamContext] = None
# TODO: check if this is needed
# Sanitize stream_name to remove path separators
self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
self._subject = f"{self._stream_name}.*"
self.dequeue_timeout = dequeue_timeout
self._subscriber: Optional[JetStreamContext.PullSubscription] = None
@classmethod
@asynccontextmanager
async def get_instance(
cls,
*,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
"""Get or create a singleton instance of NATSq"""
# TODO: check if this _lock is needed with GIL
async with cls._lock:
if cls._instance is None:
cls._instance = cls(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
await cls._instance.connect()
try:
yield cls._instance
except Exception:
if cls._instance:
await cls._instance.close()
cls._instance = None
raise
# TODO: check to see if this can be replaced by something like get_instance().close()
@classmethod
async def shutdown(cls):
"""Explicitly close the singleton instance if it exists"""
async with cls._lock:
if cls._instance:
await cls._instance.close()
cls._instance = None
async def connect(self):
"""Establish connection and create stream if needed"""
try:
if self._nc is None:
self._nc = NATS()
await self._nc.connect(self.nats_url)
self._js = self._nc.jetstream()
# Check if stream exists, if not create it
try:
await self._js.stream_info(self._stream_name)
except NotFoundError:
await self._js.add_stream(
name=self._stream_name, subjects=[self._subject]
)
# Create persistent subscriber
self._subscriber = await self._js.pull_subscribe(
f"{self._stream_name}.queue", durable="worker-group"
)
except NatsError as e:
await self.close()
raise ConnectionError(f"Failed to connect to NATS: {e}")
async def ensure_connection(self):
"""Ensure we have an active connection"""
if self._nc is None or self._nc.is_closed:
await self.connect()
async def close(self):
"""Close the connection when done"""
if self._nc:
await self._nc.close()
self._nc = None
self._js = None
self._subscriber = None
# TODO: is enqueue/dequeue_object a better name for a general queue?
async def enqueue_task(self, task_data: bytes) -> None:
"""
Enqueue a task using msgspec-encoded data
"""
await self.ensure_connection()
try:
await self._js.publish(f"{self._stream_name}.queue", task_data) # type: ignore
except NatsError as e:
raise RuntimeError(f"Failed to enqueue task: {e}")
async def dequeue_task(self) -> Optional[bytes]:
"""Dequeue and return a task as raw bytes, to be decoded with msgspec"""
await self.ensure_connection()
try:
msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout) # type: ignore
if msgs:
msg = msgs[0]
await msg.ack()
return msg.data
return None
except asyncio.TimeoutError:
return None
except NatsError as e:
raise RuntimeError(f"Failed to dequeue task: {e}")
...@@ -19,26 +19,12 @@ from contextlib import contextmanager ...@@ -19,26 +19,12 @@ from contextlib import contextmanager
import msgspec import msgspec
from vllm.distributed.device_communicators.nixl import NixlMetadata from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
METADATA_DIR = "/tmp/nixl" METADATA_DIR = "/tmp/nixl"
def parse_vllm_args() -> AsyncEngineArgs:
parser = FlexibleArgumentParser()
parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill"
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.remote_prefill = args.remote_prefill
return engine_args
@contextmanager @contextmanager
def temp_metadata_file(engine_id, metadata: NixlMetadata): def temp_metadata_file(engine_id, metadata: NixlMetadata):
os.makedirs(METADATA_DIR, exist_ok=True) os.makedirs(METADATA_DIR, exist_ok=True)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import msgspec
from utils.nats_queue import NATSQueue
from vllm.remote_prefill import RemotePrefillRequest
class PrefillQueue(NATSQueue):
"""
A wrapper of NATSQueue for PrefillRequest.
The stream name is forced to be "prefill_queue".
"""
def __init__(
self,
stream_name="prefill_queue",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
super().__init__(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
async def enqueue_prefill_request(
self, prefill_request: RemotePrefillRequest
) -> None:
encoded_request = msgspec.json.encode(prefill_request)
await self.enqueue_task(encoded_request)
async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
encoded_request = await self.dequeue_task()
if encoded_request is not None:
prefill_request = msgspec.json.decode(
encoded_request, type=RemotePrefillRequest
)
return prefill_request
else:
return None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Optional
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import PromptLogprobs, RequestMetrics
class Request(BaseModel):
prompt: str
sampling_params: dict
class Tokens(BaseModel):
tokens: list[int]
class PrefillRequest(Request):
request_id: str
class Response(BaseModel):
text: str
class PrefillResponse(BaseModel):
prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# Sampling params is a mspspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
model_config = ConfigDict(
json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)}
)
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: rename to avoid ambiguity with vllm package
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_vllm_args() -> AsyncEngineArgs:
parser = FlexibleArgumentParser()
parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill"
)
parser.add_argument(
"--conditional-disagg",
action="store_true",
help="Use disaggregated router to decide whether to prefill locally or remotely",
)
parser.add_argument(
"--custom-disagg-router",
action="store_true",
help="Use custom python implementation of disaggregated router instead of the default rust one",
)
parser.add_argument(
"--max-local-prefill-length",
type=int,
default=1000,
help="Maximum length of local prefill",
)
parser.add_argument(
"--max-remote-prefill-cache-hit-ratio",
type=float,
default=0.5,
help="Maximum cache hit ratio for remote prefill "
"(only applicable to custom python implementation of disaggregated router)",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.remote_prefill = args.remote_prefill
engine_args.conditional_disagg = args.conditional_disagg
engine_args.custom_disagg_router = args.custom_disagg_router
engine_args.max_local_prefill_length = args.max_local_prefill_length
engine_args.max_remote_prefill_cache_hit_ratio = (
args.max_remote_prefill_cache_hit_ratio
)
return engine_args
...@@ -83,14 +83,14 @@ each time. ...@@ -83,14 +83,14 @@ each time.
# Performance # Performance
The performance impacts of synchrononizing the Python and Rust async runtimes The performance impacts of synchronizing the Python and Rust async runtimes
is a critical consideration when optimizing the performance of a highly is a critical consideration when optimizing the performance of a highly
concurrent and parallel distributed system. concurrent and parallel distributed system.
The Python GIL is a global critical section and is ultimately the death of The Python GIL is a global critical section and is ultimately the death of
parallelism. To compound that, when Rust async futures become ready, parallelism. To compound that, when Rust async futures become ready,
accessing the GIL on those async event loop needs to be considered carefully. accessing the GIL on those async event loop needs to be considered carefully.
Under high load, accessing the GIL or performing CPU intenstive tasks on Under high load, accessing the GIL or performing CPU intensive tasks on
on the event loop threads can starve out other async tasks for CPU resources. on the event loop threads can starve out other async tasks for CPU resources.
However, performing a `tokio::task::spawn_blocking` is not without overheads However, performing a `tokio::task::spawn_blocking` is not without overheads
as well. as well.
......
...@@ -66,6 +66,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -66,6 +66,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<EtcdClient>()?; m.add_class::<EtcdClient>()?;
m.add_class::<AsyncResponseStream>()?; m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?; m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::disagg_router::DisaggregatedRouter>()?;
m.add_class::<llm::kv::KvMetricsPublisher>()?; m.add_class::<llm::kv::KvMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?; m.add_class::<llm::model_card::ModelDeploymentCard>()?;
m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?; m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
use super::*; use super::*;
pub mod backend; pub mod backend;
pub mod disagg_router;
pub mod kv; pub mod kv;
pub mod model_card; pub mod model_card;
pub mod preprocessor; pub mod preprocessor;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use pyo3::exceptions::PyRuntimeError;
use std::sync::Arc;
use tokio::runtime::Runtime;
#[pyclass]
pub struct DisaggregatedRouter {
inner: Arc<dynamo_llm::disagg_router::DisaggregatedRouter>,
}
#[pymethods]
impl DisaggregatedRouter {
#[new]
#[pyo3(signature = (drt, model_name, default_max_local_prefill_length))]
fn new(
drt: PyObject,
model_name: String,
default_max_local_prefill_length: i32,
) -> PyResult<Self> {
let drt_arc = Python::with_gil(|py| {
let drt_ref = drt.extract::<DistributedRuntime>(py)?;
Ok::<_, PyErr>(Arc::new(drt_ref.inner))
})?;
// Create the runtime directly with the correct import
let runtime = Runtime::new().map_err(|e| {
PyRuntimeError::new_err(format!("Failed to create tokio runtime: {}", e))
})?;
let router = runtime.block_on(async {
dynamo_llm::disagg_router::DisaggregatedRouter::new_with_etcd_and_default(
drt_arc,
model_name,
default_max_local_prefill_length,
)
.await
.map_err(|e| {
PyRuntimeError::new_err(format!("Failed to create DisaggregatedRouter: {}", e))
})
})?;
Ok(DisaggregatedRouter {
inner: Arc::new(router),
})
}
fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
self.inner.prefill_remote(prefill_length, prefix_hit_length)
}
fn get_model_name(&self) -> &str {
self.inner.get_model_name()
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment