"...git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "682ed06aeeaf7783afb719e153a68537c0a1d9dd"
Commit 5bcdb734 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

refactor: rename vllm_nixl to vllm and make default (#100)

parent a7c35dcf
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from enum import Enum
import bentoml
from common.protocol import Tokens
from dynamo.sdk import async_onstart, dynamo_context, dynamo_endpoint, service
with bentoml.importing():
from dynamo.runtime import KvRouter
WorkerId = str
class RoutingStrategy(Enum):
PREFIX = "prefix"
ROUND_ROBIN = "round_robin"
RANDOM = "random"
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Router:
"""
Request handler for the generate endpoint
"""
def __init__(self):
self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
self.routing_strategy = RoutingStrategy.PREFIX
self.runtime = dynamo_context["runtime"]
self.min_workers = 1
self.kv_block_size = 64
@async_onstart
async def init_engine(self):
workers_client = (
await self.runtime.namespace("dynamo")
.component("VllmEngine")
.endpoint("generate")
.client()
)
wait_task = workers_client.wait_for_endpoints()
await asyncio.sleep(1)
while not wait_task.done():
print("Waiting for workers to be ready...")
await asyncio.sleep(5)
wait_task.result()
while len(workers_client.endpoint_ids()) < self.min_workers:
print(
f"Waiting for more workers... Current: {len(workers_client.endpoint_ids())}, Required: {self.min_workers}"
)
await asyncio.sleep(5)
kv_listener = self.runtime.namespace("dynamo").component(self.model_name)
await kv_listener.create_service()
self.router = KvRouter(self.runtime, kv_listener, self.kv_block_size)
@dynamo_endpoint()
async def generate(self, request: Tokens):
lora_id = 0
worker_id = ""
if self.routing_strategy == RoutingStrategy.PREFIX:
try:
worker_id = await self.router.schedule(request.tokens, lora_id)
except Exception as e:
if "No worker found" in str(e):
worker_id = ""
else:
print(f"Error during worker selection: {e}")
print(f"Scheduling to worker_id: {worker_id}")
yield worker_id
else:
# TODO: Do we implement round_robin and random here?
# or just skip this router and directly enable in preprocess?
raise NotImplementedError(
f"Routing strategy {self.routing_strategy} not implemented"
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from typing import Optional
import bentoml
with bentoml.importing():
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger
from vllm.sampling_params import RequestOutputKind
from common.base_engine import BaseVllmEngine
from common.protocol import MyRequestOutput, vLLMGenerateRequest
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import (
async_onstart,
dynamo_context,
dynamo_endpoint,
server_context,
service,
)
lease_id = None
## TODO: metrics_publisher.create_endpoint(worker_component),
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmEngine(BaseVllmEngine):
"""
vLLM Inference Engine
"""
def __init__(self):
model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
self.engine_args = AsyncEngineArgs(
model=model,
gpu_memory_utilization=0.8,
enable_prefix_caching=True,
block_size=64,
max_model_len=16384,
)
VLLM_WORKER_ID = dynamo_context["endpoints"][0].lease_id()
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
os.environ["VLLM_KV_NAMESPACE"] = "dynamo"
os.environ["VLLM_KV_COMPONENT"] = "vllm"
vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
os.environ["CUDA_VISIBLE_DEVICES"] = f"{server_context.worker_index - 1}"
self.metrics_publisher = KvMetricsPublisher()
self.engine_client: Optional[MQLLMEngineClient] = None
super().__init__(self.engine_args)
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
await self.metrics_publisher.create_endpoint(component)
@async_onstart
async def init_engine(self):
if self.engine_client is None:
await super().initialize()
print("vLLM worker initialized")
assert self.engine_client is not None, "engine_client was not initialized"
self.engine_client.set_metrics_publisher(self.metrics_publisher)
self.metrics_publisher.publish(0, 1024, 0, 1024)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(lambda _: print("metrics publisher endpoint created"))
@dynamo_endpoint()
async def generate(self, request: vLLMGenerateRequest):
sampling_params = request.sampling_params
# rust HTTP requires Delta streaming
sampling_params.output_kind = RequestOutputKind.DELTA
async for response in self.engine_client.generate( # type: ignore
request.engine_prompt, sampling_params, request.request_id
):
# MyRequestOutput takes care of serializing the response as
# vLLM's RequestOutput is not serializable by default
resp = MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
yield resp
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
> **NOTE**: This example is based on an internal NVIDIA library that will soon be publicly released. The example won't work until the official release.
## Prerequisites
Start required services (etcd and NATS):
Option A: Using [Docker Compose](/deploy/docker-compose.yml) (Recommended)
```bash
docker compose -f deploy/docker-compose.yml up -d
```
Option B: Manual Setup
- [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) server with [Jetstream](https://docs.nats.io/nats-concepts/jetstream)
- example: `nats-server -js --trace`
- [etcd](https://etcd.io) server
- follow instructions in [etcd installation](https://etcd.io/docs/v3.5/install/) to start an `etcd-server` locally
## Build docker
```
./container/build.sh --framework VLLM_NIXL --target dev --build-context nixl=<path to downloaded nixl repo @ c53bb19a6a114e9093071bd1f2904f996ae1839b>
```
## Run container
```
./container/run.sh --framework VLLM_NIXL --target dev -it
```
All of the commands below are run inside the same container.
## Run deployment
This figure shows an overview of the major components to deploy:
```
+----------------+
+------| prefill worker |-------+
notify | | (optional) | |
finished | +----------------+ | pull
v v
+------+ +-----------+ +------------------+ push +---------------+
| HTTP |----->| processor |----->| decode/monolith |------------>| prefill queue |
| |<-----| |<-----| worker | (if disagg) | (optional) |
+------+ +-----------+ +------------------+ +---------------+
| ^ |
query best | | return | publish kv events
worker | | worker_id v
| | +------------------+
| +---------| kv-router |
+------------->| (optional) |
+------------------+
```
Add model to dynamo and start http server.
```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B dynamo-init.process.chat/completions
TRT_LOG=DEBUG http --port 8181
```
### Processor
Processor routes the requests to the (decode) workers. Three scheduling strategies are supported: 1. random, 2. round-robin, 3. kv (see [Kv Router](#kv-router)).
```
# Processor must take the same args as the (decoder) worker
# This is temporary until we communicate the ModelDeploymentCard over etcd
RUST_LOG=info python3 processor.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--block-size 64 \
--max-model-len 16384 \
--router <random/round-robin/kv>
```
Alternatively, the processor can be bypassed by directly hitting the worker endpoints:
```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B dynamo-init.vllm.generate
# monolithic
CUDA_VISIBLE_DEVICES=0 python3 routerless/worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager
# disaggregated
CUDA_VISIBLE_DEVICES=0 python routerless/prefill_worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}'
CUDA_VISIBLE_DEVICES=1 python3 routerless/worker.py \
--remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}'
```
### Kv Router
The KV Router is a component that aggregates KV Events from all the workers and maintains
a prefix tree of the cached tokens. It makes decisions on which worker to route requests
to based on the length of the prefix match and the load on the workers.
There are three steps needed to enable the kv router:
1. Use `--router kv` in the processor.
2. Use `--router kv` and `--enable-prefix-caching` in all the (decode) workers.
3. Launch the kv router in a separate terminal.
```
RUST_LOG=info python3 kv_router.py \
--model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--block-size 64 \
--min-workers 1
```
where `--min-workers` is the number of (decode) workers.
You can choose only the prefix strategy for now:
- `prefix`: Route requests to the worker that has the longest prefix match.
### Disaggregated Router
The disaggregated router determines whether a request should be send to a
remote prefill engine or a local prefill engine for prefilling based on the
prefill length. If kv router is enabled, the disaggregated router will use
the absolute prefill length (actual prefill length - prefix hit length) to make
the decision.
When prefilling locally, the vllm scheduler will prioritize
prefill request and pause any ongoing decode requests.
To enable the disaggregated router, add the following commands in the decode workers:
```
python worker.py \
...
--conditional-disagg \
--max-local-prefill-length <length>
```
### Worker
#### Monolithic
Only kv router is supported for monolithic deployment.
```
CUDA_VISIBLE_DEVICES=0 python3 worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--block-size 64 \
--max-model-len 16384 \
<optional kv router args: --router kv --enable-prefix-caching>
```
#### Disaggregated
Kv router and disaggregated router are supported and can be turned on/off individually.
```
# start prefill worker in one terminal
# Note: prefix caching is not supported in the prefill for now
CUDA_VISIBLE_DEVICES=0 python3 prefill_worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}' \
--block-size 64 \
--max-num-batched-tokens 16384 \
--max-model-len 16384
# start decode worker in another terminal
CUDA_VISIBLE_DEVICES=1 python3 worker.py \
--remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--tensor-parallel-size 1 \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}' \
--block-size 64 \
--max-num-batched-tokens 16384 \
--max-model-len 16384 \
<optional kv router args: --router kv --enable-prefix-caching>
<optional disaggregated router args: --conditional-disagg --max-local-prefill-length <length>>
```
### Multi-Node Deployment
For multi-node deployment, etcd, nats, processor, and kv router
are only required on the head node. The only components that need
to be deployed on all nodes are the workers.
Set the following environment variables on each node before running the workers:
```bash
export NATS_SERVER="nats://<nats-server-host>:<nats-server-port>"
export ETCD_ENDPOINTS="http://<etcd-server-host>:<etcd-server-port>"
```
### Common Issues
If torch GLOO backend is complaining about file name too long, set
```
export GLOO_SOCKET_IFNAME=lo
```
## Client
In another terminal:
```
# this test request has around 200 tokens isl
curl localhost:8181/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{
"role": "user",
"content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
}
],
"stream":false,
"max_tokens": 30
}'
```
## Run genai-perf
`genai-perf` is a tool for profiling and benchmarking LLM servers. It is already installed in the container. For more details, please refer to the [genai-perf README](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/perf_analyzer/genai-perf/README.html).
```
genai-perf profile \
-m deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--url localhost:8181 \
--endpoint-type chat \
--streaming \
--service-kind openai \
--endpoint v1/chat/completions \
--warmup-request-count 10 \
--random-seed 123 \
--synthetic-input-tokens-stddev 0 \
--output-tokens-stddev 0 \
--tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--synthetic-input-tokens-mean 3000 \
--output-tokens-mean 150 \
--extra-inputs min_tokens:150 \
--extra-inputs max_tokens:150 \
--profile-export-file my_profile_export.json \
--artifact-dir artifacts/ \
--concurrency 10 \
--request-count 40 \
-- -v \
--async
```
## Close deployment
Kill all python processes and clean up metadata files:
```
pkill -9 -f python
```
## TODOs, limitations, known issues
- [ ] Add etcd for discovery
- [ ] Multi-node deployment support
- [ ] Enable chunked prefill
- [ ] Process many remote prefill in one iteration
- [ ] Support recompute preemption
- [ ] Make sure decode does not preempt blocks before xfer finishes
- [ ] Layer wise transfer
- [ ] Non blocking send in prefill (cache manager should check xfer status)
- [ ] Test under load
- [ ] Support pp > 1
- [ ] Check why adding extra seed input is crashing vllm with remote prefill
- [ ] Unified worker for both prefill and decode
- [x] Support mixed tp
- [x] Require sending two parallel requests to start decode for the first time
- [x] Concurrency > 2 is not working
- [x] Parse cmdline args
- [x] Manual nixl example with tp1
- [x] Zero copy
- [x] Conditional remote prefill
- [x] Manual example with tp > 1
- [x] Run on dynamo distributed runtime
- [x] add oai http endpoint
- [x] Sample only on decode, do note return remote prefill response
- [x] Check if all transfers finished before moving to decode
- [x] Enable async output processing - could be working
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, engine_client, processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=request.chat_template or self.tokenizer.chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
from typing import ClassVar, Optional
from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError
class NATSQueue:
_instance: ClassVar[Optional["NATSQueue"]] = None
_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
def __init__(
self,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
self.nats_url = nats_server
self._nc: Optional[NATS] = None
self._js: Optional[JetStreamContext] = None
# TODO: check if this is needed
# Sanitize stream_name to remove path separators
self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
self._subject = f"{self._stream_name}.*"
self.dequeue_timeout = dequeue_timeout
self._subscriber: Optional[JetStreamContext.PullSubscription] = None
@classmethod
@asynccontextmanager
async def get_instance(
cls,
*,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
"""Get or create a singleton instance of NATSq"""
# TODO: check if this _lock is needed with GIL
async with cls._lock:
if cls._instance is None:
cls._instance = cls(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
await cls._instance.connect()
try:
yield cls._instance
except Exception:
if cls._instance:
await cls._instance.close()
cls._instance = None
raise
# TODO: check to see if this can be replaced by something like get_instance().close()
@classmethod
async def shutdown(cls):
"""Explicitly close the singleton instance if it exists"""
async with cls._lock:
if cls._instance:
await cls._instance.close()
cls._instance = None
async def connect(self):
"""Establish connection and create stream if needed"""
try:
if self._nc is None:
self._nc = NATS()
await self._nc.connect(self.nats_url)
self._js = self._nc.jetstream()
# Check if stream exists, if not create it
try:
await self._js.stream_info(self._stream_name)
except NotFoundError:
await self._js.add_stream(
name=self._stream_name, subjects=[self._subject]
)
# Create persistent subscriber
self._subscriber = await self._js.pull_subscribe(
f"{self._stream_name}.queue", durable="worker-group"
)
except NatsError as e:
await self.close()
raise ConnectionError(f"Failed to connect to NATS: {e}")
async def ensure_connection(self):
"""Ensure we have an active connection"""
if self._nc is None or self._nc.is_closed:
await self.connect()
async def close(self):
"""Close the connection when done"""
if self._nc:
await self._nc.close()
self._nc = None
self._js = None
self._subscriber = None
# TODO: is enqueue/dequeue_object a better name for a general queue?
async def enqueue_task(self, task_data: bytes) -> None:
"""
Enqueue a task using msgspec-encoded data
"""
await self.ensure_connection()
try:
await self._js.publish(f"{self._stream_name}.queue", task_data) # type: ignore
except NatsError as e:
raise RuntimeError(f"Failed to enqueue task: {e}")
async def dequeue_task(self) -> Optional[bytes]:
"""Dequeue and return a task as raw bytes, to be decoded with msgspec"""
await self.ensure_connection()
try:
msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout) # type: ignore
if msgs:
msg = msgs[0]
await msg.ack()
return msg.data
return None
except asyncio.TimeoutError:
return None
except NatsError as e:
raise RuntimeError(f"Failed to dequeue task: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
import msgspec
from vllm.distributed.device_communicators.nixl import NixlMetadata
from dynamo.runtime import DistributedRuntime
METADATA_DIR = "/tmp/nixl"
@contextmanager
def temp_metadata_file(engine_id, metadata: NixlMetadata):
os.makedirs(METADATA_DIR, exist_ok=True)
path = f"{METADATA_DIR}/{engine_id}.nixl_meta"
with open(path, "wb") as f:
encoded = msgspec.msgpack.encode(metadata)
print(f"Size of encoded metadata: {len(encoded)}")
f.write(encoded)
try:
yield path
finally:
if os.path.exists(path):
os.remove(path)
def find_remote_metadata(engine_id):
# find and load metadata from METADATA_DIR that do not match engine_id
remote_metadata = []
for file in os.listdir(METADATA_DIR):
if file.endswith(".nixl_meta"):
if file.split(".")[0] != engine_id:
with open(os.path.join(METADATA_DIR, file), "rb") as f:
remote_metadata.append(
msgspec.msgpack.decode(f.read(), type=NixlMetadata)
)
return remote_metadata
class NixlMetadataStore:
NIXL_METADATA_KEY = "nixl_metadata"
def __init__(self, namespace: str, runtime: DistributedRuntime) -> None:
self._namespace = namespace
# TODO Remove metadata from etcd on delete
self._stored: set[str] = set()
self._cached: dict[str, NixlMetadata] = {}
self._client = runtime.etcd_client()
self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}"
async def put(self, engine_id, metadata: NixlMetadata):
serialized_metadata = msgspec.msgpack.encode(metadata)
key = "/".join([self._key_prefix, engine_id])
await self._client.kv_put(key, serialized_metadata, None)
self._stored.add(engine_id)
async def get(self, engine_id) -> NixlMetadata:
try:
if engine_id in self._cached:
return self._cached[engine_id]
key = "/".join([self._key_prefix, engine_id])
key_values = await self._client.kv_get_prefix(key)
deserialized_metadata = None
for item in key_values:
deserialized_metadata = msgspec.msgpack.decode(
item["value"], type=NixlMetadata
)
break
if deserialized_metadata is None:
raise Exception("metadata not found in etcd")
self._cached[engine_id] = deserialized_metadata
# TODO watch for changes and update cache
# self._client.add_watch_callback(
# key,
# self._watch_callback,
# )
except Exception as e:
raise Exception("Error retrieving metadata for engine {engine_id}") from e
return deserialized_metadata
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import msgspec
from utils.nats_queue import NATSQueue
from vllm.remote_prefill import RemotePrefillRequest
class PrefillQueue(NATSQueue):
"""
A wrapper of NATSQueue for PrefillRequest.
The stream name is forced to be "prefill_queue".
"""
def __init__(
self,
stream_name="prefill_queue",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
super().__init__(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
async def enqueue_prefill_request(
self, prefill_request: RemotePrefillRequest
) -> None:
encoded_request = msgspec.json.encode(prefill_request)
await self.enqueue_task(encoded_request)
async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
encoded_request = await self.dequeue_task()
if encoded_request is not None:
prefill_request = msgspec.json.decode(
encoded_request, type=RemotePrefillRequest
)
return prefill_request
else:
return None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: rename to avoid ambiguity with vllm package
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_vllm_args() -> AsyncEngineArgs:
parser = FlexibleArgumentParser()
parser.add_argument(
"--router",
type=str,
choices=["random", "round-robin", "kv"],
default="random",
help="Router type to use for scheduling requests to workers",
)
parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill"
)
parser.add_argument(
"--conditional-disagg",
action="store_true",
help="Use disaggregated router to decide whether to prefill locally or remotely",
)
parser.add_argument(
"--max-local-prefill-length",
type=int,
default=1000,
help="Maximum length of local prefill",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.router = args.router
engine_args.remote_prefill = args.remote_prefill
engine_args.conditional_disagg = args.conditional_disagg
engine_args.max_local_prefill_length = args.max_local_prefill_length
return engine_args
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment