Commit f7a60cba authored by ptarasiewiczNV, committed by GitHub

feat: vLLM + NIXL example


Co-authored-by: Piotr Tarasiewicz Nvidia <ptarasiewicznv@Piotrs-MacBook-Pro.local>
Co-authored-by: nnshah1 <neelays@nvidia.com>
Co-authored-by: alec-flowers <aflowers@nvidia.com>
parent ea401e3b
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
ARG BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
ARG BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dev
USER root
# Install utilities
RUN apt update -y && apt install -y git wget curl nvtop tmux vim
# nats
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
# etcd
ENV ETCD_VERSION="v3.5.18"
RUN wget https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-amd64.tar.gz -O /tmp/etcd.tar.gz && \
mkdir -p /usr/local/bin/etcd && \
tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1
ENV PATH=/usr/local/bin/etcd/:$PATH
### VIRTUAL ENVIRONMENT SETUP ###
# Install uv and create virtualenv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/triton && \
uv venv /opt/triton/venv --python 3.12
# Activate virtual environment
ENV VIRTUAL_ENV=/opt/triton/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Install patched vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
# Install genai-perf for benchmarking
ARG GENAI_PERF_TAG="r25.01"
RUN uv pip install "git+https://github.com/triton-inference-server/perf_analyzer.git@${GENAI_PERF_TAG}#subdirectory=genai-perf"
# Install test dependencies
RUN --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.txt \
uv pip install --requirement /tmp/requirements.txt
### NIXL SETUP ###
ARG MOFED_VERSION=5.8-1.1.2.1
ARG PYTHON_VERSION=3.12
ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2024_4/
ARG NSYS_PKG=NsightSystems-linux-cli-public-2024.4.1.61-3431596.deb
RUN apt-get update -y && apt-get -y install curl \
git \
libnuma-dev \
numactl \
wget \
autotools-dev \
automake \
libtool \
libz-dev \
libiberty-dev \
flex \
build-essential \
cmake \
libibverbs-dev \
libgoogle-glog-dev \
libgtest-dev \
libjsoncpp-dev \
libpython3-dev \
libboost-all-dev \
libssl-dev \
libgrpc-dev \
libgrpc++-dev \
libprotobuf-dev \
protobuf-compiler-grpc \
pybind11-dev \
python3-pip \
etcd-server \
net-tools \
pciutils \
libpci-dev \
vim \
tmux \
screen \
ibverbs-utils \
libibmad-dev
RUN apt-get update && \
apt install -y wget libglib2.0-0
RUN wget ${NSYS_URL}${NSYS_PKG} && \
dpkg -i $NSYS_PKG && \
rm $NSYS_PKG
RUN apt-get install -y linux-tools-common linux-tools-generic ethtool iproute2
RUN apt-get install -y dkms linux-headers-generic
RUN apt-get install -y meson ninja-build uuid-dev gdb
RUN uv pip install --upgrade meson
RUN uv pip install ninja pybind11
RUN cd /usr/local/src && \
curl -fSsL "https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VERSION}/MLNX_OFED_LINUX-${MOFED_VERSION}-ubuntu20.04-x86_64.tgz" -o mofed.tgz && \
tar -xf /usr/local/src/mofed.tgz && \
cd MLNX_OFED_LINUX-* && \
apt-get update && \
apt-get install -y --no-install-recommends \
./DEBS/libibverbs* ./DEBS/ibverbs-providers* ./DEBS/librdmacm* ./DEBS/libibumad* && \
rm -rf /var/lib/apt/lists/* /usr/local/src/*
ENV LIBRARY_PATH=$LIBRARY_PATH:/usr/local/cuda/lib64 \
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64
ENV LIBRARY_PATH=$LIBRARY_PATH:/usr/local/lib \
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib
WORKDIR /workspace
RUN git clone https://github.com/NVIDIA/gdrcopy.git
RUN PREFIX=/usr/local DESTLIB=/usr/local/lib make -C /workspace/gdrcopy lib_install
RUN cp gdrcopy/src/libgdrapi.so.2.* /usr/lib/x86_64-linux-gnu/
RUN ldconfig
ARG UCX_VERSION=v1.18.0
RUN cd /usr/local/src && \
curl -fSsL "https://github.com/openucx/ucx/tarball/${UCX_VERSION}" | tar xz && \
cd openucx-ucx* && \
./autogen.sh && \
./configure \
--prefix=/usr/local/ucx \
--enable-shared \
--disable-static \
--disable-doxygen-doc \
--enable-optimizations \
--enable-cma \
--enable-devel-headers \
--with-cuda=/usr/local/cuda \
--with-verbs \
--with-dm \
--with-gdrcopy=/usr/local \
--enable-mt \
--with-mlx5-dv && \
make -j && \
make -j install-strip && \
ldconfig
ENV LD_LIBRARY_PATH=/usr/local/ucx/lib:$LD_LIBRARY_PATH
ENV CPATH=/usr/local/ucx/include:$CPATH
ENV PATH=/usr/local/ucx/bin:$PATH
ENV PKG_CONFIG_PATH=/usr/local/ucx/lib/pkgconfig:$PKG_CONFIG_PATH
SHELL ["/bin/bash", "-c"]
COPY --from=nixl . /opt/nixl
RUN cd /opt/nixl && \
mkdir build && \
meson setup build/ --prefix=/usr/local/nixl && \
cd build/ && \
ninja && \
ninja install && \
mkdir -p /usr/local/nixl/include/internal && \
cp ../include/*.h /usr/local/nixl/include && \
cp ../include/internal/*.h /usr/local/nixl/include/internal && \
cp ../include/internal/*.h /usr/local/nixl/include/ && \
cp ../src/utils/serdes/serdes.h /usr/local/nixl/include
ENV LD_LIBRARY_PATH=/usr/local/nixl/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
ENV PYTHONPATH=/usr/local/nixl/lib/python${PYTHON_VERSION}/site-packages/:/opt/nixl/test/python/:$PYTHONPATH
RUN ls -l /usr/local/nixl/
RUN ls -l /usr/local/nixl/include/
RUN ls -l /usr/local/nixl/include/internal/
RUN ls /opt/nixl
### MISC UTILITY SETUP ###
# Finish pyright install
RUN pyright --help > /dev/null 2>&1
# Enable Git operations in the /workspace directory
RUN printf "[safe]\n directory=/workspace\n" > /root/.gitconfig
RUN ln -sf /bin/bash /bin/sh
### BUILDS ###
# Rust build/dev dependencies
RUN apt update -y && \
apt install -y \
build-essential \
protobuf-compiler \
cmake \
libssl-dev \
pkg-config && \
curl https://sh.rustup.rs -sSf | bash -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}"
# Working directory
WORKDIR /workspace
COPY lib/runtime /workspace/lib/runtime
RUN cd lib/runtime && \
cargo build --release --locked && \
cargo doc --no-deps
# Build OpenAI HTTP Service binaries
COPY lib/llm /workspace/lib/llm
COPY examples/rust /workspace/examples/rust
RUN cd examples/rust && \
cargo build --release && \
cp target/release/http /usr/local/bin/ && \
cp target/release/llmctl /usr/local/bin/
# TODO: Build tio
# COPY applications/...
# Generate C bindings for kv cache routing in vLLM
COPY lib/bindings /workspace/lib/bindings
RUN cd lib/bindings/c && \
cargo build --release --locked && \
cargo doc --no-deps
# Build triton_distributed wheel
RUN source /opt/triton/venv/bin/activate && \
cd lib/bindings/python && \
uv build && \
uv pip install /workspace/lib/bindings/python/dist/triton_distributed*cp312*.whl
# Package the bindings
RUN mkdir -p /opt/triton/bindings/wheels && \
mkdir /opt/triton/bindings/lib && \
cp lib/bindings/python/dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
cp -r lib/bindings/c/include /opt/triton/bindings/.
# Tell vllm to use the Triton LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
# FIXME: Copy more specific folders in for dev/debug after directory restructure
COPY . /workspace
# FIXME: May want a modification with triton-distributed banner on entry
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
### Lean Runtime Image Stage ###
# FIXME: Separate build and runtime images
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS runtime
USER root
# Install tools for interactive convenience
RUN apt update -y && \
apt install -y curl tmux vim && \
echo "set -g mouse on" >> /root/.tmux.conf
# Set environment variables
ENV VIRTUAL_ENV=/opt/triton/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
# Copy binaries
COPY --from=dev /usr/local/bin/http /usr/local/bin/http
COPY --from=dev /usr/local/bin/llmctl /usr/local/bin/llmctl
COPY --from=dev /usr/local/bin/etcd/etcd /usr/local/bin/etcd
COPY --from=dev /usr/bin/nats-server /usr/local/bin/nats-server
COPY --from=dev /bin/uv /usr/local/bin/uv
COPY --from=dev /bin/uvx /usr/local/bin/uvx
# Copy venv with installed packages
RUN uv python install 3.12
COPY --from=dev /opt/vllm /opt/vllm
COPY --from=dev ${VIRTUAL_ENV} ${VIRTUAL_ENV}
# Copy minimal set of files for testing. May consider separate stage for testing
# if test dependencies start to negatively impact deployment environment/size.
COPY pyproject.toml /workspace/pyproject.toml
COPY container/deps/vllm /workspace/container/deps/vllm
# Add library for KV routing
COPY --from=dev ${VLLM_KV_CAPI_PATH} ${VLLM_KV_CAPI_PATH}
# Copy minimal set of files for deployment/examples
# FIXME: Use a more consolidated path after directory restructure
COPY examples/python_rs/llm/vllm /workspace/examples/python_rs/llm/vllm
WORKDIR /workspace
# FIXME: May want a modification with triton-distributed banner on entry
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
@@ -43,7 +43,7 @@ PYTHON_PACKAGE_VERSION=${current_tag:-$latest_tag.dev+$commit_id}
 # dependencies are specified in the /container/deps folder and
 # installed within framework specific sections of the Dockerfile.
-declare -A FRAMEWORKS=(["STANDARD"]=1 ["TENSORRTLLM"]=2 ["VLLM"]=3)
+declare -A FRAMEWORKS=(["STANDARD"]=1 ["TENSORRTLLM"]=2 ["VLLM"]=3 ["VLLM_NIXL"]=4)
 DEFAULT_FRAMEWORK=STANDARD
 SOURCE_DIR=$(dirname "$(readlink -f "$0")")
@@ -74,6 +74,9 @@ TENSORRTLLM_SKIP_CLONE=0
 VLLM_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
 VLLM_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
+VLLM_NIXL_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
+VLLM_NIXL_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
 get_options() {
     while :; do
         case $1 in
@@ -180,6 +183,14 @@ get_options() {
                missing_requirement $1
            fi
            ;;
+        --build-context)
+            if [ "$2" ]; then
+                BUILD_CONTEXT_ARG="--build-context $2"
+                shift
+            else
+                missing_requirement $1
+            fi
+            ;;
        --)
            shift
            break
@@ -274,6 +285,7 @@ show_help() {
     echo " [--tag tag for image]"
     echo " [--no-cache disable docker build cache]"
     echo " [--dry-run print docker commands without running]"
+    echo " [--build-context name=path to add build context]"
     exit 0
}
@@ -292,6 +304,8 @@ get_options "$@"
 # Update DOCKERFILE if framework is VLLM
 if [[ $FRAMEWORK == "VLLM" ]]; then
     DOCKERFILE=${SOURCE_DIR}/Dockerfile.vllm
+elif [[ $FRAMEWORK == "VLLM_NIXL" ]]; then
+    DOCKERFILE=${SOURCE_DIR}/Dockerfile.vllm_nixl
 fi

 # BUILD DEV IMAGE
@@ -327,7 +341,7 @@ if [ -z "$RUN_PREFIX" ]; then
     set -x
 fi

-$RUN_PREFIX docker build -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $TAG $LATEST_TAG $BUILD_CONTEXT $NO_CACHE
+$RUN_PREFIX docker build -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE

 { set +x; } 2>/dev/null
......
@@ -27,4 +27,4 @@ pytestmark = pytest.mark.pre_merge
 @pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
 def test_version():
     # Verify that the image has the patched version of vllm
-    assert vllm.__version__.startswith("0.7.3.dev")
+    assert vllm.__version__.startswith("0.7.3.dev")  # type: ignore
@@ -22,7 +22,7 @@ RUN_PREFIX=
 # dependencies are specified in the /container/deps folder and
 # installed within framework specific sections of the Dockerfile.
-declare -A FRAMEWORKS=(["STANDARD"]=1 ["TENSORRTLLM"]=2 ["VLLM"]=3)
+declare -A FRAMEWORKS=(["STANDARD"]=1 ["TENSORRTLLM"]=2 ["VLLM"]=3 ["VLLM_NIXL"]=4)
 DEFAULT_FRAMEWORK=STANDARD
 SOURCE_DIR=$(dirname "$(readlink -f "$0")")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -17,7 +17,6 @@ import json
 import time
 from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable

-from vllm import TokensPrompt
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.chat_utils import ConversationMessage
@@ -29,6 +28,7 @@ from vllm.entrypoints.openai.protocol import (
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_engine import RequestPrompt
+from vllm.inputs.data import TokensPrompt
 from vllm.transformers_utils.tokenizer import AnyTokenizer
......
@@ -21,7 +21,9 @@ import msgspec
 from pydantic import BaseModel, ConfigDict, field_validator
 from pydantic_core import core_schema
 from typing_extensions import NotRequired
-from vllm import CompletionOutput, SamplingParams, TokensPrompt
+from vllm.inputs.data import TokensPrompt
+from vllm.outputs import CompletionOutput
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import PromptLogprobs, RequestMetrics
......
@@ -55,7 +55,7 @@ class VllmPrefillEngine(BaseVllmEngine):
             await self.initialize()
         vllm_logger.debug(f"Received prefill request: {request}")
-        sampling_params = vllm.SamplingParams(**request.sampling_params)
+        sampling_params = vllm.sampling_params.SamplingParams(**request.sampling_params)
         if self.engine_client is None:
             raise RuntimeError("Engine client not initialized")
         else:
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
> **NOTE**: This example is based on an internal NVIDIA library that will soon be publicly released. The example won't work until the official release.
## Build the Docker image
```
./container/build.sh --framework VLLM_NIXL --target dev --build-context nixl=<path to downloaded nixl repo @ fc912eb012597be67de11fa9ba0599e4e1974fa2>
```
## Run container
```
./container/run.sh --framework VLLM_NIXL --target dev -it
```
All of the commands below are run inside the same container.
## Run deployment
Add the model to Triton and start the HTTP server.
In terminal 0:
```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B test-nixl.vllm.generate
TRT_LOG=DEBUG http --port 8181
```
### Monolithic deployment
In terminal 1:
```
cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=0 python3 worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager
```
### Disaggregated deployment
In terminal 1:
```
cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=0 python prefill_worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"TritonNixlConnector"}'
```
In terminal 2:
```
cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=1 python3 worker.py \
--remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"TritonNixlConnector"}'
```
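Under the hood, the prefill and decode workers discover each other by exchanging serialized NIXL metadata through files in `/tmp/nixl` (see `common.py` in this example; etcd is planned to replace this). A minimal sketch of that handshake, simplified from `common.py`, with `exchange_metadata` as an illustrative name:
```
import asyncio
import os

import msgspec
from vllm.distributed.device_communicators.nixl import NixlMetadata

METADATA_DIR = "/tmp/nixl"

async def exchange_metadata(metadata: NixlMetadata) -> list[NixlMetadata]:
    os.makedirs(METADATA_DIR, exist_ok=True)
    # Publish this engine's metadata for peers to pick up.
    with open(f"{METADATA_DIR}/{metadata.engine_id}.nixl_meta", "wb") as f:
        f.write(msgspec.msgpack.encode(metadata))
    # Poll until at least one peer has published its metadata.
    peers: list[NixlMetadata] = []
    while not peers:
        await asyncio.sleep(1)
        for name in os.listdir(METADATA_DIR):
            if name.endswith(".nixl_meta") and name.split(".")[0] != metadata.engine_id:
                with open(os.path.join(METADATA_DIR, name), "rb") as f:
                    peers.append(msgspec.msgpack.decode(f.read(), type=NixlMetadata))
    return peers
```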
## Client
In another terminal:
```
curl localhost:8181/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{"role": "user", "content": "What is the capital of France?"}
],
"max_tokens": 10
}'
```
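The same request can be issued from Python. A minimal sketch using the `openai` client package (an assumption; any HTTP client works, and the server ignores the API key):
```
from openai import OpenAI

# Point the client at the local HTTP frontend started with `http --port 8181`.
client = OpenAI(base_url="http://localhost:8181/v1", api_key="unused")

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=10,
)
print(response.choices[0].message.content)
```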
## Run genai-perf
`genai-perf` is a tool for profiling and benchmarking LLM servers. It is already installed in the container. For more details, please refer to the [genai-perf README](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/perf_analyzer/genai-perf/README.html).
```
genai-perf profile \
-m deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--url localhost:8181 \
--endpoint-type chat \
--streaming \
--service-kind openai \
--endpoint v1/chat/completions \
--warmup-request-count 10 \
--random-seed 123 \
--synthetic-input-tokens-stddev 0 \
--output-tokens-stddev 0 \
--tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--synthetic-input-tokens-mean 3000 \
--output-tokens-mean 150 \
--extra-inputs min_tokens:150 \
--extra-inputs max_tokens:150 \
--profile-export-file my_profile_export.json \
--artifact-dir artifacts/ \
--concurrency 10 \
--request-count 40 \
-- -v \
--async
```
## Close deployment
Kill all Python processes and clean up the metadata files:
```
pkill -9 -f python
rm -r /tmp/nixl
```
## TODOs, limitations, known issues
- [ ] Add etcd for discovery (a rough sketch follows this list)
- [ ] Multi-node deployment support
- [ ] Enable chunked prefill
- [ ] Support mixed tp
- [ ] Process many remote prefill in one iteration
- [ ] Support recompute preemption
- [ ] Make sure decode does not preempt blocks before xfer finishes
- [ ] Layer wise transfer
- [ ] Non blocking send in prefill (cache manager should check xfer status)
- [ ] Test under load
- [ ] Support pp > 1
- [ ] Check why adding extra seed input is crashing vllm with remote prefill
- [ ] Unified worker for both prefill and decode
- [x] Require sending two parallel requests to start decode for the first time
- [x] Concurrency > 2 is not working
- [x] Parse cmdline args
- [x] Manual nixl example with tp1
- [x] Zero copy
- [x] Conditional remote prefill
- [x] Manual example with tp > 1
- [x] Run on triton distributed runtime
- [x] Add OpenAI HTTP endpoint
- [x] Sample only on decode, do not return remote prefill response
- [x] Check if all transfers finished before moving to decode
- [x] Enable async output processing - could be working
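For the etcd discovery item above, one possible shape for replacing the `/tmp/nixl` file exchange, assuming the `etcd3` (python-etcd3) client package and an etcd server on `localhost:2379`; none of this is part of the commit:
```
import msgspec
import etcd3  # assumption: python-etcd3 client, etcd reachable on localhost:2379

from vllm.distributed.device_communicators.nixl import NixlMetadata

PREFIX = "/nixl/metadata/"

def publish_metadata(metadata: NixlMetadata) -> None:
    # Store msgpack-encoded metadata under a per-engine key.
    client = etcd3.client(host="localhost", port=2379)
    client.put(PREFIX + metadata.engine_id, msgspec.msgpack.encode(metadata))

def find_remote_metadata(engine_id: str) -> list[NixlMetadata]:
    # Read every peer's key under the prefix, skipping our own.
    client = etcd3.client(host="localhost", port=2379)
    remote = []
    for value, kv in client.get_prefix(PREFIX):
        if kv.key.decode() != PREFIX + engine_id:
            remote.append(msgspec.msgpack.decode(value, type=NixlMetadata))
    return remote
```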
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
import msgspec
from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
METADATA_DIR = "/tmp/nixl"
def parse_vllm_args() -> AsyncEngineArgs:
parser = FlexibleArgumentParser()
parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill"
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.remote_prefill = args.remote_prefill
return engine_args
@contextmanager
def temp_metadata_file(engine_id, metadata: NixlMetadata):
os.makedirs(METADATA_DIR, exist_ok=True)
path = f"{METADATA_DIR}/{engine_id}.nixl_meta"
with open(path, "wb") as f:
encoded = msgspec.msgpack.encode(metadata)
print(f"Size of encoded metadata: {len(encoded)}")
f.write(encoded)
try:
yield path
finally:
if os.path.exists(path):
os.remove(path)
def find_remote_metadata(engine_id):
    # Find and load metadata files in METADATA_DIR whose engine id does not match engine_id.
remote_metadata = []
for file in os.listdir(METADATA_DIR):
if file.endswith(".nixl_meta"):
if file.split(".")[0] != engine_id:
with open(os.path.join(METADATA_DIR, file), "rb") as f:
remote_metadata.append(
msgspec.msgpack.decode(f.read(), type=NixlMetadata)
)
return remote_metadata
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import msgspec
import uvloop
from common import find_remote_metadata, parse_vllm_args, temp_metadata_file
from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from triton_distributed.runtime import DistributedRuntime, triton_worker
class RequestHandler:
def __init__(self, engine_client):
self.engine_client = engine_client
print("RequestHandler initialized")
async def generate(self, raw_request: str):
request: RemotePrefillRequest = msgspec.json.decode(
raw_request.encode("utf-8"), type=RemotePrefillRequest
)
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
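        # Limit generation to a single token: the prefill worker only needs to
        # compute the prompt's KV cache, which is transferred to the decode
        # worker's blocks (see RemotePrefillParams below); it does no decoding.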
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=request.engine_id,
)
async for _ in self.engine_client.generate(
request_id=request.request_id,
prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("test-nixl").component("prefill")
await component.create_service()
endpoint = component.endpoint("generate")
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
# This should be replaced with etcd
metadata = engine_client.nixl_metadata
with temp_metadata_file(metadata.engine_id, metadata):
print(f"Waiting for remote metadata for engine {metadata.engine_id}")
remote_metadata: list[NixlMetadata] = []
while not remote_metadata:
await asyncio.sleep(1)
remote_metadata = find_remote_metadata(metadata.engine_id)
print(
f"Found {len(remote_metadata)} remote metadata for engine {metadata.engine_id}"
)
            for remote_meta in remote_metadata:
                await engine_client.add_remote_nixl_metadata(remote_meta)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import msgspec
from vllm.sampling_params import SamplingParams
class Request(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
):
"""The request data of one remote prefill output of a request.
Args:
request_id: The unique ID of the request.
prompt: The prompt string of the request.
"""
request_id: str
prompt: str
sampling_params: SamplingParams
do_remote_prefill: bool = False
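
# Illustrative sketch (not part of the original file): like the other structs in
# this example, Request round-trips through msgspec JSON, e.g.:
#
#     raw = msgspec.json.encode(
#         Request(request_id="r0", prompt="Hi", sampling_params=SamplingParams())
#     )
#     decoded = msgspec.json.decode(raw, type=Request)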
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import msgspec
import uvloop
from common import parse_vllm_args, temp_metadata_file
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import EngineClient
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class RequestHandler:
def __init__(
self,
model_name: str,
engine_client: EngineClient,
prefill_client,
do_remote_prefill: bool,
):
self.model_name = model_name
self.engine_client = engine_client
self.prefill_client = prefill_client
self.openai_serving_chat = None
self.initialized = False
self.do_remote_prefill = (
do_remote_prefill # TODO: this should be decided by the algorithm
)
print("RequestHandler initialized")
async def init(self):
models = OpenAIServingModels(
engine_client=self.engine_client,
model_config=await self.engine_client.get_model_config(),
base_model_paths=[
BaseModelPath(
name=self.model_name,
model_path=self.model_name,
)
],
)
self.openai_serving_chat = OpenAIServingChat(
engine_client=self.engine_client,
model_config=await self.engine_client.get_model_config(),
models=models,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
self.initialized = True
def get_remote_prefill_request_callback(self):
async def callback(request: RemotePrefillRequest):
json_request = msgspec.json.encode(request).decode("utf-8")
self.prefill_client.generate(json_request)
return callback
@triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(self, request):
if not self.initialized:
await self.init()
assert self.openai_serving_chat is not None
if self.do_remote_prefill:
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
)
else:
remote_prefill_params = None
async for raw_response in await self.openai_serving_chat.create_chat_completion(
request,
remote_prefill_params=remote_prefill_params,
):
if raw_response.startswith("data: [DONE]"):
break
            # removeprefix strips the SSE "data: " prefix exactly; lstrip would strip a character set.
            response = json.loads(raw_response.removeprefix("data: "))
yield response
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("test-nixl").component("vllm")
await component.create_service()
endpoint = component.endpoint("generate")
prefill_client = (
await runtime.namespace("test-nixl")
.component("prefill")
.endpoint("generate")
.client()
)
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
# This should be replaced with etcd
if engine_args.remote_prefill:
metadata = engine_client.nixl_metadata
with temp_metadata_file(metadata.engine_id, metadata):
await endpoint.serve_endpoint(
RequestHandler(
model_name=engine_args.model,
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=True,
).generate
)
else:
await endpoint.serve_endpoint(
RequestHandler(
model_name=engine_args.model,
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=False,
).generate
)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.remote_prefill:
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.preemption_mode != "swap":
print("Preemption mode is not supported yet, setting to swap")
engine_args.preemption_mode = "swap"
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
asyncio.run(worker(engine_args))