Commit fe83f8aa authored by ptarasiewiczNV, committed by GitHub

Update NIXL Dockerfile + vLLM patch with variable TP (#9)


Co-authored-by: Piotr Tarasiewicz Nvidia <ptarasiewicznv@Piotrs-MacBook-Pro.local>
parent b9ce8dd0
@@ -8,50 +8,12 @@ FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dev
USER root
-# Install utilities
-RUN apt update -y && apt install -y git wget curl nvtop tmux vim
-# nats
-RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
-# etcd
-ENV ETCD_VERSION="v3.5.18"
-RUN wget https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-amd64.tar.gz -O /tmp/etcd.tar.gz && \
-    mkdir -p /usr/local/bin/etcd && \
-    tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1
-ENV PATH=/usr/local/bin/etcd/:$PATH
-### VIRTUAL ENVIRONMENT SETUP ###
-# Install uv and create virtualenv
-COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
-RUN mkdir /opt/triton && \
-    uv venv /opt/triton/venv --python 3.12
-# Activate virtual environment
-ENV VIRTUAL_ENV=/opt/triton/venv
-ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
-# Install patched vllm - keep this early in Dockerfile to avoid
-# rebuilds from unrelated source code changes
-ARG VLLM_REF="v0.7.2"
-ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
-RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
-    bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
-# Install genai-perf for benchmarking
-ARG GENAI_PERF_TAG="r25.01"
-RUN uv pip install "git+https://github.com/triton-inference-server/perf_analyzer.git@${GENAI_PERF_TAG}#subdirectory=genai-perf"
-# Install test dependencies
-RUN --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.txt \
-    uv pip install --requirement /tmp/requirements.txt
### NIXL SETUP ###
-ARG MOFED_VERSION=5.8-1.1.2.1
+ARG MOFED_VERSION=24.10-1.1.4.0
ARG PYTHON_VERSION=3.12
-ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2024_4/
-ARG NSYS_PKG=NsightSystems-linux-cli-public-2024.4.1.61-3431596.deb
+ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_1/
+ARG NSYS_PKG=NsightSystems-linux-cli-public-2025.1.1.131-3554042.deb
RUN apt-get update -y && apt-get -y install curl \
    git \
@@ -78,7 +40,9 @@ RUN apt-get update -y && apt-get -y install curl \
    libprotobuf-dev \
    protobuf-compiler-grpc \
    pybind11-dev \
+    python3-full \
    python3-pip \
+    python3-numpy \
    etcd-server \
    net-tools \
    pciutils \
@@ -89,25 +53,18 @@ RUN apt-get update -y && apt-get -y install curl \
    ibverbs-utils \
    libibmad-dev
-RUN apt-get update && \
-    apt install -y wget libglib2.0-0
-RUN wget ${NSYS_URL}${NSYS_PKG} && \
-    dpkg -i $NSYS_PKG && \
-    rm $NSYS_PKG
RUN apt-get install -y linux-tools-common linux-tools-generic ethtool iproute2
RUN apt-get install -y dkms linux-headers-generic
RUN apt-get install -y meson ninja-build uuid-dev gdb
-RUN uv pip install --upgrade meson
-RUN uv pip install ninja pybind11
+RUN apt-get update && apt install -y wget libglib2.0-0
+RUN wget ${NSYS_URL}${NSYS_PKG} && dpkg -i $NSYS_PKG && rm $NSYS_PKG
RUN cd /usr/local/src && \
-    curl -fSsL "https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VERSION}/MLNX_OFED_LINUX-${MOFED_VERSION}-ubuntu20.04-x86_64.tgz" -o mofed.tgz && \
+    curl -fSsL "https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VERSION}/MLNX_OFED_LINUX-${MOFED_VERSION}-ubuntu24.04-x86_64.tgz" -o mofed.tgz && \
    tar -xf /usr/local/src/mofed.tgz && \
    cd MLNX_OFED_LINUX-* && \
-    apt-get update && \
-    apt-get install -y --no-install-recommends \
+    apt-get update && apt-get install -y --no-install-recommends \
    ./DEBS/libibverbs* ./DEBS/ibverbs-providers* ./DEBS/librdmacm* ./DEBS/libibumad* && \
    rm -rf /var/lib/apt/lists/* /usr/local/src/*
@@ -128,9 +85,7 @@ ARG UCX_VERSION=v1.18.0
RUN cd /usr/local/src && \
    curl -fSsL "https://github.com/openucx/ucx/tarball/${UCX_VERSION}" | tar xz && \
    cd openucx-ucx* && \
-    ./autogen.sh && \
-    ./configure \
-    --prefix=/usr/local/ucx \
+    ./autogen.sh && ./configure \
    --enable-shared \
    --disable-static \
    --disable-doxygen-doc \
@@ -147,13 +102,20 @@ RUN cd /usr/local/src && \
    make -j install-strip && \
    ldconfig
-ENV LD_LIBRARY_PATH=/usr/local/ucx/lib:$LD_LIBRARY_PATH
-ENV CPATH=/usr/local/ucx/include:$CPATH
-ENV PATH=/usr/local/ucx/bin:$PATH
-ENV PKG_CONFIG_PATH=/usr/local/ucx/lib/pkgconfig:$PKG_CONFIG_PATH
+ENV LD_LIBRARY_PATH=/usr/lib:$LD_LIBRARY_PATH
+ENV CPATH=/usr/include:$CPATH
+ENV PATH=/usr/bin:$PATH
+ENV PKG_CONFIG_PATH=/usr/lib/pkgconfig:$PKG_CONFIG_PATH
SHELL ["/bin/bash", "-c"]
+WORKDIR /workspace
+ENV LD_LIBRARY_PATH=/usr/local/ompi/lib:$LD_LIBRARY_PATH
+ENV CPATH=/usr/local/ompi/include:$CPATH
+ENV PATH=/usr/local/ompi/bin:$PATH
+ENV PKG_CONFIG_PATH=/usr/local/ompi/lib/pkgconfig:$PKG_CONFIG_PATH
COPY --from=nixl . /opt/nixl
RUN cd /opt/nixl && \
@@ -161,15 +123,10 @@ RUN cd /opt/nixl && \
    meson setup build/ --prefix=/usr/local/nixl && \
    cd build/ && \
    ninja && \
-    ninja install && \
-    mkdir -p /usr/local/nixl/include/internal && \
-    cp ../include/*.h /usr/local/nixl/include && \
-    cp ../include/internal/*.h /usr/local/nixl/include/internal && \
-    cp ../include/internal/*.h /usr/local/nixl/include/ && \
-    cp ../src/utils/serdes/serdes.h /usr/local/nixl/include
+    ninja install
ENV LD_LIBRARY_PATH=/usr/local/nixl/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
-ENV PYTHONPATH=/usr/local/nixl/lib/python${PYTHON_VERSION}/site-packages/:/opt/nixl/test/python/:$PYTHONPATH
+ENV PYTHONPATH=/usr/local/nixl/lib/python3/dist-packages/:/opt/nixl/test/python/:$PYTHONPATH
RUN ls -l /usr/local/nixl/
RUN ls -l /usr/local/nixl/include/
@@ -177,6 +134,44 @@ RUN ls -l /usr/local/nixl/include/internal/
RUN ls /opt/nixl
+# Install utilities
+RUN apt update -y && apt install -y git wget curl nvtop tmux vim
+# nats
+RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
+# etcd
+ENV ETCD_VERSION="v3.5.18"
+RUN wget https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-amd64.tar.gz -O /tmp/etcd.tar.gz && \
+    mkdir -p /usr/local/bin/etcd && \
+    tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1
+ENV PATH=/usr/local/bin/etcd/:$PATH
+### VIRTUAL ENVIRONMENT SETUP ###
+# Install uv and create virtualenv
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+RUN mkdir /opt/triton && \
+    uv venv /opt/triton/venv --python 3.12
+# Activate virtual environment
+ENV VIRTUAL_ENV=/opt/triton/venv
+ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
+# Install patched vllm - keep this early in Dockerfile to avoid
+# rebuilds from unrelated source code changes
+ARG VLLM_REF="v0.7.2"
+ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
+RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
+    bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
+# Install genai-perf for benchmarking
+ARG GENAI_PERF_TAG="r25.01"
+RUN uv pip install "git+https://github.com/triton-inference-server/perf_analyzer.git@${GENAI_PERF_TAG}#subdirectory=genai-perf"
+# Install test dependencies
+RUN --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.txt \
+    uv pip install --requirement /tmp/requirements.txt
# ### MISC UTILITY SETUP ###
# Finish pyright install
...
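Note on the PYTHONPATH change above: the NIXL Python bindings now resolve from /usr/local/nixl/lib/python3/dist-packages/ instead of the per-version site-packages path. A quick smoke test for the built image (illustrative only, not part of the commit):

# check_nixl_path.py - run inside the container to confirm the bindings resolve
import importlib.util

for mod in ("nixl_wrapper", "nixl_bindings"):
    spec = importlib.util.find_spec(mod)
    print(mod, "->", spec.origin if spec else "NOT FOUND")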
@@ -368,10 +368,15 @@ index 00000000..350453cd
+
+        self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
-index f507847a..ee20d50c 100644
+index f507847a..abe574d1 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
-@@ -8,18 +8,17 @@ from collections import deque
+@@ -4,22 +4,22 @@ import enum
import os
import random
import time
+import copy
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
@@ -393,7 +398,7 @@ index f507847a..ee20d50c 100644
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
-@@ -325,12 +324,14 @@ class Scheduler:
+@@ -325,12 +325,14 @@ class Scheduler:
    def __init__(
        self,
@@ -408,7 +413,7 @@ index f507847a..ee20d50c 100644
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
-@@ -356,6 +357,7 @@ class Scheduler:
+@@ -356,6 +358,7 @@ class Scheduler:
        # Create the block space manager.
        self.block_manager = BlockSpaceManagerImpl(
@@ -416,7 +421,7 @@ index f507847a..ee20d50c 100644
            block_size=self.cache_config.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
-@@ -371,6 +373,14 @@ class Scheduler:
+@@ -371,6 +374,16 @@ class Scheduler:
        # Sequence groups in the SWAPPED state.
        # Contain decode requests that are swapped out.
        self.swapped: Deque[SequenceGroup] = deque()
@@ -424,6 +429,8 @@ index f507847a..ee20d50c 100644
+        # Sequence groups in the REMOTE_PREFILLING state.
+        # Contain requests that are being prefilled by a remote worker.
+        self.remote_prefilling: Deque[SequenceGroup] = deque()
+        # Contain requests that are being prefilled by a local worker.
+        self.prefill_sending: Deque[SequenceGroup] = deque()
+
+        self._remote_prefill_outputs: Dict[str, int] = {}
+
@@ -431,24 +438,25 @@ index f507847a..ee20d50c 100644
        # Sequence groups finished requests ids since last step iteration.
        # It lets the model know that any state associated with these requests
        # can and must be released after the current step.
-@@ -501,7 +511,7 @@ class Scheduler:
+@@ -501,7 +514,7 @@ class Scheduler:
    def has_unfinished_seqs(self) -> bool:
        return len(self.waiting) != 0 or len(self.running) != 0 or len(
-            self.swapped) != 0
-+            self.swapped) != 0 or len(self.remote_prefilling) != 0
++            self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0
    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return self.block_manager.get_prefix_cache_hit_rate(device)
-@@ -523,6 +533,7 @@ class Scheduler:
+@@ -523,6 +536,8 @@ class Scheduler:
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
-+        finished_prefills: Optional[Set[str]] = None
++        finished_prefills: Optional[Set[str]] = None,
+        finished_transfers: Optional[Set[str]] = None
    ) -> SchedulerRunningOutputs:
        """Schedule sequence groups that are running.
-@@ -537,6 +548,8 @@ class Scheduler:
+@@ -537,6 +552,8 @@ class Scheduler:
            chunked number of tokens are scheduled if
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.
@@ -457,7 +465,7 @@ index f507847a..ee20d50c 100644
        Returns:
            SchedulerRunningOutputs.
-@@ -566,6 +579,24 @@ class Scheduler:
+@@ -566,6 +583,38 @@ class Scheduler:
        preempted: List[SequenceGroup] = ret.preempted
        swapped_out: List[SequenceGroup] = ret.swapped_out
@@ -468,6 +476,7 @@ index f507847a..ee20d50c 100644
+            if seq_group.request_id not in finished_prefills:
+                leftover_remote_prefilling_sequences.append(seq_group)
+                continue
+
+            else:
+                finished_prefills.remove(seq_group.request_id)
+                assert len(seq_group.seqs) == 1
@@ -478,39 +487,63 @@ index f507847a..ee20d50c 100644
+                seq.data._stage = SequenceStage.DECODE
+                self.running.appendleft(seq_group)
+        remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
+
+        remote_transfers_queue = self.prefill_sending
+        leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque()
+        while remote_transfers_queue:
+            seq_group = remote_transfers_queue.popleft()
+            if seq_group.request_id not in finished_transfers:
+                leftover_remote_transfers_sequences.append(seq_group)
+            else:
+                finished_transfers.remove(seq_group.request_id)
+                assert len(seq_group.seqs) == 1
+                seq = seq_group.seqs[0]
+                self.free_seq(seq)
+        remote_transfers_queue.extendleft(leftover_remote_transfers_sequences)
+
        running_queue = self.running
        assert len(self._async_stopped) == 0
        while running_queue:
-@@ -1008,7 +1039,7 @@ class Scheduler:
+@@ -1008,7 +1057,17 @@ class Scheduler:
            if curr_loras is not None and lora_int_id > 0:
                curr_loras.add(lora_int_id)
            waiting_queue.popleft()
-            self._allocate_and_set_running(seq_group)
+
+            seq_group_copy = copy.deepcopy(seq_group)
+            seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
+
+            logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
+            logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
+            self._allocate_and_set_running_or_remote_prefill(seq_group)
+            if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+                logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
+                self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
+                self.prefill_sending.append(seq_group_copy)
            if enable_chunking and self.scheduler_config.is_multi_step:
                blocks_to_copy: List[Tuple[int, int]] = []
-@@ -1048,7 +1079,7 @@ class Scheduler:
+@@ -1048,7 +1107,7 @@ class Scheduler:
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=True, enable_chunking=enable_chunking))
-    def _schedule_default(self) -> SchedulerOutputs:
-+    def _schedule_default(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
++    def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
        """Schedule queued requests.
        The current policy is designed to optimize the throughput. First,
-@@ -1090,7 +1121,8 @@ class Scheduler:
+@@ -1090,7 +1149,9 @@ class Scheduler:
        if len(prefills.seq_groups) == 0:
            running_scheduled = self._schedule_running(budget,
                                                       curr_loras,
-                                                       enable_chunking=False)
+                                                       enable_chunking=False,
-+                                                       finished_prefills=finished_prefills)
++                                                       finished_prefills=finished_prefills,
+                                                       finished_transfers=finished_transfers)
        # If any sequence group is preempted, do not swap in any sequence
        # group. because it means there's no slot for new running requests.
-@@ -1106,7 +1138,12 @@ class Scheduler:
+@@ -1106,7 +1167,12 @@ class Scheduler:
        self.waiting.extendleft(running_scheduled.preempted)
        # Update new running requests.
        if len(prefills.seq_groups) > 0:
@@ -524,30 +557,31 @@ index f507847a..ee20d50c 100644
        self.running.extend(running_scheduled.decode_seq_groups_list)
-@@ -1248,12 +1285,14 @@ class Scheduler:
+@@ -1248,12 +1314,14 @@ class Scheduler:
                len(running_scheduled.swapped_out)),
        )
-    def _schedule(self) -> SchedulerOutputs:
-+    def _schedule(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
++    def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
        """Schedule queued requests."""
        if self.scheduler_config.chunked_prefill_enabled:
-+            if finished_prefills:
++            if finished_prefills or finished_transfers:
+                raise ValueError("Chunked prefill does not support remote prefills")
            return self._schedule_chunked_prefill()
        else:
-            return self._schedule_default()
-+            return self._schedule_default(finished_prefills)
++            return self._schedule_default(finished_prefills, finished_transfers)
    def _can_append_slots(self, seq_group: SequenceGroup,
                          enable_chunking: bool) -> bool:
-@@ -1287,14 +1326,15 @@ class Scheduler:
+@@ -1287,14 +1355,16 @@ class Scheduler:
        return no_single_seq
    def schedule(
-        self
+        self,
-+        finished_prefills: Optional[Set[str]] = None
++        finished_prefills: Optional[Set[str]] = None,
+        finished_transfers: Optional[Set[str]] = None
    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
@@ -556,11 +590,11 @@ index f507847a..ee20d50c 100644
-        scheduler_outputs: SchedulerOutputs = self._schedule()
+        scheduler_start_time = time.perf_counter()
-+        scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills)
++        scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers)
        now = time.time()
        if not self.cache_config.enable_prefix_caching:
-@@ -1333,7 +1373,8 @@ class Scheduler:
+@@ -1333,7 +1403,8 @@ class Scheduler:
            encoder_seq_data = None
            cross_block_table = None
@@ -570,18 +604,24 @@ index f507847a..ee20d50c 100644
            seq_id = seq.seq_id
            seq_data[seq_id] = seq.data
            block_tables[seq_id] = self.block_manager.get_block_table(seq)
-@@ -1364,6 +1405,10 @@ class Scheduler:
+@@ -1364,9 +1435,16 @@ class Scheduler:
                    < seqs[0].data.get_len()):
                do_sample = False
+            is_remote_prefill = False
+            if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+                is_remote_prefill = True
+            if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+                block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
+
            # It assumes the scheduled_seq_groups is ordered by
            # prefill < decoding.
            if is_first_prefill or not self.scheduler_config.send_delta_data:
+                logger.debug("Assigned blocks: %s", block_tables)
                seq_group_metadata = SequenceGroupMetadata(
                    request_id=seq_group.request_id,
                    is_prompt=is_prompt,
-@@ -1392,6 +1437,7 @@ class Scheduler:
+@@ -1392,6 +1470,7 @@ class Scheduler:
                    if scheduler_outputs.num_prefill_groups > 0 else None,
                    mm_processor_kwargs=seq_group.mm_processor_kwargs,
                    prompt_adapter_request=seq_group.prompt_adapter_request,
@@ -589,7 +629,7 @@ index f507847a..ee20d50c 100644
                )
            else:
                # When SPMD mode is enabled, we only send delta data except for
-@@ -1490,10 +1536,13 @@ class Scheduler:
+@@ -1490,10 +1569,13 @@ class Scheduler:
        self._async_stopped.clear()
@@ -605,12 +645,80 @@ index 00000000..bc962726
    def _append_slots(self,
                      seq_group: SequenceGroup,
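For orientation, the queue bookkeeping added above follows one pattern: pop every group off a deque, keep the unfinished ones, and push the leftovers back. A minimal, self-contained sketch of that pattern (illustrative only; RemoteGroup and drain are hypothetical names, not vLLM code):

from collections import deque

class RemoteGroup:
    def __init__(self, request_id):
        self.request_id = request_id

def drain(queue, finished):
    """Pop groups whose ids are in `finished`; keep the rest in order."""
    leftover = deque()
    done = []
    while queue:
        group = queue.popleft()
        if group.request_id in finished:
            finished.remove(group.request_id)
            done.append(group)
        else:
            leftover.append(group)
    # note: deque.extendleft reverses its argument, so re-reverse to keep FIFO order
    queue.extendleft(reversed(leftover))
    return done

remote_prefilling = deque(RemoteGroup(r) for r in ("a", "b", "c"))
print([g.request_id for g in drain(remote_prefilling, {"b"})])  # ['b']
print([g.request_id for g in remote_prefilling])                # ['a', 'c']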
diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
new file mode 100644
index 00000000..9b938039
--- /dev/null
+++ b/vllm/distributed/device_communicators/kv_rearrange.py
@@ -0,0 +1,61 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def rearrange_kernel(
+    t1_ptr,
+    t2_ptr,
+    N,
+    B,
+    H,
+    C,
+    d,
+    tensor_subset_size,
+    block_size,
+    token_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+    curr_n = offsets // block_size
+    curr_b = offsets // token_size % B
+    curr_h = offsets // C % H
+    curr_c = offsets % C
+
+    src_pos = offsets
+
+    tp_group = curr_h * d // H
+    dst_h = curr_h % (H // d)
+    tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
+
+    dst_pos = tensor_subset_size * tp_group + tp_group_offset
+
+    tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
+
+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int):
+    N, B, H, C = t1.shape
+
+    assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
+    assert H % d == 0, "H must be divisible by d"
+
+    block_size = B * H * C
+    token_size = H * C
+    tensor_size = N * block_size
+    tensor_subset_size = tensor_size // d
+
+    BLOCK_SIZE = 1024
+    grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
+
+    rearrange_kernel[grid](
+        t1, t2,
+        N, B, H, C,
+        d,
+        tensor_subset_size,
+        block_size,
+        token_size,
+        BLOCK_SIZE=BLOCK_SIZE
+    )
\ No newline at end of file
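A hypothetical usage sketch of rearrange_tensors (assumed shapes and device, not from the commit; requires CUDA and triton). One subtlety worth noting: the kernel issues unmasked loads and stores, so N*B*H*C should be a multiple of BLOCK_SIZE (1024):

import torch
from vllm.distributed.device_communicators.kv_rearrange import rearrange_tensors

N, B, H, C = 4, 16, 8, 64   # cache blocks, tokens per block, kv heads, head dim
t1 = torch.randn(N, B, H, C, device="cuda")
t2 = torch.empty_like(t1)

# d is the TP multiplier (decode TP / prefill TP): heads are regrouped so that
# each contiguous 1/d-th of t2 holds the slice destined for one decode TP rank.
rearrange_tensors(t1, t2, d=2)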
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
-index 00000000..bc962726
+index 00000000..f1618bc4
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
-@@ -0,0 +1,249 @@
+@@ -0,0 +1,318 @@
+import torch
+from typing import List, Tuple
+from vllm.config import VllmConfig
@@ -618,46 +726,19 @@ index 00000000..bc962726
+import msgspec
+import time
+import uuid
+from collections import defaultdict
+from .kv_rearrange import rearrange_tensors
+
+logger = init_logger(__name__)
+
-+def nixl_wrapper_init_patch(self, agent_name, nixl_config):
-+    logger.info("Initializing patched NixlWrapper")
-+    import nixl_bindings as nixl
-+    # Read available backends and device info from nixl_config
-+    # For now setting the multithreading to enabled.
-+    devices = nixl.nixlAgentConfig(False)
-+    init = nixl.nixlUcxInitParams()
-+
-+    self.name = agent_name
-+    self.notifs = {}
-+    self.backends = {}
-+    self.agent = nixl.nixlAgent(agent_name, devices)
-+    self.backends["UCX"] = self.agent.createBackend(init)
-+
-+    self.nixl_mems = {"DRAM": nixl.DRAM_SEG,
-+                      "VRAM": nixl.VRAM_SEG,
-+                      "cpu": nixl.DRAM_SEG,
-+                      "cuda": nixl.VRAM_SEG}
-+    self.nixl_ops = {"WRITE": nixl.NIXL_WR_FLUSH,
-+                     "READ": nixl.NIXL_RD_FLUSH,
-+                     "WRITE_NOTIF": nixl.NIXL_WR_NOTIF,
-+                     "READ_NOTIF": nixl.NIXL_RD_NOTIF}
-+
-+    logger.info("Initializied NIXL agent: %s", agent_name)
-+
-+
+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
+try:
+    from nixl_wrapper import nixl_wrapper as NixlWrapper  # type: ignore
+    logger.info("NIXL is available")
-+    NixlWrapper.__init__ = nixl_wrapper_init_patch
+except ImportError:
+    logger.warning("NIXL is not available")
+    NixlWrapper = None
+
+
+class NixlMetadata(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
@@ -679,7 +760,9 @@ index 00000000..bc962726
+
+        self.num_layers = None
+        self.num_blocks = None
+        self.num_heads = None
+        self.block_len = None
+        self.kv_caches = None
+        self.kv_caches_base_addr = {}
+        self.kv_cache_shape = {}
+
@@ -688,33 +771,51 @@ index 00000000..bc962726
+        self.engine_id = engine_id
+        self.rank = rank
+        self.notifs = {}
+        self._tp_size = {}
+        self._block_descs = {}
+        self._xfer_side_handles = {}
+
+        self._transfers = defaultdict(list)
+
+        self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+
+    @property
+    def agent_name(self):
+        return self.nixl_wrapper.name
+
+    def register_kv_caches(self, kv_caches: List[torch.Tensor]):
-+        caches_data = []
-+        self.num_layers = len(kv_caches)
-+        _, _, block_size, num_heads, head_dim = kv_caches[0].shape
++        _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+        self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+        logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+        self.num_layers = len(kv_caches)
+        self.num_blocks = num_blocks
+        self.num_heads = num_heads
+        self.kv_caches = kv_caches
+        kv_caches_base_addr = []
+        caches_data = []
+        blocks_data = []
+        for key_cache, value_cache in kv_caches:
+            for cache in [key_cache, value_cache]:
+                base_addr = cache.data_ptr()
-+                region_len = cache.numel() * cache.element_size()
-+                gpu_id = cache.get_device()
-+                assert gpu_id > -1, "Tensor is not on GPU"
-+                caches_data.append((base_addr, region_len, gpu_id))
++                region_len = num_blocks * self.block_len
++                caches_data.append((base_addr, region_len, self.rank))
++                for block_id in range(self.num_blocks):
++                    blocks_data.append((base_addr + block_id * self.block_len, self.block_len, self.rank))
+
+            kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+        self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+        descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
+        logger.debug("Registering descs: %s", caches_data)
+        self.nixl_wrapper.register_memory(descs)
+        self._registered_descs.append(descs)
+
+        self._block_descs[self.engine_id] = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+        self._xfer_side_handles[self.engine_id] = self.nixl_wrapper.prep_xfer_side(self._block_descs[self.engine_id])
+
+    def get_agent_metadata(self):
+        return self.nixl_wrapper.get_agent_metadata()
+
@@ -724,10 +825,14 @@ index 00000000..bc962726
+        for agent_name in self._remote_agents.values():
+            self.nixl_wrapper.remove_remote_agent(agent_name)
+
-+    def add_remote_agent(self, engine_id, agent_metadata):
-+        agent_name = self.nixl_wrapper.add_remote_agent(agent_metadata)
-+        self._remote_agents[engine_id] = agent_name
-+        return agent_name
++    def add_remote_agent(self, engine_id, agent_metadata, agent_tp):
++        self._tp_size[engine_id] = agent_tp
++        agent_names = []
++        for agent_meta in agent_metadata:
+            agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+            agent_names.append(agent_name)
+        self._remote_agents[engine_id] = agent_names
+        return agent_names
+
+    def get_descs_ids(self, layer_ids, block_ids):
+        if layer_ids == "all":
@@ -742,17 +847,29 @@ index 00000000..bc962726
+                descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+        return descs_ids
+
-+    def _get_range_descs(self, engine_id, ranges, layer_ids):
++    def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0):
+        if rank is None:
+            rank = self.rank
+            offset_block_len = self.block_len
+            block_len = self.block_len // tp_multiplier
+            tp_offset = i * block_len
+        else:
+            offset_block_len = self.block_len // tp_multiplier
+            block_len = self.block_len // tp_multiplier
+            tp_offset = 0
+        logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i)
+        if layer_ids == "all":
+            layer_ids = list(range(self.num_layers))
+        blocks_data = []
+        for layer_id in layer_ids:
+            for range_start, range_end in ranges:
-+                key_base_addr, value_base_addr = self.kv_caches_base_addr[engine_id][layer_id]
-+                start_offset = range_start * self.block_len
-+                blocks_len = (range_end - range_start + 1) * self.block_len
-+                blocks_data.append((key_base_addr + start_offset, blocks_len, self.rank))
-+                blocks_data.append((value_base_addr + start_offset, blocks_len, self.rank))
++                range_len = range_end - range_start + 1
++                key_base_addr, value_base_addr = kv_caches_base_addr[layer_id]
++                start_offset = range_start * offset_block_len + tp_offset * range_len
++                blocks_len = range_len * block_len
++                blocks_data.append((key_base_addr + start_offset, blocks_len, rank))
++                blocks_data.append((value_base_addr + start_offset, blocks_len, rank))
+        logger.debug("Blocks data: %s", blocks_data)
+        return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+
+    def _get_ranges(self, block_ids):
@@ -765,9 +882,9 @@ index 00000000..bc962726
+        ranges = []
+        for i in range(len(sorted_block_ids)):
+            if i == 0 or sorted_block_ids[i] != sorted_block_ids[i-1] + 1:
-+                ranges.append([sorted_block_ids[i]])
++                ranges.append([sorted_block_ids[i], sorted_block_ids[i]])
+            else:
-+                ranges[-1].append(sorted_block_ids[i])
++                ranges[-1][1] = sorted_block_ids[i]
+        return ranges
+
+    def _get_same_length_ranges(self, src_ranges, dst_ranges):
@@ -807,11 +924,24 @@ index 00000000..bc962726
+                src_idx += 1
+
+        return src_overlapping_ranges, dst_overlapping_ranges
+
+
+    def _get_block_descs_ids(self, layer_ids, block_ids):
+        if layer_ids == "all":
+            layer_ids = list(range(self.num_layers))
+        if block_ids == "all":
+            block_ids = list(range(self.num_blocks))
+        descs_ids = []
+        for layer_id in layer_ids:
+            for is_value in [0, 1]:
+                for block_id in block_ids:
+                    descs_ids.append(layer_id * 2 * self.num_blocks + is_value * self.num_blocks + block_id)
+        return descs_ids
+
+
-+    def transfer_mem(self, src_block_ids, dst_block_ids, dst_engine_id, notify_msg):
++    def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg, use_prepped_xfer=False):
+
+        start_time = time.perf_counter()
+        logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
@@ -820,44 +950,62 @@ index 00000000..bc962726
+        # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \
+        # one less block due to the missing last token.
+        dst_block_ids = dst_block_ids[:len(src_block_ids)]
+        assert len(staging_block_ids) == len(src_block_ids)
+
+        if use_prepped_xfer:
+            raise NotImplementedError("Prepped xfer is not implemented")
+            # src_block_descs_ids = self._get_block_descs_ids("all", src_block_ids)
+            # dst_block_descs_ids = self._get_block_descs_ids("all", dst_block_ids)
+
+            # src_xfer_side_handle = self._xfer_side_handles[self.engine_id]
+            # dst_xfer_side_handle = self._xfer_side_handles[dst_engine_id]
+
+            # logger.debug("Time to get block desc ids: %s ms", (time.perf_counter() - start_time) * 1000)
+
+            # handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, src_block_descs_ids,
+            #                                              dst_xfer_side_handle, dst_block_descs_ids,
+            #                                              notify_msg, "WRITE", no_check=True)
+        # else:
+        # Legacy path using range-based transfers
+        src_ranges = self._get_ranges(src_block_ids)
+        staging_ranges = self._get_ranges(staging_block_ids)
+        dst_ranges = self._get_ranges(dst_block_ids)
-+        src_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(src_ranges, dst_ranges)
-+        logger.debug("Got %s overlapping ranges for %s blocks", len(src_overlapping_ranges), len(src_block_ids))
-+        logger.debug("Time to get ranges: %s ms", time.perf_counter() - start_time)
+
++        assert len(src_ranges) == 1
++        assert len(staging_ranges) == 1
+
++        tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+
+        src_range_start, src_range_end = src_ranges[0]
+        src_range_len = src_range_end - src_range_start + 1
+        staging_range_start, staging_range_end = staging_ranges[0]
+        staging_range_len = staging_range_end - staging_range_start + 1
+
-+        src_descs = self._get_range_descs(self.engine_id, src_overlapping_ranges, "all")
-+        dst_descs = self._get_range_descs(dst_engine_id, dst_overlapping_ranges, "all")
++        logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
++        for kv_cache in self.kv_caches:
+            for cache in kv_cache:
+                rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier)
+
-+        logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
-+        handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs, self._remote_agents[dst_engine_id], notify_msg, "WRITE")
-+        logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
-+        logger.debug("Transfer handle: %s", handle)
-+        status = self.nixl_wrapper.transfer(handle)
-+        logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
-+        logger.debug("Transfer status: %s", status)
-+        # TODO ptarasiewicz: remove blocking transfer mem
-+        # add scheduler check for transfer done
-+        while True:
-+            xfer_state = self.nixl_wrapper.check_xfer_state(handle)
-+            if xfer_state == "ERR":
-+                raise RuntimeError("Transfer failed")
-+            elif xfer_state == "DONE":
-+                logger.debug("Transfer done")
-+                break
-+            elif xfer_state == "PROC":
-+                time.sleep(0.01)
-+            else:
-+                raise RuntimeError("Unknown transfer state")
-+        logger.debug("Time to wait for transfer: %s ms", (time.perf_counter() - start_time) * 1000)
-+        self.nixl_wrapper.abort_xfer(handle)
-+        logger.debug("Time to abort xfer: %s ms", (time.perf_counter() - start_time) * 1000)
-+        logger.debug("Transfer time: %s ms", (time.perf_counter() - start_time) * 1000)
++        staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges)
++        assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
++
++        for i in range(tp_multiplier):
+
+            src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i)
+            dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i)
+            logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+            logger.debug("Transferring to agent %s", self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i])
+            handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs,
+                                                       self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i],
+                                                       notify_msg, "WRITE")
+            self._transfers[notify_msg].append(handle)
+            logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+            logger.debug("Transfer handle: %s", handle)
+            status = self.nixl_wrapper.transfer(handle)
+            logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+            logger.debug("Transfer status: %s", status)
+
+    def deserialize_descs(self, serialized_descs):
+        return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
@@ -870,6 +1018,26 @@ index 00000000..bc962726
+
+    def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr):
+        self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
+
+    def get_done_tranfers(self) -> List[str]:
+        done_req_ids = []
+        for req_id, handles in self._transfers.items():
+            running_reqs = []
+            for handle in handles:
+                xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+                if xfer_state == "DONE":
+                    # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
+                    continue
+                if xfer_state == "PROC":
+                    running_reqs.append(handle)
+                else:
+                    raise RuntimeError("Transfer failed with state %s", xfer_state)
+            if len(running_reqs) == 0:
+                done_req_ids.append(req_id)
+            else:
+                self._transfers[req_id] = running_reqs
+        return done_req_ids
\ No newline at end of file
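The per-rank fan-out above is the core of the variable-TP support: a prefill worker with rank r sends to decode ranks r*m through r*m+m-1, where m = decode_tp // prefill_tp. A tiny illustration of that mapping (assumed helper name, not part of the commit):

def destination_ranks(prefill_rank: int, prefill_tp: int, decode_tp: int):
    tp_multiplier = decode_tp // prefill_tp
    return [prefill_rank * tp_multiplier + i for i in range(tp_multiplier)]

# With prefill TP=2 and decode TP=4 (tp_multiplier = 2):
print(destination_ranks(0, 2, 4))  # [0, 1]
print(destination_ranks(1, 2, 4))  # [2, 3]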
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..61a357d0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -2231,7 +2399,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
-index d82d9ad9..9ba1a326 100644
+index d82d9ad9..254337cb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@
@@ -2318,7 +2486,9 @@ index d82d9ad9..9ba1a326 100644
+        self._nixl_agents_names = self._initialize_nixl()
+
+        self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+        self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+        self._finished_prefills = set()
+        self._finished_transfers = set()
+
+    @property
+    def is_nixl_initialized(self) -> bool:
@@ -2337,8 +2507,6 @@ index d82d9ad9..9ba1a326 100644
+        engine_id = nixl_metadata.engine_id
+        agents_metadata = nixl_metadata.agent_metadata
+        kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
-+        if len(agents_metadata) != len(self._nixl_agents_names):
-+            raise ValueError("Number of agents does not match. Make sure all engines are initialized with the same parallel sizes.")
+        return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr))
+
+    def _initialize_nixl(self) -> List[bytes]:
@@ -2372,7 +2540,16 @@ index d82d9ad9..9ba1a326 100644
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
@@ -584,7 +634,7 @@ class LLMEngine: @@ -574,6 +624,8 @@ class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
+ next(self.seq_counter) # empty sequence for staging
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
@@ -584,7 +636,7 @@ class LLMEngine:
            encoder_inputs = None
        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
@@ -2381,7 +2558,7 @@ index d82d9ad9..9ba1a326 100644
        encoder_seq = (None if encoder_inputs is None else Sequence(
            seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
-@@ -601,8 +651,12 @@ class LLMEngine:
+@@ -601,8 +653,12 @@ class LLMEngine:
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
@@ -2395,7 +2572,7 @@ index d82d9ad9..9ba1a326 100644
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
-@@ -673,6 +727,7 @@ class LLMEngine:
+@@ -673,6 +729,7 @@ class LLMEngine:
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
@@ -2403,7 +2580,7 @@ index d82d9ad9..9ba1a326 100644
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> None:
-@@ -765,6 +820,7 @@ class LLMEngine:
+@@ -765,6 +822,7 @@ class LLMEngine:
            prompt_adapter_request=prompt_adapter_request,
            trace_headers=trace_headers,
            priority=priority,
@@ -2411,7 +2588,7 @@ index d82d9ad9..9ba1a326 100644
        )
    def _validate_token_prompt(self, prompt: PromptType,
-@@ -799,6 +855,7 @@ class LLMEngine:
+@@ -799,6 +857,7 @@ class LLMEngine:
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        encoder_seq: Optional[Sequence] = None,
        priority: int = 0,
@@ -2419,7 +2596,7 @@ index d82d9ad9..9ba1a326 100644
    ) -> SequenceGroup:
        """Creates a SequenceGroup with SamplingParams."""
        max_logprobs = self.get_model_config().max_logprobs
-@@ -829,7 +886,9 @@ class LLMEngine:
+@@ -829,7 +888,9 @@ class LLMEngine:
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            encoder_seq=encoder_seq,
@@ -2430,7 +2607,7 @@ index d82d9ad9..9ba1a326 100644
        return seq_group
-@@ -995,11 +1054,11 @@ class LLMEngine:
+@@ -995,11 +1056,11 @@ class LLMEngine:
        # When we process only one request, no pop is required
        # (since later we will process all of the rest)
        (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
@@ -2444,7 +2621,7 @@ index d82d9ad9..9ba1a326 100644
        # Sanity check
        assert len(seq_group_metadata_list) == len(
-@@ -1325,15 +1384,49 @@ class LLMEngine:
+@@ -1325,15 +1386,49 @@ class LLMEngine:
            # Clear outputs for each new scheduler iteration
            ctx.request_outputs.clear()
@@ -2464,7 +2641,7 @@ index d82d9ad9..9ba1a326 100644
            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
-            ) = self.scheduler[virtual_engine].schedule()
-+            ) = self.scheduler[virtual_engine].schedule(self._finished_prefills)
++            ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers)
+
+
+            # Separate remote prefill and running seq groups
@@ -2496,7 +2673,7 @@ index d82d9ad9..9ba1a326 100644
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs
-@@ -1383,9 +1476,29 @@ class LLMEngine:
+@@ -1383,9 +1478,31 @@ class LLMEngine:
                execute_model_req.async_callback = self.async_callbacks[
                    virtual_engine]
@@ -2510,9 +2687,11 @@ index d82d9ad9..9ba1a326 100644
+                    req_id = scheduled_seq_group.seq_group.request_id
+                    seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+                    block_table = seq_group_metadata.block_tables[seq_id]
+                    staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]
+                    memory_transfer_req = MemoryTransferRequest(
+                        request_id=req_id,
+                        src_block_ids=block_table,
+                        staging_block_ids=staging_block_ids,
+                        dst_block_ids=remote_prefill_params.decode_block_ids,
+                        dst_engine_id=remote_prefill_params.decode_engine_id,
+                        notify_msg=req_id,
@@ -2522,13 +2701,13 @@ index d82d9ad9..9ba1a326 100644
+
+            execute_model_req.memory_transfer_requests = memory_transfer_reqs
+
-+            outputs, request_notif_counter = self.model_executor.execute_model(
++            outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
-
            # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
-@@ -1396,7 +1509,20 @@ class LLMEngine:
+@@ -1396,7 +1513,26 @@ class LLMEngine:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
@@ -2539,7 +2718,7 @@ index d82d9ad9..9ba1a326 100644
+                blocks_to_swap_out=[],
+                blocks_to_copy=[])
+
-+            outputs, request_notif_counter = self.model_executor.execute_model(
++            outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+                execute_model_req=execute_model_req)
+
+            for req_id, notif_count in request_notif_counter.items():
@@ -2547,10 +2726,16 @@ index d82d9ad9..9ba1a326 100644
+                if self._request_notif_counter[req_id] > -1:
+                    self._finished_prefills.add(req_id)
+                    del self._request_notif_counter[req_id]
+
+            for req_id, done_count in request_done_counter.items():
+                self._request_done_counter[req_id] += done_count
+                if self._request_done_counter[req_id] > -1:
+                    self._finished_transfers.add(req_id)
+                    del self._request_done_counter[req_id]
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
-@@ -1456,7 +1582,7 @@ class LLMEngine:
+@@ -1456,7 +1592,7 @@ class LLMEngine:
        # queued control plane messages, such as add/remove lora adapters.
        logger.debug("Stopping remote worker execution loop.")
        self.model_executor.stop_remote_worker_execution_loop()
@@ -2628,7 +2813,7 @@ index 3cf1850e..6b90ece7 100644
+    kv_active_blocks: int
+    kv_total_blocks: int
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..d33d546a 100644 index 85b5f31e..c501e4c8 100644
--- a/vllm/engine/multiprocessing/client.py --- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, @@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
...@@ -2756,7 +2941,7 @@ index 85b5f31e..d33d546a 100644 ...@@ -2756,7 +2941,7 @@ index 85b5f31e..d33d546a 100644
+ kv_metrics.kv_active_blocks, + kv_metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks) + kv_metrics.kv_total_blocks)
+ +
+ logger.debug("Metrics successful.") + logger.debug("Metrics successful.")
+ +
+ except asyncio.CancelledError: + except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.") + logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
...@@ -3161,10 +3346,10 @@ index 786380c3..56a7cf89 100644 ...@@ -3161,10 +3346,10 @@ index 786380c3..56a7cf89 100644
"""The output data of one completion output of a request. """The output data of one completion output of a request.
diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
new file mode 100644 new file mode 100644
index 00000000..03f02006 index 00000000..957f55de
--- /dev/null --- /dev/null
+++ b/vllm/remote_prefill.py +++ b/vllm/remote_prefill.py
@@ -0,0 +1,53 @@ @@ -0,0 +1,54 @@
+from dataclasses import dataclass +from dataclasses import dataclass
+from typing import Callable, Optional, List, Coroutine +from typing import Callable, Optional, List, Coroutine
+ +
...@@ -3202,6 +3387,7 @@ index 00000000..03f02006 ...@@ -3202,6 +3387,7 @@ index 00000000..03f02006
+ """ + """
+ request_id: str + request_id: str
+ src_block_ids: List[int] + src_block_ids: List[int]
+ staging_block_ids: List[int]
+ dst_block_ids: List[int] + dst_block_ids: List[int]
+ dst_engine_id: str + dst_engine_id: str
+ notify_msg: str + notify_msg: str
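Reassembled from the hunks above, the record added in vllm/remote_prefill.py has the following shape, shown here as a plain dataclass to match the module's `dataclass` import (docstring abridged; the real file may carry extra members):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MemoryTransferRequest:
    """One KV-cache block transfer from a prefill engine to a decode
    engine (docstring abridged)."""
    request_id: str
    src_block_ids: List[int]
    staging_block_ids: List[int]
    dst_block_ids: List[int]
    dst_engine_id: str
    notify_msg: str
```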
...@@ -3361,7 +3547,7 @@ index 12baecde..cbada27f 100644 ...@@ -3361,7 +3547,7 @@ index 12baecde..cbada27f 100644
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..ffb7b403 100644 index 582aa460..1b8515bf 100644
--- a/vllm/worker/worker.py --- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py +++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@ @@ -2,7 +2,7 @@
...@@ -3402,13 +3588,13 @@ index 582aa460..ffb7b403 100644 ...@@ -3402,13 +3588,13 @@ index 582aa460..ffb7b403 100644
+ +
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str: + def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata[self.local_rank]) # TODO ptarasiewicz: rank or local_rank? + agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata)) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr[self.local_rank]) + self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr)
+ return agent_name + return agent_name
+ +
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None: + def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name[self.local_rank], notify_msg) # TODO ptarasiewicz: rank or local_rank? + self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name, notify_msg) # TODO ptarasiewicz: rank or local_rank?
+ +
+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]: + def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
...@@ -3416,8 +3602,8 @@ index 582aa460..ffb7b403 100644 ...@@ -3416,8 +3602,8 @@ index 582aa460..ffb7b403 100644
+ +
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None: + def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ if worker_input.src_block_ids is not None: + if worker_input.src_block_ids is not None:
+ for src_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg): + for src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.staging_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, dst_block_ids, dst_engine_id, notify_msg) + self.nixl_connector.transfer_mem(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+ +
+ def shutdown_nixl(self) -> None: + def shutdown_nixl(self) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized" + assert self.nixl_connector is not None, "Nixl connector is not initialized"
...@@ -3435,11 +3621,12 @@ index 582aa460..ffb7b403 100644 ...@@ -3435,11 +3621,12 @@ index 582aa460..ffb7b403 100644
return WorkerInput( return WorkerInput(
num_seq_groups=num_seq_groups, num_seq_groups=num_seq_groups,
@@ -375,6 +416,10 @@ class Worker(LocalOrDistributedWorkerBase): @@ -375,6 +416,11 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
num_steps=num_steps, num_steps=num_steps,
+ src_block_ids=[r.src_block_ids for r in mem_transfer_reqs], + src_block_ids=[r.src_block_ids for r in mem_transfer_reqs],
+ staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs],
+ dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs], + dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs],
+ dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs], + dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs],
+ notify_msg=[r.notify_msg for r in mem_transfer_reqs], + notify_msg=[r.notify_msg for r in mem_transfer_reqs],
...@@ -3447,7 +3634,7 @@ index 582aa460..ffb7b403 100644 ...@@ -3447,7 +3634,7 @@ index 582aa460..ffb7b403 100644
@torch.inference_mode() @torch.inference_mode()
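The WorkerInput construction above flattens the request list into parallel per-field lists so the transfers ride the existing broadcast path unchanged; a sketch restating those comprehensions (the helper name and dict return type are illustrative only):

```python
# Restates the five list comprehensions from the hunk above as a
# standalone helper; not part of the patch itself.
from typing import Dict, List

from vllm.remote_prefill import MemoryTransferRequest


def flatten_transfer_reqs(
        mem_transfer_reqs: List[MemoryTransferRequest]) -> Dict[str, list]:
    return {
        "src_block_ids": [r.src_block_ids for r in mem_transfer_reqs],
        "staging_block_ids": [r.staging_block_ids for r in mem_transfer_reqs],
        "dst_block_ids": [r.dst_block_ids for r in mem_transfer_reqs],
        "dst_engine_id": [r.dst_engine_id for r in mem_transfer_reqs],
        "notify_msg": [r.notify_msg for r in mem_transfer_reqs],
    }
```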
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fb..d9c039eb 100644 index 819b81fb..ecb68530 100644
--- a/vllm/worker/worker_base.py --- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
...@@ -3475,11 +3662,12 @@ index 819b81fb..d9c039eb 100644 ...@@ -3475,11 +3662,12 @@ index 819b81fb..d9c039eb 100644
@abstractmethod @abstractmethod
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device """Initialize device state, such as loading the model or other on-device
@@ -216,6 +220,11 @@ class WorkerInput: @@ -216,6 +220,12 @@ class WorkerInput:
virtual_engine: int = 0 virtual_engine: int = 0
num_steps: int = 1 num_steps: int = 1
+ src_block_ids: Optional[List[List[int]]] = None + src_block_ids: Optional[List[List[int]]] = None
+ staging_block_ids: Optional[List[List[int]]] = None
+ dst_block_ids: Optional[List[List[int]]] = None + dst_block_ids: Optional[List[List[int]]] = None
+ dst_engine_id: Optional[List[str]] = None + dst_engine_id: Optional[List[str]] = None
+ notify_msg: Optional[List[str]] = None + notify_msg: Optional[List[str]] = None
...@@ -3487,29 +3675,31 @@ index 819b81fb..d9c039eb 100644 ...@@ -3487,29 +3675,31 @@ index 819b81fb..d9c039eb 100644
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"], cls: Type["WorkerInput"],
@@ -232,6 +241,10 @@ class WorkerInput: @@ -232,6 +242,11 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"], virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"), num_steps=tensor_dict.pop("num_steps"),
+ src_block_ids=tensor_dict.pop("src_block_ids"), + src_block_ids=tensor_dict.pop("src_block_ids"),
+ staging_block_ids=tensor_dict.pop("staging_block_ids"),
+ dst_block_ids=tensor_dict.pop("dst_block_ids"), + dst_block_ids=tensor_dict.pop("dst_block_ids"),
+ dst_engine_id=tensor_dict.pop("dst_engine_id"), + dst_engine_id=tensor_dict.pop("dst_engine_id"),
+ notify_msg=tensor_dict.pop("notify_msg"), + notify_msg=tensor_dict.pop("notify_msg"),
) )
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
@@ -246,6 +259,10 @@ class WorkerInput: @@ -246,6 +261,11 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy, "blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"num_steps": self.num_steps, "num_steps": self.num_steps,
+ "src_block_ids": self.src_block_ids, + "src_block_ids": self.src_block_ids,
+ "staging_block_ids": self.staging_block_ids,
+ "dst_block_ids": self.dst_block_ids, + "dst_block_ids": self.dst_block_ids,
+ "dst_engine_id": self.dst_engine_id, + "dst_engine_id": self.dst_engine_id,
+ "notify_msg": self.notify_msg, + "notify_msg": self.notify_msg,
} }
return tensor_dict return tensor_dict
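Together the two hunks give the new fields a symmetric pack/unpack path: the driver serializes them in `as_broadcastable_tensor_dict`, and the other ranks rebuild the WorkerInput from the received dict. A toy round-trip with stand-in types, not vLLM's broadcast machinery:

```python
# Toy round-trip for the new fields; ToyWorkerInput is a stand-in for
# vLLM's WorkerInput, reduced to three of the five added fields.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyWorkerInput:
    src_block_ids: Optional[List[List[int]]] = None
    staging_block_ids: Optional[List[List[int]]] = None
    notify_msg: Optional[List[str]] = None

    def as_broadcastable_tensor_dict(self) -> dict:
        return {
            "src_block_ids": self.src_block_ids,
            "staging_block_ids": self.staging_block_ids,
            "notify_msg": self.notify_msg,
        }

    @classmethod
    def from_broadcasted_tensor_dict(cls, tensor_dict: dict) -> "ToyWorkerInput":
        return cls(
            src_block_ids=tensor_dict.pop("src_block_ids"),
            staging_block_ids=tensor_dict.pop("staging_block_ids"),
            notify_msg=tensor_dict.pop("notify_msg"),
        )


packed = ToyWorkerInput(src_block_ids=[[0, 1]], staging_block_ids=[[2, 3]],
                        notify_msg=["req-0"]).as_broadcastable_tensor_dict()
assert ToyWorkerInput.from_broadcasted_tensor_dict(packed).notify_msg == ["req-0"]
```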
@@ -316,13 +333,16 @@ class LocalOrDistributedWorkerBase(WorkerBase): @@ -316,13 +336,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return None return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
...@@ -3531,7 +3721,7 @@ index 819b81fb..d9c039eb 100644 ...@@ -3531,7 +3721,7 @@ index 819b81fb..d9c039eb 100644
def _get_driver_input_and_broadcast( def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest self, execute_model_req: ExecuteModelRequest
@@ -396,49 +416,79 @@ class LocalOrDistributedWorkerBase(WorkerBase): @@ -396,49 +419,87 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.execute_worker(worker_input) self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
...@@ -3628,7 +3818,7 @@ index 819b81fb..d9c039eb 100644 ...@@ -3628,7 +3818,7 @@ index 819b81fb..d9c039eb 100644
+ else: + else:
+ for i in range(1, get_tp_group().world_size): + for i in range(1, get_tp_group().world_size):
+ all_new_notifs.append(get_tp_group().recv_object(src=i)) + all_new_notifs.append(get_tp_group().recv_object(src=i))
+
+ request_notif_counter = defaultdict(int) + request_notif_counter = defaultdict(int)
+ for notifs in all_new_notifs: + for notifs in all_new_notifs:
+ for req_ids in notifs.values(): + for req_ids in notifs.values():
...@@ -3637,12 +3827,20 @@ index 819b81fb..d9c039eb 100644 ...@@ -3637,12 +3827,20 @@ index 819b81fb..d9c039eb 100644
+ +
+ if request_notif_counter: + if request_notif_counter:
+ logger.debug("Request notif counter: %s", request_notif_counter) + logger.debug("Request notif counter: %s", request_notif_counter)
+
+ request_done_counter = defaultdict(int)
+ for req_id in self.nixl_connector.get_done_tranfers():
+ request_done_counter[req_id] += 1
+
+ if request_done_counter:
+ logger.debug("Request done counter: %s", request_done_counter)
+
+ else: + else:
+ request_notif_counter = {} + request_notif_counter = {}
+ request_done_counter = {}
# output is List[SamplerOutput] # output is List[SamplerOutput]
- return output - return output
+ return output, request_notif_counter + return output, request_notif_counter, request_done_counter
+ +
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None: + def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ pass + pass
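On the driver, rank 0 folds every rank's NIXL notifications into one counter before returning them alongside the sampler output; a standalone sketch with the TP group's send/recv replaced by a plain list of per-rank dicts, each mapping agent name to notified request ids as in the patch:

```python
# Sketch of the rank-0 fold; the per-rank dicts stand in for what
# get_tp_group().recv_object delivers in the real code.
from collections import defaultdict
from typing import Dict, List


def count_notifs(per_rank_notifs: List[Dict[str, List[str]]]) -> Dict[str, int]:
    request_notif_counter: Dict[str, int] = defaultdict(int)
    for notifs in per_rank_notifs:          # one entry per TP rank
        for req_ids in notifs.values():     # one list per remote agent
            for req_id in req_ids:
                request_notif_counter[req_id] += 1
    return request_notif_counter


# Two ranks each saw a completion notification for "req-0":
assert count_notifs([{"agent0": ["req-0"]}, {"agent0": ["req-0"]}])["req-0"] == 2
```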
......
...@@ -20,7 +20,7 @@ limitations under the License. ...@@ -20,7 +20,7 @@ limitations under the License.
## Build docker ## Build docker
``` ```
./container/build.sh --framework VLLM_NIXL --target dev --build-context nixl=<path to downloaded nixl repo @ fc912eb012597be67de11fa9ba0599e4e1974fa2> ./container/build.sh --framework VLLM_NIXL --target dev --build-context nixl=<path to downloaded nixl repo @ c53bb19a6a114e9093071bd1f2904f996ae1839b>
``` ```
## Run container ## Run container
...@@ -72,10 +72,11 @@ In terminal 2: ...@@ -72,10 +72,11 @@ In terminal 2:
``` ```
cd /workspace/examples/python_rs/llm/vllm_nixl cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=1 python3 worker.py \ CUDA_VISIBLE_DEVICES=1,2 python3 worker.py \
--remote-prefill \ --remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \ --enforce-eager \
--tensor-parallel-size 2 \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"TritonNixlConnector"}' '{"kv_connector":"TritonNixlConnector"}'
``` ```
......
...@@ -18,7 +18,7 @@ import asyncio ...@@ -18,7 +18,7 @@ import asyncio
import msgspec import msgspec
import uvloop import uvloop
from common import find_remote_metadata, parse_vllm_args, temp_metadata_file from common import find_remote_metadata, parse_vllm_args
from vllm.distributed.device_communicators.nixl import NixlMetadata from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
...@@ -69,19 +69,18 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -69,19 +69,18 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
async with build_async_engine_client_from_engine_args(engine_args) as engine_client: async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
# This should be replaced with etcd # This should be replaced with etcd
metadata = engine_client.nixl_metadata metadata = engine_client.nixl_metadata
with temp_metadata_file(metadata.engine_id, metadata): print(f"Waiting for remote metadata for engine {metadata.engine_id}")
print(f"Waiting for remote metadata for engine {metadata.engine_id}") remote_metadata: list[NixlMetadata] = []
remote_metadata: list[NixlMetadata] = [] while not remote_metadata:
while not remote_metadata: await asyncio.sleep(1)
await asyncio.sleep(1) remote_metadata = find_remote_metadata(metadata.engine_id)
remote_metadata = find_remote_metadata(metadata.engine_id)
print(
print( f"Found {len(remote_metadata)} remote metadata for engine {metadata.engine_id}"
f"Found {len(remote_metadata)} remote metadata for engine {metadata.engine_id}" )
) for remote_metadata in remote_metadata:
for remote_metadata in remote_metadata: await engine_client.add_remote_nixl_metadata(remote_metadata)
await engine_client.add_remote_nixl_metadata(remote_metadata) await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -96,4 +95,12 @@ if __name__ == "__main__": ...@@ -96,4 +95,12 @@ if __name__ == "__main__":
print("Pipeline parallel size is not supported yet, setting to 1") print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1 engine_args.pipeline_parallel_size = 1
if engine_args.disable_async_output_proc is not True:
print("Async output processing is not supported yet, setting to True")
engine_args.disable_async_output_proc = True
if engine_args.enforce_eager is not True:
print("Prefill must be done eagerly, setting to True")
engine_args.enforce_eager = True
asyncio.run(worker(engine_args)) asyncio.run(worker(engine_args))
...@@ -93,6 +93,8 @@ class RequestHandler: ...@@ -93,6 +93,8 @@ class RequestHandler:
await self.init() await self.init()
assert self.openai_serving_chat is not None assert self.openai_serving_chat is not None
request.model = "vllm"
if self.do_remote_prefill: if self.do_remote_prefill:
remote_prefill_params = RemotePrefillParams( remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True, is_remote_prefill=True,
...@@ -133,7 +135,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -133,7 +135,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
with temp_metadata_file(metadata.engine_id, metadata): with temp_metadata_file(metadata.engine_id, metadata):
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
RequestHandler( RequestHandler(
model_name=engine_args.model, model_name="vllm",
engine_client=engine_client, engine_client=engine_client,
prefill_client=prefill_client, prefill_client=prefill_client,
do_remote_prefill=True, do_remote_prefill=True,
...@@ -142,7 +144,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -142,7 +144,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
else: else:
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
RequestHandler( RequestHandler(
model_name=engine_args.model, model_name="vllm",
engine_client=engine_client, engine_client=engine_client,
prefill_client=prefill_client, prefill_client=prefill_client,
do_remote_prefill=False, do_remote_prefill=False,
......