"deploy/compoundai/api-server/api/database/database.go" did not exist on "14ce7e03bedde2b7632fa526a01a187f4b374996"
Commit f7a60cba authored by ptarasiewiczNV, committed by GitHub

feat: vLLM + NIXL example


Co-authored-by: Piotr Tarasiewicz Nvidia <ptarasiewicznv@Piotrs-MacBook-Pro.local>
Co-authored-by: nnshah1 <neelays@nvidia.com>
Co-authored-by: alec-flowers <aflowers@nvidia.com>
parent ea401e3b
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
ARG BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
ARG BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dev
USER root
# Install utilities
RUN apt update -y && apt install -y git wget curl nvtop tmux vim
# nats
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
# etcd
ENV ETCD_VERSION="v3.5.18"
RUN wget https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-amd64.tar.gz -O /tmp/etcd.tar.gz && \
mkdir -p /usr/local/bin/etcd && \
tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1
ENV PATH=/usr/local/bin/etcd/:$PATH
### VIRTUAL ENVIRONMENT SETUP ###
# Install uv and create virtualenv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/triton && \
uv venv /opt/triton/venv --python 3.12
# Activate virtual environment
ENV VIRTUAL_ENV=/opt/triton/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Install patched vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
# Install genai-perf for benchmarking
ARG GENAI_PERF_TAG="r25.01"
RUN uv pip install "git+https://github.com/triton-inference-server/perf_analyzer.git@${GENAI_PERF_TAG}#subdirectory=genai-perf"
# Install test dependencies
RUN --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.txt \
uv pip install --requirement /tmp/requirements.txt
### NIXL SETUP ###
ARG MOFED_VERSION=5.8-1.1.2.1
ARG PYTHON_VERSION=3.12
ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2024_4/
ARG NSYS_PKG=NsightSystems-linux-cli-public-2024.4.1.61-3431596.deb
RUN apt-get update -y && apt-get -y install curl \
git \
libnuma-dev \
numactl \
wget \
autotools-dev \
automake \
libtool \
libz-dev \
libiberty-dev \
flex \
build-essential \
cmake \
libibverbs-dev \
libgoogle-glog-dev \
libgtest-dev \
libjsoncpp-dev \
libpython3-dev \
libboost-all-dev \
libssl-dev \
libgrpc-dev \
libgrpc++-dev \
libprotobuf-dev \
protobuf-compiler-grpc \
pybind11-dev \
python3-pip \
etcd-server \
net-tools \
pciutils \
libpci-dev \
vim \
tmux \
screen \
ibverbs-utils \
libibmad-dev
RUN apt-get update && \
apt install -y wget libglib2.0-0
RUN wget ${NSYS_URL}${NSYS_PKG} && \
dpkg -i $NSYS_PKG && \
rm $NSYS_PKG
RUN apt-get install -y linux-tools-common linux-tools-generic ethtool iproute2
RUN apt-get install -y dkms linux-headers-generic
RUN apt-get install -y meson ninja-build uuid-dev gdb
RUN uv pip install --upgrade meson
RUN uv pip install ninja pybind11
RUN cd /usr/local/src && \
curl -fSsL "https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VERSION}/MLNX_OFED_LINUX-${MOFED_VERSION}-ubuntu20.04-x86_64.tgz" -o mofed.tgz && \
tar -xf /usr/local/src/mofed.tgz && \
cd MLNX_OFED_LINUX-* && \
apt-get update && \
apt-get install -y --no-install-recommends \
./DEBS/libibverbs* ./DEBS/ibverbs-providers* ./DEBS/librdmacm* ./DEBS/libibumad* && \
rm -rf /var/lib/apt/lists/* /usr/local/src/*
ENV LIBRARY_PATH=$LIBRARY_PATH:/usr/local/cuda/lib64 \
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64
ENV LIBRARY_PATH=$LIBRARY_PATH:/usr/local/lib \
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib
WORKDIR /workspace
RUN git clone https://github.com/NVIDIA/gdrcopy.git
RUN PREFIX=/usr/local DESTLIB=/usr/local/lib make -C /workspace/gdrcopy lib_install
RUN cp gdrcopy/src/libgdrapi.so.2.* /usr/lib/x86_64-linux-gnu/
RUN ldconfig
ARG UCX_VERSION=v1.18.0
RUN cd /usr/local/src && \
curl -fSsL "https://github.com/openucx/ucx/tarball/${UCX_VERSION}" | tar xz && \
cd openucx-ucx* && \
./autogen.sh && \
./configure \
--prefix=/usr/local/ucx \
--enable-shared \
--disable-static \
--disable-doxygen-doc \
--enable-optimizations \
--enable-cma \
--enable-devel-headers \
--with-cuda=/usr/local/cuda \
--with-verbs \
--with-dm \
--with-gdrcopy=/usr/local \
--enable-mt \
--with-mlx5-dv && \
make -j && \
make -j install-strip && \
ldconfig
ENV LD_LIBRARY_PATH=/usr/local/ucx/lib:$LD_LIBRARY_PATH
ENV CPATH=/usr/local/ucx/include:$CPATH
ENV PATH=/usr/local/ucx/bin:$PATH
ENV PKG_CONFIG_PATH=/usr/local/ucx/lib/pkgconfig:$PKG_CONFIG_PATH
SHELL ["/bin/bash", "-c"]
COPY --from=nixl . /opt/nixl
RUN cd /opt/nixl && \
mkdir build && \
meson setup build/ --prefix=/usr/local/nixl && \
cd build/ && \
ninja && \
ninja install && \
mkdir -p /usr/local/nixl/include/internal && \
cp ../include/*.h /usr/local/nixl/include && \
cp ../include/internal/*.h /usr/local/nixl/include/internal && \
cp ../include/internal/*.h /usr/local/nixl/include/ && \
cp ../src/utils/serdes/serdes.h /usr/local/nixl/include
ENV LD_LIBRARY_PATH=/usr/local/nixl/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
ENV PYTHONPATH=/usr/local/nixl/lib/python${PYTHON_VERSION}/site-packages/:/opt/nixl/test/python/:$PYTHONPATH
RUN ls -l /usr/local/nixl/
RUN ls -l /usr/local/nixl/include/
RUN ls -l /usr/local/nixl/include/internal/
RUN ls /opt/nixl
# ### MISC UTILITY SETUP ###
# Finish pyright install
RUN pyright --help > /dev/null 2>&1
# Enable Git operations in the /workspace directory
RUN printf "[safe]\n directory=/workspace\n" > /root/.gitconfig
RUN ln -sf /bin/bash /bin/sh
### BUILDS ###
# Rust build/dev dependencies
RUN apt update -y && \
apt install -y \
build-essential \
protobuf-compiler \
cmake \
libssl-dev \
pkg-config && \
curl https://sh.rustup.rs -sSf | bash -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}"
# Working directory
WORKDIR /workspace
COPY lib/runtime /workspace/lib/runtime
RUN cd lib/runtime && \
cargo build --release --locked && \
cargo doc --no-deps
# Build OpenAI HTTP Service binaries
COPY lib/llm /workspace/lib/llm
COPY examples/rust /workspace/examples/rust
RUN cd examples/rust && \
cargo build --release && \
cp target/release/http /usr/local/bin/ && \
cp target/release/llmctl /usr/local/bin/
# TODO: Build tio
# COPY applications/...
# Generate C bindings for kv cache routing in vLLM
COPY lib/bindings /workspace/lib/bindings
RUN cd lib/bindings/c && \
cargo build --release --locked && \
cargo doc --no-deps
# Build triton_distributed wheel
RUN source /opt/triton/venv/bin/activate && \
cd lib/bindings/python && \
uv build && \
uv pip install /workspace/lib/bindings/python/dist/triton_distributed*cp312*.whl
# Package the bindings
RUN mkdir -p /opt/triton/bindings/wheels && \
mkdir /opt/triton/bindings/lib && \
cp lib/bindings/python/dist/triton_distributed*cp312*.whl /opt/triton/bindings/wheels/. && \
cp lib/bindings/c/target/release/libtriton_distributed_llm_capi.so /opt/triton/bindings/lib/. && \
cp -r lib/bindings/c/include /opt/triton/bindings/.
# Tell vllm to use the Triton LLM C API for KV Cache Routing
ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
# FIXME: Copy more specific folders in for dev/debug after directory restructure
COPY . /workspace
# FIXME: May want a modification with triton-distributed banner on entry
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
### Lean Runtime Image Stage ###
# FIXME: Separate build and runtime images
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS runtime
USER root
# Install tools for interactive convenience
RUN apt update -y && \
apt install -y curl tmux vim && \
echo "set -g mouse on" >> /root/.tmux.conf
# Set environment variables
ENV VIRTUAL_ENV=/opt/triton/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
ENV RAPIDS_LIBUCX_PREFER_SYSTEM_LIBRARY=true
ENV VLLM_KV_CAPI_PATH="/opt/triton/bindings/lib/libtriton_distributed_llm_capi.so"
# Copy binaries
COPY --from=dev /usr/local/bin/http /usr/local/bin/http
COPY --from=dev /usr/local/bin/llmctl /usr/local/bin/llmctl
COPY --from=dev /usr/local/bin/etcd/etcd /usr/local/bin/etcd
COPY --from=dev /usr/bin/nats-server /usr/local/bin/nats-server
COPY --from=dev /bin/uv /usr/local/bin/uv
COPY --from=dev /bin/uvx /usr/local/bin/uvx
# Copy venv with installed packages
RUN uv python install 3.12
COPY --from=dev /opt/vllm /opt/vllm
COPY --from=dev ${VIRTUAL_ENV} ${VIRTUAL_ENV}
# Copy minimal set of files for testing. May consider separate stage for testing
# if test dependencies start to negatively impact deployment environment/size.
COPY pyproject.toml /workspace/pyproject.toml
COPY container/deps/vllm /workspace/container/deps/vllm
# Add library for KV routing
COPY --from=dev ${VLLM_KV_CAPI_PATH} ${VLLM_KV_CAPI_PATH}
# Copy minimal set of files for deployment/examples
# FIXME: Use a more consolidated path after directory restructure
COPY examples/python_rs/llm/vllm /workspace/examples/python_rs/llm/vllm
WORKDIR /workspace
# FIXME: May want a modification with triton-distributed banner on entry
ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"]
CMD []
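Note: both stages above copy NIXL sources via `COPY --from=nixl . /opt/nixl`, which resolves against an extra named build context rather than the repository checkout. A minimal direct invocation is sketched below; it assumes a BuildKit/buildx version that supports --build-context, a local NIXL checkout at /path/to/nixl, and an illustrative image tag (none of these are part of this commit):

# Illustrative sketch only: build the lean runtime stage of Dockerfile.vllm_nixl.
# The named "nixl" context satisfies the `COPY --from=nixl . /opt/nixl` step.
docker buildx build \
  -f Dockerfile.vllm_nixl \
  --target runtime \
  --build-context nixl=/path/to/nixl \
  --build-arg BASE_IMAGE=nvcr.io/nvidia/cuda-dl-base \
  --build-arg BASE_IMAGE_TAG=25.01-cuda12.8-devel-ubuntu24.04 \
  -t triton-distributed:vllm-nixl-runtime .

The build-script changes below wire the same mechanism through the wrapper via a new --build-context option.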
@@ -43,7 +43,7 @@ PYTHON_PACKAGE_VERSION=${current_tag:-$latest_tag.dev+$commit_id}
 # dependencies are specified in the /container/deps folder and
 # installed within framework specific sections of the Dockerfile.
-declare -A FRAMEWORKS=(["STANDARD"]=1 ["TENSORRTLLM"]=2 ["VLLM"]=3)
+declare -A FRAMEWORKS=(["STANDARD"]=1 ["TENSORRTLLM"]=2 ["VLLM"]=3 ["VLLM_NIXL"]=4)
 DEFAULT_FRAMEWORK=STANDARD

 SOURCE_DIR=$(dirname "$(readlink -f "$0")")
@@ -74,6 +74,9 @@ TENSORRTLLM_SKIP_CLONE=0
 VLLM_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
 VLLM_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"

+VLLM_NIXL_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base"
+VLLM_NIXL_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04"
+
 get_options() {
     while :; do
         case $1 in
@@ -180,6 +183,14 @@ get_options() {
                 missing_requirement $1
             fi
             ;;
+        --build-context)
+            if [ "$2" ]; then
+                BUILD_CONTEXT_ARG="--build-context $2"
+                shift
+            else
+                missing_requirement $1
+            fi
+            ;;
         --)
             shift
             break
@@ -274,6 +285,7 @@ show_help() {
     echo " [--tag tag for image]"
     echo " [--no-cache disable docker build cache]"
     echo " [--dry-run print docker commands without running]"
+    echo " [--build-context name=path to add build context]"
     exit 0
 }
@@ -292,6 +304,8 @@ get_options "$@"
 # Update DOCKERFILE if framework is VLLM
 if [[ $FRAMEWORK == "VLLM" ]]; then
     DOCKERFILE=${SOURCE_DIR}/Dockerfile.vllm
+elif [[ $FRAMEWORK == "VLLM_NIXL" ]]; then
+    DOCKERFILE=${SOURCE_DIR}/Dockerfile.vllm_nixl
 fi

 # BUILD DEV IMAGE
@@ -327,7 +341,7 @@ if [ -z "$RUN_PREFIX" ]; then
     set -x
 fi

-$RUN_PREFIX docker build -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $TAG $LATEST_TAG $BUILD_CONTEXT $NO_CACHE
+$RUN_PREFIX docker build -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE

 { set +x; } 2>/dev/null
...
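For reference, a hypothetical wrapper invocation combining the new framework value with the new flag; the VLLM_NIXL value and the --build-context option come from this diff, while the --framework flag name, the script location at container/build.sh, and the NIXL path are assumptions:

# Illustrative sketch only: select Dockerfile.vllm_nixl and pass the extra NIXL context.
./container/build.sh --framework VLLM_NIXL --build-context nixl=/path/to/nixl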
@@ -27,4 +27,4 @@ pytestmark = pytest.mark.pre_merge
 @pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
 def test_version():
     # Verify that the image has the patched version of vllm
-    assert vllm.__version__.startswith("0.7.3.dev")
+    assert vllm.__version__.startswith("0.7.3.dev")  # type: ignore
diff --git a/vllm/config.py b/vllm/config.py
index 9ba49757..a2f88854 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2629,7 +2629,7 @@ class KVTransferConfig(BaseModel):
...@@ -26,7 +26,13 @@ index 9ba49757..7e871521 100644
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -2680,11 +2688,12 @@ class KVTransferConfig(BaseModel):
f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`")
- if self.kv_connector is not None and self.kv_role is None:
+ if self.kv_connector is not None and self.kv_connector != "TritonNixlConnector" and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
"is set, supported roles are `kv_producer`, " "is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`") "`kv_consumer`, and `kv_both`")
...@@ -34,7 +40,16 @@ index 9ba49757..7e871521 100644 ...@@ -34,7 +40,16 @@ index 9ba49757..7e871521 100644
@property @property
def is_kv_transfer_instance(self) -> bool: def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \ return self.kv_connector is not None and \
@@ -2706,6 +2715,18 @@ class KVTransferConfig(BaseModel): @@ -2694,6 +2703,8 @@ class KVTransferConfig(BaseModel):
def need_kv_parallel_group(self) -> bool:
# for those database-based connector, vLLM does not need to create
# parallel group, and in that case the kv parallel size will be 1.
+ if self.kv_connector == "TritonNixlConnector":
+ return False
return self.kv_connector is not None and self.kv_parallel_size > 1
@property
@@ -2706,6 +2717,18 @@ class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]
...@@ -118,6 +133,18 @@ index 359b5b26..d52ee050 100644
self._swap_mapping: Dict[int, int] = {}
self._null_block: Optional[Block] = None
diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py
index c388366b..c1883736 100644
--- a/vllm/core/block/naive_block.py
+++ b/vllm/core/block/naive_block.py
@@ -135,6 +135,7 @@ class NaiveBlockAllocator(BlockAllocator):
raise BlockAllocator.NoFreeBlocksError()
block_id = self._free_block_indices.popleft()
+ # TODO: figure out why sometime block_id is None
self._refcounter.incr(block_id)
return block_id
diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py
index 1ca9e49d..b1591c0c 100644
--- a/vllm/core/block/prefix_caching_block.py
...@@ -341,19 +368,32 @@ index 00000000..350453cd
+
+ self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..ee20d50c 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -8,18 +8,17 @@ from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
-from typing import Set, Tuple, Union
+from typing import Set, Tuple, Union, Any
-from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
+from vllm.config import ModelConfig, CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
- SequenceStatus)
+ SequenceStatus, SequenceStage)
from vllm.utils import Device, PyObjectCache
-
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
@@ -325,12 +324,14 @@ class Scheduler:
def __init__(
self,
...@@ -368,7 +408,7 @@ index f507847a..6af77646 100644
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
@@ -356,6 +357,7 @@ class Scheduler:
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
...@@ -376,6 +416,450 @@ index f507847a..6af77646 100644
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
@@ -371,6 +373,14 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
+
+ # Sequence groups in the REMOTE_PREFILLING state.
+ # Contain requests that are being prefilled by a remote worker.
+ self.remote_prefilling: Deque[SequenceGroup] = deque()
+
+ self._remote_prefill_outputs: Dict[str, int] = {}
+
+
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
@@ -501,7 +511,7 @@ class Scheduler:
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
- self.swapped) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
@@ -523,6 +533,7 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
+ finished_prefills: Optional[Set[str]] = None
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
@@ -537,6 +548,8 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
+ finished_remote_prefill_request_ids: Set of request ids of remote
+ prefills that have finished.
Returns:
SchedulerRunningOutputs.
@@ -566,6 +579,24 @@ class Scheduler:
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
+ remote_prefilling_queue = self.remote_prefilling
+ leftover_remote_prefilling_sequences: Deque[SequenceGroup] = deque()
+ while remote_prefilling_queue:
+ seq_group = remote_prefilling_queue.popleft()
+ if seq_group.request_id not in finished_prefills:
+ leftover_remote_prefilling_sequences.append(seq_group)
+ continue
+ else:
+ finished_prefills.remove(seq_group.request_id)
+ assert len(seq_group.seqs) == 1
+ seq = seq_group.seqs[0]
+ # we computed all but the last token in prefill, we need to decode the first token on decode
+ seq_group.update_num_computed_tokens(seq.get_len() - 1)
+ seq.status = SequenceStatus.RUNNING
+ seq.data._stage = SequenceStage.DECODE
+ self.running.appendleft(seq_group)
+ remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
+
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
@@ -1008,7 +1039,7 @@ class Scheduler:
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
- self._allocate_and_set_running(seq_group)
+ self._allocate_and_set_running_or_remote_prefill(seq_group)
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
@@ -1048,7 +1079,7 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
- def _schedule_default(self) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
@@ -1090,7 +1121,8 @@ class Scheduler:
if len(prefills.seq_groups) == 0:
running_scheduled = self._schedule_running(budget,
curr_loras,
- enable_chunking=False)
+ enable_chunking=False,
+ finished_prefills=finished_prefills)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
@@ -1106,7 +1138,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
if len(prefills.seq_groups) > 0:
- self.running.extend([s.seq_group for s in prefills.seq_groups])
+ for s in prefills.seq_groups:
+ seq_group = s.seq_group
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ self.remote_prefilling.append(seq_group)
+ else:
+ self.running.append(seq_group)
self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1248,12 +1285,14 @@ class Scheduler:
len(running_scheduled.swapped_out)),
)
- def _schedule(self) -> SchedulerOutputs:
+ def _schedule(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled:
+ if finished_prefills:
+ raise ValueError("Chunked prefill does not support remote prefills")
return self._schedule_chunked_prefill()
else:
- return self._schedule_default()
+ return self._schedule_default(finished_prefills)
def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool:
@@ -1287,14 +1326,15 @@ class Scheduler:
return no_single_seq
def schedule(
- self
+ self,
+ finished_prefills: Optional[Set[str]] = None
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
- scheduler_start_time = time.perf_counter()
- scheduler_outputs: SchedulerOutputs = self._schedule()
+ scheduler_start_time = time.perf_counter()
+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills)
now = time.time()
if not self.cache_config.enable_prefix_caching:
@@ -1333,7 +1373,8 @@ class Scheduler:
encoder_seq_data = None
cross_block_table = None
- for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+ running_or_remote_prefilling_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + seq_group.get_seqs(status=SequenceStatus.REMOTE_PREFILLING)
+ for seq in running_or_remote_prefilling_seqs:
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1364,6 +1405,10 @@ class Scheduler:
< seqs[0].data.get_len()):
do_sample = False
+ is_remote_prefill = False
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ is_remote_prefill = True
+
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
if is_first_prefill or not self.scheduler_config.send_delta_data:
@@ -1392,6 +1437,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
+ do_remote_prefill=is_remote_prefill,
)
else:
# When SPMD mode is enabled, we only send delta data except for
@@ -1490,10 +1536,13 @@ class Scheduler:
self._async_stopped.clear()
- def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
+ def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
- seq.status = SequenceStatus.RUNNING
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ seq.status = SequenceStatus.REMOTE_PREFILLING
+ else:
+ seq.status = SequenceStatus.RUNNING
def _append_slots(self,
seq_group: SequenceGroup,
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
index 00000000..bc962726
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,249 @@
+import torch
+from typing import List, Tuple
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+import msgspec
+import time
+import uuid
+from nixl_wrapper import nixl_wrapper as NixlWrapper
+
+logger = init_logger(__name__)
+
+
+def nixl_wrapper_init_patch(self, agent_name, nixl_config):
+ logger.info("Initializing patched NixlWrapper")
+ import nixl_bindings as nixl
+ # Read available backends and device info from nixl_config
+ # For now setting the multithreading to enabled.
+ devices = nixl.nixlAgentConfig(False)
+ init = nixl.nixlUcxInitParams()
+
+ self.name = agent_name
+ self.notifs = {}
+ self.backends = {}
+ self.agent = nixl.nixlAgent(agent_name, devices)
+ self.backends["UCX"] = self.agent.createBackend(init)
+
+ self.nixl_mems = {"DRAM": nixl.DRAM_SEG,
+ "VRAM": nixl.VRAM_SEG,
+ "cpu": nixl.DRAM_SEG,
+ "cuda": nixl.VRAM_SEG}
+ self.nixl_ops = {"WRITE": nixl.NIXL_WR_FLUSH,
+ "READ": nixl.NIXL_RD_FLUSH,
+ "WRITE_NOTIF": nixl.NIXL_WR_NOTIF,
+ "READ_NOTIF": nixl.NIXL_RD_NOTIF}
+
+ print("Initializied NIXL agent:", agent_name)
+
+NixlWrapper.__init__ = nixl_wrapper_init_patch
+
+
+
+class NixlMetadata(
+ msgspec.Struct,
+ omit_defaults=True, # type: ignore[call-arg]
+ # required for @cached_property.
+ dict=True):
+ engine_id: str
+ agent_metadata: List[bytes]
+ kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values
+
+
+class TritonNixlConnector:
+ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
+ self.vllm_config = vllm_config
+ self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+
+ self.num_layers = None
+ self.num_blocks = None
+ self.block_len = None
+ self.kv_caches_base_addr = {}
+ self.kv_cache_shape = {}
+
+ self._registered_descs = []
+ self._remote_agents = {}
+ self.engine_id = engine_id
+ self.rank = rank
+ self.notifs = {}
+
+ @property
+ def agent_name(self):
+ return self.nixl_wrapper.name
+
+ def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+ caches_data = []
+ self.num_layers = len(kv_caches)
+ _, _, block_size, num_heads, head_dim = kv_caches[0].shape
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+
+ kv_caches_base_addr = []
+ for key_cache, value_cache in kv_caches:
+ for cache in [key_cache, value_cache]:
+ base_addr = cache.data_ptr()
+ region_len = cache.numel() * cache.element_size()
+ gpu_id = cache.get_device()
+ assert gpu_id > -1, "Tensor is not on GPU"
+ caches_data.append((base_addr, region_len, gpu_id))
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+
+ def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata()
+
+ def shutdown(self):
+ for descs_list in self._registered_descs:
+ self.nixl_wrapper.deregister_memory(descs_list)
+ for agent_name in self._remote_agents.values():
+ self.nixl_wrapper.remove_remote_agent(agent_name)
+
+ def add_remote_agent(self, engine_id, agent_metadata):
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_metadata)
+ self._remote_agents[engine_id] = agent_name
+ return agent_name
+
+ def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ if block_ids == "all":
+ block_ids = list(range(self.num_blocks))
+ descs_ids = []
+ for layer_id in layer_ids:
+ for block_id in block_ids:
+ assert block_id < self.num_blocks, f"Block id {block_id} is greater than the number of blocks {self.num_blocks}"
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id))
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+ return descs_ids
+
+ def _get_range_descs(self, engine_id, ranges, layer_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ blocks_data = []
+ for layer_id in layer_ids:
+ for range_start, range_end in ranges:
+ key_base_addr, value_base_addr = self.kv_caches_base_addr[engine_id][layer_id]
+ start_offset = range_start * self.block_len
+ blocks_len = (range_end - range_start + 1) * self.block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, self.rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, self.rank))
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+
+ def _get_ranges(self, block_ids):
+ # This function should return a list of ranges of block ids that are contiguous
+ # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]]
+ # The ranges are sorted by the starting block id
+ # The function should also make sure that the block ids are contiguous
+ # If the block ids are not contiguous, the function should raise an error
+ sorted_block_ids = sorted(block_ids)
+ ranges = []
+ for i in range(len(sorted_block_ids)):
+ if i == 0 or sorted_block_ids[i] != sorted_block_ids[i-1] + 1:
+ ranges.append([sorted_block_ids[i]])
+ else:
+ ranges[-1].append(sorted_block_ids[i])
+ return ranges
+
+ def _get_same_length_ranges(self, src_ranges, dst_ranges):
+ # This function should return a list of ranges for both src and dst so that corresponding ranges are the same length
+ # For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]]
+ # The function should return ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]])
+ src_overlapping_ranges, dst_overlapping_ranges = [], []
+
+ src_idx, dst_idx = 0, 0
+ while src_idx < len(src_ranges) and dst_idx < len(dst_ranges):
+ src_range = src_ranges[src_idx]
+ dst_range = dst_ranges[dst_idx]
+
+ # Calculate the length of each range
+ src_len = src_range[-1] - src_range[0] + 1
+ dst_len = dst_range[-1] - dst_range[0] + 1
+
+ # If ranges have the same length, add them directly
+ if src_len == dst_len:
+ src_overlapping_ranges.append([src_range[0], src_range[-1]])
+ dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
+ src_idx += 1
+ dst_idx += 1
+ # If source range is longer, split it
+ elif src_len > dst_len:
+ src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1])
+ dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
+ # Update source range for next iteration
+ src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]]
+ dst_idx += 1
+ # If destination range is longer, split it
+ else: # src_len < dst_len
+ src_overlapping_ranges.append([src_range[0], src_range[-1]])
+ dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1])
+ # Update destination range for next iteration
+ dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]]
+ src_idx += 1
+
+ return src_overlapping_ranges, dst_overlapping_ranges
+
+
+
+ def transfer_mem(self, src_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
+ # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last
+ # isl[-1] token is calculated in the first iteration in decode.
+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+ # one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+
+ src_ranges = self._get_ranges(src_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids)
+ src_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(src_ranges, dst_ranges)
+
+ logger.debug("Got %s overlapping ranges for %s blocks", len(src_overlapping_ranges), len(src_block_ids))
+
+ logger.debug("Time to get ranges: %s ms", time.perf_counter() - start_time)
+
+ src_descs = self._get_range_descs(self.engine_id, src_overlapping_ranges, "all")
+ dst_descs = self._get_range_descs(dst_engine_id, dst_overlapping_ranges, "all")
+
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs, self._remote_agents[dst_engine_id], notify_msg, "WRITE")
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+ # TODO ptarasiewicz: remove blocking transfer mem
+ # add scheduler check for transfer done
+ while True:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "ERR":
+ raise RuntimeError("Transfer failed")
+ elif xfer_state == "DONE":
+ logger.debug("Transfer done")
+ break
+ elif xfer_state == "PROC":
+ time.sleep(0.01)
+ else:
+ raise RuntimeError("Unknown transfer state")
+ logger.debug("Time to wait for transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ self.nixl_wrapper.abort_xfer(handle)
+ logger.debug("Time to abort xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer time: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ def deserialize_descs(self, serialized_descs):
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
+ def get_notifs(self):
+ self.notifs = self.nixl_wrapper.agent.getNotifs(self.notifs)
+ return self.notifs
+
+ def get_new_notifs(self):
+ return self.nixl_wrapper.agent.getNotifs({})
+
+ def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr):
+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..61a357d0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
...@@ -1737,10 +2221,75 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..9ba1a326 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@
import copy
import time
+import pickle
+import uuid
from collections import Counter as collectionsCounter
from collections import deque
+from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
+from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
- List, Mapping, NamedTuple, Optional)
+ List, Mapping, NamedTuple, Optional, Tuple)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
@@ -60,6 +64,9 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION
+from vllm.remote_prefill import RemotePrefillRequest, RemotePrefillParams, MemoryTransferRequest
+from vllm.distributed.device_communicators.nixl import NixlMetadata
+
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
@@ -90,7 +97,7 @@ class OutputData(NamedTuple):
# outputs from multiple steps.
is_first_step_output: Optional[bool]
skip: List[int]
-
+ remote_prefill_requests: Optional[List[RemotePrefillRequest]]
class SchedulerContext:
@@ -104,11 +111,14 @@ class SchedulerContext:
self.multi_step_stream_outputs: bool = multi_step_stream_outputs
+ self.remote_prefill_requests: List[RemotePrefillRequest] = []
+
def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool,
is_last_step: bool,
- is_first_step_output: Optional[bool]):
+ is_first_step_output: Optional[bool],
+ remote_prefill_requests: Optional[List[RemotePrefillRequest]] = None):
self.output_queue.append(
OutputData(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
@@ -116,7 +126,9 @@ class SchedulerContext:
is_async=is_async,
is_last_step=is_last_step,
is_first_step_output=is_first_step_output,
- skip=[]))
+ skip=[],
+ remote_prefill_requests=remote_prefill_requests))
+
class LLMEngine:
@@ -348,7 +360,7 @@ class LLMEngine:
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = [
Scheduler(
...@@ -1749,19 +2298,315 @@ index d82d9ad9..542ccfe8 100644
self.parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
if self.model_config.use_async_output_proc else None)
@@ -405,6 +417,39 @@ class LLMEngine:
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+ self.engine_id = str(uuid.uuid4())
+ self._nixl_agents_names: Optional[List[str]] = None
+ if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector":
+ self._nixl_agents_names = self._initialize_nixl()
+
+ self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._finished_prefills = set()
+
+ @property
+ def is_nixl_initialized(self) -> bool:
+ return self._nixl_agents_names is not None
+
+ def get_nixl_metadata(self) -> NixlMetadata:
+ if not self.is_nixl_initialized:
+ raise RuntimeError("Nixl is not initialized")
+ agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata")
+ kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr")
+ return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr)
+
+ def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]:
+ if not self.is_nixl_initialized:
+ raise RuntimeError("Nixl is not initialized")
+ engine_id = nixl_metadata.engine_id
+ agents_metadata = nixl_metadata.agent_metadata
+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
+ if len(agents_metadata) != len(self._nixl_agents_names):
+ raise ValueError("Number of agents does not match. Make sure all engines are initialized with the same parallel sizes.")
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr))
+
+ def _initialize_nixl(self) -> List[bytes]:
+ agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,))
+ return agents_names
+
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@@ -500,6 +545,8 @@ class LLMEngine:
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if model_executor := getattr(self, "model_executor", None):
+ if self._nixl_agents_names:
+ model_executor.collective_rpc("shutdown_nixl")
model_executor.shutdown()
def get_tokenizer_group(
@@ -552,11 +599,14 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> Optional[SequenceGroup]:
"""Add a processed request to the engine's request pool.
return the created sequence group.
"""
if isinstance(params, SamplingParams) and params.n > 1:
+ if remote_prefill_params is not None:
+ raise ValueError("Remote prefill params are not supported for multi-step sampling")
ParallelSampleSequenceGroup.add_request(
request_id,
self,
@@ -584,7 +634,7 @@ class LLMEngine:
encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
- lora_request, prompt_adapter_request)
+ lora_request, prompt_adapter_request, remote_prefill_params)
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
@@ -601,8 +651,12 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
- priority=priority)
+ priority=priority,
+ remote_prefill_params=remote_prefill_params,
+ )
elif isinstance(params, PoolingParams):
+ if remote_prefill_params is not None:
+ raise ValueError("Remote prefill params are not supported for pooling")
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
@@ -673,6 +727,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@@ -765,6 +820,7 @@ class LLMEngine:
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
+ remote_prefill_params=remote_prefill_params,
)
def _validate_token_prompt(self, prompt: PromptType,
@@ -799,6 +855,7 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@@ -829,7 +886,9 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
- priority=priority)
+ priority=priority,
+ remote_prefill_params=remote_prefill_params
+ )
return seq_group
@@ -995,11 +1054,11 @@ class LLMEngine:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
- is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
+ is_last_step, is_first_step_output, skip, remote_prefill_requests) = ctx.output_queue[0]
else:
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, is_first_step_output,
- skip) = ctx.output_queue.popleft()
+ skip, remote_prefill_requests) = ctx.output_queue.popleft()
# Sanity check
assert len(seq_group_metadata_list) == len(
@@ -1325,15 +1384,49 @@ class LLMEngine:
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
+ ctx.remote_prefill_requests.clear()
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
+ remote_prefill_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+ running_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+ remote_prefill_scheduled_seq_groups: List[ScheduledSequenceGroup] = []
+ running_scheduled_seq_groups: List[ScheduledSequenceGroup] = []
+
if not self._has_remaining_steps(seq_group_metadata_list):
- # Schedule iteration
+
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
- ) = self.scheduler[virtual_engine].schedule()
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills)
+
+
+ # Separate remote prefill and running seq groups
+ for seq_group_metadata, scheduled_seq_group in zip(seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups):
+ if seq_group_metadata.do_remote_prefill:
+ remote_prefill_seq_group_metadata_list.append(seq_group_metadata)
+ remote_prefill_scheduled_seq_groups.append(scheduled_seq_group)
+ else:
+ running_seq_group_metadata_list.append(seq_group_metadata)
+ running_scheduled_seq_groups.append(scheduled_seq_group)
+
+ seq_group_metadata_list = running_seq_group_metadata_list
+ scheduler_outputs.scheduled_seq_groups = running_scheduled_seq_groups
+
+ # Send remote prefill requests before model execution
+ for seq_group_metadata, scheduled_seq_group in zip(remote_prefill_seq_group_metadata_list, remote_prefill_scheduled_seq_groups):
+ assert len(scheduled_seq_group.seq_group.seqs) == 1
+ assert self._nixl_agents_names
+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+ block_table = seq_group_metadata.block_tables[seq_id]
+ remote_prefill_request = RemotePrefillRequest(
+ request_id=seq_group_metadata.request_id,
+ prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway
+ sampling_params=scheduled_seq_group.seq_group.sampling_params,
+ block_ids=block_table,
+ engine_id=self.engine_id,
+ )
+ scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request)
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +1476,29 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
- outputs = self.model_executor.execute_model(
+ # After model execution, we need to transfer the memory from the prefill to the decode
+ memory_transfer_reqs = []
+ for scheduled_seq_group, seq_group_metadata in zip(scheduler_outputs.scheduled_seq_groups, seq_group_metadata_list):
+ remote_prefill_params = scheduled_seq_group.seq_group.remote_prefill_params
+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
+ assert len(scheduled_seq_group.seq_group.seqs) == 1
+ req_id = scheduled_seq_group.seq_group.request_id
+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+ block_table = seq_group_metadata.block_tables[seq_id]
+ memory_transfer_req = MemoryTransferRequest(
+ request_id=req_id,
+ src_block_ids=block_table,
+ dst_block_ids=remote_prefill_params.decode_block_ids,
+ dst_engine_id=remote_prefill_params.decode_engine_id,
+ notify_msg=req_id,
+ )
+
+ memory_transfer_reqs.append(memory_transfer_req)
+
+ execute_model_req.memory_transfer_requests = memory_transfer_reqs
+
+ outputs, request_notif_counter = self.model_executor.execute_model(
execute_model_req=execute_model_req)
-
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
@@ -1396,7 +1509,20 @@ class LLMEngine:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
- outputs = []
+ execute_model_req = ExecuteModelRequest(
+ seq_group_metadata_list=[],
+ blocks_to_swap_in=[],
+ blocks_to_swap_out=[],
+ blocks_to_copy=[])
+
+ outputs, request_notif_counter = self.model_executor.execute_model(
+ execute_model_req=execute_model_req)
+
+ for req_id, notif_count in request_notif_counter.items():
+ self._request_notif_counter[req_id] += notif_count
+ if self._request_notif_counter[req_id] > -1:
+ self._finished_prefills.add(req_id)
+ del self._request_notif_counter[req_id]
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
@@ -1456,7 +1582,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
-
+
return ctx.request_outputs
def _has_remaining_steps(
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index 3cf1850e..6b90ece7 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -14,13 +14,17 @@ from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs
-
+from vllm.remote_prefill import RemotePrefillParams
+from vllm.distributed.device_communicators.nixl import NixlMetadata
VLLM_RPC_SUCCESS_STR = "SUCCESS"
IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
+IPC_REMOTE_PREFILL_REQUEST_EXT = "_remote_prefill_request_socket"
+IPC_REMOTE_NIXL_METADATA_EXT = "_remote_nixl_metadata_socket"
+IPC_METRICS_EXT = "_metrics_socket"
class MQEngineDeadError(RuntimeError):
@@ -36,6 +40,7 @@ class RPCProcessRequest:
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
+ remote_prefill_params: Optional[RemotePrefillParams] = None
@overload
def __init__(
@@ -78,6 +83,7 @@ class RPCProcessRequest:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@@ -95,7 +101,7 @@ class RPCProcessRequest:
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
-
+ self.remote_prefill_params = remote_prefill_params
@dataclass
class RPCError:
@@ -116,7 +122,7 @@ class RPCStartupRequest(Enum):
@dataclass
class RPCStartupResponse:
tracing_enabled: bool
-
+ nixl_metadata: Optional[bytes] = None
class RPCUProfileRequest(Enum):
START_PROFILE = 1
@@ -157,3 +163,10 @@ def ENGINE_DEAD_ERROR(
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")
...@@ -1773,17 +2618,27 @@ index 3cf1850e..38acca0e 100644
+ kv_active_blocks: int
+ kv_total_blocks: int
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..d33d546a 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, cast, overload)
import cloudpickle
+import msgspec
import psutil
import zmq
import zmq.asyncio
@@ -25,14 +26,16 @@ from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
- IPC_OUTPUT_EXT, RPC_REQUEST_T,
- VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
+ IPC_OUTPUT_EXT, IPC_REMOTE_PREFILL_REQUEST_EXT,
+ RPC_REQUEST_T,
+ VLLM_RPC_SUCCESS_STR, IPC_REMOTE_NIXL_METADATA_EXT, RPCAbortRequest,
+ IPC_METRICS_EXT,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
...@@ -1794,7 +2649,24 @@ index 85b5f31e..6a7ea3ae 100644
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
@@ -46,6 +49,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
+from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback
+from vllm.distributed.device_communicators.nixl import NixlMetadata
logger = init_logger(__name__)
@@ -91,6 +96,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None
# Get the configs.
+ self.vllm_config = engine_config
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
@@ -115,6 +121,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
...@@ -1805,7 +2677,7 @@ index 85b5f31e..6a7ea3ae 100644
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -129,8 +139,27 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
...@@ -1817,11 +2689,36 @@ index 85b5f31e..6a7ea3ae 100644 ...@@ -1817,11 +2689,36 @@ index 85b5f31e..6a7ea3ae 100644
+ +
self._engine_process = psutil.Process(engine_pid) self._engine_process = psutil.Process(engine_pid)
+ self.nixl_metadata: Optional[NixlMetadata] = None
+ self.remote_prefill_request_socket: Socket = self.context.socket(zmq.constants.PULL)
+ self.remote_nixl_metadata_socket: Socket = self.context.socket(zmq.constants.PUSH)
+ self.remote_prefill_requests_callback: Dict[str, RemotePrefillRequestCallback] = {}
+ if self.using_nixl_connector:
+ self.remote_prefill_request_socket.connect(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}")
+ self.remote_nixl_metadata_socket.connect(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}")
+
+
+ @property
+ def using_nixl_connector(self) -> bool:
+ return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector"
+
@staticmethod
@@ -180,6 +191,46 @@ class MQLLMEngineClient(EngineClient): def is_unsupported_config(engine_args: AsyncEngineArgs):
# Pipeline parallel not yet supported
@@ -180,6 +209,56 @@ class MQLLMEngineClient(EngineClient):
except Exception as e:
self._set_errored(e)
+ async def run_remote_prefill_request_handler_loop(self):
+ try:
+ while True:
+ if await self.remote_prefill_request_socket.poll(timeout=VLLM_RPC_TIMEOUT):
+ frames = await self.remote_prefill_request_socket.recv(copy=False)
+ remote_prefill_request = msgspec.msgpack.decode(frames.buffer, type=RemotePrefillRequest)
+ await self.remote_prefill_requests_callback[remote_prefill_request.request_id](remote_prefill_request)
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient remote prefill request handler loop.")
+
+ async def run_metrics_loop(self, timeout: int):
+ """Background loop that continually checks to ensure the engine process
+ is still alive.
...@@ -1865,11 +2762,25 @@ index 85b5f31e..6a7ea3ae 100644
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
@@ -284,6 +335,12 @@ class MQLLMEngineClient(EngineClient): @@ -278,12 +357,26 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
+ if response.nixl_metadata is not None:
+ assert self.using_nixl_connector
+ self.nixl_metadata = msgspec.msgpack.decode(response.nixl_metadata, type=NixlMetadata)
+
self.tracing_flag = response.tracing_enabled
# Start health_loop.
if self.health_loop is None:
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
+
+ if self.using_nixl_connector:
+ self.remote_prefill_loop = asyncio.create_task(
+ self.run_remote_prefill_request_handler_loop())
+
+ # Start metrics_loop.
+ if self.metrics_loop is None:
+ self.metrics_loop = asyncio.create_task(
...@@ -1878,7 +2789,7 @@ index 85b5f31e..6a7ea3ae 100644
def close(self):
"""Destroy the ZeroMQ Context."""
@@ -293,6 +350,8 @@ class MQLLMEngineClient(EngineClient): @@ -293,6 +386,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
...@@ -1887,7 +2798,70 @@ index 85b5f31e..6a7ea3ae 100644
if self.output_loop is not None:
self.output_loop.cancel()
@@ -705,3 +764,6 @@ class MQLLMEngineClient(EngineClient): @@ -415,6 +510,9 @@ class MQLLMEngineClient(EngineClient):
"""
if self._errored_with is not None:
raise self._errored_with
+
+ async def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata):
+ await self.remote_nixl_metadata_socket.send(msgspec.msgpack.encode(nixl_metadata), copy=False)
@property
def is_running(self) -> bool:
@@ -473,6 +571,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
@@ -502,7 +601,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
- prompt_adapter_request, priority)
+ prompt_adapter_request, priority,
+ remote_prefill_params)
@overload
def encode(
@@ -586,6 +686,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@@ -630,6 +731,12 @@ class MQLLMEngineClient(EngineClient):
else:
lp_bytes = None
+ if remote_prefill_params is not None:
+ self.remote_prefill_requests_callback[request_id] = remote_prefill_params.remote_prefill_request_callback
+ remote_prefill_params.remote_prefill_request_callback = None
+ else:
+ remote_prefill_request_callback = None
+
request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
@@ -639,11 +746,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
+ remote_prefill_params=remote_prefill_params,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
- parts = (request_bytes,
- lp_bytes) if lp_bytes else (request_bytes, )
+ parts = (request_bytes, lp_bytes) if lp_bytes else (request_bytes,)
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
@@ -705,3 +812,6 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
...@@ -1895,28 +2869,48 @@ index 85b5f31e..6a7ea3ae 100644
+ def set_metrics_publisher(self, metrics_publisher):
+ self.metrics_publisher = metrics_publisher
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a0dd7958..dc6ea25d 100644 index a0dd7958..dbd9d58d 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -14,24 +14,56 @@ from vllm.engine.llm_engine import LLMEngine @@ -3,35 +3,73 @@
import pickle
import signal
from contextlib import contextmanager
-from typing import Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Union, Dict
import cloudpickle
+import time
import zmq
-
+import msgspec
from vllm import AsyncEngineArgs, SamplingParams
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
- IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
+ IPC_OUTPUT_EXT, IPC_METRICS_EXT, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
+ REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + VLLM_RPC_SUCCESS_STR, IPC_REMOTE_PREFILL_REQUEST_EXT,
+ RPCAbortRequest,
+ IPC_OUTPUT_EXT, IPC_METRICS_EXT,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
- RPCUProfileRequest)
+ RPCUProfileRequest, KvMetrics) + RPCUProfileRequest, IPC_REMOTE_NIXL_METADATA_EXT,
+ KvMetrics)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
+from vllm.remote_prefill import RemotePrefillRequest
+from vllm.distributed.device_communicators.nixl import NixlMetadata
+
+from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo
+from dataclasses import dataclass, field
...@@ -1957,7 +2951,7 @@ index a0dd7958..dc6ea25d 100644
class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
@@ -94,12 +126,24 @@ class MQLLMEngine: @@ -94,12 +132,31 @@ class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
...@@ -1971,6 +2965,13 @@ index a0dd7958..dc6ea25d 100644
# Error state.
self._errored_with: Optional[BaseException] = None
+ self.remote_prefill_request_socket = self.ctx.socket(zmq.constants.PUSH)
+ self.remote_nixl_metadata_socket = self.ctx.socket(zmq.constants.PULL)
+ if self.engine.is_nixl_initialized:
+ self.remote_prefill_request_socket.bind(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}")
+ self.remote_nixl_metadata_socket.bind(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}")
+
+
+ # Attach logger for continuous metrics publishing
+ self.stat_logger = KvStatLogger(
+ self.engine.scheduler_config.max_num_seqs,
...@@ -1982,6 +2983,99 @@ index a0dd7958..dc6ea25d 100644
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
@@ -171,8 +228,17 @@ class MQLLMEngine:
# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
- response = RPCStartupResponse(
- tracing_enabled=tracing_enabled)
+
+ # Send nixl metadata to the client
+ if self.engine.is_nixl_initialized:
+ nixl_metadata = self.engine.get_nixl_metadata()
+ encoded_nixl_metadata = msgspec.msgpack.encode(nixl_metadata)
+ response = RPCStartupResponse(
+ tracing_enabled=tracing_enabled,
+ nixl_metadata=encoded_nixl_metadata)
+ else:
+ response = RPCStartupResponse(
+ tracing_enabled=tracing_enabled)
except Exception as e:
response = e
@@ -185,6 +251,7 @@ class MQLLMEngine:
while True:
if not self.engine.has_unfinished_requests():
+ logger.debug("No unfinished requests")
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
@@ -220,6 +287,13 @@ class MQLLMEngine:
def handle_new_input(self):
"""Handle new input from the socket"""
try:
+ if self.engine.is_nixl_initialized:
+ while self.remote_nixl_metadata_socket.poll(timeout=0) != 0:
+ frames = self.remote_nixl_metadata_socket.recv(copy=False)
+ nixl_metadata = msgspec.msgpack.decode(frames.buffer, type=NixlMetadata)
+ logger.debug("Adding remote nixl metadata for engine: %s", nixl_metadata.engine_id)
+ self.engine.add_remote_nixl_metadata(nixl_metadata)
+
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
@@ -262,6 +336,11 @@ class MQLLMEngine:
self._send_outputs(rpc_err)
try:
+ if request.remote_prefill_params is not None and request.remote_prefill_params.is_remote_prefill:
+ def remote_prefill_request_callback(request: RemotePrefillRequest):
+ logger.debug("Sending remote prefill request: %s", request.request_id)
+ self.remote_prefill_request_socket.send(msgspec.msgpack.encode(request), copy=False)
+ request.remote_prefill_params.remote_prefill_request_callback = remote_prefill_request_callback
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
@@ -269,7 +348,9 @@ class MQLLMEngine:
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
- priority=request.priority)
+ priority=request.priority,
+ remote_prefill_params=request.remote_prefill_params,
+ )
if self.log_requests:
logger.info("Added request %s.", request.request_id)
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 107220d5..c716f75f 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -34,6 +34,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
+from vllm.remote_prefill import RemotePrefillParams
logger = init_logger(__name__)
@@ -112,6 +113,7 @@ class OpenAIServingChat(OpenAIServing):
self,
request: ChatCompletionRequest,
raw_request: Optional[Request] = None,
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ErrorResponse]:
"""
@@ -243,6 +245,7 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
+ remote_prefill_params=remote_prefill_params,
)
generators.append(generator)
diff --git a/vllm/envs.py b/vllm/envs.py
index 745b068b..0ae63d9b 100644
--- a/vllm/envs.py
...@@ -2032,3 +3126,516 @@ index 773f5abe..3eefd266 100644
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 786380c3..56a7cf89 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -6,16 +6,16 @@ from typing import Dict, Generic, List, MutableSequence, Optional
from typing import Sequence as GenericSequence
from typing import Union
+import msgspec
import torch
from typing_extensions import TypeVar, deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
-from vllm.sampling_params import RequestOutputKind
+from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceGroupBase, SequenceStatus)
-
@dataclass
class CompletionOutput:
"""The output data of one completion output of a request.
diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
new file mode 100644
index 00000000..03f02006
--- /dev/null
+++ b/vllm/remote_prefill.py
@@ -0,0 +1,53 @@
+from dataclasses import dataclass
+from typing import Callable, Optional, List, Coroutine
+
+import msgspec
+
+from vllm.sampling_params import SamplingParams
+
+
+class RemotePrefillRequest(
+ msgspec.Struct,
+ omit_defaults=True, # type: ignore[call-arg]
+ # required for @cached_property.
+ dict=True):
+ """The request data of one remote prefill output of a request.
+
+ Args:
+ request_id: The unique ID of the request.
+ prompt: The prompt string of the request.
+ """
+ request_id: str
+ prompt_token_ids: List[int]
+ sampling_params: SamplingParams
+ block_ids: List[int]
+ engine_id: str
+
+
+class MemoryTransferRequest(
+ msgspec.Struct,
+ array_like=True, # type: ignore[call-arg]
+ omit_defaults=True): # type: ignore[call-arg]
+ """The request data of one memory transfer output of a request.
+
+ Args:
+ request_id: The unique ID of the request.
+ """
+ request_id: str
+ src_block_ids: List[int]
+ dst_block_ids: List[int]
+ dst_engine_id: str
+ notify_msg: str
+
+
+RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None]
+
+
+@dataclass
+class RemotePrefillParams:
+ """Remote prefill parameters for text generation."""
+ is_remote_prefill: bool = False
+ is_remote_decode: bool = False
+ decode_block_ids: Optional[List[int]] = None
+ decode_engine_id: Optional[str] = None
+ remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None
\ No newline at end of file
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 97f9e212..1bb97b00 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -83,7 +83,7 @@ class RequestOutputKind(Enum):
DELTA = 1
# Do not return intermediate RequestOuputs
FINAL_ONLY = 2
-
+
class SamplingParams(
msgspec.Struct,
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 534b9e60..18675d2f 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -20,6 +20,7 @@ from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.remote_prefill import RemotePrefillParams, MemoryTransferRequest
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@@ -59,13 +60,14 @@ class SequenceStatus(enum.IntEnum):
"""Status of a sequence."""
WAITING = 0
RUNNING = 1
- SWAPPED = 2
- # Note: anything after SWAPPED (2) will be considered
+ REMOTE_PREFILLING = 2
+ SWAPPED = 3
+ # Note: anything after SWAPPED (3) will be considered
# as a finished status.
- FINISHED_STOPPED = 3
- FINISHED_LENGTH_CAPPED = 4
- FINISHED_ABORTED = 5
- FINISHED_IGNORED = 6
+ FINISHED_STOPPED = 4
+ FINISHED_LENGTH_CAPPED = 5
+ FINISHED_ABORTED = 6
+ FINISHED_IGNORED = 7
@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
@@ -409,6 +411,7 @@ class Sequence:
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> None:
self.seq_id = seq_id
self.inputs = SingletonInputsAdapter(inputs)
@@ -416,7 +419,7 @@ class Sequence:
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
-
+ self.remote_prefill_params = remote_prefill_params
self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
@@ -639,6 +642,7 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
+ remote_prefill_params: Remote prefill parameters.
"""
def __init__(
@@ -654,6 +658,7 @@ class SequenceGroup:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> None:
self.request_id = request_id
self.seqs = seqs
@@ -678,7 +683,7 @@ class SequenceGroup:
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
self.priority = priority
-
+ self.remote_prefill_params = remote_prefill_params
self.cached_request_output = None
@property
@@ -927,6 +932,9 @@ class SequenceGroupMetadata(
query tokens for prefill, we don't need sampling.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
+ do_remote_prefill: True if remote prefill is required.
+ do_remote_decode: True if remote decode is required.
+ decode_memory_desc: The memory descriptor for the decoder blocks.
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
@@ -966,6 +974,9 @@ class SequenceGroupMetadata(
cross_block_table: Optional[List[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None
+ do_remote_prefill: bool = False
+ do_remote_decode: bool = False
+ decode_memory_desc: Optional[bytes] = None
### Stateful fields that are lazily defined. ###
# The number of speculative tokens adopted in this request.
@@ -1310,6 +1321,8 @@ class ExecuteModelRequest(
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
async_callback: Optional[Callable] = None
+ # The memory transfer requests.
+ memory_transfer_requests: Optional[List[MemoryTransferRequest]] = None
@property
def is_first_multi_step(self) -> bool:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 12baecde..cbada27f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1824,6 +1824,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if self.vllm_config.kv_transfer_config is None:
return False
+
+ if self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector":
+ return False
prefill_meta = model_input.attn_metadata.prefill_metadata
@@ -1849,6 +1852,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if self.vllm_config.kv_transfer_config is None:
return False
+
+ if self.vllm_config.kv_transfer_config.kv_connector == "TritonNixlConnector":
+ return False
prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..ffb7b403 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@
"""A GPU worker class."""
import gc
import os
-from typing import Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Dict, List, Optional, Set, Tuple, Type, Union, TYPE_CHECKING, Any
import torch
import torch.distributed
@@ -31,6 +31,8 @@ from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
+from vllm.distributed.device_communicators.nixl import TritonNixlConnector
+
logger = init_logger(__name__)
@@ -306,6 +308,43 @@ class Worker(LocalOrDistributedWorkerBase):
self._init_cache_engine()
self._warm_up_model()
+ def initialize_nixl(self, engine_id: str) -> List[bytes]:
+
+ # TODO ptarasiewicz nixl can also support DRAM
+ assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector"
+
+ self.nixl_connector = TritonNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank?
+ assert len(self.cache_engine) == 1, "Only one cache engine is supported for now"
+ self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache)
+ return self.nixl_connector.agent_name
+
+ def get_nixl_agent_metadata(self) -> bytes:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ return self.nixl_connector.get_agent_metadata()
+
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata[self.local_rank]) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr[self.local_rank])
+ return agent_name
+
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name[self.local_rank], notify_msg) # TODO ptarasiewicz: rank or local_rank?
+
+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id]
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ if worker_input.src_block_ids is not None:
+ for src_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+
+ def shutdown_nixl(self) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ self.nixl_connector.shutdown()
+
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
@@ -367,6 +406,8 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
+
+ mem_transfer_reqs = execute_model_req.memory_transfer_requests or []
return WorkerInput(
num_seq_groups=num_seq_groups,
@@ -375,6 +416,10 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
+ src_block_ids=[r.src_block_ids for r in mem_transfer_reqs],
+ dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs],
+ dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs],
+ notify_msg=[r.notify_msg for r in mem_transfer_reqs],
)
@torch.inference_mode()
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fb..d9c039eb 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import cloudpickle
import torch
import torch.nn as nn
+from collections import defaultdict
from vllm.config import (ObservabilityConfig, VllmConfig,
set_current_vllm_config)
@@ -23,6 +24,7 @@ from vllm.utils import (enable_trace_function_call_for_thread,
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
+from vllm.distributed.device_communicators.nixl import TritonNixlConnector
logger = init_logger(__name__)
@@ -53,6 +55,8 @@ class WorkerBase(ABC):
from vllm.platforms import current_platform
self.current_platform = current_platform
+ self.nixl_connector: Optional[TritonNixlConnector] = None
+
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
@@ -216,6 +220,11 @@ class WorkerInput:
virtual_engine: int = 0
num_steps: int = 1
+ src_block_ids: Optional[List[List[int]]] = None
+ dst_block_ids: Optional[List[List[int]]] = None
+ dst_engine_id: Optional[List[str]] = None
+ notify_msg: Optional[List[str]] = None
+
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
@@ -232,6 +241,10 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
+ src_block_ids=tensor_dict.pop("src_block_ids"),
+ dst_block_ids=tensor_dict.pop("dst_block_ids"),
+ dst_engine_id=tensor_dict.pop("dst_engine_id"),
+ notify_msg=tensor_dict.pop("notify_msg"),
)
def as_broadcastable_tensor_dict(
@@ -246,6 +259,10 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
+ "src_block_ids": self.src_block_ids,
+ "dst_block_ids": self.dst_block_ids,
+ "dst_engine_id": self.dst_engine_id,
+ "notify_msg": self.notify_msg,
}
return tensor_dict
@@ -316,13 +333,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
- model_input = (
- self.model_runner.make_model_input_from_broadcasted_tensor_dict(
- broadcast_data))
+ if worker_input.num_seq_groups > 0:
+ model_input = (
+ self.model_runner.make_model_input_from_broadcasted_tensor_dict(
+ broadcast_data))
- kwargs = extract_previous_hidden_states(broadcast_data)
+ kwargs = extract_previous_hidden_states(broadcast_data)
- return model_input, worker_input, kwargs
+ return model_input, worker_input, kwargs
+ else:
+ return None, worker_input, {}
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
@@ -396,49 +416,79 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
- if worker_input.num_seq_groups == 0:
- return []
-
- intermediate_tensors = None
- orig_model_execute_time = 0.0
- if not get_pp_group().is_first_rank:
- intermediate_tensors = IntermediateTensors(
- get_pp_group().recv_tensor_dict(
- all_gather_group=get_tp_group()))
+ if worker_input.num_seq_groups > 0:
+
+ intermediate_tensors = None
+ orig_model_execute_time = 0.0
+ if not get_pp_group().is_first_rank:
+ intermediate_tensors = IntermediateTensors(
+ get_pp_group().recv_tensor_dict(
+ all_gather_group=get_tp_group()))
+ if (self.observability_config is not None
+ and self.observability_config.collect_model_execute_time):
+ orig_model_execute_time = intermediate_tensors.tensors.get(
+ "model_execute_time", torch.tensor(0)).item()
+
+ output = self.model_runner.execute_model(
+ model_input=model_input,
+ kv_caches=self.kv_cache[worker_input.virtual_engine]
+ if self.kv_cache is not None else None,
+ intermediate_tensors=intermediate_tensors,
+ num_steps=num_steps,
+ **kwargs,
+ )
+
+ model_execute_time = time.perf_counter() - start_time
+ if not get_pp_group().is_last_rank:
+ # output is IntermediateTensors
+ assert isinstance(output, IntermediateTensors)
+ if (self.observability_config is not None
+ and self.observability_config.collect_model_execute_time):
+ output.tensors["model_execute_time"] = torch.tensor(
+ model_execute_time + orig_model_execute_time)
+ get_pp_group().send_tensor_dict(output.tensors,
+ all_gather_group=get_tp_group())
+ return [None]
if (self.observability_config is not None
- and self.observability_config.collect_model_execute_time):
- orig_model_execute_time = intermediate_tensors.tensors.get(
- "model_execute_time", torch.tensor(0)).item()
+ and self.observability_config.collect_model_execute_time
+ and output is not None):
+ for o in output:
+ o.model_execute_time = (orig_model_execute_time +
+ model_execute_time)
- output = self.model_runner.execute_model(
- model_input=model_input,
- kv_caches=self.kv_cache[worker_input.virtual_engine]
- if self.kv_cache is not None else None,
- intermediate_tensors=intermediate_tensors,
- num_steps=num_steps,
- **kwargs,
- )
+ self._transfer_blocks(worker_input)
- model_execute_time = time.perf_counter() - start_time
- if not get_pp_group().is_last_rank:
- # output is IntermediateTensors
- assert isinstance(output, IntermediateTensors)
- if (self.observability_config is not None
- and self.observability_config.collect_model_execute_time):
- output.tensors["model_execute_time"] = torch.tensor(
- model_execute_time + orig_model_execute_time)
- get_pp_group().send_tensor_dict(output.tensors,
- all_gather_group=get_tp_group())
- return [None]
- if (self.observability_config is not None
- and self.observability_config.collect_model_execute_time
- and output is not None):
- for o in output:
- o.model_execute_time = (orig_model_execute_time +
- model_execute_time)
+ else:
+ output = []
+
+ # collect kv transfer notifications from non driver workers
+
+ if self.nixl_connector is not None:
+ new_notifs = self.nixl_connector.get_new_notifs()
+ rank = get_tp_group().rank
+ all_new_notifs = [new_notifs]
+ if rank > 0:
+ get_tp_group().send_object(new_notifs, dst=0)
+ else:
+ for i in range(1, get_tp_group().world_size):
+ all_new_notifs.append(get_tp_group().recv_object(src=i))
+
+ request_notif_counter = defaultdict(int)
+ for notifs in all_new_notifs:
+ for req_ids in notifs.values():
+ for req_id in req_ids:
+ request_notif_counter[req_id] += 1
+
+ if request_notif_counter:
+ logger.debug("Request notif counter: %s", request_notif_counter)
+ else:
+ request_notif_counter = {}
# output is List[SamplerOutput]
- return output
+ return output, request_notif_counter
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ pass
def _execute_model_spmd(
self,
...@@ -22,7 +22,7 @@ RUN_PREFIX=
# dependencies are specified in the /container/deps folder and
# installed within framework specific sections of the Dockerfile.
declare -A FRAMEWORKS=(["STANDARD"]=1 ["TENSORRTLLM"]=2 ["VLLM"]=3) declare -A FRAMEWORKS=(["STANDARD"]=1 ["TENSORRTLLM"]=2 ["VLLM"]=3 ["VLLM_NIXL"]=4)
DEFAULT_FRAMEWORK=STANDARD
SOURCE_DIR=$(dirname "$(readlink -f "$0")")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -17,7 +17,6 @@ import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm import TokensPrompt
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
...@@ -29,6 +28,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
......
...@@ -21,7 +21,9 @@ import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm import CompletionOutput, SamplingParams, TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import PromptLogprobs, RequestMetrics
......
...@@ -55,7 +55,7 @@ class VllmPrefillEngine(BaseVllmEngine):
await self.initialize()
vllm_logger.debug(f"Received prefill request: {request}")
sampling_params = vllm.SamplingParams(**request.sampling_params) sampling_params = vllm.sampling_params.SamplingParams(**request.sampling_params)
if self.engine_client is None:
raise RuntimeError("Engine client not initialized")
else:
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
> **NOTE**: This example is based on an internal NVIDIA library that will soon be publicly released. The example won't work until the official release.
## Build container
```
./container/build.sh --framework VLLM_NIXL --target dev --build-context nixl=<path to downloaded nixl repo @ fc912eb012597be67de11fa9ba0599e4e1974fa2>
```
## Run container
```
./container/run.sh --framework VLLM_NIXL --target dev -it
```
All of the commands below are run inside the same container.
## Run deployment
Add the model to Triton and start the HTTP server.
In terminal 0:
```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B test-nixl.vllm.generate
TRT_LOG=DEBUG http --port 8181
```
### Monolithic deployment
A single worker performs both prefill and decode.
In terminal 1:
```
cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=0 python3 worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager
```
### Disaggregated deployment
The prefill worker serves remote prefill requests, while the decode worker (started with `--remote-prefill`) forwards prefill work to it and continues decoding from the KV cache transferred over NIXL.
In terminal 1:
```
cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=0 python prefill_worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"TritonNixlConnector"}'
```
In terminal 2:
```
cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=1 python3 worker.py \
--remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"TritonNixlConnector"}'
```
## Client
In another terminal:
```
curl localhost:8181/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{"role": "user", "content": "What is the capital of France?"}
],
"max_tokens": 10
}'
```
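For scripted access, the same OpenAI-compatible endpoint can be called from Python. Below is a minimal sketch mirroring the `curl` request above, assuming the `openai` client package is available (it is not part of this example):
```
from openai import OpenAI

# Points at the OpenAI-compatible server started above; the API key is a
# placeholder, assuming this local deployment does not check it.
client = OpenAI(base_url="http://localhost:8181/v1", api_key="not-used")

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=10,
)
print(response.choices[0].message.content)
```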
## Run genai-perf
`genai-perf` is a tool for profiling and benchmarking LLM servers. It is already installed in the container. For more details, please refer to the [genai-perf README](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/perf_analyzer/genai-perf/README.html).
```
genai-perf profile \
-m deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--url localhost:8181 \
--endpoint-type chat \
--streaming \
--service-kind openai \
--endpoint v1/chat/completions \
--warmup-request-count 10 \
--random-seed 123 \
--synthetic-input-tokens-stddev 0 \
--output-tokens-stddev 0 \
--tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--synthetic-input-tokens-mean 3000 \
--output-tokens-mean 150 \
--extra-inputs min_tokens:150 \
--extra-inputs max_tokens:150 \
--profile-export-file my_profile_export.json \
--artifact-dir artifacts/ \
--concurrency 10 \
--request-count 40 \
-- -v \
--async
```
## Close deployment
Kill all Python processes and clean up the metadata files:
```
pkill -9 -f python
rm -r /tmp/nixl
```
## TODOs, limitations, known issues
- [ ] Add etcd for discovery
- [ ] Multi-node deployment support
- [ ] Enable chunked prefill
- [ ] Support mixed tp
- [ ] Process many remote prefills in one iteration
- [ ] Support recompute preemption
- [ ] Make sure decode does not preempt blocks before xfer finishes
- [ ] Layer wise transfer
- [ ] Non blocking send in prefill (cache manager should check xfer status)
- [ ] Test under load
- [ ] Support pp > 1
- [ ] Check why adding extra seed input is crashing vllm with remote prefill
- [ ] Unified worker for both prefill and decode
- [x] Require sending two parallel requests to start decode for the first time
- [x] Concurrency > 2 is not working
- [x] Parse cmdline args
- [x] Manual nixl example with tp1
- [x] Zero copy
- [x] Conditional remote prefill
- [x] Manual example with tp > 1
- [x] Run on triton distributed runtime
- [x] add oai http endpoint
- [x] Sample only on decode, do not return remote prefill response
- [x] Check if all transfers finished before moving to decode
- [x] Enable async output processing - could be working
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
import msgspec
from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
METADATA_DIR = "/tmp/nixl"
def parse_vllm_args() -> AsyncEngineArgs:
parser = FlexibleArgumentParser()
parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill"
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.remote_prefill = args.remote_prefill
return engine_args
@contextmanager
def temp_metadata_file(engine_id, metadata: NixlMetadata):
os.makedirs(METADATA_DIR, exist_ok=True)
path = f"{METADATA_DIR}/{engine_id}.nixl_meta"
with open(path, "wb") as f:
encoded = msgspec.msgpack.encode(metadata)
print(f"Size of encoded metadata: {len(encoded)}")
f.write(encoded)
try:
yield path
finally:
if os.path.exists(path):
os.remove(path)
def find_remote_metadata(engine_id):
# Find and load metadata files from METADATA_DIR that do not match the given engine_id
remote_metadata = []
for file in os.listdir(METADATA_DIR):
if file.endswith(".nixl_meta"):
if file.split(".")[0] != engine_id:
with open(os.path.join(METADATA_DIR, file), "rb") as f:
remote_metadata.append(
msgspec.msgpack.decode(f.read(), type=NixlMetadata)
)
return remote_metadata
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import msgspec
import uvloop
from common import find_remote_metadata, parse_vllm_args, temp_metadata_file
from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from triton_distributed.runtime import DistributedRuntime, triton_worker
class RequestHandler:
def __init__(self, engine_client):
self.engine_client = engine_client
print("RequestHandler initialized")
async def generate(self, raw_request: str):
request: RemotePrefillRequest = msgspec.json.decode(
raw_request.encode("utf-8"), type=RemotePrefillRequest
)
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=request.engine_id,
)
async for _ in self.engine_client.generate(
request_id=request.request_id,
prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("test-nixl").component("prefill")
await component.create_service()
endpoint = component.endpoint("generate")
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
# This should be replaced with etcd
metadata = engine_client.nixl_metadata
with temp_metadata_file(metadata.engine_id, metadata):
print(f"Waiting for remote metadata for engine {metadata.engine_id}")
remote_metadata: list[NixlMetadata] = []
while not remote_metadata:
await asyncio.sleep(1)
remote_metadata = find_remote_metadata(metadata.engine_id)
print(
f"Found {len(remote_metadata)} remote metadata for engine {metadata.engine_id}"
)
for remote_metadata in remote_metadata:
await engine_client.add_remote_nixl_metadata(remote_metadata)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import msgspec
from vllm.sampling_params import SamplingParams
class Request(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
):
"""The request data of one remote prefill output of a request.
Args:
request_id: The unique ID of the request.
prompt: The prompt string of the request.
"""
request_id: str
prompt: str
sampling_params: SamplingParams
do_remote_prefill: bool = False
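As an aside, here is a small sketch of how a struct like this can be serialized for transport with `msgspec` (illustrative only; the field values are hypothetical, and `SamplingParams` round-trips because it is itself a msgspec Struct):
```
import msgspec

from vllm.sampling_params import SamplingParams

# Illustrative request; field values are made up for this sketch.
req = Request(
    request_id="req-0",
    prompt="What is the capital of France?",
    sampling_params=SamplingParams(max_tokens=10),
    do_remote_prefill=True,
)

payload = msgspec.json.encode(req)                    # bytes to send over the wire
decoded = msgspec.json.decode(payload, type=Request)  # typed decode on the receiving side
assert decoded.request_id == req.request_id
```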
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import msgspec
import uvloop
from common import parse_vllm_args, temp_metadata_file
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import EngineClient
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from triton_distributed.runtime import (
DistributedRuntime,
triton_endpoint,
triton_worker,
)
class RequestHandler:
def __init__(
self,
model_name: str,
engine_client: EngineClient,
prefill_client,
do_remote_prefill: bool,
):
self.model_name = model_name
self.engine_client = engine_client
self.prefill_client = prefill_client
self.openai_serving_chat = None
self.initialized = False
self.do_remote_prefill = (
do_remote_prefill # TODO: this should be decided by the algorithm
)
print("RequestHandler initialized")
async def init(self):
models = OpenAIServingModels(
engine_client=self.engine_client,
model_config=await self.engine_client.get_model_config(),
base_model_paths=[
BaseModelPath(
name=self.model_name,
model_path=self.model_name,
)
],
)
self.openai_serving_chat = OpenAIServingChat(
engine_client=self.engine_client,
model_config=await self.engine_client.get_model_config(),
models=models,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
self.initialized = True
def get_remote_prefill_request_callback(self):
async def callback(request: RemotePrefillRequest):
json_request = msgspec.json.encode(request).decode("utf-8")
self.prefill_client.generate(json_request)
return callback
@triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(self, request):
if not self.initialized:
await self.init()
assert self.openai_serving_chat is not None
if self.do_remote_prefill:
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
)
else:
remote_prefill_params = None
async for raw_response in await self.openai_serving_chat.create_chat_completion(
request,
remote_prefill_params=remote_prefill_params,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("test-nixl").component("vllm")
await component.create_service()
endpoint = component.endpoint("generate")
prefill_client = (
await runtime.namespace("test-nixl")
.component("prefill")
.endpoint("generate")
.client()
)
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
# This should be replaced with etcd
if engine_args.remote_prefill:
metadata = engine_client.nixl_metadata
with temp_metadata_file(metadata.engine_id, metadata):
await endpoint.serve_endpoint(
RequestHandler(
model_name=engine_args.model,
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=True,
).generate
)
else:
await endpoint.serve_endpoint(
RequestHandler(
model_name=engine_args.model,
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=False,
).generate
)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.remote_prefill:
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.preemption_mode != "swap":
print("Preemption mode is not supported yet, setting to swap")
engine_args.preemption_mode = "swap"
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
asyncio.run(worker(engine_args))