Commit fe83f8aa authored by ptarasiewiczNV, committed by GitHub
Browse files

Update NIXL Dockerfile + vLLM patch with variable TP (#9)


Co-authored-by: Piotr Tarasiewicz Nvidia <ptarasiewicznv@Piotrs-MacBook-Pro.local>
parent b9ce8dd0
...@@ -8,50 +8,12 @@ FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dev ...@@ -8,50 +8,12 @@ FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dev
USER root USER root
# Install utilities
RUN apt update -y && apt install -y git wget curl nvtop tmux vim
# nats
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
# etcd
ENV ETCD_VERSION="v3.5.18"
RUN wget https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-amd64.tar.gz -O /tmp/etcd.tar.gz && \
mkdir -p /usr/local/bin/etcd && \
tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1
ENV PATH=/usr/local/bin/etcd/:$PATH
### VIRTUAL ENVIRONMENT SETUP ###
# Install uv and create virtualenv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/triton && \
uv venv /opt/triton/venv --python 3.12
# Activate virtual environment
ENV VIRTUAL_ENV=/opt/triton/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Install patched vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
# Install genai-perf for benchmarking
ARG GENAI_PERF_TAG="r25.01"
RUN uv pip install "git+https://github.com/triton-inference-server/perf_analyzer.git@${GENAI_PERF_TAG}#subdirectory=genai-perf"
# Install test dependencies
RUN --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.txt \
uv pip install --requirement /tmp/requirements.txt
### NIXL SETUP ### ### NIXL SETUP ###
ARG MOFED_VERSION=5.8-1.1.2.1 ARG MOFED_VERSION=24.10-1.1.4.0
ARG PYTHON_VERSION=3.12 ARG PYTHON_VERSION=3.12
ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2024_4/ ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_1/
ARG NSYS_PKG=NsightSystems-linux-cli-public-2024.4.1.61-3431596.deb ARG NSYS_PKG=NsightSystems-linux-cli-public-2025.1.1.131-3554042.deb
RUN apt-get update -y && apt-get -y install curl \ RUN apt-get update -y && apt-get -y install curl \
git \ git \
...@@ -78,7 +40,9 @@ RUN apt-get update -y && apt-get -y install curl \ ...@@ -78,7 +40,9 @@ RUN apt-get update -y && apt-get -y install curl \
libprotobuf-dev \ libprotobuf-dev \
protobuf-compiler-grpc \ protobuf-compiler-grpc \
pybind11-dev \ pybind11-dev \
python3-full \
python3-pip \ python3-pip \
python3-numpy \
etcd-server \ etcd-server \
net-tools \ net-tools \
pciutils \ pciutils \
...@@ -89,25 +53,18 @@ RUN apt-get update -y && apt-get -y install curl \ ...@@ -89,25 +53,18 @@ RUN apt-get update -y && apt-get -y install curl \
ibverbs-utils \ ibverbs-utils \
libibmad-dev libibmad-dev
RUN apt-get update && \
apt install -y wget libglib2.0-0
RUN wget ${NSYS_URL}${NSYS_PKG} && \
dpkg -i $NSYS_PKG && \
rm $NSYS_PKG
RUN apt-get install -y linux-tools-common linux-tools-generic ethtool iproute2 RUN apt-get install -y linux-tools-common linux-tools-generic ethtool iproute2
RUN apt-get install -y dkms linux-headers-generic RUN apt-get install -y dkms linux-headers-generic
RUN apt-get install -y meson ninja-build uuid-dev gdb RUN apt-get install -y meson ninja-build uuid-dev gdb
RUN uv pip install --upgrade meson RUN apt-get update && apt install -y wget libglib2.0-0
RUN uv pip install ninja pybind11 RUN wget ${NSYS_URL}${NSYS_PKG} && dpkg -i $NSYS_PKG && rm $NSYS_PKG
RUN cd /usr/local/src && \ RUN cd /usr/local/src && \
curl -fSsL "https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VERSION}/MLNX_OFED_LINUX-${MOFED_VERSION}-ubuntu20.04-x86_64.tgz" -o mofed.tgz && \ curl -fSsL "https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VERSION}/MLNX_OFED_LINUX-${MOFED_VERSION}-ubuntu24.04-x86_64.tgz" -o mofed.tgz && \
tar -xf /usr/local/src/mofed.tgz && \ tar -xf /usr/local/src/mofed.tgz && \
cd MLNX_OFED_LINUX-* && \ cd MLNX_OFED_LINUX-* && \
apt-get update && \ apt-get update && apt-get install -y --no-install-recommends \
apt-get install -y --no-install-recommends \
./DEBS/libibverbs* ./DEBS/ibverbs-providers* ./DEBS/librdmacm* ./DEBS/libibumad* && \ ./DEBS/libibverbs* ./DEBS/ibverbs-providers* ./DEBS/librdmacm* ./DEBS/libibumad* && \
rm -rf /var/lib/apt/lists/* /usr/local/src/* rm -rf /var/lib/apt/lists/* /usr/local/src/*
...@@ -128,9 +85,7 @@ ARG UCX_VERSION=v1.18.0 ...@@ -128,9 +85,7 @@ ARG UCX_VERSION=v1.18.0
RUN cd /usr/local/src && \ RUN cd /usr/local/src && \
curl -fSsL "https://github.com/openucx/ucx/tarball/${UCX_VERSION}" | tar xz && \ curl -fSsL "https://github.com/openucx/ucx/tarball/${UCX_VERSION}" | tar xz && \
cd openucx-ucx* && \ cd openucx-ucx* && \
./autogen.sh && \ ./autogen.sh && ./configure \
./configure \
--prefix=/usr/local/ucx \
--enable-shared \ --enable-shared \
--disable-static \ --disable-static \
--disable-doxygen-doc \ --disable-doxygen-doc \
...@@ -147,13 +102,20 @@ RUN cd /usr/local/src && \ ...@@ -147,13 +102,20 @@ RUN cd /usr/local/src && \
make -j install-strip && \ make -j install-strip && \
ldconfig ldconfig
ENV LD_LIBRARY_PATH=/usr/lib:$LD_LIBRARY_PATH
ENV LD_LIBRARY_PATH=/usr/local/ucx/lib:$LD_LIBRARY_PATH ENV CPATH=/usr/include:$CPATH
ENV CPATH=/usr/local/ucx/include:$CPATH ENV PATH=/usr/bin:$PATH
ENV PATH=/usr/local/ucx/bin:$PATH ENV PKG_CONFIG_PATH=/usr/lib/pkgconfig:$PKG_CONFIG_PATH
ENV PKG_CONFIG_PATH=/usr/local/ucx/lib/pkgconfig:$PKG_CONFIG_PATH
SHELL ["/bin/bash", "-c"] SHELL ["/bin/bash", "-c"]
WORKDIR /workspace
ENV LD_LIBRARY_PATH=/usr/local/ompi/lib:$LD_LIBRARY_PATH
ENV CPATH=/usr/local/ompi/include:$CPATH
ENV PATH=/usr/local/ompi/bin:$PATH
ENV PKG_CONFIG_PATH=/usr/local/ompi/lib/pkgconfig:$PKG_CONFIG_PATH
COPY --from=nixl . /opt/nixl COPY --from=nixl . /opt/nixl
RUN cd /opt/nixl && \ RUN cd /opt/nixl && \
...@@ -161,15 +123,10 @@ RUN cd /opt/nixl && \ ...@@ -161,15 +123,10 @@ RUN cd /opt/nixl && \
meson setup build/ --prefix=/usr/local/nixl && \ meson setup build/ --prefix=/usr/local/nixl && \
cd build/ && \ cd build/ && \
ninja && \ ninja && \
ninja install && \ ninja install
mkdir -p /usr/local/nixl/include/internal && \
cp ../include/*.h /usr/local/nixl/include && \
cp ../include/internal/*.h /usr/local/nixl/include/internal && \
cp ../include/internal/*.h /usr/local/nixl/include/ && \
cp ../src/utils/serdes/serdes.h /usr/local/nixl/include
ENV LD_LIBRARY_PATH=/usr/local/nixl/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH ENV LD_LIBRARY_PATH=/usr/local/nixl/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
ENV PYTHONPATH=/usr/local/nixl/lib/python${PYTHON_VERSION}/site-packages/:/opt/nixl/test/python/:$PYTHONPATH ENV PYTHONPATH=/usr/local/nixl/lib/python3/dist-packages/:/opt/nixl/test/python/:$PYTHONPATH
RUN ls -l /usr/local/nixl/ RUN ls -l /usr/local/nixl/
RUN ls -l /usr/local/nixl/include/ RUN ls -l /usr/local/nixl/include/
...@@ -177,6 +134,44 @@ RUN ls -l /usr/local/nixl/include/internal/ ...@@ -177,6 +134,44 @@ RUN ls -l /usr/local/nixl/include/internal/
RUN ls /opt/nixl RUN ls /opt/nixl
# Install utilities
RUN apt update -y && apt install -y git wget curl nvtop tmux vim
# nats
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
# etcd
ENV ETCD_VERSION="v3.5.18"
RUN wget https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-amd64.tar.gz -O /tmp/etcd.tar.gz && \
mkdir -p /usr/local/bin/etcd && \
tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1
ENV PATH=/usr/local/bin/etcd/:$PATH
### VIRTUAL ENVIRONMENT SETUP ###
# Install uv and create virtualenv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/triton && \
uv venv /opt/triton/venv --python 3.12
# Activate virtual environment
ENV VIRTUAL_ENV=/opt/triton/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Install patched vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
# Install genai-perf for benchmarking
ARG GENAI_PERF_TAG="r25.01"
RUN uv pip install "git+https://github.com/triton-inference-server/perf_analyzer.git@${GENAI_PERF_TAG}#subdirectory=genai-perf"
# Install test dependencies
RUN --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.txt \
uv pip install --requirement /tmp/requirements.txt
# ### MISC UTILITY SETUP ### # ### MISC UTILITY SETUP ###
# Finish pyright install # Finish pyright install
......
...@@ -20,7 +20,7 @@ limitations under the License. ...@@ -20,7 +20,7 @@ limitations under the License.
## Build docker ## Build docker
``` ```
./container/build.sh --framework VLLM_NIXL --target dev --build-context nixl=<path to downloaded nixl repo @ fc912eb012597be67de11fa9ba0599e4e1974fa2> ./container/build.sh --framework VLLM_NIXL --target dev --build-context nixl=<path to downloaded nixl repo @ c53bb19a6a114e9093071bd1f2904f996ae1839b>
``` ```
## Run container ## Run container
...@@ -72,10 +72,11 @@ In terminal 2: ...@@ -72,10 +72,11 @@ In terminal 2:
``` ```
cd /workspace/examples/python_rs/llm/vllm_nixl cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=1 python3 worker.py \ CUDA_VISIBLE_DEVICES=1,2 python3 worker.py \
--remote-prefill \ --remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \ --enforce-eager \
--tensor-parallel-size 2 \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"TritonNixlConnector"}' '{"kv_connector":"TritonNixlConnector"}'
``` ```
......
...@@ -18,7 +18,7 @@ import asyncio ...@@ -18,7 +18,7 @@ import asyncio
import msgspec import msgspec
import uvloop import uvloop
from common import find_remote_metadata, parse_vllm_args, temp_metadata_file from common import find_remote_metadata, parse_vllm_args
from vllm.distributed.device_communicators.nixl import NixlMetadata from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
...@@ -69,19 +69,18 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -69,19 +69,18 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
async with build_async_engine_client_from_engine_args(engine_args) as engine_client: async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
# This should be replaced with etcd # This should be replaced with etcd
metadata = engine_client.nixl_metadata metadata = engine_client.nixl_metadata
with temp_metadata_file(metadata.engine_id, metadata): print(f"Waiting for remote metadata for engine {metadata.engine_id}")
print(f"Waiting for remote metadata for engine {metadata.engine_id}") remote_metadata: list[NixlMetadata] = []
remote_metadata: list[NixlMetadata] = [] while not remote_metadata:
while not remote_metadata: await asyncio.sleep(1)
await asyncio.sleep(1) remote_metadata = find_remote_metadata(metadata.engine_id)
remote_metadata = find_remote_metadata(metadata.engine_id)
print(
print( f"Found {len(remote_metadata)} remote metadata for engine {metadata.engine_id}"
f"Found {len(remote_metadata)} remote metadata for engine {metadata.engine_id}" )
) for remote_metadata in remote_metadata:
for remote_metadata in remote_metadata: await engine_client.add_remote_nixl_metadata(remote_metadata)
await engine_client.add_remote_nixl_metadata(remote_metadata) await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -96,4 +95,12 @@ if __name__ == "__main__": ...@@ -96,4 +95,12 @@ if __name__ == "__main__":
print("Pipeline parallel size is not supported yet, setting to 1") print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1 engine_args.pipeline_parallel_size = 1
if engine_args.disable_async_output_proc is not True:
print("Async output processing is not supported yet, setting to True")
engine_args.disable_async_output_proc = True
if engine_args.enforce_eager is not True:
print("Prefill must be done eagerly, setting to True")
engine_args.enforce_eager = True
asyncio.run(worker(engine_args)) asyncio.run(worker(engine_args))
...@@ -93,6 +93,8 @@ class RequestHandler: ...@@ -93,6 +93,8 @@ class RequestHandler:
await self.init() await self.init()
assert self.openai_serving_chat is not None assert self.openai_serving_chat is not None
request.model = "vllm"
if self.do_remote_prefill: if self.do_remote_prefill:
remote_prefill_params = RemotePrefillParams( remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True, is_remote_prefill=True,
...@@ -133,7 +135,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -133,7 +135,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
with temp_metadata_file(metadata.engine_id, metadata): with temp_metadata_file(metadata.engine_id, metadata):
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
RequestHandler( RequestHandler(
model_name=engine_args.model, model_name="vllm",
engine_client=engine_client, engine_client=engine_client,
prefill_client=prefill_client, prefill_client=prefill_client,
do_remote_prefill=True, do_remote_prefill=True,
...@@ -142,7 +144,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -142,7 +144,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
else: else:
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
RequestHandler( RequestHandler(
model_name=engine_args.model, model_name="vllm",
engine_client=engine_client, engine_client=engine_client,
prefill_client=prefill_client, prefill_client=prefill_client,
do_remote_prefill=False, do_remote_prefill=False,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment