Commit fe83f8aa authored by ptarasiewiczNV, committed by GitHub

Update NIXL Dockerfile + vLLM patch with variable TP (#9)


Co-authored-by: Piotr Tarasiewicz Nvidia <ptarasiewicznv@Piotrs-MacBook-Pro.local>
parent b9ce8dd0
......@@ -8,50 +8,12 @@ FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dev
USER root
# Install utilities
RUN apt update -y && apt install -y git wget curl nvtop tmux vim
# nats
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
# etcd
ENV ETCD_VERSION="v3.5.18"
RUN wget https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-amd64.tar.gz -O /tmp/etcd.tar.gz && \
mkdir -p /usr/local/bin/etcd && \
tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1
ENV PATH=/usr/local/bin/etcd/:$PATH
### VIRTUAL ENVIRONMENT SETUP ###
# Install uv and create virtualenv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/triton && \
uv venv /opt/triton/venv --python 3.12
# Activate virtual environment
ENV VIRTUAL_ENV=/opt/triton/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Install patched vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
# Install genai-perf for benchmarking
ARG GENAI_PERF_TAG="r25.01"
RUN uv pip install "git+https://github.com/triton-inference-server/perf_analyzer.git@${GENAI_PERF_TAG}#subdirectory=genai-perf"
# Install test dependencies
RUN --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.txt \
uv pip install --requirement /tmp/requirements.txt
### NIXL SETUP ###
ARG MOFED_VERSION=5.8-1.1.2.1
ARG MOFED_VERSION=24.10-1.1.4.0
ARG PYTHON_VERSION=3.12
ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2024_4/
ARG NSYS_PKG=NsightSystems-linux-cli-public-2024.4.1.61-3431596.deb
ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_1/
ARG NSYS_PKG=NsightSystems-linux-cli-public-2025.1.1.131-3554042.deb
RUN apt-get update -y && apt-get -y install curl \
git \
......@@ -78,7 +40,9 @@ RUN apt-get update -y && apt-get -y install curl \
libprotobuf-dev \
protobuf-compiler-grpc \
pybind11-dev \
python3-full \
python3-pip \
python3-numpy \
etcd-server \
net-tools \
pciutils \
......@@ -89,25 +53,18 @@ RUN apt-get update -y && apt-get -y install curl \
ibverbs-utils \
libibmad-dev
RUN apt-get update && \
apt install -y wget libglib2.0-0
RUN wget ${NSYS_URL}${NSYS_PKG} && \
dpkg -i $NSYS_PKG && \
rm $NSYS_PKG
RUN apt-get install -y linux-tools-common linux-tools-generic ethtool iproute2
RUN apt-get install -y dkms linux-headers-generic
RUN apt-get install -y meson ninja-build uuid-dev gdb
RUN uv pip install --upgrade meson
RUN uv pip install ninja pybind11
RUN apt-get update && apt install -y wget libglib2.0-0
RUN wget ${NSYS_URL}${NSYS_PKG} && dpkg -i $NSYS_PKG && rm $NSYS_PKG
RUN cd /usr/local/src && \
curl -fSsL "https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VERSION}/MLNX_OFED_LINUX-${MOFED_VERSION}-ubuntu20.04-x86_64.tgz" -o mofed.tgz && \
curl -fSsL "https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VERSION}/MLNX_OFED_LINUX-${MOFED_VERSION}-ubuntu24.04-x86_64.tgz" -o mofed.tgz && \
tar -xf /usr/local/src/mofed.tgz && \
cd MLNX_OFED_LINUX-* && \
apt-get update && \
apt-get install -y --no-install-recommends \
apt-get update && apt-get install -y --no-install-recommends \
./DEBS/libibverbs* ./DEBS/ibverbs-providers* ./DEBS/librdmacm* ./DEBS/libibumad* && \
rm -rf /var/lib/apt/lists/* /usr/local/src/*
......@@ -128,9 +85,7 @@ ARG UCX_VERSION=v1.18.0
RUN cd /usr/local/src && \
curl -fSsL "https://github.com/openucx/ucx/tarball/${UCX_VERSION}" | tar xz && \
cd openucx-ucx* && \
./autogen.sh && \
./configure \
--prefix=/usr/local/ucx \
./autogen.sh && ./configure \
--enable-shared \
--disable-static \
--disable-doxygen-doc \
......@@ -147,13 +102,20 @@ RUN cd /usr/local/src && \
make -j install-strip && \
ldconfig
ENV LD_LIBRARY_PATH=/usr/local/ucx/lib:$LD_LIBRARY_PATH
ENV CPATH=/usr/local/ucx/include:$CPATH
ENV PATH=/usr/local/ucx/bin:$PATH
ENV PKG_CONFIG_PATH=/usr/local/ucx/lib/pkgconfig:$PKG_CONFIG_PATH
ENV LD_LIBRARY_PATH=/usr/lib:$LD_LIBRARY_PATH
ENV CPATH=/usr/include:$CPATH
ENV PATH=/usr/bin:$PATH
ENV PKG_CONFIG_PATH=/usr/lib/pkgconfig:$PKG_CONFIG_PATH
SHELL ["/bin/bash", "-c"]
WORKDIR /workspace
ENV LD_LIBRARY_PATH=/usr/local/ompi/lib:$LD_LIBRARY_PATH
ENV CPATH=/usr/local/ompi/include:$CPATH
ENV PATH=/usr/local/ompi/bin:$PATH
ENV PKG_CONFIG_PATH=/usr/local/ompi/lib/pkgconfig:$PKG_CONFIG_PATH
COPY --from=nixl . /opt/nixl
RUN cd /opt/nixl && \
......@@ -161,15 +123,10 @@ RUN cd /opt/nixl && \
meson setup build/ --prefix=/usr/local/nixl && \
cd build/ && \
ninja && \
ninja install && \
mkdir -p /usr/local/nixl/include/internal && \
cp ../include/*.h /usr/local/nixl/include && \
cp ../include/internal/*.h /usr/local/nixl/include/internal && \
cp ../include/internal/*.h /usr/local/nixl/include/ && \
cp ../src/utils/serdes/serdes.h /usr/local/nixl/include
ninja install
ENV LD_LIBRARY_PATH=/usr/local/nixl/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH
ENV PYTHONPATH=/usr/local/nixl/lib/python${PYTHON_VERSION}/site-packages/:/opt/nixl/test/python/:$PYTHONPATH
ENV PYTHONPATH=/usr/local/nixl/lib/python3/dist-packages/:/opt/nixl/test/python/:$PYTHONPATH
RUN ls -l /usr/local/nixl/
RUN ls -l /usr/local/nixl/include/
......@@ -177,6 +134,44 @@ RUN ls -l /usr/local/nixl/include/internal/
RUN ls /opt/nixl
# Install utilities
RUN apt update -y && apt install -y git wget curl nvtop tmux vim
# nats
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
# etcd
ENV ETCD_VERSION="v3.5.18"
RUN wget https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-amd64.tar.gz -O /tmp/etcd.tar.gz && \
mkdir -p /usr/local/bin/etcd && \
tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1
ENV PATH=/usr/local/bin/etcd/:$PATH
### VIRTUAL ENVIRONMENT SETUP ###
# Install uv and create virtualenv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN mkdir /opt/triton && \
uv venv /opt/triton/venv --python 3.12
# Activate virtual environment
ENV VIRTUAL_ENV=/opt/triton/venv
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Install patched vllm - keep this early in Dockerfile to avoid
# rebuilds from unrelated source code changes
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_${VLLM_REF}-triton-kv-disagg-patch.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm
# Install genai-perf for benchmarking
ARG GENAI_PERF_TAG="r25.01"
RUN uv pip install "git+https://github.com/triton-inference-server/perf_analyzer.git@${GENAI_PERF_TAG}#subdirectory=genai-perf"
# Install test dependencies
RUN --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.txt \
uv pip install --requirement /tmp/requirements.txt
# ### MISC UTILITY SETUP ###
# Finish pyright install
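The relocated setup above installs the patched vLLM editable into the uv venv at /opt/triton/venv and exposes the NIXL bindings via PYTHONPATH. A minimal smoke test inside the built container might look like this (a hypothetical check, not part of the commit):

```
# Hypothetical smoke test: run inside the container to confirm the patched
# vLLM and the NIXL-backed connector module both import from the venv.
import vllm
from vllm.distributed.device_communicators.nixl import NixlMetadata

print(vllm.__version__)  # expected to match VLLM_REF (v0.7.2)
```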
......
......@@ -368,10 +368,15 @@ index 00000000..350453cd
+
+ self.event_id_counter += 1
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f507847a..ee20d50c 100644
index f507847a..abe574d1 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -8,18 +8,17 @@ from collections import deque
@@ -4,22 +4,22 @@ import enum
import os
import random
import time
+import copy
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
......@@ -393,7 +398,7 @@ index f507847a..ee20d50c 100644
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
@@ -325,12 +324,14 @@ class Scheduler:
@@ -325,12 +325,14 @@ class Scheduler:
def __init__(
self,
......@@ -408,7 +413,7 @@ index f507847a..ee20d50c 100644
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
@@ -356,6 +357,7 @@ class Scheduler:
@@ -356,6 +358,7 @@ class Scheduler:
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
......@@ -416,7 +421,7 @@ index f507847a..ee20d50c 100644
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
@@ -371,6 +373,14 @@ class Scheduler:
@@ -371,6 +374,16 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
......@@ -424,6 +429,8 @@ index f507847a..ee20d50c 100644
+ # Sequence groups in the REMOTE_PREFILLING state.
+ # Contain requests that are being prefilled by a remote worker.
+ self.remote_prefilling: Deque[SequenceGroup] = deque()
+ # Contain requests that are being prefilled by a local worker.
+ self.prefill_sending: Deque[SequenceGroup] = deque()
+
+ self._remote_prefill_outputs: Dict[str, int] = {}
+
......@@ -431,24 +438,25 @@ index f507847a..ee20d50c 100644
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
@@ -501,7 +511,7 @@ class Scheduler:
@@ -501,7 +514,7 @@ class Scheduler:
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
- self.swapped) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0
+ self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
@@ -523,6 +533,7 @@ class Scheduler:
@@ -523,6 +536,8 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
+ finished_prefills: Optional[Set[str]] = None
+ finished_prefills: Optional[Set[str]] = None,
+ finished_transfers: Optional[Set[str]] = None
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
@@ -537,6 +548,8 @@ class Scheduler:
@@ -537,6 +552,8 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
......@@ -457,7 +465,7 @@ index f507847a..ee20d50c 100644
Returns:
SchedulerRunningOutputs.
@@ -566,6 +579,24 @@ class Scheduler:
@@ -566,6 +583,38 @@ class Scheduler:
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
......@@ -468,6 +476,7 @@ index f507847a..ee20d50c 100644
+ if seq_group.request_id not in finished_prefills:
+ leftover_remote_prefilling_sequences.append(seq_group)
+ continue
+
+ else:
+ finished_prefills.remove(seq_group.request_id)
+ assert len(seq_group.seqs) == 1
......@@ -478,39 +487,63 @@ index f507847a..ee20d50c 100644
+ seq.data._stage = SequenceStage.DECODE
+ self.running.appendleft(seq_group)
+ remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
+
+ remote_transfers_queue = self.prefill_sending
+ leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque()
+ while remote_transfers_queue:
+ seq_group = remote_transfers_queue.popleft()
+ if seq_group.request_id not in finished_transfers:
+ leftover_remote_transfers_sequences.append(seq_group)
+ else:
+ finished_transfers.remove(seq_group.request_id)
+ assert len(seq_group.seqs) == 1
+ seq = seq_group.seqs[0]
+ self.free_seq(seq)
+ remote_transfers_queue.extendleft(leftover_remote_transfers_sequences)
+
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
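Both loops above follow the same drain pattern: pop every group off the side queue (remote_prefilling or prefill_sending), promote or free the ones whose request id appears in the finished set, and push the rest back. A condensed sketch of that pattern (names hypothetical, not the patch's exact code):

```
from collections import deque

def drain(queue: deque, finished_ids: set, on_finished) -> None:
    # Pop every group; hand finished ones to on_finished, keep the rest.
    leftovers = deque()
    while queue:
        group = queue.popleft()
        if group.request_id in finished_ids:
            finished_ids.remove(group.request_id)
            on_finished(group)      # e.g. move to running, or free its blocks
        else:
            leftovers.append(group)
    queue.extend(leftovers)         # queue is empty here, so order is kept
```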
@@ -1008,7 +1039,7 @@ class Scheduler:
@@ -1008,7 +1057,17 @@ class Scheduler:
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
- self._allocate_and_set_running(seq_group)
+
+ seq_group_copy = copy.deepcopy(seq_group)
+ seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
+
+ logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
+ logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group)
+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
+ self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
+ self.prefill_sending.append(seq_group_copy)
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
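On the engine whose request has `is_remote_decode` set (my reading: the prefill side, whose decode happens remotely), the hunk above allocates blocks twice: once for the real sequence and once for a deep-copied staging sequence under `seq_id + 1`. The staging blocks later receive the re-sharded KV layout before the NIXL write. A sketch of the flow under those assumptions:

```
import copy

def allocate_with_staging(seq_group, allocate):
    # The staging copy shares the prompt but owns separate KV blocks under
    # seq_id + 1; rearrange_tensors later writes the re-sharded KV into them.
    staging = copy.deepcopy(seq_group)
    staging.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
    allocate(seq_group)   # blocks used for the local prefill itself
    allocate(staging)     # blocks the NIXL transfer reads from
    return staging        # tracked on the prefill_sending queue until sent
```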
@@ -1048,7 +1079,7 @@ class Scheduler:
@@ -1048,7 +1107,7 @@ class Scheduler:
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
- def _schedule_default(self) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
@@ -1090,7 +1121,8 @@ class Scheduler:
@@ -1090,7 +1149,9 @@ class Scheduler:
if len(prefills.seq_groups) == 0:
running_scheduled = self._schedule_running(budget,
curr_loras,
- enable_chunking=False)
+ enable_chunking=False,
+ finished_prefills=finished_prefills)
+ finished_prefills=finished_prefills,
+ finished_transfers=finished_transfers)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
@@ -1106,7 +1138,12 @@ class Scheduler:
@@ -1106,7 +1167,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
if len(prefills.seq_groups) > 0:
......@@ -524,30 +557,31 @@ index f507847a..ee20d50c 100644
self.running.extend(running_scheduled.decode_seq_groups_list)
@@ -1248,12 +1285,14 @@ class Scheduler:
@@ -1248,12 +1314,14 @@ class Scheduler:
len(running_scheduled.swapped_out)),
)
- def _schedule(self) -> SchedulerOutputs:
+ def _schedule(self, finished_prefills: Optional[Set[str]] = None) -> SchedulerOutputs:
+ def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled:
+ if finished_prefills:
+ if finished_prefills or finished_transfers:
+ raise ValueError("Chunked prefill does not support remote prefills")
return self._schedule_chunked_prefill()
else:
- return self._schedule_default()
+ return self._schedule_default(finished_prefills)
+ return self._schedule_default(finished_prefills, finished_transfers)
def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool:
@@ -1287,14 +1326,15 @@ class Scheduler:
@@ -1287,14 +1355,16 @@ class Scheduler:
return no_single_seq
def schedule(
- self
+ self,
+ finished_prefills: Optional[Set[str]] = None
+ finished_prefills: Optional[Set[str]] = None,
+ finished_transfers: Optional[Set[str]] = None
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
......@@ -556,11 +590,11 @@ index f507847a..ee20d50c 100644
- scheduler_outputs: SchedulerOutputs = self._schedule()
+ scheduler_start_time = time.perf_counter()
+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills)
+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers)
now = time.time()
if not self.cache_config.enable_prefix_caching:
@@ -1333,7 +1373,8 @@ class Scheduler:
@@ -1333,7 +1403,8 @@ class Scheduler:
encoder_seq_data = None
cross_block_table = None
......@@ -570,18 +604,24 @@ index f507847a..ee20d50c 100644
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
@@ -1364,6 +1405,10 @@ class Scheduler:
@@ -1364,9 +1435,16 @@ class Scheduler:
< seqs[0].data.get_len()):
do_sample = False
+ is_remote_prefill = False
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
+ is_remote_prefill = True
+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
+ block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
+
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
if is_first_prefill or not self.scheduler_config.send_delta_data:
@@ -1392,6 +1437,7 @@ class Scheduler:
+ logger.debug("Assinged blocks: %s", block_tables)
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
@@ -1392,6 +1470,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
......@@ -589,7 +629,7 @@ index f507847a..ee20d50c 100644
)
else:
# When SPMD mode is enabled, we only send delta data except for
@@ -1490,10 +1536,13 @@ class Scheduler:
@@ -1490,10 +1569,13 @@ class Scheduler:
self._async_stopped.clear()
......@@ -605,12 +645,80 @@ index f507847a..ee20d50c 100644
def _append_slots(self,
seq_group: SequenceGroup,
diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
new file mode 100644
index 00000000..9b938039
--- /dev/null
+++ b/vllm/distributed/device_communicators/kv_rearrange.py
@@ -0,0 +1,61 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def rearrange_kernel(
+ t1_ptr,
+ t2_ptr,
+ N,
+ B,
+ H,
+ C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+ curr_n = offsets // block_size
+ curr_b = offsets // token_size % B
+ curr_h = offsets // C % H
+ curr_c = offsets % C
+
+ src_pos = offsets
+
+ tp_group = curr_h * d // H
+ dst_h = curr_h % (H // d)
+ tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
+
+ dst_pos = tensor_subset_size * tp_group + tp_group_offset
+
+ tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
+
+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int):
+ N, B, H, C = t1.shape
+
+ assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
+ assert H % d == 0, "H must be divisible by d"
+
+ block_size = B * H * C
+ token_size = H * C
+ tensor_size = N * block_size
+ tensor_subset_size = tensor_size // d
+
+ BLOCK_SIZE = 1024
+ grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
+
+ rearrange_kernel[grid](
+ t1, t2,
+ N, B, H, C,
+ d,
+ tensor_subset_size,
+ block_size,
+ token_size,
+ BLOCK_SIZE=BLOCK_SIZE
+ )
\ No newline at end of file
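For an (N, B, H, C) slice of a layer's cache (blocks, tokens per block, heads, head dim), the kernel regroups the H heads into d contiguous shards so each destination TP rank receives one contiguous region. A plain PyTorch reference for the same layout change, useful for checking the kernel (my reading of the index math, not part of the commit):

```
import torch

def rearrange_reference(t1: torch.Tensor, d: int) -> torch.Tensor:
    # Split heads into d groups and make the group the leading dimension, so
    # the flat output is d back-to-back shards, each laid out (N, B, H//d, C).
    N, B, H, C = t1.shape
    assert H % d == 0
    out = t1.reshape(N, B, d, H // d, C).permute(2, 0, 1, 3, 4).contiguous()
    return out.reshape(N, B, H, C)

# e.g. H=8 heads, d=2 shards: heads 0-3 fill the first half of the buffer and
# heads 4-7 the second half, matching tp_group = curr_h * d // H above.
```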
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
index 00000000..bc962726
index 00000000..f1618bc4
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,249 @@
@@ -0,0 +1,318 @@
+import torch
+from typing import List, Tuple
+from vllm.config import VllmConfig
......@@ -618,46 +726,19 @@ index 00000000..bc962726
+import msgspec
+import time
+import uuid
+from collections import defaultdict
+from .kv_rearrange import rearrange_tensors
+
+logger = init_logger(__name__)
+
+def nixl_wrapper_init_patch(self, agent_name, nixl_config):
+ logger.info("Initializing patched NixlWrapper")
+ import nixl_bindings as nixl
+ # Read available backends and device info from nixl_config
+ # For now setting the multithreading to enabled.
+ devices = nixl.nixlAgentConfig(False)
+ init = nixl.nixlUcxInitParams()
+
+ self.name = agent_name
+ self.notifs = {}
+ self.backends = {}
+ self.agent = nixl.nixlAgent(agent_name, devices)
+ self.backends["UCX"] = self.agent.createBackend(init)
+
+ self.nixl_mems = {"DRAM": nixl.DRAM_SEG,
+ "VRAM": nixl.VRAM_SEG,
+ "cpu": nixl.DRAM_SEG,
+ "cuda": nixl.VRAM_SEG}
+ self.nixl_ops = {"WRITE": nixl.NIXL_WR_FLUSH,
+ "READ": nixl.NIXL_RD_FLUSH,
+ "WRITE_NOTIF": nixl.NIXL_WR_NOTIF,
+ "READ_NOTIF": nixl.NIXL_RD_NOTIF}
+
+ logger.info("Initializied NIXL agent: %s", agent_name)
+
+
+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
+try:
+ from nixl_wrapper import nixl_wrapper as NixlWrapper # type: ignore
+ logger.info("NIXL is available")
+ NixlWrapper.__init__ = nixl_wrapper_init_patch
+except ImportError:
+ logger.warning("NIXL is not available")
+ NixlWrapper = None
+
+
+
+class NixlMetadata(
+ msgspec.Struct,
+ omit_defaults=True, # type: ignore[call-arg]
......@@ -679,7 +760,9 @@ index 00000000..bc962726
+
+ self.num_layers = None
+ self.num_blocks = None
+ self.num_heads = None
+ self.block_len = None
+ self.kv_caches = None
+ self.kv_caches_base_addr = {}
+ self.kv_cache_shape = {}
+
......@@ -688,33 +771,51 @@ index 00000000..bc962726
+ self.engine_id = engine_id
+ self.rank = rank
+ self.notifs = {}
+ self._tp_size = {}
+ self._block_descs = {}
+ self._xfer_side_handles = {}
+
+
+ self._transfers = defaultdict(list)
+
+
+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+
+
+ @property
+ def agent_name(self):
+ return self.nixl_wrapper.name
+
+ def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+ caches_data = []
+ self.num_layers = len(kv_caches)
+ _, _, block_size, num_heads, head_dim = kv_caches[0].shape
+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+
+ self.num_layers = len(kv_caches)
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads
+ self.kv_caches = kv_caches
+ kv_caches_base_addr = []
+ caches_data = []
+ blocks_data = []
+ for key_cache, value_cache in kv_caches:
+ for cache in [key_cache, value_cache]:
+ base_addr = cache.data_ptr()
+ region_len = cache.numel() * cache.element_size()
+ gpu_id = cache.get_device()
+ assert gpu_id > -1, "Tensor is not on GPU"
+ caches_data.append((base_addr, region_len, gpu_id))
+ region_len = num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank))
+ for block_id in range(self.num_blocks):
+ blocks_data.append((base_addr + block_id * self.block_len, self.block_len, self.rank))
+
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_descs(("VRAM", caches_data))
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+
+ self._block_descs[self.engine_id] = self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+ self._xfer_side_handles[self.engine_id] = self.nixl_wrapper.prep_xfer_side(self._block_descs[self.engine_id])
+
+ def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata()
+
......@@ -724,10 +825,14 @@ index 00000000..bc962726
+ for agent_name in self._remote_agents.values():
+ self.nixl_wrapper.remove_remote_agent(agent_name)
+
+ def add_remote_agent(self, engine_id, agent_metadata):
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_metadata)
+ self._remote_agents[engine_id] = agent_name
+ return agent_name
+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp):
+ self._tp_size[engine_id] = agent_tp
+ agent_names = []
+ for agent_meta in agent_metadata:
+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
+ agent_names.append(agent_name)
+ self._remote_agents[engine_id] = agent_names
+ return agent_names
+
+ def get_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
......@@ -742,17 +847,29 @@ index 00000000..bc962726
+ descs_ids.append(2 * (self.num_blocks * layer_id + block_id) + 1)
+ return descs_ids
+
+ def _get_range_descs(self, engine_id, ranges, layer_ids):
+ def _get_range_descs(self, ranges, layer_ids, kv_caches_base_addr, tp_multiplier=1, rank=None, i=0):
+ if rank is None:
+ rank = self.rank
+ offset_block_len = self.block_len
+ block_len = self.block_len // tp_multiplier
+ tp_offset = i * block_len
+ else:
+ offset_block_len = self.block_len // tp_multiplier
+ block_len = self.block_len // tp_multiplier
+ tp_offset = 0
+ logger.debug("Getting range descs for layer ids: %s, ranges: %s, tp_multiplier: %s, rank: %s, i: %s", layer_ids, ranges, tp_multiplier, rank, i)
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ blocks_data = []
+ for layer_id in layer_ids:
+ for range_start, range_end in ranges:
+ key_base_addr, value_base_addr = self.kv_caches_base_addr[engine_id][layer_id]
+ start_offset = range_start * self.block_len
+ blocks_len = (range_end - range_start + 1) * self.block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, self.rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, self.rank))
+ range_len = range_end - range_start + 1
+ key_base_addr, value_base_addr = kv_caches_base_addr[layer_id]
+ start_offset = range_start * offset_block_len + tp_offset * range_len
+ blocks_len = range_len * block_len
+ blocks_data.append((key_base_addr + start_offset, blocks_len, rank))
+ blocks_data.append((value_base_addr + start_offset, blocks_len, rank))
+ logger.debug("Blocks data: %s", blocks_data)
+ return self.nixl_wrapper.get_descs(("VRAM", blocks_data))
+
+ def _get_ranges(self, block_ids):
......@@ -765,9 +882,9 @@ index 00000000..bc962726
+ ranges = []
+ for i in range(len(sorted_block_ids)):
+ if i == 0 or sorted_block_ids[i] != sorted_block_ids[i-1] + 1:
+ ranges.append([sorted_block_ids[i]])
+ ranges.append([sorted_block_ids[i], sorted_block_ids[i]])
+ else:
+ ranges[-1].append(sorted_block_ids[i])
+ ranges[-1][1] = sorted_block_ids[i]
+ return ranges
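The rewritten helper now emits inclusive `[start, end]` pairs instead of accumulating every id. An equivalent standalone version with a worked example:

```
def get_ranges(block_ids):
    # Collapse sorted block ids into inclusive [start, end] runs.
    ranges = []
    for bid in sorted(block_ids):
        if ranges and bid == ranges[-1][1] + 1:
            ranges[-1][1] = bid
        else:
            ranges.append([bid, bid])
    return ranges

assert get_ranges([5, 0, 1, 2, 6]) == [[0, 2], [5, 6]]
```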
+
+ def _get_same_length_ranges(self, src_ranges, dst_ranges):
......@@ -807,11 +924,24 @@ index 00000000..bc962726
+ src_idx += 1
+
+ return src_overlapping_ranges, dst_overlapping_ranges
+
+
+
+ def _get_block_descs_ids(self, layer_ids, block_ids):
+ if layer_ids == "all":
+ layer_ids = list(range(self.num_layers))
+ if block_ids == "all":
+ block_ids = list(range(self.num_blocks))
+ descs_ids = []
+ for layer_id in layer_ids:
+ for is_value in [0, 1]:
+ for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * self.num_blocks + is_value * self.num_blocks + block_id)
+ return descs_ids
+
+
+
+ def transfer_mem(self, src_block_ids, dst_block_ids, dst_engine_id, notify_msg):
+
+ def transfer_mem(self, src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg, use_prepped_xfer=False):
+ start_time = time.perf_counter()
+ logger.debug("Transferring memory from %s to %s with notify message %s", self.agent_name, dst_engine_id, notify_msg)
+
......@@ -820,44 +950,62 @@ index 00000000..bc962726
+ # If the ISL equals a multiple of tokens_per_block + 1, the prefill engine
+ # will have one less block due to the missing last token.
+ dst_block_ids = dst_block_ids[:len(src_block_ids)]
+ assert len(staging_block_ids) == len(src_block_ids)
+
+ if use_prepped_xfer:
+ raise NotImplementedError("Prepped xfer is not implemented")
+ # src_block_descs_ids = self._get_block_descs_ids("all", src_block_ids)
+ # dst_block_descs_ids = self._get_block_descs_ids("all", dst_block_ids)
+
+ # src_xfer_side_handle = self._xfer_side_handles[self.engine_id]
+ # dst_xfer_side_handle = self._xfer_side_handles[dst_engine_id]
+
+ # logger.debug("Time to get block desc ids: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ # handle = self.nixl_wrapper.make_prepped_xfer(src_xfer_side_handle, src_block_descs_ids,
+ # dst_xfer_side_handle, dst_block_descs_ids,
+ # notify_msg, "WRITE", no_check=True)
+ # else:
+ # Legacy path using range-based transfers
+ src_ranges = self._get_ranges(src_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ dst_ranges = self._get_ranges(dst_block_ids)
+ src_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(src_ranges, dst_ranges)
+
+ logger.debug("Got %s overlapping ranges for %s blocks", len(src_overlapping_ranges), len(src_block_ids))
+ assert len(src_ranges) == 1
+ assert len(staging_ranges) == 1
+
+ logger.debug("Time to get ranges: %s ms", time.perf_counter() - start_time)
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+
+ src_range_start, src_range_end = src_ranges[0]
+ src_range_len = src_range_end - src_range_start + 1
+ staging_range_start, staging_range_end = staging_ranges[0]
+ staging_range_len = staging_range_end - staging_range_start + 1
+
+ src_descs = self._get_range_descs(self.engine_id, src_overlapping_ranges, "all")
+ dst_descs = self._get_range_descs(dst_engine_id, dst_overlapping_ranges, "all")
+ logger.debug("Rearranging tensors for cache: %s, src_ranges: %s of len %s, staging_ranges: %s of len %s", self.kv_caches[0].shape, src_ranges, src_range_len, staging_ranges, staging_range_len)
+ for kv_cache in self.kv_caches:
+ for cache in kv_cache:
+ rearrange_tensors(cache[src_range_start:src_range_start + src_range_len], cache[staging_range_start:staging_range_start + staging_range_len], tp_multiplier)
+
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+ staging_overlapping_ranges, dst_overlapping_ranges = self._get_same_length_ranges(staging_ranges, dst_ranges)
+ assert len(staging_overlapping_ranges) == len(dst_overlapping_ranges)
+
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs, self._remote_agents[dst_engine_id], notify_msg, "WRITE")
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+ # TODO ptarasiewicz: remove blocking transfer mem
+ # add scheduler check for transfer done
+ while True:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "ERR":
+ raise RuntimeError("Transfer failed")
+ elif xfer_state == "DONE":
+ logger.debug("Transfer done")
+ break
+ elif xfer_state == "PROC":
+ time.sleep(0.01)
+ else:
+ raise RuntimeError("Unknown transfer state")
+ logger.debug("Time to wait for transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ self.nixl_wrapper.abort_xfer(handle)
+ logger.debug("Time to abort xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer time: %s ms", (time.perf_counter() - start_time) * 1000)
+ for i in range(tp_multiplier):
+
+ src_descs = self._get_range_descs(staging_overlapping_ranges, "all", self.kv_caches_base_addr[self.engine_id], tp_multiplier, i=i)
+ dst_descs = self._get_range_descs(dst_overlapping_ranges, "all", self.kv_caches_base_addr[dst_engine_id][self.rank * tp_multiplier + i], tp_multiplier, rank=self.rank * tp_multiplier + i)
+ logger.debug("Time to get descs: %s ms", (time.perf_counter() - start_time) * 1000)
+
+ logger.debug("Transfering to agent %s", self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i])
+ handle = self.nixl_wrapper.initialize_xfer(src_descs, dst_descs,
+ self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i],
+ notify_msg, "WRITE")
+ self._transfers[notify_msg].append(handle)
+ logger.debug("Time to initialize xfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer handle: %s", handle)
+ status = self.nixl_wrapper.transfer(handle)
+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Transfer status: %s", status)
+
+ def deserialize_descs(self, serialized_descs):
+ return self.nixl_wrapper.deserialize_descs(serialized_descs)
+
......@@ -870,6 +1018,26 @@ index 00000000..bc962726
+
+ def add_remote_kv_caches_base_addr(self, engine_id, kv_caches_base_addr):
+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
+
+ def get_done_transfers(self) -> List[str]:
+ done_req_ids = []
+ for req_id, handles in self._transfers.items():
+ running_reqs = []
+ for handle in handles:
+ xfer_state = self.nixl_wrapper.check_xfer_state(handle)
+ if xfer_state == "DONE":
+ # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why is abort throwing errors?
+ continue
+ if xfer_state == "PROC":
+ running_reqs.append(handle)
+ else:
+ raise RuntimeError("Transfer failed with state %s", xfer_state)
+ if len(running_reqs) == 0:
+ done_req_ids.append(req_id)
+ else:
+ self._transfers[req_id] = running_reqs
+ return done_req_ids
\ No newline at end of file
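The variable-TP core of this commit is the loop in `transfer_mem` above: with `tp_multiplier = dst_tp // src_tp`, each prefill rank rearranges its blocks into staging and issues one WRITE per destination shard, targeting decode rank `self.rank * tp_multiplier + i`. Worked numbers (hypothetical, but matching the README example below):

```
# Prefill engine TP=1, decode engine TP=2.
src_tp, dst_tp = 1, 2
tp_multiplier = dst_tp // src_tp           # 2

rank = 0                                   # the single prefill rank
targets = [rank * tp_multiplier + i for i in range(tp_multiplier)]
print(targets)                             # [0, 1]
# Each decode rank receives half the heads (block_len // tp_multiplier bytes
# per block), taken from the contiguous shard that rearrange_tensors staged.
```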
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..61a357d0 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
......@@ -2231,7 +2399,7 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d82d9ad9..9ba1a326 100644
index d82d9ad9..254337cb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2,13 +2,17 @@
......@@ -2318,7 +2486,9 @@ index d82d9ad9..9ba1a326 100644
+ self._nixl_agents_names = self._initialize_nixl()
+
+ self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
+ self._finished_prefills = set()
+ self._finished_transfers = set()
+
+ @property
+ def is_nixl_initialized(self) -> bool:
......@@ -2337,8 +2507,6 @@ index d82d9ad9..9ba1a326 100644
+ engine_id = nixl_metadata.engine_id
+ agents_metadata = nixl_metadata.agent_metadata
+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
+ if len(agents_metadata) != len(self._nixl_agents_names):
+ raise ValueError("Number of agents does not match. Make sure all engines are initialized with the same parallel sizes.")
+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr))
+
+ def _initialize_nixl(self) -> List[bytes]:
......@@ -2372,7 +2540,16 @@ index d82d9ad9..9ba1a326 100644
ParallelSampleSequenceGroup.add_request(
request_id,
self,
@@ -584,7 +634,7 @@ class LLMEngine:
@@ -574,6 +624,8 @@ class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
+ next(self.seq_counter) # empty sequence for staging
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
@@ -584,7 +636,7 @@ class LLMEngine:
encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
......@@ -2381,7 +2558,7 @@ index d82d9ad9..9ba1a326 100644
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
@@ -601,8 +651,12 @@ class LLMEngine:
@@ -601,8 +653,12 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
......@@ -2395,7 +2572,7 @@ index d82d9ad9..9ba1a326 100644
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
@@ -673,6 +727,7 @@ class LLMEngine:
@@ -673,6 +729,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
......@@ -2403,7 +2580,7 @@ index d82d9ad9..9ba1a326 100644
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@@ -765,6 +820,7 @@ class LLMEngine:
@@ -765,6 +822,7 @@ class LLMEngine:
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
......@@ -2411,7 +2588,7 @@ index d82d9ad9..9ba1a326 100644
)
def _validate_token_prompt(self, prompt: PromptType,
@@ -799,6 +855,7 @@ class LLMEngine:
@@ -799,6 +857,7 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
......@@ -2419,7 +2596,7 @@ index d82d9ad9..9ba1a326 100644
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@@ -829,7 +886,9 @@ class LLMEngine:
@@ -829,7 +888,9 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
......@@ -2430,7 +2607,7 @@ index d82d9ad9..9ba1a326 100644
return seq_group
@@ -995,11 +1054,11 @@ class LLMEngine:
@@ -995,11 +1056,11 @@ class LLMEngine:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
......@@ -2444,7 +2621,7 @@ index d82d9ad9..9ba1a326 100644
# Sanity check
assert len(seq_group_metadata_list) == len(
@@ -1325,15 +1384,49 @@ class LLMEngine:
@@ -1325,15 +1386,49 @@ class LLMEngine:
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
......@@ -2464,7 +2641,7 @@ index d82d9ad9..9ba1a326 100644
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
- ) = self.scheduler[virtual_engine].schedule()
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills)
+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers)
+
+
+ # Separate remote prefill and running seq groups
......@@ -2496,7 +2673,7 @@ index d82d9ad9..9ba1a326 100644
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
@@ -1383,9 +1476,29 @@ class LLMEngine:
@@ -1383,9 +1478,31 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
......@@ -2510,9 +2687,11 @@ index d82d9ad9..9ba1a326 100644
+ req_id = scheduled_seq_group.seq_group.request_id
+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
+ block_table = seq_group_metadata.block_tables[seq_id]
+ staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]
+ memory_transfer_req = MemoryTransferRequest(
+ request_id=req_id,
+ src_block_ids=block_table,
+ staging_block_ids=staging_block_ids,
+ dst_block_ids=remote_prefill_params.decode_block_ids,
+ dst_engine_id=remote_prefill_params.decode_engine_id,
+ notify_msg=req_id,
......@@ -2522,13 +2701,13 @@ index d82d9ad9..9ba1a326 100644
+
+ execute_model_req.memory_transfer_requests = memory_transfer_reqs
+
+ outputs, request_notif_counter = self.model_executor.execute_model(
+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
execute_model_req=execute_model_req)
-
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
@@ -1396,7 +1509,20 @@ class LLMEngine:
@@ -1396,7 +1513,26 @@ class LLMEngine:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
......@@ -2539,7 +2718,7 @@ index d82d9ad9..9ba1a326 100644
+ blocks_to_swap_out=[],
+ blocks_to_copy=[])
+
+ outputs, request_notif_counter = self.model_executor.execute_model(
+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
+ execute_model_req=execute_model_req)
+
+ for req_id, notif_count in request_notif_counter.items():
......@@ -2547,10 +2726,16 @@ index d82d9ad9..9ba1a326 100644
+ if self._request_notif_counter[req_id] > -1:
+ self._finished_prefills.add(req_id)
+ del self._request_notif_counter[req_id]
+
+ for req_id, done_count in request_done_counter.items():
+ self._request_done_counter[req_id] += done_count
+ if self._request_done_counter[req_id] > -1:
+ self._finished_transfers.add(req_id)
+ del self._request_done_counter[req_id]
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
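The counters above are seeded at `-tensor_parallel_size`, so a request only counts as finished once every TP rank has reported. A minimal illustration of the trick:

```
from collections import defaultdict

tp_size = 4
notif_counter = defaultdict(lambda: -tp_size)

for rank in range(tp_size):          # one NIXL notification per TP rank
    notif_counter["req-0"] += 1

assert notif_counter["req-0"] > -1   # all ranks reported: prefill finished
```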
@@ -1456,7 +1582,7 @@ class LLMEngine:
@@ -1456,7 +1592,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
......@@ -2628,7 +2813,7 @@ index 3cf1850e..6b90ece7 100644
+ kv_active_blocks: int
+ kv_total_blocks: int
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 85b5f31e..d33d546a 100644
index 85b5f31e..c501e4c8 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
......@@ -2756,7 +2941,7 @@ index 85b5f31e..d33d546a 100644
+ kv_metrics.kv_active_blocks,
+ kv_metrics.kv_total_blocks)
+
+ logger.debug("Metircs successful.")
+ logger.debug("Metircs successful.")
+
+ except asyncio.CancelledError:
+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
......@@ -3161,10 +3346,10 @@ index 786380c3..56a7cf89 100644
"""The output data of one completion output of a request.
diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py
new file mode 100644
index 00000000..03f02006
index 00000000..957f55de
--- /dev/null
+++ b/vllm/remote_prefill.py
@@ -0,0 +1,53 @@
@@ -0,0 +1,54 @@
+from dataclasses import dataclass
+from typing import Callable, Optional, List, Coroutine
+
......@@ -3202,6 +3387,7 @@ index 00000000..03f02006
+ """
+ request_id: str
+ src_block_ids: List[int]
+ staging_block_ids: List[int]
+ dst_block_ids: List[int]
+ dst_engine_id: str
+ notify_msg: str
......@@ -3361,7 +3547,7 @@ index 12baecde..cbada27f 100644
prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 582aa460..ffb7b403 100644
index 582aa460..1b8515bf 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -2,7 +2,7 @@
......@@ -3402,13 +3588,13 @@ index 582aa460..ffb7b403 100644
+
+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]]) -> str:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata[self.local_rank]) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr[self.local_rank])
+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata)) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.add_remote_kv_caches_base_addr(engine_id, kv_caches_base_addr)
+ return agent_name
+
+ def transfer_nixl_memory(self, src_descs: List[bytes], dst_descs: List[bytes], remote_agent_name: List[str], notify_msg: str) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name[self.local_rank], notify_msg) # TODO ptarasiewicz: rank or local_rank?
+ self.nixl_connector.transfer_mem(src_descs[self.local_rank], dst_descs[self.local_rank], remote_agent_name, notify_msg) # TODO ptarasiewicz: rank or local_rank?
+
+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
......@@ -3416,8 +3602,8 @@ index 582aa460..ffb7b403 100644
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ if worker_input.src_block_ids is not None:
+ for src_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+ for src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg in zip(worker_input.src_block_ids, worker_input.staging_block_ids, worker_input.dst_block_ids, worker_input.dst_engine_id, worker_input.notify_msg):
+ self.nixl_connector.transfer_mem(src_block_ids, staging_block_ids, dst_block_ids, dst_engine_id, notify_msg)
+
+ def shutdown_nixl(self) -> None:
+ assert self.nixl_connector is not None, "Nixl connector is not initialized"
......@@ -3435,11 +3621,12 @@ index 582aa460..ffb7b403 100644
return WorkerInput(
num_seq_groups=num_seq_groups,
@@ -375,6 +416,10 @@ class Worker(LocalOrDistributedWorkerBase):
@@ -375,6 +416,11 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
+ src_block_ids=[r.src_block_ids for r in mem_transfer_reqs],
+ staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs],
+ dst_block_ids=[r.dst_block_ids for r in mem_transfer_reqs],
+ dst_engine_id=[r.dst_engine_id for r in mem_transfer_reqs],
+ notify_msg=[r.notify_msg for r in mem_transfer_reqs],
......@@ -3447,7 +3634,7 @@ index 582aa460..ffb7b403 100644
@torch.inference_mode()
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fb..d9c039eb 100644
index 819b81fb..ecb68530 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
......@@ -3475,11 +3662,12 @@ index 819b81fb..d9c039eb 100644
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
@@ -216,6 +220,11 @@ class WorkerInput:
@@ -216,6 +220,12 @@ class WorkerInput:
virtual_engine: int = 0
num_steps: int = 1
+ src_block_ids: Optional[List[List[int]]] = None
+ staging_block_ids: Optional[List[List[int]]] = None
+ dst_block_ids: Optional[List[List[int]]] = None
+ dst_engine_id: Optional[List[str]] = None
+ notify_msg: Optional[List[str]] = None
......@@ -3487,29 +3675,31 @@ index 819b81fb..d9c039eb 100644
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
@@ -232,6 +241,10 @@ class WorkerInput:
@@ -232,6 +242,11 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
+ src_block_ids=tensor_dict.pop("src_block_ids"),
+ staging_block_ids=tensor_dict.pop("staging_block_ids"),
+ dst_block_ids=tensor_dict.pop("dst_block_ids"),
+ dst_engine_id=tensor_dict.pop("dst_engine_id"),
+ notify_msg=tensor_dict.pop("notify_msg"),
)
def as_broadcastable_tensor_dict(
@@ -246,6 +259,10 @@ class WorkerInput:
@@ -246,6 +261,11 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
+ "src_block_ids": self.src_block_ids,
+ "staging_block_ids": self.staging_block_ids,
+ "dst_block_ids": self.dst_block_ids,
+ "dst_engine_id": self.dst_engine_id,
+ "notify_msg": self.notify_msg,
}
return tensor_dict
@@ -316,13 +333,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@@ -316,13 +336,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
......@@ -3531,7 +3721,7 @@ index 819b81fb..d9c039eb 100644
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
@@ -396,49 +416,79 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@@ -396,49 +419,87 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
......@@ -3628,7 +3818,7 @@ index 819b81fb..d9c039eb 100644
+ else:
+ for i in range(1, get_tp_group().world_size):
+ all_new_notifs.append(get_tp_group().recv_object(src=i))
+
+ request_notif_counter = defaultdict(int)
+ for notifs in all_new_notifs:
+ for req_ids in notifs.values():
......@@ -3637,12 +3827,20 @@ index 819b81fb..d9c039eb 100644
+
+ if request_notif_counter:
+ logger.debug("Request notif counter: %s", request_notif_counter)
+
+ request_done_counter = defaultdict(int)
+ for req_id in self.nixl_connector.get_done_transfers():
+ request_done_counter[req_id] += 1
+
+ if request_done_counter:
+ logger.debug("Request done counter: %s", request_done_counter)
+
+ else:
+ request_notif_counter = {}
+ request_done_counter = {}
# output is List[SamplerOutput]
- return output
+ return output, request_notif_counter
+ return output, request_notif_counter, request_done_counter
+
+ def _transfer_blocks(self, worker_input: WorkerInput) -> None:
+ pass
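On the driver, the per-rank notification dicts gathered above are folded into per-request counts (the same shape the `get_done_transfers` results feed for completed writes). A condensed sketch of that fold:

```
from collections import defaultdict

def count_notifs(per_rank_notifs):
    # per_rank_notifs: one {remote_agent_name: [req_id, ...]} dict per TP rank
    counter = defaultdict(int)
    for notifs in per_rank_notifs:
        for req_ids in notifs.values():
            for req_id in req_ids:
                counter[req_id] += 1
    return counter

assert count_notifs([{"agent-0": ["req-0"]},
                     {"agent-1": ["req-0"]}]) == {"req-0": 2}
```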
......
......@@ -20,7 +20,7 @@ limitations under the License.
## Build docker
```
./container/build.sh --framework VLLM_NIXL --target dev --build-context nixl=<path to downloaded nixl repo @ fc912eb012597be67de11fa9ba0599e4e1974fa2>
./container/build.sh --framework VLLM_NIXL --target dev --build-context nixl=<path to downloaded nixl repo @ c53bb19a6a114e9093071bd1f2904f996ae1839b>
```
## Run container
......@@ -72,10 +72,11 @@ In terminal 2:
```
cd /workspace/examples/python_rs/llm/vllm_nixl
CUDA_VISIBLE_DEVICES=1 python3 worker.py \
CUDA_VISIBLE_DEVICES=1,2 python3 worker.py \
--remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--tensor-parallel-size 2 \
--kv-transfer-config \
'{"kv_connector":"TritonNixlConnector"}'
```
......
......@@ -18,7 +18,7 @@ import asyncio
import msgspec
import uvloop
from common import find_remote_metadata, parse_vllm_args, temp_metadata_file
from common import find_remote_metadata, parse_vllm_args
from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
......@@ -69,19 +69,18 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
# This should be replaced with etcd
metadata = engine_client.nixl_metadata
with temp_metadata_file(metadata.engine_id, metadata):
print(f"Waiting for remote metadata for engine {metadata.engine_id}")
remote_metadata: list[NixlMetadata] = []
while not remote_metadata:
await asyncio.sleep(1)
remote_metadata = find_remote_metadata(metadata.engine_id)
print(
f"Found {len(remote_metadata)} remote metadata for engine {metadata.engine_id}"
)
for remote_metadata in remote_metadata:
await engine_client.add_remote_nixl_metadata(remote_metadata)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
print(f"Waiting for remote metadata for engine {metadata.engine_id}")
remote_metadata: list[NixlMetadata] = []
while not remote_metadata:
await asyncio.sleep(1)
remote_metadata = find_remote_metadata(metadata.engine_id)
print(
f"Found {len(remote_metadata)} remote metadata for engine {metadata.engine_id}"
)
for remote_metadata in remote_metadata:
await engine_client.add_remote_nixl_metadata(remote_metadata)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
if __name__ == "__main__":
......@@ -96,4 +95,12 @@ if __name__ == "__main__":
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
if engine_args.disable_async_output_proc is not True:
print("Async output processing is not supported yet, setting to True")
engine_args.disable_async_output_proc = True
if engine_args.enforce_eager is not True:
print("Prefill must be done eagerly, setting to True")
engine_args.enforce_eager = True
asyncio.run(worker(engine_args))
......@@ -93,6 +93,8 @@ class RequestHandler:
await self.init()
assert self.openai_serving_chat is not None
request.model = "vllm"
if self.do_remote_prefill:
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
......@@ -133,7 +135,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
with temp_metadata_file(metadata.engine_id, metadata):
await endpoint.serve_endpoint(
RequestHandler(
model_name=engine_args.model,
model_name="vllm",
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=True,
......@@ -142,7 +144,7 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
else:
await endpoint.serve_endpoint(
RequestHandler(
model_name=engine_args.model,
model_name="vllm",
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=False,
......