Commit 9e4a548d authored by ptarasiewiczNV, committed by GitHub

feat: add openai endpoint to the vllm example (#183)


Co-authored-by: Ryan Olson <rolson@nvidia.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 695754ae
......@@ -135,8 +135,16 @@ COPY runtime /workspace/runtime
RUN cd runtime/rust && \
cargo build --release --locked && cargo doc --no-deps
# Build OpenAI HTTP Service binaries
COPY llm/rust /workspace/llm/rust
COPY examples/rust /workspace/examples/rust
RUN cd examples/rust && \
cargo build --release && \
cp target/release/http /usr/local/bin/ && \
cp target/release/llmctl /usr/local/bin/
# Generate C bindings. Note that this is required for TRTLLM backend re-build
COPY llm /workspace/llm
COPY llm/rust /workspace/llm/rust
RUN cd llm/rust/ && \
cargo build --release --locked && cargo doc --no-deps
......
......@@ -78,6 +78,14 @@ COPY runtime /workspace/runtime
RUN cd runtime/rust && \
cargo build --release --locked && cargo doc --no-deps
# Build OpenAI HTTP Service binaries
COPY llm/rust /workspace/llm/rust
COPY examples/rust /workspace/examples/rust
RUN cd examples/rust && \
cargo build --release && \
cp target/release/http /usr/local/bin/ && \
cp target/release/llmctl /usr/local/bin/
# Generate C bindings for kv cache routing in vLLM
COPY llm /workspace/llm
RUN cd llm/rust/ && \
......@@ -94,6 +102,8 @@ RUN mkdir -p /opt/triton/llm_binding/wheels && mkdir /opt/triton/llm_binding/lib
RUN cp python-wheel/dist/triton_distributed_rs*cp312*.whl /opt/triton/llm_binding/wheels/.
RUN cp llm/rust/target/release/libtriton_llm_capi.so /opt/triton/llm_binding/lib/.
RUN cp -r llm/rust/libtriton-llm/include /opt/triton/llm_binding/.
# Tell vllm to use the Triton LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH="/opt/triton/llm_binding/lib/libtriton_llm_capi.so"
# Install patched vllm
ARG VLLM_REF="v0.7.2"
......@@ -118,8 +128,6 @@ COPY . /workspace
# Environment setup
ENV PYTHONPATH="${PYTHONPATH}:/workspace/examples/python:/opt/tritonserver/python/openai/openai_frontend"
ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
# Tell vllm to use the Triton LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH="/opt/triton/llm_binding/lib/libtriton_llm_capi.so"
CMD []
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
......@@ -51,48 +51,36 @@ The example is designed to run in a containerized environment using Triton Distr
./container/run.sh --framework VLLM -it
```
## Deployment Options
## Deployment
### 1. Monolithic Deployment
#### 1. HTTP Server
Run the server and client components in separate terminal sessions:
Run the server (with debug-level logging):
```bash
TRD_LOG=DEBUG http
```
By default, the server will run on port 9992.
Add a model to the server:
```bash
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B triton-init.vllm.generate
```
### 2. Workers
#### 2.1. Monolithic Deployment
In a separate terminal, run the vLLM worker:
**Terminal 1 - Server:**
```bash
# Launch worker
cd /workspace/examples/python_rs/llm/vllm
python3 -m monolith.worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \
--enforce-eager
```
**Terminal 2 - Client:**
```bash
# Run client
cd /workspace/examples/python_rs/llm/vllm
python3 -m common.client \
--prompt "what is the capital of france?" \
--max-tokens 10 \
--temperature 0.5
```
The output should look similar to:
```
Annotated(data=' Well', event=None, comment=[], id=None)
Annotated(data=' Well,', event=None, comment=[], id=None)
Annotated(data=' Well, France', event=None, comment=[], id=None)
Annotated(data=' Well, France is', event=None, comment=[], id=None)
Annotated(data=' Well, France is a', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western Europe', event=None, comment=[], id=None)
```
### 2. Disaggregated Deployment
#### 2.2. Disaggregated Deployment
This deployment option splits model serving across prefill and decode workers, enabling more efficient resource utilization.
......@@ -102,7 +90,6 @@ This deployment option splits the model serving across prefill and decode worker
cd /workspace/examples/python_rs/llm/vllm
VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--enforce-eager \
--tensor-parallel-size 1 \
......@@ -116,7 +103,6 @@ VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=0 python3 -m disaggregat
cd /workspace/examples/python_rs/llm/vllm
VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=1,2 python3 -m disaggregated.decode_worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--enforce-eager \
--tensor-parallel-size 2 \
......@@ -124,21 +110,43 @@ VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=1,2 python3 -m disaggreg
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
```
**Terminal 3 - Client:**
```bash
# Run client
cd /workspace/examples/python_rs/llm/vllm
python3 -m common.client \
--prompt "what is the capital of france?" \
--max-tokens 10 \
--temperature 0.5
```
The disaggregated deployment utilizes separate GPUs for prefill and decode operations, allowing for optimized resource allocation and improved performance. For more details on the disaggregated deployment, please refer to the [vLLM documentation](https://docs.vllm.ai/en/latest/features/disagg_prefill.html).
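The hand-off between the two worker types boils down to a shared request id and a one-token prefill call. The sketch below is a simplified, illustrative view of that pattern; the actual implementation is `disaggregated/decode_worker.py` (shown later in this diff), and the object names and call shapes here are abbreviations, not its API.
```python
import random
import uuid

# Illustrative sketch of the prefill/decode hand-off; the worker/engine
# objects are assumed to exist and the call shapes are simplified.
async def generate(engine, prefill_clients, num_prefill_workers, kv_rank,
                   engine_prompt, sampling_params):
    # Pick a prefill worker and encode both KV ranks into the request id so
    # the KV-transfer layer can pair the producer (prefill) with this consumer.
    prefill_rank = random.choice(range(num_prefill_workers))
    request_id = (f"{uuid.uuid4()}___prefill_kv_rank_{prefill_rank}"
                  f"___decode_kv_rank_{kv_rank}")

    # The prefill worker only needs to emit one token: its real output is the
    # populated KV cache, which is transferred to the decode worker.
    prefill_params = {**sampling_params, "max_tokens": 1}
    prefill_clients[prefill_rank].generate(
        {"prompt": engine_prompt,
         "sampling_params": prefill_params,
         "request_id": request_id})

    # Decode locally, reusing the KV blocks produced remotely.
    async for output in engine.generate(engine_prompt, sampling_params, request_id):
        yield output
```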
### 3. Client
### 3. Multi-Node Deployment
```bash
curl localhost:9992/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{"role": "user", "content": "What is the capital of France?"}
]
}'
```
Expected output:
```json
{
"id": "5b04e7b0-0dcd-4c45-baa0-1d03d924010c",
"choices": [{
"message": {
"role": "assistant",
"content": "The capital of France is Paris. Paris is a major city known for iconic landmarks like the Eiffel Tower and the Louvre Museum."
},
"index": 0,
"finish_reason": "stop"
}],
"created": 1739548787,
"model": "vllm",
"object": "chat.completion",
"usage": null,
"system_fingerprint": null
}
```
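If you prefer a Python client, the endpoint is OpenAI-compatible, so the standard `openai` package can be pointed at it. The sketch below assumes the default port 9992, that a placeholder API key is accepted, and requests a streamed response (the vLLM workers in this example currently assert `request.stream`):
```python
# Minimal sketch using the OpenAI Python client against the local endpoint.
# Port and API-key handling are assumptions; adjust to your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9992/v1", api_key="not-used")

stream = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    stream=True,  # only streaming is confirmed supported by the workers
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```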
### 4. Multi-Node Deployment
The vLLM workers can be deployed across multiple nodes by configuring the NATS and etcd connection endpoints through environment variables. This enables distributed inference across a cluster.
......@@ -158,7 +166,7 @@ For disaggregated deployment, you will also need to pass the `kv_ip` and `kv_por
```
### 4. KV Router Deployment
### 5. KV Router Deployment
The KV Router is a component that aggregates KV Events from all the workers and maintains a prefix tree of the cached tokens. It decides which worker to route each request to based on the length of the prefix match and the load on each worker.
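As a mental model, the decision can be reduced to a per-worker score in which a longer cached-prefix match raises the score and a higher load lowers it. The toy sketch below is illustrative only and does not reflect the router's actual data structures or weights.
```python
# Purely illustrative scoring rule. cached_prefix_len[w] is how many tokens
# of this prompt worker w already holds in its KV cache (derived from the
# prefix tree); active_requests[w] is the current load on worker w.
def pick_worker(prompt_len, cached_prefix_len, active_requests, load_weight=0.1):
    def score(worker):
        overlap = cached_prefix_len[worker] / max(prompt_len, 1)
        return overlap - load_weight * active_requests[worker]
    return max(cached_prefix_len, key=score)


# Example: worker "a" has most of the prompt cached but is busier than "b".
pick_worker(100, {"a": 80, "b": 10}, {"a": 4, "b": 1})  # -> "a"
```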
......@@ -237,11 +245,9 @@ python3 -m common.client \
--temperature 0.5
```
### 5. Known Issues and Limitations
### 6. Known Issues and Limitations
- vLLM is not working well with the `fork` method for multiprocessing and TP > 1. This is a known issue and a workaround is to use the `spawn` method instead. See [vLLM issue](https://github.com/vllm-project/vllm/issues/6152).
- `kv_rank` of `kv_producer` must be smaller than that of `kv_consumer`.
- Instances with the same `kv_role` must have the same `--tensor-parallel-size`.
- Currently only `--pipeline-parallel-size 1` is supported for XpYd disaggregated deployment.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import vllm
from common.chat_processor import ChatProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
class BaseVllmEngine:
"""
Request handler for the generate endpoint
"""
def __init__(self, engine_args: AsyncEngineArgs):
self.model_config = engine_args.create_model_config()
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
self.chat_processor = ChatProcessor(self.engine, self.model_config)
async def _parse_raw_request(self, raw_request):
request = self.chat_processor.parse_raw_request(raw_request)
(
conversation,
request_prompt,
engine_prompt,
) = await self.chat_processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return request, conversation, request_prompt, engine_prompt, sampling_params
async def _stream_response(self, request, generator, request_id, conversation):
return self.chat_processor.stream_response(
request,
generator,
request_id,
conversation,
)
@abc.abstractmethod
async def generate(self, raw_request):
pass
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import AsyncIterator, List
import vllm
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
class ChatProcessor:
def __init__(self, engine_client: vllm.AsyncLLMEngine, model_config: ModelConfig):
self.engine_client = engine_client
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(self, raw_request: dict) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: dict):
request = self.parse_raw_request(raw_request)
tokenizer = await self.engine_client.get_tokenizer()
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or tokenizer.chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return conversation[0], request_prompts[0], engine_prompts[0]
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
tokenizer = await self.engine_client.get_tokenizer()
request_metadata = RequestResponseMetadata(request_id=request_id)
assert request.stream, "Only stream is supported"
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
......@@ -18,16 +18,21 @@ import asyncio
import random
import uuid
import msgspec
import uvloop
import vllm
from common.base_engine import BaseVllmEngine
from common.parser import parse_vllm_args
from common.protocol import PrefillRequest, Request, Response
from common.protocol import PrefillRequest
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
)
from vllm.logger import logger as vllm_logger
class VllmDecodeEngine:
class VllmDecodeEngine(BaseVllmEngine):
"""
Request handler for the generate endpoint
"""
......@@ -36,10 +41,10 @@ class VllmDecodeEngine:
assert (
engine_args.kv_transfer_config.is_kv_consumer
), "Decode worker must be a KV consumer"
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
super().__init__(engine_args)
self.prefills: list = []
self.prefill_workers = (
self.num_prefill_workers = (
self.engine.engine.vllm_config.kv_transfer_config.kv_producers_parallel_size
)
self.kv_rank = self.engine.engine.vllm_config.kv_transfer_config.kv_rank
......@@ -47,29 +52,41 @@ class VllmDecodeEngine:
def add_prefill(self, prefill):
self.prefills.append(prefill)
@triton_endpoint(Request, Response)
async def generate(self, request):
vllm_logger.info(f"Received request: {request}")
sampling_params = vllm.SamplingParams(**request.sampling_params)
prefill_rank = random.choice(range(self.prefill_workers))
@triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(self, raw_request):
vllm_logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
request_prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
prefill_rank = random.choice(range(self.num_prefill_workers))
request_id = f"{uuid.uuid4()}___prefill_kv_rank_{prefill_rank}___decode_kv_rank_{self.kv_rank}"
prefill_sampling_params = {**request.sampling_params}
prefill_sampling_params = {**msgspec.to_builtins(sampling_params)}
prefill_sampling_params["max_tokens"] = 1
prefill_request = PrefillRequest(
prompt=request.prompt,
prompt=request_prompt, # TODO: we should use engine prompt to avoid extra tokenization
sampling_params=prefill_sampling_params,
request_id=request_id,
)
vllm_logger.debug(f"Prefill request: {prefill_request}")
self.prefills[prefill_rank].generate(
prefill_request.model_dump_json(),
)
async for response in self.engine.generate(
request.prompt, sampling_params, request_id
vllm_logger.debug(
f"Running generate with engine_prompt: {engine_prompt}, sampling_params: {sampling_params}, request_id: {request_id}"
)
generator = self.engine.generate(engine_prompt, sampling_params, request_id)
async for response in await self._stream_response(
request, generator, request_id, conversation
):
vllm_logger.debug(f"Generated response: {response}")
yield response.outputs[0].text
yield response
@triton_worker()
......@@ -82,7 +99,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
await component.create_service()
decode_engine = VllmDecodeEngine(engine_args)
for i in range(decode_engine.prefill_workers):
for i in range(decode_engine.num_prefill_workers):
prefill = (
await runtime.namespace("triton-init")
.component("prefill")
......
......@@ -18,6 +18,7 @@ import asyncio
import uvloop
import vllm
from common.base_engine import BaseVllmEngine
from common.parser import parse_vllm_args
from common.protocol import PrefillRequest, PrefillResponse
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
......@@ -25,7 +26,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger
class VllmPrefillEngine:
class VllmPrefillEngine(BaseVllmEngine):
"""
Request handler for the generate endpoint
"""
......@@ -34,7 +35,7 @@ class VllmPrefillEngine:
assert (
engine_args.kv_transfer_config.is_kv_producer
), "Prefill worker must be a KV producer"
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
super().__init__(engine_args)
self.kv_rank = self.engine.engine.vllm_config.kv_transfer_config.kv_rank
@triton_endpoint(PrefillRequest, PrefillResponse)
......
......@@ -18,32 +18,47 @@ import asyncio
import uuid
import uvloop
import vllm
from common.base_engine import BaseVllmEngine
from common.parser import parse_vllm_args
from common.protocol import Request, Response
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
)
from vllm.logger import logger as vllm_logger
class VllmEngine:
class VllmEngine(BaseVllmEngine):
"""
Request handler for the generate endpoint
"""
def __init__(self, engine_args: AsyncEngineArgs):
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
super().__init__(engine_args)
@triton_endpoint(Request, Response)
async def generate(self, request):
vllm_logger.debug(f"Received request: {request}")
sampling_params = vllm.SamplingParams(**request.sampling_params)
@triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(self, raw_request):
vllm_logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
_,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
request_id = str(uuid.uuid4())
async for response in self.engine.generate(
request.prompt, sampling_params, request_id
vllm_logger.debug(
f"Running generate with engine_prompt: {engine_prompt}, sampling_params: {sampling_params}, request_id: {request_id}"
)
generator = self.engine.generate(engine_prompt, sampling_params, request_id)
async for response in await self._stream_response(
request, generator, request_id, conversation
):
vllm_logger.debug(f"Generated response: {response}")
yield response.outputs[0].text
yield response
@triton_worker()
......
......@@ -64,7 +64,12 @@ def triton_endpoint(
try:
if len(args) in [1, 2]:
args = list(args)
if isinstance(args[-1], str):
args[-1] = request_model.parse_raw(args[-1])
elif isinstance(args[-1], dict):
args[-1] = request_model.parse_obj(args[-1])
else:
raise ValueError(f"Invalid request: {args[-1]}")
except ValidationError as e:
raise ValueError(f"Invalid request: {e}")
......
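For reference, the new branch in `triton_endpoint` above dispatches on the payload type before validating it against the endpoint's request model. Below is a standalone sketch of the same Pydantic v1 pattern, using a made-up model rather than the actual request types:
```python
# Standalone illustration of the str/dict dispatch added to triton_endpoint,
# with a generic Pydantic v1 model standing in for the real request types.
from pydantic import BaseModel, ValidationError


class ExampleRequest(BaseModel):
    prompt: str
    max_tokens: int = 16


def parse_request(payload):
    try:
        if isinstance(payload, str):    # JSON string off the wire
            return ExampleRequest.parse_raw(payload)
        if isinstance(payload, dict):   # already-decoded dict
            return ExampleRequest.parse_obj(payload)
        raise ValueError(f"Invalid request: {payload}")
    except ValidationError as e:
        raise ValueError(f"Invalid request: {e}")


parse_request('{"prompt": "hello"}')  # validated from a JSON string
parse_request({"prompt": "hello"})    # same result from a dict payload
```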