Commit f38aa469 authored by ptarasiewiczNV's avatar ptarasiewiczNV Committed by GitHub
Browse files

feat: vLLM v0.7.2 patch supporting XpYd and heterogeneous P/D parallel configs (#167)


Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
Co-authored-by: default avatarNeelay Shah <neelays@nvidia.com>
parent dddebc0d
...@@ -36,3 +36,4 @@ ...@@ -36,3 +36,4 @@
**/.github **/.github
**/*backup*/ **/*backup*/
.dockerignore .dockerignore
**/target/*
\ No newline at end of file
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
exclude: ^src/grpc_generated exclude: ^(src/grpc_generated|.*\.patch$)
repos: repos:
- repo: https://github.com/timothycrosley/isort - repo: https://github.com/timothycrosley/isort
rev: 5.12.0 rev: 5.12.0
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
ARG BASE_IMAGE="nvcr.io/nvidia/tritonserver" ARG BASE_IMAGE="nvcr.io/nvidia/tritonserver"
ARG BASE_IMAGE_TAG="25.01-py3" ARG BASE_IMAGE_TAG="25.01-py3"
ARG VLLM_WHEEL
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS triton-distributed FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS triton-distributed
...@@ -76,21 +75,6 @@ RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,targe ...@@ -76,21 +75,6 @@ RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,targe
--mount=type=bind,source=./container/deps/clone_tensorrtllm.sh,target=/tmp/clone_tensorrtllm.sh \ --mount=type=bind,source=./container/deps/clone_tensorrtllm.sh,target=/tmp/clone_tensorrtllm.sh \
if [[ "$FRAMEWORK" == "TENSORRTLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt; /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} ; fi if [[ "$FRAMEWORK" == "TENSORRTLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt; /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} ; fi
RUN --mount=type=bind,source=./container/deps/requirements.vllm.txt,target=/tmp/requirements.txt \
if [[ "$FRAMEWORK" == "VLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi
# NOTE: python3-distro is a system package that causes conflicts with user
# packages installed by `pip`. Removing the system package allows user packages
# to correctly manage dependencies with the `distro` pip user package instead.
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
if [[ "$FRAMEWORK" == "VLLM" ]]; then \
apt install -y zip && \
apt remove -y python3-distro && \
pip install distro && \
cp -r /tmp/deps/vllm /tmp/local_vllm && \
cd /tmp/local_vllm && \
bash ./prepare_wheel.sh --install --debug --force ; \
fi
RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \ RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \
if [[ "$FRAMEWORK" == "STANDARD" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi if [[ "$FRAMEWORK" == "STANDARD" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi
...@@ -102,23 +86,6 @@ ENV LD_LIBRARY_PATH=${FRAMEWORK_LD_LIBRARY_PATH}:${LD_LIBRARY_PATH} ...@@ -102,23 +86,6 @@ ENV LD_LIBRARY_PATH=${FRAMEWORK_LD_LIBRARY_PATH}:${LD_LIBRARY_PATH}
ENV TENSORRTLLM_BACKEND_REPO_TAG=$TENSORRTLLM_BACKEND_REPO_TAG ENV TENSORRTLLM_BACKEND_REPO_TAG=$TENSORRTLLM_BACKEND_REPO_TAG
ENV TRTLLM_USE_MPI_KVCACHE=${TENSORRTLLM_FRAMEWORK:+"1"} ENV TRTLLM_USE_MPI_KVCACHE=${TENSORRTLLM_FRAMEWORK:+"1"}
# TODO set VLLM Version
# ENV VLLM_VERSION
ARG VLLM_FRAMEWORK
# DEFAULT VLLM VARIABLES
ENV VLLM_ATTENTION_BACKEND=${VLLM_FRAMEWORK:+FLASHINFER}
ENV VLLM_WORKER_MULTIPROC_METHOD=${VLLM_FRAMEWORK:+spawn}
ENV VLLM_TORCH_HOST=${VLLM_FRAMEWORK:+localhost}
ENV VLLM_TORCH_PORT=${VLLM_FRAMEWORK:+36183}
ENV VLLM_DATA_PLANE_BACKEND=${VLLM_FRAMEWORK:+nccl}
ENV VLLM_BASELINE_WORKERS=${VLLM_FRAMEWORK:+0}
ENV VLLM_CONTEXT_WORKERS=${VLLM_FRAMEWORK:+1}
ENV VLLM_GENERATE_WORKERS=${VLLM_FRAMEWORK:+1}
ENV VLLM_BASELINE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_CONTEXT_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_GENERATE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV PYTHONUNBUFFERED=1
# Install NATS - pointing toward NATS github instead of binaries.nats.dev due to server instability # Install NATS - pointing toward NATS github instead of binaries.nats.dev due to server instability
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
...@@ -161,6 +128,16 @@ RUN mkdir /opt/triton && \ ...@@ -161,6 +128,16 @@ RUN mkdir /opt/triton && \
uv build && \ uv build && \
uv pip install dist/triton_distributed_rs*cp312*.whl uv pip install dist/triton_distributed_rs*cp312*.whl
# Install patched vllm
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_v0.7.2.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
if [[ "$FRAMEWORK" == "VLLM" ]]; then \
source /opt/triton/venv/bin/activate && \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm ; \
fi
# Install triton_distributed_rs wheel globally in container for tests that # Install triton_distributed_rs wheel globally in container for tests that
# currently run without virtual environment activated. # currently run without virtual environment activated.
# TODO: In future, we may use a virtualenv for everything and remove this. # TODO: In future, we may use a virtualenv for everything and remove this.
......
...@@ -71,7 +71,7 @@ TENSORRTLLM_BACKEND_REBUILD=0 ...@@ -71,7 +71,7 @@ TENSORRTLLM_BACKEND_REBUILD=0
# vllm version installed in the base image. # vllm version installed in the base image.
VLLM_BASE_VERSION=25.01 VLLM_BASE_VERSION=25.01
VLLM_BASE_IMAGE=nvcr.io/nvidia/tritonserver VLLM_BASE_IMAGE=nvcr.io/nvidia/tritonserver
VLLM_BASE_IMAGE_TAG=${VLLM_BASE_VERSION}-vllm-python-py3 VLLM_BASE_IMAGE_TAG=${VLLM_BASE_VERSION}-py3
get_options() { get_options() {
while :; do while :; do
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Necessary for vLLM engine.
--extra-index-url https://flashinfer.ai/whl/cu121/torch2.4
flashinfer<0.2.0
# Necessary for vLLM engine.
ninja==1.11.1.3
ucx-py-cu12
# vLLM is installed by patching script
# vllm==0.6.3post1
...@@ -15,4 +15,4 @@ See the License for the specific language governing permissions and ...@@ -15,4 +15,4 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
--> -->
Apply this patch to Python source code from vLLM release [v0.6.3.post1](https://github.com/vllm-project/vllm/releases/tag/v0.6.3.post1). Copy files in ``data_plane`` folder into vLLM folder ``vllm/distributed``. Apply this patch to Python source code from vLLM release [v0.7.2](https://github.com/vllm-project/vllm/releases/tag/v0.7.2).
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A script to download a Python wheel, patch it, copy additional files,
# repackage it, and optionally install the new wheel.
import logging
import math
import socket
import threading
import typing
import uuid
import torch
import torch.distributed
import tritonserver
import zmq
from triton_distributed.icp.data_plane import (
set_icp_data_type,
set_icp_memory_type,
set_icp_shape,
set_icp_tensor_size,
set_icp_tensor_uri,
)
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest
from triton_distributed.icp.ucp_data_plane import DataPlaneError, UcpDataPlane
logger = logging.getLogger(__name__)
class VllmUcpDataPlane:
def __init__(
self,
hostname: typing.Optional[str] = None,
port: int = 0,
keep_endpoints_open: bool = False,
) -> None:
self._data_plane = UcpDataPlane(hostname, port, keep_endpoints_open)
@property
def hostname(self) -> str:
return self._data_plane.hostname
@property
def port(self) -> int:
return self._data_plane.port
def connect(self) -> None:
self._data_plane.connect()
def close(self) -> None:
self._data_plane.close()
def put_input_tensor(
self,
tensor: torch.Tensor,
tensor_id: typing.Optional[uuid.UUID] = None,
):
logger.debug(
f"Putting input tensor with id {tensor_id} on {self.hostname}:{self.port}"
)
if tensor.dtype == torch.bfloat16:
tensor = tensor.view(torch.float16)
triton_tensor = tritonserver.Tensor.from_dlpack(tensor)
self._data_plane.put_input_tensor(triton_tensor, tensor_id)
def put_output_tensor(
self,
tensor: torch.Tensor,
tensor_id: typing.Optional[uuid.UUID] = None,
):
logger.debug(
f"Putting input tensor with id {tensor_id} on {self.hostname}:{self.port}"
)
if tensor.dtype == torch.bfloat16:
tensor = tensor.view(torch.float16)
triton_tensor = tritonserver.Tensor.from_dlpack(tensor)
self._data_plane.put_output_tensor(triton_tensor, tensor_id)
def get_tensor(
self,
tensor_uri: str,
shape: typing.Sequence[int],
dtype: torch.dtype,
device_id: int,
) -> torch.Tensor:
logger.debug("Getting tensor from %s", tensor_uri)
result = ModelInferRequest.InferInputTensor()
triton_dtype = {
torch.float32: tritonserver.DataType.FP32,
torch.float16: tritonserver.DataType.FP16,
torch.bfloat16: tritonserver.DataType.FP16,
torch.uint8: tritonserver.DataType.UINT8,
}.get(dtype)
if triton_dtype is None:
raise DataPlaneError(f"Unsupported dtype {dtype}")
tensor_size = math.prod(shape) * dtype.itemsize
set_icp_data_type(result, triton_dtype)
set_icp_shape(result, shape)
set_icp_tensor_uri(result, tensor_uri)
set_icp_memory_type(result, tritonserver.MemoryType.GPU)
set_icp_tensor_size(result, tensor_size)
triton_tensor = self._data_plane.get_tensor(
remote_tensor=result,
requested_memory_type=tritonserver.MemoryType.GPU,
requested_memory_type_id=device_id,
)
tensor = torch.utils.dlpack.from_dlpack(triton_tensor)
if dtype == torch.bfloat16:
tensor = tensor.view(torch.bfloat16)
logger.debug("Got tensor from %s", tensor_uri)
return tensor
class VllmNcclDataPlane:
def __init__(
self,
hostname: str = "",
port: int = 0,
# FIXME: world_size and rank both unused
world_size: int = -1,
rank: int = -1,
) -> None:
if not torch.distributed.is_initialized():
raise RuntimeError("NCCL backend not initialized")
if not hostname:
hostname = socket.gethostname()
if port == 0:
port = 13337 + torch.distributed.get_rank()
self._hostname = hostname
self._port = port
self._rank = torch.distributed.get_rank()
self._world_size: int = world_size
self._current_device = torch.cuda.current_device()
# FIXME: Use stricter type for req value in tuple
self.store: typing.Dict[str, typing.Tuple[torch.Tensor, int, typing.Any]] = {}
self.context = zmq.Context()
self.rep_socket = self.context.socket(zmq.REP)
logger.info(f"Rank {self._rank} binding to {self._hostname}:{self._port}")
self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
self._listener_thread = threading.Thread(
target=self.listen_for_requests, daemon=True
)
self._listener_thread.start()
# FIXME: Use stricter ZMQ socket type hint
self.req_sockets: typing.Dict[str, typing.Any] = {}
logger.info(f"Rank {self._rank} connected to the server")
@property
def world_size(self):
return self._world_size
@property
def rank(self):
return self._rank
def put_input_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
return self._put_tensor(tensor, rank, tensor_id, remote_address)
def put_output_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
return self._put_tensor(tensor, rank, tensor_id, remote_address)
def get_tensor(
self,
rank: int,
tensor_id: str,
remote_address: str,
) -> torch.Tensor:
return self._get_tensor(rank, tensor_id, remote_address)
def _put_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
logger.debug(
f"Rank {self._rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}"
)
if remote_address is None:
self.store[tensor_id] = (tensor, rank, None)
else:
tensor_shape = "_".join(str(dim) for dim in tensor.shape)
if remote_address not in self.req_sockets:
self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
req_socket = self.req_sockets[remote_address]
req_socket.connect(f"tcp://{remote_address}")
req_socket.send_string(f"PUT {self._rank} {tensor_shape} {tensor_id}")
ret = req_socket.recv_string()
assert ret == "OK"
torch.distributed.isend(tensor, dst=rank)
def _get_tensor(
self,
rank: int,
tensor_id: str,
remote_address: str,
) -> torch.Tensor:
logger.debug(f"Rank {self._rank} receiving tensor from rank {rank}")
if tensor_id in self.store:
tensor, _, req = self.store.pop(tensor_id)
req.wait() # TODO ptarasiewicz we should run other request instead of wait
logger.debug(f"Rank {self._rank} received tensor from rank {rank}")
return tensor
raise NotImplementedError("Getting tensor from remote rank not implemented")
def _receive_tensor(
self,
tensor_id: str,
rank: int,
shape: typing.Sequence[int],
):
tensor = torch.empty(
shape, dtype=torch.uint8, device=f"cuda:{self._current_device}"
)
req = torch.distributed.irecv(tensor, src=rank)
self.store[tensor_id] = (tensor, rank, req)
def listen_for_requests(self):
while True:
cmd, _rank, _shape, tensor_id = self.rep_socket.recv_string().split()
logger.debug(f"Rank {self._rank} received request for tensor {tensor_id}")
self.rep_socket.send_string("OK")
if cmd == "GET":
raise NotImplementedError(
"Getting tensor from remote rank not implemented"
)
elif cmd == "PUT":
rank = int(_rank)
shape = [int(dim) for dim in _shape.split("_")]
self._receive_tensor(tensor_id, rank, shape)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A script to download a Python wheel, patch it, copy additional files,
# repackage it, and optionally install the new wheel.
# FIXME: Address type checking with divergent interfaces for Ucp/Nccl data planes
# type: ignore
import typing
if typing.TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata # type: ignore
from vllm.attention.backends.abstract import AttentionMetadata # type: ignore
import hashlib
import uuid
from concurrent.futures import ThreadPoolExecutor
import torch
import torch.distributed
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.flashinfer import FlashInferBackend, FlashInferMetadata
from vllm.distributed.data_plane import VllmNcclDataPlane, VllmUcpDataPlane
from vllm.distributed.parallel_state import get_store, get_tp_group # type: ignore
from vllm.logger import init_logger
logger = init_logger(__name__)
_kv_cache_handler = None
def get_kv_cache_handler():
global _kv_cache_handler
if _kv_cache_handler is None:
_kv_cache_handler = KVCacheHandler()
return _kv_cache_handler
class KVCacheHandler:
def __init__(self):
if _kv_cache_handler is not None:
raise ValueError("KVCacheHandler is a singleton")
self._data_plane_backend = envs.VLLM_DATA_PLANE_BACKEND
if self._data_plane_backend == "nccl":
self._data_plane = VllmNcclDataPlane()
self._store = get_store()
logger.info("Store set up")
self._store.set(
f"worker_{envs.VLLM_WORKER_ID}_rank_{get_tp_group().local_rank}",
f"{self._data_plane._hostname}:{self._data_plane._port}",
)
elif self._data_plane_backend == "ucx":
self._data_plane = VllmUcpDataPlane(keep_endpoints_open=True)
self._data_plane.connect()
rank = torch.distributed.get_rank()
is_master = envs.VLLM_WORKER_ID == 0 and rank == 0
self._store = torch.distributed.TCPStore(
envs.VLLM_TORCH_HOST, envs.VLLM_TORCH_PORT, is_master=is_master
)
self._store.set(
f"worker_{envs.VLLM_WORKER_ID}_rank_{rank}",
f"{self._data_plane.hostname}:{self._data_plane.port}",
)
else:
raise ValueError(f"Unknown data plane backend {self._data_plane_backend}")
self._local_store = {}
self.transport_thread = ThreadPoolExecutor(max_workers=1)
logger.info("KVCacheHandler initialized")
def send(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
with torch.cuda.nvtx.range("KV send"):
self._send(
model,
model_input,
kv_caches,
)
def _send(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
seq_lens = model_input.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model.start_layer
end_layer = model.end_layer
request_ids = list(model_input.request_ids_to_seq_ids.keys())
_, _, block_size, num_heads, _ = kv_caches[0].shape
attention_backend = _get_attention_backend(model_input.attn_metadata)
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
request_id = request_ids[idx]
keys, values = [], []
logger.debug(
f"seq_len {slen}, start_pos {start_pos}, end_pos {end_pos}, slot_mapping_flat {slot_mapping_flat.shape}"
)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
if len(model_input.attn_metadata.block_tables[idx]) > 0:
block_inds = torch.range(
0,
block_size - 1,
device=current_slot_mapping.device,
dtype=current_slot_mapping.dtype,
)
additional_inds = (
torch.cat(
[
block_inds + block_size * i
for i in model_input.attn_metadata.block_tables[idx]
]
)
.to(current_slot_mapping.device)
.to(current_slot_mapping.dtype)
)
logger.debug(f"additional_inds: {additional_inds.shape}")
logger.debug(f"current_slot_mapping: {current_slot_mapping.shape}")
current_slot_mapping = torch.cat(
[additional_inds, current_slot_mapping]
)
logger.debug(f"new current_slot_mapping: {current_slot_mapping.shape}")
current_slot_mapping_quotient = current_slot_mapping // block_size
current_slot_mapping_remainder = current_slot_mapping % block_size
logger.debug("kv_caches shape: %s", kv_caches[0].shape)
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
if attention_backend == "flash_attn":
key_cache = kv_cache[
0, current_slot_mapping_quotient, current_slot_mapping_remainder
]
value_cache = kv_cache[
1, current_slot_mapping_quotient, current_slot_mapping_remainder
]
elif attention_backend == "flash_infer":
key_cache = kv_cache[
current_slot_mapping_quotient, 0, current_slot_mapping_remainder
]
value_cache = kv_cache[
current_slot_mapping_quotient, 1, current_slot_mapping_remainder
]
else:
raise ValueError(f"Unknown attention backend {attention_backend}")
keys.append(key_cache)
values.append(value_cache)
keys = torch.stack(keys, dim=0) # type: ignore
values = torch.stack(values, dim=0) # type: ignore
tp_multipler = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
first_rank = envs.VLLM_CONTEXT_WORKERS * envs.VLLM_CONTEXT_TP_SIZE
for i in range(tp_multipler):
num_heads_per_generate_rank = num_heads // tp_multipler
first_head = i * num_heads_per_generate_rank
partial_keys = keys[
:, :, first_head : first_head + num_heads_per_generate_rank, :
].clone() # type: ignore
partial_values = values[
:, :, first_head : first_head + num_heads_per_generate_rank, :
].clone() # type: ignore
target_local_rank = get_tp_group().local_rank * tp_multipler + i
target_rank = target_local_rank + first_rank
# torch.cuda.synchronize()
self._send_tensors(
request_id,
target_rank,
target_local_rank,
partial_keys,
partial_values,
)
logger.debug("Tensors sent")
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
with torch.cuda.nvtx.range("KV recv"):
self._recv(
model.start_layer,
model.end_layer,
[
model.layers[i].self_attn.attn.kv_cache_dtype
for i in range(model.start_layer, model.end_layer)
],
[
model.layers[i].self_attn.attn._k_scale
for i in range(model.start_layer, model.end_layer)
],
[
model.layers[i].self_attn.attn._v_scale
for i in range(model.start_layer, model.end_layer)
],
model_input,
kv_caches,
)
def _recv(
self,
start_layer: int,
end_layer: int,
kv_cache_dtypes: typing.List[torch.dtype],
k_scales: typing.List[float],
v_scales: typing.List[float],
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
seq_lens = model_input.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
request_ids = list(model_input.request_ids_to_seq_ids.keys())
_, _, block_size, num_heads, head_dim = kv_caches[0].shape
attention_backend = _get_attention_backend(model_input.attn_metadata)
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
request_id = request_ids[idx]
num_tokens = slen
base_request_id, context_worker_id = request_id.split("___")
context_worker_id = int(context_worker_id)
keys, values = self._recv_tensors(
base_request_id,
context_worker_id,
num_tokens,
end_layer - start_layer,
num_heads,
head_dim,
)
logger.debug(f"Received tensors for request_id {request_id}")
if kv_caches[0].dtype == torch.uint8:
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8_e4m3")
keys = keys.view(torch_dtype).to(torch.bfloat16)
values = values.view(torch_dtype).to(torch.bfloat16)
logger.debug("Converted caches to torch.blfoat16")
for i in range(start_layer, end_layer):
kv_cache = kv_caches[i - start_layer]
key, value = keys[i], values[i]
if attention_backend == "flash_attn":
key_cache, value_cache = kv_cache[0], kv_cache[1]
elif attention_backend == "flash_infer":
key_cache, value_cache = kv_cache[:, 0], kv_cache[:, 1]
else:
raise ValueError(f"Unknown attention backend {attention_backend}")
ops.reshape_and_cache_flash( # type: ignore
key,
value,
key_cache,
value_cache,
slot_mapping_flat[start_pos:end_pos],
kv_cache_dtypes[i],
k_scales[i],
v_scales[i],
)
logger.debug(f"KV receive DONE for rank {torch.distributed.get_rank()}")
def _send_tensors(self, request_id, target_rank, target_local_rank, keys, values):
logger.debug(
f"Sending tensors for request_id {request_id} to rank {target_rank}"
)
logger.debug(f"Tensor shapes: keys {keys.shape}, values {values.shape}")
logger.debug(f"Tensor dtypes: keys {keys.dtype}, values {values.dtype}")
if self._data_plane_backend == "nccl":
self._send_tensors_nccl(
request_id, target_rank, target_local_rank, keys, values
)
elif self._data_plane_backend == "ucx":
self._send_tensors_ucx(request_id, target_local_rank, keys, values)
def _send_tensors_nccl(
self, request_id, target_rank, target_local_rank, keys, values
):
generate_worker_id = envs.VLLM_CONTEXT_WORKERS
target_addr = self._store.get(
f"worker_{generate_worker_id}_rank_{target_local_rank}"
).decode()
self._data_plane.put_output_tensor(
keys,
rank=target_rank,
tensor_id=f"{request_id}_keys_rank{target_local_rank}",
remote_address=target_addr,
)
self._data_plane.put_output_tensor(
values,
rank=target_rank,
tensor_id=f"{request_id}_values_rank{target_local_rank}",
remote_address=target_addr,
)
def _send_tensors_ucx(self, request_id, target_local_rank, keys, values):
torch.cuda.synchronize()
self._data_plane.put_output_tensor(
keys,
_create_id_from_str(f"{request_id}_keys_rank{target_local_rank}"),
)
self._data_plane.put_output_tensor(
values,
_create_id_from_str(f"{request_id}_values_rank{target_local_rank}"),
)
def _recv_tensors(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
logger.debug(
f"Receiving tensors for request_id {request_id} from worker {context_worker_id}"
)
if self._data_plane_backend == "nccl":
return self._recv_tensors_nccl(
request_id,
context_worker_id,
num_tokens,
num_layers,
num_heads,
head_dim,
)
elif self._data_plane_backend == "ucx":
return self._recv_tensors_ucx(
request_id,
context_worker_id,
num_tokens,
num_layers,
num_heads,
head_dim,
)
def _recv_tensors_nccl(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
tp_rank = get_tp_group().local_rank
tp_multipler = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
source_tp_rank = tp_rank // tp_multipler
source_rank = context_worker_id * envs.VLLM_CONTEXT_TP_SIZE + source_tp_rank
worker_key = f"worker_{context_worker_id}_rank_{source_tp_rank}"
source_addr = self._local_store.get(worker_key)
if source_addr is None:
logger.info(
"Fetching source address for worker %d by key %s",
context_worker_id,
worker_key,
)
source_addr = self._store.get(worker_key).decode()
self._local_store[worker_key] = source_addr
keys = self._data_plane.get_tensor(
rank=source_rank,
tensor_id=f"{request_id}_keys_rank{tp_rank}",
remote_address=source_addr,
)
values = self._data_plane.get_tensor(
rank=source_rank,
tensor_id=f"{request_id}_values_rank{tp_rank}",
remote_address=source_addr,
)
return keys, values
def _recv_tensors_ucx(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
local_rank = get_tp_group().local_rank
tp_rank = get_tp_group().local_rank
tp_multipler = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
source_tp_rank = tp_rank // tp_multipler
source_addr = self._store.get(
f"worker_{context_worker_id}_rank_{source_tp_rank}"
).decode()
keys_id = _create_id_from_str(f"{request_id}_keys_rank{local_rank}")
keys_uri = f"ucp://{source_addr}/{keys_id}"
keys = self._data_plane.get_tensor(
keys_uri,
(num_layers, num_tokens, num_heads, head_dim),
torch.uint8,
device_id=local_rank,
)
values_id = _create_id_from_str(f"{request_id}_values_rank{local_rank}")
values_uri = f"ucp://{source_addr}/{values_id}"
values = self._data_plane.get_tensor(
values_uri,
(num_layers, num_tokens, num_heads, head_dim),
torch.uint8,
device_id=local_rank,
)
return keys, values
def _get_attention_backend(attn_metadata: "AttentionMetadata") -> str:
if isinstance(attn_metadata, FlashAttentionMetadata):
return "flash_attn"
elif isinstance(attn_metadata, FlashInferMetadata):
return "flash_infer"
else:
raise ValueError(
f"Unknown attention metadata type {type(attn_metadata)}. Only FlashAttentionMetadata and FlashInferMetadata are supported."
)
def _create_id_from_str(str_id: str) -> uuid.UUID:
# Create a hash of the str string using SHA-1
hashed_key = hashlib.sha1(str_id.encode("utf-8"))
# Generate a UUID from the hash, ensuring it's in the correct format
hash_hex = hashed_key.hexdigest()[:32] # Get first 32 characters
uuid_generated = uuid.UUID(hash_hex)
return uuid_generated
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# Print usage information
print_usage() {
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --patch PATH Apply a patch file during installation"
echo " --ref REF Specify the vLLM git reference (branch/tag/commit) to install"
echo " --install-cmd CMD Specify the installation command (default: 'pip install')"
echo " --use-precompiled Use precompiled kernels during installation"
echo " --installation-dir DIR Specify the installation directory (default: 'vllm')"
echo " --help Show this help message"
}
# Default values
INSTALL_CMD="pip install"
VLLM_REF="main"
PATCH_PATH=""
USE_PRECOMPILED=false
# Parse arguments
while [[ $# -gt 0 ]]; do
case $1 in
--patch)
PATCH_PATH="$2"
shift 2
;;
--ref)
VLLM_REF="$2"
shift 2
;;
--install-cmd)
INSTALL_CMD="$2"
shift 2
;;
--use-precompiled)
USE_PRECOMPILED=true
shift
;;
--installation-dir)
INSTALLATION_DIR="$2"
shift 2
;;
--help)
print_usage
exit 0
;;
*)
echo "Unknown argument: $1"
print_usage
exit 1
;;
esac
done
# Create temp directory and clean it up on exit
# Convert patch path to absolute path if it's relative
if [[ ! "$PATCH_PATH" = /* ]]; then
PATCH_PATH="$(pwd)/${PATCH_PATH}"
fi
# Clone vLLM repository
echo "Cloning vLLM repository at ref: $VLLM_REF"
git clone https://github.com/vllm-project/vllm.git "$INSTALLATION_DIR"
cd "$INSTALLATION_DIR"
git checkout "$VLLM_REF"
# Apply patch if provided
if [ -n "$PATCH_PATH" ]; then
echo "Applying patch from: $PATCH_PATH"
git apply "$PATCH_PATH"
fi
# Install using specified command
echo "Installing using: $INSTALL_CMD"
if [ "$USE_PRECOMPILED" = true ]; then
echo "Using precompiled kernels"
export VLLM_USE_PRECOMPILED=1
fi
$INSTALL_CMD .
echo "Installation complete!"
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# Function to print usage
print_usage() {
echo "Usage: $0 --original-ref <original_tag_or_branch> --fork-repo <fork_repo_url> --fork-ref <fork_tag_or_branch> --output <patch_output_path>"
echo
echo "Arguments:"
echo " --original-ref The tag or branch name from the original vllm-project/vllm repo"
echo " --fork-repo The URL of the forked repository"
echo " --fork-ref The tag or branch name from the forked repository"
echo " --output Path where the generated patch file should be saved"
echo
echo "Example:"
echo " $0 --original-ref v0.2.0 --fork-repo https://github.com/user/vllm.git --fork-ref feature-branch --output ./my-patch.diff"
exit 1
}
# Parse named arguments
while [[ $# -gt 0 ]]; do
case $1 in
--original-ref)
ORIGINAL_REF="$2"
shift 2
;;
--fork-repo)
FORK_REPO="$2"
shift 2
;;
--fork-ref)
FORK_REF="$2"
shift 2
;;
--output)
PATCH_OUTPUT="$2"
shift 2
;;
*)
print_usage
;;
esac
done
# Check if all required arguments are provided
if [ -z "$ORIGINAL_REF" ] || [ -z "$FORK_REPO" ] || [ -z "$FORK_REF" ] || [ -z "$PATCH_OUTPUT" ]; then
print_usage
fi
# Convert patch output path to absolute path if it's relative
if [[ ! "$PATCH_OUTPUT" = /* ]]; then
PATCH_OUTPUT="$(pwd)/${PATCH_OUTPUT}"
fi
TEMP_DIR=$(mktemp -d)
# Clean up temp directory on script exit
trap 'rm -rf "$TEMP_DIR"' EXIT
# Clone original vLLM to a temp directory
git clone https://github.com/vllm-project/vllm.git "$TEMP_DIR/original_vllm"
cd "$TEMP_DIR/original_vllm"
git remote add fork "$FORK_REPO"
git fetch fork "$FORK_REF"
git diff "$ORIGINAL_REF" fork/"$FORK_REF" > "$PATCH_OUTPUT"
echo "Patch created successfully: $PATCH_OUTPUT"
\ No newline at end of file
#!/bin/bash -e
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A script to download a Python wheel, patch it, copy additional files,
# repackage it, and optionally install the new wheel.
###############################################################################
# CONFIGURATION & DEFAULTS
###############################################################################
DEFAULT_WHEEL_URL="https://files.pythonhosted.org/packages/4a/4c/ee65ba33467a4c0de350ce29fbae39b9d0e7fcd887cc756fa993654d1228/vllm-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl"
DEFAULT_PATCH_FILE="vllm_patch_063post1.patch"
DEFAULT_DATA_PLANE_DIR="data_plane"
DEFAULT_WHEEL_DIR="wheel"
DEFAULT_OUTPUT_WHEEL="vllm-dist-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl"
# Optionally set a default SHA256 checksum for the downloaded wheel
# DEFAULT_CHECKSUM="SOME_SHA256_HERE"
###############################################################################
# HELPER FUNCTIONS
###############################################################################
usage() {
cat << EOF
Usage: $0 [OPTIONS]
Options:
-u, --url <URL> Wheel URL to download (default: $DEFAULT_WHEEL_URL)
-p, --patch <FILE> Patch file path (default: $DEFAULT_PATCH_FILE)
-d, --data-plane <DIR> Directory with additional data-plane files (default: $DEFAULT_DATA_PLANE_DIR)
-w, --wheel-dir <DIR> Extract destination directory (default: $DEFAULT_WHEEL_DIR)
-o, --output-wheel <FILE> Name/path for the repackaged wheel (default: $DEFAULT_OUTPUT_WHEEL)
-f, --force Force overwriting existing directories/files without prompts
-i, --install Install the new wheel after repackaging
-D, --debug Enable debug mode (verbose output)
-h, --help Show this help and exit
Example:
$0 \\
--url "https://example.com/path/to/vllm.whl" \\
--patch my_patch.patch \\
--data-plane custom_data/ \\
--wheel-dir extracted_wheel \\
--output-wheel my_vllm_dist.whl \\
--install \\
--force \\
--debug
EOF
}
info_log() {
# Always prints, showing a timestamp.
echo "[$(date +'%Y-%m-%d %H:%M:%S')] INFO: $*"
}
debug_log() {
# Prints only if DEBUG=true
if [[ "$DEBUG" == true ]]; then
echo "[$(date +'%Y-%m-%d %H:%M:%S')] DEBUG: $*"
fi
}
error_exit() {
echo "ERROR: $*" >&2
exit 1
}
###############################################################################
# PARSE ARGUMENTS
###############################################################################
FORCE_OVERWRITE=false
INSTALL_WHEEL=false
DEBUG=false
WHEEL_URL="$DEFAULT_WHEEL_URL"
PATCH_FILE="$DEFAULT_PATCH_FILE"
DATA_PLANE_DIR="$DEFAULT_DATA_PLANE_DIR"
WHEEL_DIR="$DEFAULT_WHEEL_DIR"
OUTPUT_WHEEL="$DEFAULT_OUTPUT_WHEEL"
while [[ $# -gt 0 ]]; do
case "$1" in
-u|--url)
WHEEL_URL="$2"
shift 2
;;
-p|--patch)
PATCH_FILE="$2"
shift 2
;;
-d|--data-plane)
DATA_PLANE_DIR="$2"
shift 2
;;
-w|--wheel-dir)
WHEEL_DIR="$2"
shift 2
;;
-o|--output-wheel)
OUTPUT_WHEEL="$2"
shift 2
;;
-f|--force)
FORCE_OVERWRITE=true
shift
;;
-i|--install)
INSTALL_WHEEL=true
shift
;;
-D|--debug)
DEBUG=true
shift
;;
-h|--help)
usage
exit 0
;;
*)
error_exit "Unknown option: $1"
;;
esac
done
###############################################################################
# MAIN SCRIPT
###############################################################################
# Enable debug mode if requested
if [[ "$DEBUG" == true ]]; then
set -x
fi
info_log "Starting wheel patching script..."
# Function to check if a command exists
command_exists() {
command -v "$1" >/dev/null 2>&1 || { echo >&2 "I require $1 but it's not installed. Aborting."; exit 1; }
}
# Check for required commands
command_exists pip
command_exists unzip
command_exists zip
command_exists patch
# ---------------------------------------------------------------------------
# 1. Check for existing wheel file or directory
# ---------------------------------------------------------------------------
WHEEL_FILENAME=$(basename "$WHEEL_URL")
if [[ -f "$WHEEL_FILENAME" && "$FORCE_OVERWRITE" != true ]]; then
info_log "File '$WHEEL_FILENAME' already exists. Reusing existing file."
info_log "If you want to redownload, remove '$WHEEL_FILENAME' or use --force."
else
info_log "Downloading wheel from $WHEEL_URL..."
rm -f "$WHEEL_FILENAME" 2>/dev/null || true # Remove existing file if forcing
wget -O "$WHEEL_FILENAME" "$WHEEL_URL"
fi
# ---------------------------------------------------------------------------
# 2. Optional: Verify checksum (commented out by default)
# ---------------------------------------------------------------------------
# if [[ -n "$DEFAULT_CHECKSUM" ]]; then
# info_log "Verifying SHA256 checksum..."
# echo "${DEFAULT_CHECKSUM} ${WHEEL_FILENAME}" | sha256sum --check - || error_exit "Checksum mismatch!"
# fi
# ---------------------------------------------------------------------------
# 3. Create/clean wheel extraction directory
# ---------------------------------------------------------------------------
if [[ -d "$WHEEL_DIR" ]]; then
if [[ "$FORCE_OVERWRITE" == true ]]; then
info_log "Removing existing directory '$WHEEL_DIR' due to --force..."
rm -rf "$WHEEL_DIR"
else
error_exit "Directory '$WHEEL_DIR' already exists. Use --force to overwrite."
fi
fi
info_log "Creating directory '$WHEEL_DIR'..."
mkdir -p "$WHEEL_DIR"
# ---------------------------------------------------------------------------
# 4. Unzip the wheel into the specified directory
# ---------------------------------------------------------------------------
info_log "Unzipping wheel into directory '$WHEEL_DIR'..."
unzip -q "$WHEEL_FILENAME" -d "$WHEEL_DIR"
# ---------------------------------------------------------------------------
# 5. Check/Apply patch
# ---------------------------------------------------------------------------
if [[ ! -f "$PATCH_FILE" ]]; then
error_exit "Patch file '$PATCH_FILE' not found."
fi
PATCH_TARGET_DIR="$WHEEL_DIR/vllm"
if [[ ! -d "$PATCH_TARGET_DIR" ]]; then
error_exit "Could not find directory '$PATCH_TARGET_DIR' in unzipped wheel."
fi
info_log "Applying patch '$PATCH_FILE' to '$PATCH_TARGET_DIR'..."
debug_log "Executing: (cd \"$PATCH_TARGET_DIR\" && patch -p1 < \"../../$PATCH_FILE\")"
(
cd "$PATCH_TARGET_DIR"
patch -p1 < "../../$PATCH_FILE"
)
# ---------------------------------------------------------------------------
# 6. Copy data plane files
# ---------------------------------------------------------------------------
if [[ ! -d "$DATA_PLANE_DIR" ]]; then
error_exit "Data plane directory '$DATA_PLANE_DIR' not found."
fi
info_log "Copying files from '$DATA_PLANE_DIR' to '$PATCH_TARGET_DIR/distributed'..."
mkdir -p "$PATCH_TARGET_DIR/distributed"
cp -r "$DATA_PLANE_DIR/"* "$PATCH_TARGET_DIR/distributed/"
# ---------------------------------------------------------------------------
# 7. Re-package into a new wheel
# ---------------------------------------------------------------------------
if [[ -f "$OUTPUT_WHEEL" && "$FORCE_OVERWRITE" != true ]]; then
error_exit "Output wheel '$OUTPUT_WHEEL' already exists. Use --force to overwrite."
fi
info_log "Creating new wheel file '$OUTPUT_WHEEL'..."
debug_log "Executing: (cd \"$WHEEL_DIR\" && zip -rq \"../$OUTPUT_WHEEL\" .)"
(
cd "$WHEEL_DIR"
zip -rq "../$OUTPUT_WHEEL" .
)
# ---------------------------------------------------------------------------
# 8. Optional: Install the new wheel
# ---------------------------------------------------------------------------
if [[ "$INSTALL_WHEEL" == true ]]; then
# Check if pip is installed
if ! command -v pip >/dev/null 2>&1; then
error_exit "pip is not installed or not found in PATH."
fi
info_log "Installing newly created wheel '$OUTPUT_WHEEL'..."
pip install --force-reinstall --upgrade --break-system-packages "$OUTPUT_WHEEL"
fi
info_log "Patch and repackage completed successfully!"
info_log "New wheel: $OUTPUT_WHEEL"
if [[ "$INSTALL_WHEEL" == true ]]; then
info_log "Wheel has been installed."
fi
...@@ -27,14 +27,4 @@ pytestmark = pytest.mark.pre_merge ...@@ -27,14 +27,4 @@ pytestmark = pytest.mark.pre_merge
@pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed") @pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
def test_version(): def test_version():
# Verify that the image has the patched version of vllm # Verify that the image has the patched version of vllm
assert vllm.__version__ == "0.6.3.post2.dev16+gf61960ce" assert vllm.__version__.startswith("0.7.3.dev")
@pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
def test_patch_imports():
# Verify patched files have no glaring syntax or import issues
import vllm.distributed.data_plane as d
import vllm.distributed.kv_cache as k
# Placeholder to avoid unused import errors or removal by linters
assert d, k
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
diff -Naur v0.6.3.post1_vllm/_version.py patched_vllm/_version.py
--- v0.6.3.post1_vllm/_version.py 2025-01-09 03:03:32.439278263 -0800
+++ patched_vllm/_version.py 2025-01-09 01:49:43.785300620 -0800
@@ -12,5 +12,5 @@
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
-__version__ = version = '0.6.3.post1'
-__version_tuple__ = version_tuple = (0, 6, 3)
+__version__ = version = '0.6.3.post2.dev16+gf61960ce'
+__version_tuple__ = version_tuple = (0, 6, 3, 'dev16', 'gf61960ce')
diff -Naur v0.6.3.post1_vllm/config.py patched_vllm/config.py
--- v0.6.3.post1_vllm/config.py 2025-01-09 03:03:32.439278263 -0800
+++ patched_vllm/config.py 2025-01-09 01:49:43.785300620 -0800
@@ -1046,7 +1046,7 @@
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len.")
- if self.max_num_batched_tokens < self.max_num_seqs:
+ if self.max_num_seqs is not None and self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
diff -Naur v0.6.3.post1_vllm/core/scheduler.py patched_vllm/core/scheduler.py
--- v0.6.3.post1_vllm/core/scheduler.py 2025-01-09 03:03:32.291290245 -0800
+++ patched_vllm/core/scheduler.py 2025-01-09 01:49:43.785300620 -0800
@@ -17,6 +17,7 @@
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
from vllm.utils import Device, PyObjectCache
+import vllm.envs as envs
logger = init_logger(__name__)
@@ -883,12 +884,17 @@
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
- num_new_tokens = self._get_num_new_tokens(seq_group,
- SequenceStatus.WAITING,
- enable_chunking, budget)
- if not enable_chunking:
- num_prompt_tokens = waiting_seqs[0].get_len()
- assert num_new_tokens == num_prompt_tokens
+ real_num_new_tokens = self._get_num_new_tokens(seq_group,
+ SequenceStatus.WAITING,
+ enable_chunking, budget)
+ if envs.VLLM_DISAGG_STAGE == "GENERATE":
+ num_new_tokens = 1
+ assert not enable_chunking
+ else:
+ num_new_tokens = real_num_new_tokens
+ if not enable_chunking:
+ num_prompt_tokens = waiting_seqs[0].get_len()
+ assert num_new_tokens == num_prompt_tokens
prompt_limit = self._get_prompt_limit(seq_group)
if num_new_tokens > prompt_limit:
@@ -967,7 +973,7 @@
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
- token_chunk_size=num_new_tokens))
+ token_chunk_size=real_num_new_tokens))
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
diff -Naur v0.6.3.post1_vllm/distributed/parallel_state.py patched_vllm/distributed/parallel_state.py
--- v0.6.3.post1_vllm/distributed/parallel_state.py 2025-01-09 03:03:32.291290245 -0800
+++ patched_vllm/distributed/parallel_state.py 2025-01-09 01:49:43.785300620 -0800
@@ -889,6 +889,14 @@
get_pipeline_model_parallel_group = get_pp_group
+_STORE: Optional[Any] = None
+
+def get_store() -> Any:
+ assert _STORE is not None, ("store is not initialized")
+ return _STORE
+
+
+
@contextmanager
def graph_capture():
"""
@@ -926,20 +934,51 @@
local_rank: int = -1,
backend: str = "nccl",
):
+
+ # TODO ptarasiewicz this is a hack to get the stage from the environment
+ logger.info("="*50)
+ logger.info("Patching init_distributed_environment")
+ stage = envs.VLLM_DISAGG_STAGE
+ logger.info(f"Stage: {stage}")
+ store = None
+ if stage is not None and envs.VLLM_DATA_PLANE_BACKEND == "nccl":
+ context_workers = envs.VLLM_CONTEXT_WORKERS
+ context_tp_size = envs.VLLM_CONTEXT_TP_SIZE
+ generate_workers = envs.VLLM_GENERATE_WORKERS
+ generate_tp_size = envs.VLLM_GENERATE_TP_SIZE
+ world_size = context_workers * context_tp_size + generate_workers * generate_tp_size
+ if stage == "PREFILL":
+ worker_id = envs.VLLM_WORKER_ID
+ rank += worker_id * context_tp_size
+ if stage == "GENERATE":
+ rank += context_workers * context_tp_size # TODO ptarasiewicz this only works for 1 generate worker
+ # distributed_init_method = f"tcp://{envs.VLLM_TORCH_HOST}:{envs.VLLM_TORCH_PORT}"
+ distributed_init_method = None
+ store = torch.distributed.TCPStore(envs.VLLM_TORCH_HOST, envs.VLLM_TORCH_PORT, world_size=world_size, is_master = rank == 0)
+ logger.info(f"world_size: {world_size}, rank: {rank}, distributed_init_method: {distributed_init_method}, local_rank: {local_rank}, backend: {backend}")
+
logger.debug(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
distributed_init_method, backend)
if not torch.distributed.is_initialized():
- assert distributed_init_method is not None, (
- "distributed_init_method must be provided when initializing "
+ assert distributed_init_method is not None or store is not None, (
+ "distributed_init_method or store must be provided when initializing "
"distributed environment")
# this backend is used for WORLD
- torch.distributed.init_process_group(
- backend=backend,
- init_method=distributed_init_method,
- world_size=world_size,
- rank=rank)
+ if store is None:
+ torch.distributed.init_process_group(
+ backend=backend,
+ init_method=distributed_init_method,
+ world_size=world_size,
+ rank=rank)
+ else:
+ torch.distributed.init_process_group(
+ backend=backend,
+ # init_method=distributed_init_method,
+ world_size=world_size,
+ rank=rank,
+ store=store)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
@@ -958,6 +997,10 @@
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size")
+ global _STORE
+ if store is not None and _STORE is None:
+ _STORE = store
+
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
@@ -986,32 +1029,60 @@
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
+ logger.debug("="*50)
+ logger.debug("Patching initialize_model_parallel")
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
- if (world_size !=
- tensor_model_parallel_size * pipeline_model_parallel_size):
- raise RuntimeError(
- f"world_size ({world_size}) is not equal to "
- f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
- f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
+
+ # TODO ptarasiewicz this assertion does not work with disagg
+ # if (world_size !=
+ # tensor_model_parallel_size * pipeline_model_parallel_size):
+ # raise RuntimeError(
+ # f"world_size ({world_size}) is not equal to "
+ # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
+ # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size //
- tensor_model_parallel_size)
+ tensor_model_parallel_size)
+
+ stage = envs.VLLM_DISAGG_STAGE
+
global _TP
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
- ranks = list(
- range(i * tensor_model_parallel_size,
- (i + 1) * tensor_model_parallel_size))
- group_ranks.append(ranks)
+
+ # TODO ptarasiewicz this is a hack to adjust the ranks for the stage
+ if stage is not None and envs.VLLM_DATA_PLANE_BACKEND == "nccl":
+ ranks = []
+ context_workers = envs.VLLM_CONTEXT_WORKERS
+ context_tp_size = envs.VLLM_CONTEXT_TP_SIZE
+ generate_workers = envs.VLLM_GENERATE_WORKERS
+ generate_tp_size = envs.VLLM_GENERATE_TP_SIZE
+ for context_id in range(context_workers):
+ ranks.append(
+ [context_id * context_tp_size + i for i in range(context_tp_size)]
+ )
+ for generate_id in range(generate_workers):
+ ranks.append(
+ [context_workers * context_tp_size + generate_id * generate_tp_size + i for i in range(generate_tp_size)]
+ )
+ group_ranks.extend(ranks)
+ break
+ else:
+ ranks = list(
+ range(i * tensor_model_parallel_size,
+ (i + 1) * tensor_model_parallel_size))
+ group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
+ logger.debug("initializing tensor model parallel group")
+ logger.debug(f"group_ranks {group_ranks}")
_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
@@ -1020,15 +1091,32 @@
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
- pipeline_model_parallel_size)
+ pipeline_model_parallel_size)
+
global _PP
assert _PP is None, (
"pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
- ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
- group_ranks.append(ranks)
+
+
+
+ # TODO ptarasiewicz this is a hack to adjust the ranks for the stage
+ if stage is not None and envs.VLLM_DATA_PLANE_BACKEND == "nccl":
+ context_workers = envs.VLLM_CONTEXT_WORKERS
+ context_tp_size = envs.VLLM_CONTEXT_TP_SIZE
+ generate_workers = envs.VLLM_GENERATE_WORKERS
+ generate_tp_size = envs.VLLM_GENERATE_TP_SIZE
+ world_size = context_workers * context_tp_size + generate_workers * generate_tp_size
+ ranks = [[i] for i in range(world_size)]
+ group_ranks.extend(ranks)
+ break
+ else:
+ ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
+ group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
+ logger.debug("initializing pipeline model parallel group")
+ logger.debug(f"group_ranks {group_ranks}")
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
diff -Naur v0.6.3.post1_vllm/engine/async_llm_engine.py patched_vllm/engine/async_llm_engine.py
--- v0.6.3.post1_vllm/engine/async_llm_engine.py 2025-01-09 03:03:32.443277939 -0800
+++ patched_vllm/engine/async_llm_engine.py 2025-01-09 01:49:43.785300620 -0800
@@ -371,7 +371,7 @@
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
-
+
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
diff -Naur v0.6.3.post1_vllm/engine/llm_engine.py patched_vllm/engine/llm_engine.py
--- v0.6.3.post1_vllm/engine/llm_engine.py 2025-01-09 03:03:32.443277939 -0800
+++ patched_vllm/engine/llm_engine.py 2025-01-09 01:49:43.785300620 -0800
@@ -479,8 +479,28 @@
The workers will determine the number of blocks in both the GPU cache
and the swap CPU cache.
"""
- num_gpu_blocks, num_cpu_blocks = (
- self.model_executor.determine_num_available_blocks())
+
+
+ max_num_seqs = None
+ if self.scheduler_config.max_num_seqs is not None:
+ num_gpu_blocks, num_cpu_blocks = (
+ self.model_executor.determine_num_available_blocks())
+ else:
+ max_num_seqs = 1
+ max_concurrency = None
+ max_iter_count_left = 5
+ while True:
+ logger.info("Profiling with %d sequences", max_num_seqs)
+ num_gpu_blocks, num_cpu_blocks = (
+ self.model_executor.determine_num_available_blocks(max_num_seqs))
+ max_concurrency = (num_gpu_blocks * self.cache_config.block_size / self.model_config.max_model_len)
+ logger.info("Maximum concurrency for %d sequences and %s tokens per request: %.2fx",
+ max_num_seqs, self.model_config.max_model_len, max_concurrency)
+ max_iter_count_left -= 1
+ if max_iter_count_left < 1 or (max_concurrency - max_num_seqs) ** 2 < 2:
+ break
+ max_num_seqs = int(max_concurrency)
+
if self.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
@@ -494,6 +514,10 @@
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+
+ if max_num_seqs is not None and self.scheduler_config.max_num_seqs is None:
+ self.scheduler_config.max_num_seqs = max_num_seqs
+ logger.info("Setting max_num_seqs to %d", max_num_seqs)
@classmethod
def _get_executor_cls(cls,
diff -Naur v0.6.3.post1_vllm/envs.py patched_vllm/envs.py
--- v0.6.3.post1_vllm/envs.py 2025-01-09 03:03:32.439278263 -0800
+++ patched_vllm/envs.py 2025-01-09 01:49:43.789300297 -0800
@@ -66,6 +66,15 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_DISABLED_KERNELS: List[str] = []
+ VLLM_DISAGG_STAGE: Optional[str] = None
+ VLLM_CONTEXT_TP_SIZE: int = 0
+ VLLM_CONTEXT_WORKERS: int = 0
+ VLLM_GENERATE_TP_SIZE: int = 0
+ VLLM_GENERATE_WORKERS: int = 0
+ VLLM_TORCH_HOST: str = "localhost"
+ VLLM_TORCH_PORT: int = 36183
+ VLLM_WORKER_ID: int = 0
+ VLLM_DATA_PLANE_BACKEND: str = "nccl"
def get_default_cache_root():
@@ -433,6 +442,35 @@
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),
+
+ # Stage of the disaggregated model
+ "VLLM_DISAGG_STAGE":
+ lambda: os.getenv("VLLM_DISAGG_STAGE", None),
+
+ # Disaggregation global config
+ "VLLM_CONTEXT_TP_SIZE":
+ lambda: int(os.getenv("VLLM_CONTEXT_TP_SIZE", "0")),
+
+ "VLLM_CONTEXT_WORKERS":
+ lambda: int(os.getenv("VLLM_CONTEXT_WORKERS", "0")),
+
+ "VLLM_GENERATE_TP_SIZE":
+ lambda: int(os.getenv("VLLM_GENERATE_TP_SIZE", "0")),
+
+ "VLLM_GENERATE_WORKERS":
+ lambda: int(os.getenv("VLLM_GENERATE_WORKERS", "0")),
+
+ "VLLM_TORCH_HOST":
+ lambda: os.getenv("VLLM_TORCH_HOST", "localhost"),
+
+ "VLLM_TORCH_PORT":
+ lambda: int(os.getenv("VLLM_TORCH_PORT", "36183")),
+
+ "VLLM_WORKER_ID":
+ lambda: int(os.getenv("VLLM_WORKER_ID", "0")),
+
+ "VLLM_DATA_PLANE_BACKEND":
+ lambda: os.getenv("VLLM_DATA_PLANE_BACKEND", "nccl"),
}
# end-env-vars-definition
diff -Naur v0.6.3.post1_vllm/executor/distributed_gpu_executor.py patched_vllm/executor/distributed_gpu_executor.py
--- v0.6.3.post1_vllm/executor/distributed_gpu_executor.py 2025-01-09 03:03:32.443277939 -0800
+++ patched_vllm/executor/distributed_gpu_executor.py 2025-01-09 01:49:43.789300297 -0800
@@ -25,7 +25,7 @@
super().__init__(*args, **kwargs)
- def determine_num_available_blocks(self) -> Tuple[int, int]:
+ def determine_num_available_blocks(self, max_num_seqs: Optional[int] = None) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
@@ -36,7 +36,7 @@
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
- num_blocks = self._run_workers("determine_num_available_blocks", )
+ num_blocks = self._run_workers("determine_num_available_blocks", max_num_seqs=max_num_seqs)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
diff -Naur v0.6.3.post1_vllm/executor/gpu_executor.py patched_vllm/executor/gpu_executor.py
--- v0.6.3.post1_vllm/executor/gpu_executor.py 2025-01-09 03:03:32.443277939 -0800
+++ patched_vllm/executor/gpu_executor.py 2025-01-09 01:49:43.789300297 -0800
@@ -107,11 +107,11 @@
rank=rank,
distributed_init_method=distributed_init_method))
- def determine_num_available_blocks(self) -> Tuple[int, int]:
+ def determine_num_available_blocks(self, max_num_seqs: Optional[int] = None) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
- return self.driver_worker.determine_num_available_blocks()
+ return self.driver_worker.determine_num_available_blocks(max_num_seqs=max_num_seqs)
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker.
diff -Naur v0.6.3.post1_vllm/worker/model_runner.py patched_vllm/worker/model_runner.py
--- v0.6.3.post1_vllm/worker/model_runner.py 2025-01-09 03:03:32.559268548 -0800
+++ patched_vllm/worker/model_runner.py 2025-01-09 01:49:43.993283816 -0800
@@ -1,7 +1,9 @@
+import time
import dataclasses
import gc
import inspect
import itertools
+import os
import time
import warnings
import weakref
@@ -25,6 +27,7 @@
PromptAdapterConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
+from vllm.distributed.kv_cache import get_kv_cache_handler
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -46,7 +49,7 @@
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata, Logprob, SequenceOutput, CompletionSequenceGroupOutput
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available,
@@ -58,6 +61,7 @@
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict, dump_input_when_exception)
+
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -120,6 +124,7 @@
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
+ "seq_lens": self.seq_lens,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@@ -158,6 +163,7 @@
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
+ "seq_lens": self.seq_lens,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -959,6 +965,7 @@
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
+ self._kv_cache_handler = None
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
@@ -978,8 +985,7 @@
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
- self.max_batchsize_to_capture = _get_max_graph_batch_size(
- self.scheduler_config.max_num_seqs)
+ self.max_batchsize_to_capture = None
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -995,9 +1001,7 @@
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
- self.graph_block_tables = np.zeros(
- (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
- dtype=np.int32)
+ self.graph_block_tables = None
# Attention-free but stateful models like Mamba need a placeholder attn
# backend, as the attention metadata is needed to manage internal state.
@@ -1196,11 +1200,18 @@
return builder.build() # type: ignore
@torch.inference_mode()
- def profile_run(self) -> None:
+ def profile_run(self, max_num_seqs: Optional[int] = None) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
- max_num_seqs = self.scheduler_config.max_num_seqs
+ if max_num_seqs is None:
+ max_num_seqs = self.scheduler_config.max_num_seqs
+
+ self.max_batchsize_to_capture = _get_max_graph_batch_size(
+ max_num_seqs)
+ self.graph_block_tables = np.zeros(
+ (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
+ dtype=np.int32)
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
@@ -1522,6 +1533,11 @@
# This usually takes < 10 seconds.
logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
+ def init_kv_cache_handler(self) -> None:
+ if envs.VLLM_DISAGG_STAGE is not None:
+ self._kv_cache_handler = get_kv_cache_handler()
+ # torch.distributed.barrier() # TODO ptarasiewicz check why this is raising NCCL errors
+
def _update_inputs_to_capture_for_enc_dec_model(self,
capture_inputs: Dict[str,
Any]):
@@ -1552,6 +1568,13 @@
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
+ _first_tokens = {}
+
+ def set_first_token(self, request_id: str, first_token: int) -> None:
+ self._first_tokens[request_id] = first_token
+
+ def pop_first_token(self, request_id: str) -> int:
+ return self._first_tokens.pop(request_id)
def make_model_input_from_broadcasted_tensor_dict(
self,
@@ -1610,6 +1633,8 @@
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+ logger.debug(f"execute model input_ids: {model_input.input_tokens.shape}")
+ # logger.info(f"model input {model_input}")
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1654,21 +1679,98 @@
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
+ # Check if we need to run inference or not
+ # We do not run inference if we are in generating stage of disaggregated serving
+ disagg_stage = envs.VLLM_DISAGG_STAGE
+ assert disagg_stage in ["PREFILL", "GENERATE", None], f"Invalid disagg_stage: {disagg_stage}, must be one of ['PREFILL', 'GENERATE', None]"
+ is_profile_run = (
+ (kv_caches is None) or (kv_caches[0] is None) or (kv_caches[0].numel() == 0)
+ )
+ is_generate = disagg_stage == "GENERATE" and prefill_meta is None
+ run_inference = any(
+ [
+ disagg_stage == "PREFILL",
+ is_profile_run,
+ is_generate,
+ disagg_stage is None,
+ ]
+ )
+
+ logger.debug(
+ f"Running inference: {run_inference}, disagg_stage: {disagg_stage}, "
+ f"is_profile_run: {is_profile_run}, is_generate: {is_generate}"
+ )
+ logger.debug(f"Batch size: {model_input.input_tokens.shape} Prefill: {prefill_meta is not None}, Generate: {prefill_meta is None}")
+
+ if not run_inference:
+ if self.is_driver_worker:
+ logger.debug(f"Pulling KV cache for seq lens: {model_input.seq_lens}")
+ # We are in the GENERATE stage of disaggregated serving
+ # instead of running inference, we just need to pull the KV cache
+ # and hidden states from the context stage
+ # TODO ptarasiewicz check why without torch.cuda.synchronize() thorughput is lower
+ logger.debug("PULLING KV CACHE")
+ # torch.cuda.synchronize()
+ # start_time = time.perf_counter_ns()
+ self._kv_cache_handler.recv(
+ self.model.model,
+ model_input,
+ kv_caches,
+ )
+ # torch.cuda.synchronize()
+ # end_time = time.perf_counter_ns()
+ # logger.info(f"KV CACHE PULL TIME: {(end_time - start_time) / 1e6} ms")
+
+ if not self.is_driver_worker:
+ return []
+
+ fist_token = self.pop_first_token(list(model_input.request_ids_to_seq_ids.keys())[0])
+ mocked_output = SamplerOutput(
+ outputs=[
+ CompletionSequenceGroupOutput(
+ samples=[
+ SequenceOutput(
+ parent_seq_id=seq_group.seq_ids[0],
+ output_token=fist_token,
+ logprobs={fist_token: Logprob(float('inf'))},
+ )
+ ],
+ prompt_logprobs=None,
+ )
+ for seq_group in model_input.sampling_metadata.seq_groups
+ ],
+ )
+ logger.debug(f"MOCKED OUTPUT {mocked_output}")
+ return [mocked_output]
+
+ logger.debug("RUNNING INFERENCE")
with set_forward_context(model_input.attn_metadata):
- hidden_or_intermediate_states = model_executable(
- input_ids=model_input.input_tokens,
- positions=model_input.input_positions,
- kv_caches=kv_caches,
- attn_metadata=model_input.attn_metadata,
- intermediate_tensors=intermediate_tensors,
- **MultiModalInputs.as_kwargs(multi_modal_kwargs,
- device=self.device),
- **seqlen_agnostic_kwargs)
+ with torch.cuda.nvtx.range("model_executable"):
+ hidden_or_intermediate_states = model_executable(
+ input_ids=model_input.input_tokens,
+ positions=model_input.input_positions,
+ kv_caches=kv_caches,
+ attn_metadata=model_input.attn_metadata,
+ intermediate_tensors=intermediate_tensors,
+ **MultiModalInputs.as_kwargs(multi_modal_kwargs,
+ device=self.device),
+ **seqlen_agnostic_kwargs)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
+ if disagg_stage == "PREFILL" and not is_profile_run:
+ # TODO ptarasiewicz check why without torch.cuda.synchronize() thorughput is lower
+ logger.debug("Pushing KV cache")
+ torch.cuda.synchronize()
+ self._kv_cache_handler.send(
+ self.model.model,
+ model_input,
+ kv_caches,
+ )
+ logger.debug("finished sending")
+
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
@@ -1687,6 +1789,31 @@
hidden_or_intermediate_states.tensors["model_forward_time"] = (
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states
+
+ if not run_inference:
+
+ if not self.is_driver_worker:
+ return []
+
+ fist_token = self.pop_first_token(list(model_input.request_ids_to_seq_ids.keys())[0])
+ mocked_output = SamplerOutput(
+ outputs=[
+ CompletionSequenceGroupOutput(
+ samples=[
+ SequenceOutput(
+ parent_seq_id=seq_group.seq_ids[0],
+ output_token=fist_token,
+ logprobs={fist_token: Logprob(float('inf'))},
+ )
+ ],
+ prompt_logprobs=None,
+ )
+ for seq_group in model_input.sampling_metadata.seq_groups
+ ],
+ )
+ # print("OUTPUT", output)
+ print("MOCKED OUTPUT", mocked_output)
+ return [mocked_output]
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
@@ -1702,6 +1829,7 @@
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
+
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
diff -Naur v0.6.3.post1_vllm/worker/worker.py patched_vllm/worker/worker.py
--- v0.6.3.post1_vllm/worker/worker.py 2025-01-09 03:03:32.559268548 -0800
+++ patched_vllm/worker/worker.py 2025-01-09 01:49:43.993283816 -0800
@@ -202,7 +202,7 @@
tensorizer_config=tensorizer_config, )
@torch.inference_mode()
- def determine_num_available_blocks(self) -> Tuple[int, int]:
+ def determine_num_available_blocks(self, max_num_seqs: Optional[int] = None) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
@@ -214,13 +214,14 @@
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
+
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
- self.model_runner.profile_run()
+ self.model_runner.profile_run(max_num_seqs)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
@@ -242,16 +243,18 @@
else:
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
- peak_memory) // cache_block_size)
+ peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
- cache_block_size)
+ cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache()
- return num_gpu_blocks, num_cpu_blocks
+ return num_gpu_blocks, num_cpu_blocks
+
+
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
@@ -285,6 +288,7 @@
def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model(self.gpu_cache)
+ self.model_runner.init_kv_cache_handler()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
diff --git a/vllm/config.py b/vllm/config.py
index 9ba49757..7e871521 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2629,7 +2629,7 @@ class KVTransferConfig(BaseModel):
kv_buffer_size: float = 1e9
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
- # are 'kv_producer', 'kv_consumer', and 'both'.
+ # are 'kv_producer', 'kv_consumer', and 'kv_both'.
kv_role: Optional[str] = None
# The rank of this vLLM instance in the KV cache transfer. Typical value:
@@ -2647,6 +2647,14 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection
kv_port: int = 14579
+
+ # This does not need to be set by the user. It is set by the connector.
+ kv_producers_parallel_size: Optional[int] = None
+ kv_producers_tensor_parallel_size: Optional[int] = None
+ kv_producers_pipeline_parallel_size: Optional[int] = None
+ kv_consumers_tensor_parallel_size: Optional[int] = None
+ kv_consumers_pipeline_parallel_size: Optional[int] = None
+
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -2685,6 +2693,7 @@ class KVTransferConfig(BaseModel):
"is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`")
+
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
@@ -2706,6 +2715,18 @@ class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]
+ @property
+ def tensor_parallel_multiplier(self) -> int:
+ return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size
+
+ @property
+ def kv_consumers_parallel_size(self) -> int:
+ return self.kv_parallel_size - self.kv_producers_parallel_size
+
+ @property
+ def kv_world_size(self) -> int:
+ return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier
+
class CompilationLevel:
# constants for the levels of the compilation process
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index fe480533..b768e03c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -27,13 +27,13 @@ class KVConnectorFactory:
@classmethod
def create_connector(cls, rank: int, local_rank: int,
- config: "VllmConfig") -> KVConnectorBase:
+ config: "VllmConfig", world_group) -> KVConnectorBase:
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_cls = cls._registry[connector_name]()
- return connector_cls(rank, local_rank, config)
+ return connector_cls(rank, local_rank, config, world_group)
# Register various connectors here.
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..71cd0567 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -8,13 +8,15 @@ MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
+import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.utils import StatelessProcessGroup
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
@@ -33,10 +35,10 @@ class SimpleConnector(KVConnectorBase):
rank: int,
local_rank: int,
config: VllmConfig,
+ world_group,
):
self.config = config.kv_transfer_config
- self.tp_size = config.parallel_config.tensor_parallel_size
if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -71,20 +73,31 @@ class SimpleConnector(KVConnectorBase):
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
+ self._broadcast_and_enhance_kv_config(rank, config, world_group)
+
+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
+ self.tp_size = config.parallel_config.tensor_parallel_size
+
# 2 pipes for every rank in the world
- port_offset_base = 2 * rank
+ if self.config.is_kv_producer:
+ port_offset_base = 2 * rank + 1
+ else:
+ port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1
+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:
if self.config.kv_connector == "PyNcclConnector":
self.producer_data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -108,11 +121,13 @@ class SimpleConnector(KVConnectorBase):
# its recv pipe to the send pipe of KV producder
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
+ kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -131,21 +146,25 @@ class SimpleConnector(KVConnectorBase):
self.config.kv_buffer_size,
)
- def select(self, input_tokens: Optional[torch.Tensor],
+ def select(self, source_rank: int, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
+ logger.info("Selecting KV caches and hidden states for source rank %d", source_rank)
+
assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
- return self.consumer_buffer.drop_select(input_tokens, roi)
+ return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi)
- def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
+ logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank)
+
assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."
- self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
+ self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden)
def send_kv_caches_and_hidden_states(
self,
@@ -161,6 +180,7 @@ class SimpleConnector(KVConnectorBase):
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
model_config = model_executable.model.config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
@@ -175,27 +195,36 @@ class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ _, decode_kv_rank = self.parse_request_id(current_request_id)
+ starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config)
+
+ for target_rank in range(self.config.tensor_parallel_multiplier):
+
+ keys, values = [], []
- keys, values = [], []
+ for layer_id in range(start_layer, end_layer):
+ kv_cache = kv_caches[layer_id - start_layer]
- for layer_id in range(start_layer, end_layer):
- kv_cache = kv_caches[layer_id - start_layer]
+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
- key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
- value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
- current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
+ head_start = target_rank * num_heads_per_rank
+ head_end = head_start + num_heads_per_rank
- keys.append(key_cache[current_slot_mapping].unsqueeze(0))
- values.append(value_cache[current_slot_mapping].unsqueeze(0))
+ keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+ values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
- keys = torch.cat(keys, dim=0)
- values = torch.cat(values, dim=0)
+ keys = torch.cat(keys, dim=0)
+ values = torch.cat(values, dim=0)
- self.insert(current_tokens,
- torch.ones_like(current_tokens,
- dtype=bool), keys, values,
- hidden_or_intermediate_states[start_pos:end_pos])
+ self.insert(starting_kv_group_rank, target_rank, current_tokens,
+ torch.ones_like(current_tokens,
+ dtype=bool), keys, values,
+ hidden_or_intermediate_states[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
@@ -215,6 +244,7 @@ class SimpleConnector(KVConnectorBase):
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
hidden_or_intermediate_states_for_one_req = []
@@ -229,13 +259,15 @@ class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
+ current_request_id = request_ids[idx]
+ prefill_rank, _ = self.parse_request_id(current_request_id)
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
- ret = self.select(current_tokens,
+ ret = self.select(prefill_rank, current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
@@ -312,3 +344,77 @@ class SimpleConnector(KVConnectorBase):
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
# close the data_pipe.
pass
+
+ @staticmethod
+ def parse_request_id(request_id):
+ # Regular expression to match the ranks
+ pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)"
+
+ # Use re.search to find the pattern in the request_id
+ match = re.search(pattern, request_id)
+
+ if match:
+ # Extract the ranks
+ prefill_rank = int(match.group(1))
+ decode_rank = int(match.group(2))
+
+ return prefill_rank, decode_rank
+ else:
+ return None, None
+
+
+
+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
+ if kv_rank < config.kv_producers_parallel_size:
+ return kv_rank
+
+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
+
+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
+ if rank == 0:
+ if self.config.kv_connector == "PyNcclConnector":
+ config_group = StatelessProcessGroup.create(
+ host=self.config.kv_ip,
+ port=self.config.kv_port,
+ rank=self.config.kv_rank,
+ world_size=self.config.kv_parallel_size,
+ )
+ parallel_configs = config_group.all_gather_obj({
+ "kv_role": self.config.kv_role,
+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
+ })
+ logger.debug("parallel_configs: %s", parallel_configs)
+ kv_config_enhanced = {
+ "kv_producers_tensor_parallel_size": None,
+ "kv_consumers_tensor_parallel_size": None,
+ "kv_producers_pipeline_parallel_size": None,
+ "kv_consumers_pipeline_parallel_size": None,
+ "kv_producers_parallel_size": 0,
+ }
+ for parallel_config in parallel_configs:
+ kv_role = parallel_config["kv_role"]
+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
+
+ if kv_role == "kv_producer":
+ kv_config_enhanced["kv_producers_parallel_size"] += 1
+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
+ else:
+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
+ world_group.broadcast_object(kv_config_enhanced)
+
+ else:
+ raise NotImplementedError("MooncakeConnector is not supported in Triton Distributed vllm patch")
+ else:
+ kv_config_enhanced = world_group.broadcast_object()
+ logger.info("kv_config_enhanced: %s", kv_config_enhanced)
+
+ self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
+ self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
\ No newline at end of file
diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
index 5e1b6235..b4506877 100644
--- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
+++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
@@ -12,7 +12,8 @@
import threading
import time
from collections import deque
-from typing import Deque, List, Optional, Union
+from concurrent.futures import ThreadPoolExecutor
+from typing import Deque, List, Optional, Union, Dict
import torch
@@ -46,7 +47,7 @@ class SimpleBuffer(KVLookupBufferBase):
self.buffer_lock = threading.Lock()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
- self.request_handling_thread: Optional[threading.Thread] = None
+ self.request_handling_thread: Optional[ThreadPoolExecutor] = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
@@ -57,10 +58,16 @@ class SimpleBuffer(KVLookupBufferBase):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
- tokens_sender = tokens_roi_sender[0]
- tokens_recver = tokens_roi_recver[0]
- roi_sender = tokens_roi_sender[1]
- roi_recver = tokens_roi_recver[1]
+ target_rank_sender = tokens_roi_sender[0]
+ target_rank_recver = tokens_roi_recver[0]
+
+ if target_rank_sender.item() != target_rank_recver.item():
+ return 0
+
+ tokens_sender = tokens_roi_sender[1]
+ tokens_recver = tokens_roi_recver[1]
+ roi_sender = tokens_roi_sender[2]
+ roi_recver = tokens_roi_recver[2]
if tokens_recver is None:
# consumer sends an empty request
@@ -80,14 +87,14 @@ class SimpleBuffer(KVLookupBufferBase):
return 0
- def _send_tensor_and_dec_size(self,
- tensor: Optional[torch.Tensor]) -> None:
+ def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor],
+ target_rank: int) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
- self.data_pipe.send_tensor(tensor)
+ self.data_pipe.send_tensor(tensor, target_rank)
def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
@@ -100,7 +107,7 @@ class SimpleBuffer(KVLookupBufferBase):
raise AssertionError(f"Unknown data type {type(data)}")
- def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+ def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor):
@@ -115,7 +122,7 @@ class SimpleBuffer(KVLookupBufferBase):
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
- buffer_item = [input_tokens, roi, key, value, hidden]
+ buffer_item = [torch.tensor(target_rank), input_tokens, roi, key, value, hidden]
with self.buffer_lock:
for data in buffer_item:
@@ -125,53 +132,54 @@ class SimpleBuffer(KVLookupBufferBase):
def _is_end_signal(self, signal):
return signal is None
- def drop_select_handler(self):
+ def drop_select_handler(self, rank: int):
try:
- while True:
- signal = self.signal_pipe.recv_tensor()
- if self._is_end_signal(signal):
- logger.info("Received end signal!")
- break
-
- input_tokens = self.data_pipe.recv_tensor()
-
- roi = self.data_pipe.recv_tensor()
- assert roi is not None, "Please provide the roi when sending "\
- "drop-select request"
- roi = (roi > 0.5)
- tokens_roi_recver = [input_tokens, roi]
-
- matched_length = 0
-
- # perform input tokens and roi matching
- # FIXME: this matching is O(n), ideally it should be O(1)
- # but this buffer size won't (and shouldn't) be too large so
- # the fix is not urgent.
- with self.buffer_lock:
-
- for _ in range(len(self.buffer)):
-
- temp_length = self._matches(self.buffer[0],
- tokens_roi_recver)
- if temp_length > 0:
- matched_length = temp_length
- break
- # rotate the element we just accessed to the end
- self.buffer.rotate(-1)
-
- if matched_length > 0:
- # need to clone the tensor
- # in case the tensor is freed before sending finishes
- matched_item = self.buffer.popleft()
- for tensor in matched_item:
- self._send_tensor_and_dec_size(tensor)
-
- else:
- # no match, just send None
- for _ in range(5):
- self.data_pipe.send_tensor(None)
+ signal = self.signal_pipe.recv_tensor(rank)
+ if self._is_end_signal(signal):
+ logger.info("Received end signal!")
+ return
+ target_kv_rank = self.data_pipe.recv_tensor(rank)
+ # assert target_rank.item() == rank, "Target rank does not match"\
+ # "the rank of the drop-select handler"
+ input_tokens = self.data_pipe.recv_tensor(rank)
+ roi = self.data_pipe.recv_tensor(rank)
+ assert roi is not None, "Please provide the roi when sending "\
+ "drop-select request"
+ roi = (roi > 0.5)
+ tokens_roi_recver = [target_kv_rank, input_tokens, roi]
+
+ matched_length = 0
+
+ # perform input tokens and roi matching
+ # FIXME: this matching is O(n), ideally it should be O(1)
+ # but this buffer size won't (and shouldn't) be too large so
+ # the fix is not urgent.
+ with self.buffer_lock:
+
+ for _ in range(len(self.buffer)):
+
+ temp_length = self._matches(self.buffer[0],
+ tokens_roi_recver)
+ if temp_length > 0:
+ matched_length = temp_length
+ break
+ # rotate the element we just accessed to the end
+ self.buffer.rotate(-1)
+
+ if matched_length > 0:
+ # need to clone the tensor
+ # in case the tensor is freed before sending finishes
+ matched_item = self.buffer.popleft()
+ target_rank = matched_item[0].item()
+ for tensor in matched_item[1:]:
+ self._send_tensor_and_dec_size(tensor, rank)
+
+ else:
+ # no match, just send None
+ for _ in range(5):
+ self.data_pipe.send_tensor(None, rank)
except RuntimeError as e:
if 'Connection closed by peer' not in str(e):
@@ -180,10 +188,10 @@ class SimpleBuffer(KVLookupBufferBase):
logger.debug("Closing drop_select_handler")
def drop_select(
- self, input_tokens: Optional[torch.Tensor],
+ self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
- assert self.request_handling_thread is None, \
+ assert not self.request_handling_thread, \
"drop_select should be called by the KV cache consumer "\
"(e.g. the decode vLLM instance)"
@@ -192,26 +200,28 @@ class SimpleBuffer(KVLookupBufferBase):
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
- self.signal_pipe.send_tensor(self.normal_signal)
- self.data_pipe.send_tensor(input_tokens)
- self.data_pipe.send_tensor(roi)
+ self.signal_pipe.send_tensor(self.normal_signal, rank)
+
+ self.data_pipe.send_tensor(torch.tensor(kv_rank), rank)
+ self.data_pipe.send_tensor(input_tokens, rank)
+ self.data_pipe.send_tensor(roi, rank)
- input_tokens = self.data_pipe.recv_tensor()
- roi = self.data_pipe.recv_tensor()
+ input_tokens = self.data_pipe.recv_tensor(rank)
+ roi = self.data_pipe.recv_tensor(rank)
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = (roi > 0.5)
- key = self.data_pipe.recv_tensor()
- value = self.data_pipe.recv_tensor()
- hidden = self.data_pipe.recv_tensor()
+ key = self.data_pipe.recv_tensor(rank)
+ value = self.data_pipe.recv_tensor(rank)
+ hidden = self.data_pipe.recv_tensor(rank)
return [input_tokens, roi, key, value, hidden]
def full_handler(self):
time.sleep(0.001)
- def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
@@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase):
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()
- self._add_to_buffer(input_tokens, roi, key, value, hidden)
+ self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
+ target_rank_global = target_rank + kv_group_rank
if self.request_handling_thread is None:
- self.request_handling_thread = threading.Thread(
- target=self.drop_select_handler)
- self.request_handling_thread.start()
+ self.request_handling_thread = ThreadPoolExecutor(max_workers=1)
+ self.request_handling_thread.submit(self.drop_select_handler, target_rank_global)
def close(self):
- if hasattr(self, "request_handling_thread"
- ) and self.request_handling_thread is not None:
- self.request_handling_thread.join()
+ if hasattr(self, "request_handling_thread") and self.request_handling_thread:
+ self.request_handling_thread.shutdown()
else:
# TODO: have a explicit close signal and have a explicit way to
diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py
index 40589fb3..da2829cf 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/base.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/base.py
@@ -23,7 +23,7 @@ class KVPipeBase(ABC):
"""
@abstractmethod
- def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
+ def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
@@ -41,7 +41,7 @@ class KVPipeBase(ABC):
raise NotImplementedError
@abstractmethod
- def recv_tensor(self) -> Optional[torch.Tensor]:
+ def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
"""Receive a tensor (can be None) from the pipeline.
Returns:
diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
index 7aa53d07..db10f8a0 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
@@ -45,33 +45,33 @@ class PyNcclPipe(KVPipeBase):
METADATA_DTYPE = torch.int64
def __init__(self,
+ kv_group_rank: int,
local_rank: int,
config: KVTransferConfig,
device: Optional[str] = None,
port_offset: int = 0):
self.config = config
self.local_rank = local_rank
- self.kv_rank = self.config.kv_rank
+ self.kv_group_rank = kv_group_rank
self.kv_parallel_size = self.config.kv_parallel_size
+ self.kv_world_size = self.config.kv_world_size
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:
self.device = self._select_device(device)
# build distributed connection and send/recv implementation
+ logger.info("Creating process group for kv transfer with rank %d and world size %d, ip: %s, port: %d", self.kv_group_rank, self.kv_world_size, self.config.kv_ip, self.config.kv_port + port_offset)
self.group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port + port_offset,
- rank=self.kv_rank,
- world_size=self.kv_parallel_size,
+ rank=self.kv_group_rank,
+ world_size=self.kv_world_size,
)
# add a barrier to make sure the connection is initiated properly
self.group.barrier()
impl = self._get_device_send_recv_impl(self.group)
self.device_send_func, self.device_recv_func = impl
- # set target rank
- self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
- self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
# transportation-related variables
self.transport_thread: Optional[ThreadPoolExecutor] = None
@@ -145,16 +145,16 @@ class PyNcclPipe(KVPipeBase):
dtype=metadata["dtype"],
device=self.device)
- def _send_metadata(self, metadata: Metadata):
+ def _send_metadata(self, metadata: Metadata, target_rank: int):
"""
Send the metadata dictionary to the target rank.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape".
"""
- self.group.send_obj(metadata, self.target_rank_for_send)
+ self.group.send_obj(metadata, target_rank)
- def _recv_metadata(self) -> Metadata:
+ def _recv_metadata(self, src_rank: int) -> Metadata:
"""
Receive the metadata dictionary from the target rank.
@@ -162,9 +162,9 @@ class PyNcclPipe(KVPipeBase):
- metadata: A dictionary with keys "dtype" and "shape" describing
the tensor.
"""
- return self.group.recv_obj(self.target_rank_for_recv)
+ return self.group.recv_obj(src_rank)
- def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
+ def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None:
"""
The actual implementation of sending the tensor and its metadata to the
target rank.
@@ -174,12 +174,12 @@ class PyNcclPipe(KVPipeBase):
being sent.
"""
metadata = self._make_metadata(tensor)
- self._send_metadata(metadata)
+ self._send_metadata(metadata, target_rank)
if tensor is not None:
self.device_send_func(tensor.to(self.device),
- self.target_rank_for_send)
+ target_rank)
- def _recv_impl(self) -> Optional[torch.Tensor]:
+ def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]:
"""
The actual implementation of receiving a tensor and its metadata from
the target rank.
@@ -187,21 +187,22 @@ class PyNcclPipe(KVPipeBase):
Returns:
- buffer: The received tensor, or None if no tensor is received.
"""
- metadata = self._recv_metadata()
+ metadata = self._recv_metadata(src_rank)
if metadata["dtype"] is None:
return None
buffer = self._prepare_recv_buffer(metadata)
- self.device_recv_func(buffer, self.target_rank_for_recv)
+ self.device_recv_func(buffer, src_rank)
return buffer
def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
- tensor_size: int) -> None:
+ tensor_size: int,
+ target_rank: int) -> None:
"""
Wrapper for _send_impl to handle exceptions and update buffer size.
"""
try:
- self._send_impl(tensor)
+ self._send_impl(tensor, target_rank)
with self.buffer_size_lock:
self.buffer_size -= tensor_size
@@ -220,7 +221,7 @@ class PyNcclPipe(KVPipeBase):
logger.debug("KV cache transfer pipe is full. Waiting...")
time.sleep(0.05)
- def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
+ def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None:
"""
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
@@ -228,6 +229,7 @@ class PyNcclPipe(KVPipeBase):
Parameters:
- tensor: The tensor to send, or None if no tensor is being sent.
"""
+ logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank)
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
@@ -242,19 +244,23 @@ class PyNcclPipe(KVPipeBase):
self.buffer_size += tensor_size
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
- tensor_size)
+ tensor_size,
+ target_rank)
- def recv_tensor(self) -> Optional[torch.Tensor]:
+ def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Returns:
- tensor: The received tensor, or None if no tensor is received.
"""
+
+ logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank)
+
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
- future = self.transport_thread.submit(self._recv_impl)
+ future = self.transport_thread.submit(self._recv_impl, src_rank)
try:
tensor = future.result()
diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py
index 1e80e0bd..cd90206f 100644
--- a/vllm/distributed/kv_transfer/kv_transfer_agent.py
+++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py
@@ -35,6 +35,7 @@ class KVTransferAgent:
rank: int,
local_rank: int,
config: "VllmConfig",
+ world_group,
):
self.config = config
@@ -47,7 +48,7 @@ class KVTransferAgent:
"TransferAgent should only be used when kv_connector is set."
self.connector = KVConnectorFactory.create_connector(
- rank, local_rank, config)
+ rank, local_rank, config, world_group)
def send_kv_caches_and_hidden_states(
self,
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 321902d1..b8937ef8 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -1085,7 +1085,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
- config=vllm_config)
+ config=vllm_config,
+ world_group=get_world_group())
def ensure_model_parallel_initialized(
...@@ -17,6 +17,9 @@ limitations under the License. ...@@ -17,6 +17,9 @@ limitations under the License.
# Disaggregated Serving with VLLM # Disaggregated Serving with VLLM
> **Warning**
> This example is currently not tested and might not work as expected. For working disaggregated serving examples, please see the [vLLM example](/examples/python_rs/llm/vllm/).
This example demonstrates **disaggregated serving** [^1] using Triton Distributed together with vLLM engines. Disaggregated serving decouples the prefill (prompt encoding) and the decode (token generation) stages of large language model (LLM) inference into separate processes. This separation allows you to independently scale, optimize, and distribute resources for each stage. This example demonstrates **disaggregated serving** [^1] using Triton Distributed together with vLLM engines. Disaggregated serving decouples the prefill (prompt encoding) and the decode (token generation) stages of large language model (LLM) inference into separate processes. This separation allows you to independently scale, optimize, and distribute resources for each stage.
In this example, you will deploy: In this example, you will deploy:
......
...@@ -42,17 +42,13 @@ The example is designed to run in a containerized environment using Triton Distr ...@@ -42,17 +42,13 @@ The example is designed to run in a containerized environment using Triton Distr
```bash ```bash
# Build image # Build image
./container/build.sh ./container/build.sh --framework VLLM
``` ```
## Launching the Environment ## Launching the Environment
``` ```
# Run image interactively # Run image interactively
./container/run.sh -it ./container/run.sh --framework VLLM -it
# Add vllm into the python virtual environment
source /opt/triton/venv/bin/activate
uv pip install vllm==0.7.2
``` ```
## Deployment Options ## Deployment Options
...@@ -113,11 +109,12 @@ source /opt/triton/venv/bin/activate ...@@ -113,11 +109,12 @@ source /opt/triton/venv/bin/activate
# Launch prefill worker # Launch prefill worker
cd /workspace/examples/python_rs/llm/vllm cd /workspace/examples/python_rs/llm/vllm
CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \ VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \ --max-model-len 100 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.8 \
--enforce-eager \ --enforce-eager \
--tensor-parallel-size 1 \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
``` ```
...@@ -129,11 +126,12 @@ source /opt/triton/venv/bin/activate ...@@ -129,11 +126,12 @@ source /opt/triton/venv/bin/activate
# Launch decode worker # Launch decode worker
cd /workspace/examples/python_rs/llm/vllm cd /workspace/examples/python_rs/llm/vllm
CUDA_VISIBLE_DEVICES=1 python3 -m disaggregated.decode_worker \ VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=1,2 python3 -m disaggregated.decode_worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \ --max-model-len 100 \
--gpu-memory-utilization 0.8 \ --gpu-memory-utilization 0.8 \
--enforce-eager \ --enforce-eager \
--tensor-parallel-size 2 \
--kv-transfer-config \ --kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
``` ```
...@@ -174,6 +172,11 @@ For disaggregated deployment, you will also need to pass the `kv_ip` and `kv_por ...@@ -174,6 +172,11 @@ For disaggregated deployment, you will also need to pass the `kv_ip` and `kv_por
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":<rank>,"kv_parallel_size":2,"kv_ip":<master_node_ip>,"kv_port":<kv_port>}' '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":<rank>,"kv_parallel_size":2,"kv_ip":<master_node_ip>,"kv_port":<kv_port>}'
``` ```
### 4. Known Issues and Limitations
- vLLM is not working well with the `fork` method for multiprocessing and TP > 1. This is a known issue and a workaround is to use the `spawn` method instead. See [vLLM issue](https://github.com/vllm-project/vllm/issues/6152).
- `kv_rank` of `kv_producer` must be smaller than of `kv_consumer`.
- Instances with the same `kv_role` must have the same `--tensor-parallel-size`.
- Currently only `--pipeline-parallel-size 1` is supported for XpYd disaggregated deployment.
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
# limitations under the License. # limitations under the License.
import argparse
import asyncio import asyncio
import uvloop import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker from triton_distributed_rs import DistributedRuntime, triton_worker
from vllm.utils import FlexibleArgumentParser
from .protocol import Request from .protocol import Request
...@@ -40,14 +40,23 @@ async def worker( ...@@ -40,14 +40,23 @@ async def worker(
print(client.endpoint_ids()) print(client.endpoint_ids())
# issue request # issue request
stream = await client.generate( tasks = []
for _ in range(1):
tasks.append(
client.generate(
Request( Request(
prompt=prompt, prompt=prompt,
sampling_params={"temperature": temperature, "max_tokens": max_tokens}, sampling_params={
"temperature": temperature,
"max_tokens": max_tokens,
},
).model_dump_json() ).model_dump_json()
) )
)
streams = await asyncio.gather(*tasks)
# process response # process response
for stream in streams:
async for resp in stream: async for resp in stream:
print(resp) print(resp)
...@@ -55,7 +64,7 @@ async def worker( ...@@ -55,7 +64,7 @@ async def worker(
if __name__ == "__main__": if __name__ == "__main__":
uvloop.install() uvloop.install()
parser = FlexibleArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, default="what is the capital of france?") parser.add_argument("--prompt", type=str, default="what is the capital of france?")
parser.add_argument("--max-tokens", type=int, default=10) parser.add_argument("--max-tokens", type=int, default=10)
parser.add_argument("--temperature", type=float, default=0.5) parser.add_argument("--temperature", type=float, default=0.5)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import asyncio import asyncio
import random
import uuid import uuid
import uvloop import uvloop
...@@ -31,18 +32,27 @@ class VllmDecodeEngine: ...@@ -31,18 +32,27 @@ class VllmDecodeEngine:
Request handler for the generate endpoint Request handler for the generate endpoint
""" """
def __init__(self, engine_args: AsyncEngineArgs, prefill): def __init__(self, engine_args: AsyncEngineArgs):
assert ( assert (
engine_args.kv_transfer_config.is_kv_consumer engine_args.kv_transfer_config.is_kv_consumer
), "Decode worker must be a KV consumer" ), "Decode worker must be a KV consumer"
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args) self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
self.prefill = prefill self.prefills: list = []
self.prefill_workers = (
self.engine.engine.vllm_config.kv_transfer_config.kv_producers_parallel_size
)
self.kv_rank = self.engine.engine.vllm_config.kv_transfer_config.kv_rank
def add_prefill(self, prefill):
self.prefills.append(prefill)
@triton_endpoint(Request, Response) @triton_endpoint(Request, Response)
async def generate(self, request): async def generate(self, request):
vllm_logger.info(f"Received request: {request}") vllm_logger.info(f"Received request: {request}")
sampling_params = vllm.SamplingParams(**request.sampling_params) sampling_params = vllm.SamplingParams(**request.sampling_params)
request_id = str(uuid.uuid4()) prefill_rank = random.choice(range(self.prefill_workers))
request_id = f"{uuid.uuid4()}___prefill_kv_rank_{prefill_rank}___decode_kv_rank_{self.kv_rank}"
prefill_sampling_params = {**request.sampling_params} prefill_sampling_params = {**request.sampling_params}
prefill_sampling_params["max_tokens"] = 1 prefill_sampling_params["max_tokens"] = 1
...@@ -51,13 +61,9 @@ class VllmDecodeEngine: ...@@ -51,13 +61,9 @@ class VllmDecodeEngine:
sampling_params=prefill_sampling_params, sampling_params=prefill_sampling_params,
request_id=request_id, request_id=request_id,
) )
prefill_generator = await self.prefill.generate( self.prefills[prefill_rank].generate(
prefill_request.model_dump_json() prefill_request.model_dump_json(),
) )
prefill_response = [resp async for resp in prefill_generator]
assert len(prefill_response) == 1, "Prefill response should be a single boolean"
prefill_response = prefill_response[0]
vllm_logger.debug(f"Prefill response: {prefill_response}")
async for response in self.engine.generate( async for response in self.engine.generate(
request.prompt, sampling_params, request_id request.prompt, sampling_params, request_id
...@@ -75,15 +81,18 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -75,15 +81,18 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("triton-init").component("vllm") component = runtime.namespace("triton-init").component("vllm")
await component.create_service() await component.create_service()
decode_engine = VllmDecodeEngine(engine_args)
for i in range(decode_engine.prefill_workers):
prefill = ( prefill = (
await runtime.namespace("triton-init") await runtime.namespace("triton-init")
.component("prefill") .component("prefill")
.endpoint("generate") .endpoint(f"generate_kv_rank_{i}")
.client() .client()
) )
decode_engine.add_prefill(prefill)
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(VllmDecodeEngine(engine_args, prefill).generate) await endpoint.serve_endpoint(decode_engine.generate)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -35,6 +35,7 @@ class VllmPrefillEngine: ...@@ -35,6 +35,7 @@ class VllmPrefillEngine:
engine_args.kv_transfer_config.is_kv_producer engine_args.kv_transfer_config.is_kv_producer
), "Prefill worker must be a KV producer" ), "Prefill worker must be a KV producer"
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args) self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
self.kv_rank = self.engine.engine.vllm_config.kv_transfer_config.kv_rank
@triton_endpoint(PrefillRequest, PrefillResponse) @triton_endpoint(PrefillRequest, PrefillResponse)
async def generate(self, request): async def generate(self, request):
...@@ -56,8 +57,9 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -56,8 +57,9 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("triton-init").component("prefill") component = runtime.namespace("triton-init").component("prefill")
await component.create_service() await component.create_service()
endpoint = component.endpoint("generate") prefill_engine = VllmPrefillEngine(engine_args)
await endpoint.serve_endpoint(VllmPrefillEngine(engine_args).generate) endpoint = component.endpoint(f"generate_kv_rank_{prefill_engine.kv_rank}")
await endpoint.serve_endpoint(prefill_engine.generate)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment