Commit 45bc426c authored by ptarasiewiczNV, committed by GitHub

feat: vLLM DistributedRuntime Monolith and Disagg Workers Example

Co-authored-by: Neelay Shah <neelays@nvidia.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 39441642
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# vLLM Integration with Triton Distributed
This example demonstrates how to use Triton Distributed to serve large language models with the vLLM engine, supporting both monolithic and disaggregated deployments.
## Prerequisites
1. Follow the setup instructions in the Python bindings [README](/runtime/rust/python-wheel/README.md) to prepare your environment
2. Install vLLM:

   ```bash
   uv pip install vllm==0.7.2
   ```

3. Start required services (etcd and NATS):

   Option A: Using [Docker Compose](/runtime/rust/docker-compose.yml) (Recommended)

   ```bash
   docker-compose up -d
   ```

   Option B: Manual Setup

   - [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) server with [Jetstream](https://docs.nats.io/nats-concepts/jetstream)
     - example: `nats-server -js --trace`
   - [etcd](https://etcd.io) server
     - follow the instructions in [etcd installation](https://etcd.io/docs/v3.5/install/) to start an `etcd-server` locally
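
Once the services are running, you can optionally sanity-check that both are reachable before starting any workers. This check assumes the default ports (`4222` for NATS, `2379` for etcd):

```bash
# NATS: the client port should accept TCP connections
nc -zv localhost 4222
# etcd: the client port exposes a /health endpoint
curl http://localhost:2379/health
```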
## Deployment Options
### 1. Monolithic Deployment
Run the server and client components in separate terminal sessions:
**Terminal 1 - Server:**
```bash
python3 -m monolith.worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \
--enforce-eager
```
**Terminal 2 - Client:**
```bash
python3 -m common.client \
--prompt "what is the capital of france?" \
--max-tokens 10 \
--temperature 0.5
```
The output should look similar to:
```
Annotated(data=' Well', event=None, comment=[], id=None)
Annotated(data=' Well,', event=None, comment=[], id=None)
Annotated(data=' Well, France', event=None, comment=[], id=None)
Annotated(data=' Well, France is', event=None, comment=[], id=None)
Annotated(data=' Well, France is a', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western', event=None, comment=[], id=None)
Annotated(data=' Well, France is a country located in Western Europe', event=None, comment=[], id=None)
```
### 2. Disaggregated Deployment
This deployment option splits the model serving across prefill and decode workers, enabling more efficient resource utilization.
**Terminal 1 - Prefill Worker:**
```bash
CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
```
**Terminal 2 - Decode Worker:**
```bash
CUDA_VISIBLE_DEVICES=1 python3 -m disaggregated.decode_worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
```
**Terminal 3 - Client:**
```bash
python3 -m common.client \
--prompt "what is the capital of france?" \
--max-tokens 10 \
--temperature 0.5
```
The disaggregated deployment runs prefill and decode on separate GPUs, so each stage can be provisioned and scaled to match its own compute and memory profile. For more details, see the [vLLM documentation](https://docs.vllm.ai/en/latest/features/disagg_prefill.html) on disaggregated prefill.
### 3. Multi-Node Deployment
The vLLM workers can be deployed across multiple nodes by configuring the NATS and etcd connection endpoints through environment variables. This enables distributed inference across a cluster.
Set the following environment variables on each node before running the workers:
```bash
export NATS_SERVER="nats://<nats-server-host>:<nats-server-port>"
export ETCD_ENDPOINTS="http://<etcd-server-host1>:<etcd-server-port>,http://<etcd-server-host2>:<etcd-server-port>,..."
```
For disaggregated deployment, you will also need to pass the `kv_ip` and `kv_port` to the workers in the `kv_transfer_config` argument:
```bash
...
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":<rank>,"kv_parallel_size":2,"kv_ip":<master_node_ip>,"kv_port":<kv_port>}'
```
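
For illustration, a two-node disaggregated launch could look like the following. The host name `head-node`, the IP `10.0.0.1`, and the port `14579` are hypothetical placeholder values; the remaining flags are taken from the single-node example above:

```bash
# On every node: point the workers at the shared NATS and etcd instances
export NATS_SERVER="nats://head-node:4222"
export ETCD_ENDPOINTS="http://head-node:2379"

# Node 0 (10.0.0.1): the prefill worker acts as the KV producer
CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --enforce-eager \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_ip":"10.0.0.1","kv_port":14579}'

# Node 1: the decode worker acts as the KV consumer and connects to the producer's kv_ip/kv_port
CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.decode_worker \
    --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --enforce-eager \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_ip":"10.0.0.1","kv_port":14579}'
```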
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
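#
# common/client.py (path inferred from the README's `python3 -m common.client` invocation)
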
import asyncio

import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker
from vllm.utils import FlexibleArgumentParser

from .protocol import Request


@triton_worker()
async def worker(
    runtime: DistributedRuntime, prompt: str, max_tokens: int, temperature: float
):
    """
    Instantiate a `backend` client and call the `generate` endpoint
    """
    # get endpoint
    endpoint = runtime.namespace("triton-init").component("vllm").endpoint("generate")

    # create client
    client = await endpoint.client()

    # list the endpoints
    print(client.endpoint_ids())

    # issue request
    stream = await client.generate(
        Request(
            prompt=prompt,
            sampling_params={"temperature": temperature, "max_tokens": max_tokens},
        ).model_dump_json()
    )

    # process response
    async for resp in stream:
        print(resp)


if __name__ == "__main__":
    uvloop.install()

    parser = FlexibleArgumentParser()
    parser.add_argument("--prompt", type=str, default="what is the capital of france?")
    parser.add_argument("--max-tokens", type=int, default=10)
    parser.add_argument("--temperature", type=float, default=0.5)
    args = parser.parse_args()

    asyncio.run(worker(args.prompt, args.max_tokens, args.temperature))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
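#
# common/parser.py (path inferred from `from common.parser import parse_vllm_args` in the worker files)
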
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_vllm_args() -> AsyncEngineArgs:
    # Reuse vLLM's own CLI definitions so every worker accepts the standard
    # vLLM engine flags (--model, --max-model-len, --kv-transfer-config, ...)
    parser = FlexibleArgumentParser()
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    return AsyncEngineArgs.from_cli_args(args)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
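#
# common/protocol.py (path inferred from the `common.protocol` imports in the worker files)
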
from pydantic import BaseModel


class Request(BaseModel):
    prompt: str
    sampling_params: dict


class PrefillRequest(Request):
    request_id: str


class Response(BaseModel):
    text: str


class PrefillResponse(BaseModel):
    prefilled: bool
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
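#
# disaggregated/decode_worker.py (path inferred from the README's `python3 -m disaggregated.decode_worker` invocation)
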
import asyncio
import uuid

import uvloop
import vllm
from common.parser import parse_vllm_args
from common.protocol import PrefillRequest, Request, Response
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger


class VllmDecodeEngine:
    """
    Request handler for the generate endpoint
    """

    def __init__(self, engine_args: AsyncEngineArgs, prefill):
        assert (
            engine_args.kv_transfer_config.is_kv_consumer
        ), "Decode worker must be a KV consumer"
        self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
        self.prefill = prefill

    @triton_endpoint(Request, Response)
    async def generate(self, request):
        vllm_logger.info(f"Received request: {request}")
        sampling_params = vllm.SamplingParams(**request.sampling_params)
        request_id = str(uuid.uuid4())

        # Delegate prefill to the remote prefill worker. Capping it at a single
        # token means the prefill engine only populates the KV cache, which the
        # KV connector then transfers to this (consumer) engine.
        prefill_sampling_params = {**request.sampling_params}
        prefill_sampling_params["max_tokens"] = 1
        prefill_request = PrefillRequest(
            prompt=request.prompt,
            sampling_params=prefill_sampling_params,
            request_id=request_id,
        )
        prefill_generator = await self.prefill.generate(
            prefill_request.model_dump_json()
        )
        prefill_response = [resp async for resp in prefill_generator]
        assert len(prefill_response) == 1, "Prefill response should be a single boolean"
        prefill_response = prefill_response[0]
        vllm_logger.debug(f"Prefill response: {prefill_response}")

        # Decode locally; as the KV consumer, the engine receives the prefilled
        # KV cache from the producer instead of recomputing the prompt
        async for response in self.engine.generate(
            request.prompt, sampling_params, request_id
        ):
            vllm_logger.debug(f"Generated response: {response}")
            yield response.outputs[0].text


@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """
    Instantiate a `backend` component and serve the `generate` endpoint

    A `Component` can serve multiple endpoints
    """
    component = runtime.namespace("triton-init").component("vllm")
    await component.create_service()

    prefill = (
        await runtime.namespace("triton-init")
        .component("prefill")
        .endpoint("generate")
        .client()
    )

    endpoint = component.endpoint("generate")
    await endpoint.serve_endpoint(VllmDecodeEngine(engine_args, prefill).generate)


if __name__ == "__main__":
    uvloop.install()
    engine_args = parse_vllm_args()
    asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
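#
# disaggregated/prefill_worker.py (path inferred from the README's `python3 -m disaggregated.prefill_worker` invocation)
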
import asyncio

import uvloop
import vllm
from common.parser import parse_vllm_args
from common.protocol import PrefillRequest, PrefillResponse
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger


class VllmPrefillEngine:
    """
    Request handler for the generate endpoint
    """

    def __init__(self, engine_args: AsyncEngineArgs):
        assert (
            engine_args.kv_transfer_config.is_kv_producer
        ), "Prefill worker must be a KV producer"
        self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)

    @triton_endpoint(PrefillRequest, PrefillResponse)
    async def generate(self, request):
        vllm_logger.info(f"Received prefill request: {request}")
        sampling_params = vllm.SamplingParams(**request.sampling_params)
        # The request id is supplied by the decode worker; running generation
        # here (with max_tokens=1 set by the caller) fills the KV cache, which
        # the KV connector pushes to the consumer side
        async for response in self.engine.generate(
            request.prompt, sampling_params, request.request_id
        ):
            vllm_logger.debug(f"Generated response: {response}")
            yield True


@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """
    Instantiate a `backend` component and serve the `generate` endpoint

    A `Component` can serve multiple endpoints
    """
    component = runtime.namespace("triton-init").component("prefill")
    await component.create_service()

    endpoint = component.endpoint("generate")
    await endpoint.serve_endpoint(VllmPrefillEngine(engine_args).generate)


if __name__ == "__main__":
    uvloop.install()
    engine_args = parse_vllm_args()
    asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
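#
# monolith/worker.py (path inferred from the README's `python3 -m monolith.worker` invocation)
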
import asyncio
import uuid

import uvloop
import vllm
from common.parser import parse_vllm_args
from common.protocol import Request, Response
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger


class VllmEngine:
    """
    Request handler for the generate endpoint
    """

    def __init__(self, engine_args: AsyncEngineArgs):
        self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)

    @triton_endpoint(Request, Response)
    async def generate(self, request):
        vllm_logger.debug(f"Received request: {request}")
        sampling_params = vllm.SamplingParams(**request.sampling_params)
        request_id = str(uuid.uuid4())
        async for response in self.engine.generate(
            request.prompt, sampling_params, request_id
        ):
            vllm_logger.debug(f"Generated response: {response}")
            yield response.outputs[0].text


@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """
    Instantiate a `backend` component and serve the `generate` endpoint

    A `Component` can serve multiple endpoints
    """
    component = runtime.namespace("triton-init").component("vllm")
    await component.create_service()

    endpoint = component.endpoint("generate")
    await endpoint.serve_endpoint(VllmEngine(engine_args).generate)


if __name__ == "__main__":
    uvloop.install()
    engine_args = parse_vllm_args()
    asyncio.run(worker(engine_args))