Commit f38aa469 authored by ptarasiewiczNV, committed by GitHub

feat: vLLM v0.7.2 patch supporting XpYd and heterogeneous P/D parallel configs (#167)


Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
Co-authored-by: Neelay Shah <neelays@nvidia.com>
parent dddebc0d
......@@ -35,4 +35,5 @@
**/.git
**/.github
**/*backup*/
.dockerignore
\ No newline at end of file
.dockerignore
**/target/*
\ No newline at end of file
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
exclude: ^src/grpc_generated
exclude: ^(src/grpc_generated|.*\.patch$)
repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
......
......@@ -15,7 +15,6 @@
ARG BASE_IMAGE="nvcr.io/nvidia/tritonserver"
ARG BASE_IMAGE_TAG="25.01-py3"
ARG VLLM_WHEEL
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS triton-distributed
......@@ -76,21 +75,6 @@ RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,targe
--mount=type=bind,source=./container/deps/clone_tensorrtllm.sh,target=/tmp/clone_tensorrtllm.sh \
if [[ "$FRAMEWORK" == "TENSORRTLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt; /tmp/clone_tensorrtllm.sh --tensorrtllm-backend-repo-tag ${TENSORRTLLM_BACKEND_REPO_TAG} --tensorrtllm-backend-rebuild ${TENSORRTLLM_BACKEND_REBUILD} ; fi
RUN --mount=type=bind,source=./container/deps/requirements.vllm.txt,target=/tmp/requirements.txt \
if [[ "$FRAMEWORK" == "VLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi
# NOTE: python3-distro is a system package that causes conflicts with user
# packages installed by `pip`. Removing the system package allows user packages
# to correctly manage dependencies with the `distro` pip user package instead.
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
if [[ "$FRAMEWORK" == "VLLM" ]]; then \
apt install -y zip && \
apt remove -y python3-distro && \
pip install distro && \
cp -r /tmp/deps/vllm /tmp/local_vllm && \
cd /tmp/local_vllm && \
bash ./prepare_wheel.sh --install --debug --force ; \
fi
RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \
if [[ "$FRAMEWORK" == "STANDARD" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi
......@@ -102,23 +86,6 @@ ENV LD_LIBRARY_PATH=${FRAMEWORK_LD_LIBRARY_PATH}:${LD_LIBRARY_PATH}
ENV TENSORRTLLM_BACKEND_REPO_TAG=$TENSORRTLLM_BACKEND_REPO_TAG
ENV TRTLLM_USE_MPI_KVCACHE=${TENSORRTLLM_FRAMEWORK:+"1"}
# TODO set VLLM Version
# ENV VLLM_VERSION
ARG VLLM_FRAMEWORK
# DEFAULT VLLM VARIABLES
ENV VLLM_ATTENTION_BACKEND=${VLLM_FRAMEWORK:+FLASHINFER}
ENV VLLM_WORKER_MULTIPROC_METHOD=${VLLM_FRAMEWORK:+spawn}
ENV VLLM_TORCH_HOST=${VLLM_FRAMEWORK:+localhost}
ENV VLLM_TORCH_PORT=${VLLM_FRAMEWORK:+36183}
ENV VLLM_DATA_PLANE_BACKEND=${VLLM_FRAMEWORK:+nccl}
ENV VLLM_BASELINE_WORKERS=${VLLM_FRAMEWORK:+0}
ENV VLLM_CONTEXT_WORKERS=${VLLM_FRAMEWORK:+1}
ENV VLLM_GENERATE_WORKERS=${VLLM_FRAMEWORK:+1}
ENV VLLM_BASELINE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_CONTEXT_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_GENERATE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV PYTHONUNBUFFERED=1
# Install NATS - pointing toward NATS github instead of binaries.nats.dev due to server instability
RUN wget https://github.com/nats-io/nats-server/releases/download/v2.10.24/nats-server-v2.10.24-amd64.deb && dpkg -i nats-server-v2.10.24-amd64.deb
......@@ -161,6 +128,16 @@ RUN mkdir /opt/triton && \
uv build && \
uv pip install dist/triton_distributed_rs*cp312*.whl
# Install patched vllm
ARG VLLM_REF="v0.7.2"
ARG VLLM_PATCH="vllm_v0.7.2.patch"
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
if [[ "$FRAMEWORK" == "VLLM" ]]; then \
source /opt/triton/venv/bin/activate && \
bash /tmp/deps/vllm/install.sh --patch /tmp/deps/vllm/${VLLM_PATCH} --ref ${VLLM_REF} --install-cmd "uv pip install --editable" --use-precompiled --installation-dir /opt/vllm ; \
fi
# Install triton_distributed_rs wheel globally in container for tests that
# currently run without virtual environment activated.
# TODO: In future, we may use a virtualenv for everything and remove this.
......
......@@ -71,7 +71,7 @@ TENSORRTLLM_BACKEND_REBUILD=0
# vllm version installed in the base image.
VLLM_BASE_VERSION=25.01
VLLM_BASE_IMAGE=nvcr.io/nvidia/tritonserver
VLLM_BASE_IMAGE_TAG=${VLLM_BASE_VERSION}-vllm-python-py3
VLLM_BASE_IMAGE_TAG=${VLLM_BASE_VERSION}-py3
get_options() {
while :; do
......
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Necessary for vLLM engine.
--extra-index-url https://flashinfer.ai/whl/cu121/torch2.4
flashinfer<0.2.0
# Necessary for vLLM engine.
ninja==1.11.1.3
ucx-py-cu12
# vLLM is installed by patching script
# vllm==0.6.3.post1
......@@ -15,4 +15,4 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
Apply this patch to Python source code from vLLM release [v0.6.3.post1](https://github.com/vllm-project/vllm/releases/tag/v0.6.3.post1). Copy files in ``data_plane`` folder into vLLM folder ``vllm/distributed``.
Apply this patch to Python source code from vLLM release [v0.7.2](https://github.com/vllm-project/vllm/releases/tag/v0.7.2).
\ No newline at end of file
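A minimal sketch of applying the patch by hand, assuming a local checkout; the patch path is a placeholder and the same steps are automated by the install script added in this commit:

```bash
# Clone the matching vLLM release and apply the patch on top of it
git clone https://github.com/vllm-project/vllm.git && cd vllm
git checkout v0.7.2
git apply /path/to/vllm_v0.7.2.patch   # placeholder path to the patch file
# Optionally reuse upstream precompiled kernels instead of rebuilding them
VLLM_USE_PRECOMPILED=1 pip install --editable .
```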
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Data plane implementations (UCX and NCCL) for transferring KV-cache tensors between vLLM workers.
import logging
import math
import socket
import threading
import typing
import uuid
import torch
import torch.distributed
import tritonserver
import zmq
from triton_distributed.icp.data_plane import (
set_icp_data_type,
set_icp_memory_type,
set_icp_shape,
set_icp_tensor_size,
set_icp_tensor_uri,
)
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest
from triton_distributed.icp.ucp_data_plane import DataPlaneError, UcpDataPlane
logger = logging.getLogger(__name__)
class VllmUcpDataPlane:
def __init__(
self,
hostname: typing.Optional[str] = None,
port: int = 0,
keep_endpoints_open: bool = False,
) -> None:
self._data_plane = UcpDataPlane(hostname, port, keep_endpoints_open)
@property
def hostname(self) -> str:
return self._data_plane.hostname
@property
def port(self) -> int:
return self._data_plane.port
def connect(self) -> None:
self._data_plane.connect()
def close(self) -> None:
self._data_plane.close()
def put_input_tensor(
self,
tensor: torch.Tensor,
tensor_id: typing.Optional[uuid.UUID] = None,
):
logger.debug(
f"Putting input tensor with id {tensor_id} on {self.hostname}:{self.port}"
)
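# bfloat16 has no direct Triton data type in the mapping used by get_tensor, so it is
# transported as raw float16 bits and viewed back to bfloat16 on the receiving side.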
if tensor.dtype == torch.bfloat16:
tensor = tensor.view(torch.float16)
triton_tensor = tritonserver.Tensor.from_dlpack(tensor)
self._data_plane.put_input_tensor(triton_tensor, tensor_id)
def put_output_tensor(
self,
tensor: torch.Tensor,
tensor_id: typing.Optional[uuid.UUID] = None,
):
logger.debug(
f"Putting input tensor with id {tensor_id} on {self.hostname}:{self.port}"
)
if tensor.dtype == torch.bfloat16:
tensor = tensor.view(torch.float16)
triton_tensor = tritonserver.Tensor.from_dlpack(tensor)
self._data_plane.put_output_tensor(triton_tensor, tensor_id)
def get_tensor(
self,
tensor_uri: str,
shape: typing.Sequence[int],
dtype: torch.dtype,
device_id: int,
) -> torch.Tensor:
logger.debug("Getting tensor from %s", tensor_uri)
result = ModelInferRequest.InferInputTensor()
triton_dtype = {
torch.float32: tritonserver.DataType.FP32,
torch.float16: tritonserver.DataType.FP16,
torch.bfloat16: tritonserver.DataType.FP16,
torch.uint8: tritonserver.DataType.UINT8,
}.get(dtype)
if triton_dtype is None:
raise DataPlaneError(f"Unsupported dtype {dtype}")
tensor_size = math.prod(shape) * dtype.itemsize
set_icp_data_type(result, triton_dtype)
set_icp_shape(result, shape)
set_icp_tensor_uri(result, tensor_uri)
set_icp_memory_type(result, tritonserver.MemoryType.GPU)
set_icp_tensor_size(result, tensor_size)
triton_tensor = self._data_plane.get_tensor(
remote_tensor=result,
requested_memory_type=tritonserver.MemoryType.GPU,
requested_memory_type_id=device_id,
)
tensor = torch.utils.dlpack.from_dlpack(triton_tensor)
if dtype == torch.bfloat16:
tensor = tensor.view(torch.bfloat16)
logger.debug("Got tensor from %s", tensor_uri)
return tensor
class VllmNcclDataPlane:
def __init__(
self,
hostname: str = "",
port: int = 0,
# FIXME: world_size and rank both unused
world_size: int = -1,
rank: int = -1,
) -> None:
if not torch.distributed.is_initialized():
raise RuntimeError("NCCL backend not initialized")
if not hostname:
hostname = socket.gethostname()
if port == 0:
port = 13337 + torch.distributed.get_rank()
self._hostname = hostname
self._port = port
self._rank = torch.distributed.get_rank()
self._world_size: int = world_size
self._current_device = torch.cuda.current_device()
# FIXME: Use stricter type for req value in tuple
self.store: typing.Dict[str, typing.Tuple[torch.Tensor, int, typing.Any]] = {}
self.context = zmq.Context()
self.rep_socket = self.context.socket(zmq.REP)
logger.info(f"Rank {self._rank} binding to {self._hostname}:{self._port}")
self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
self._listener_thread = threading.Thread(
target=self.listen_for_requests, daemon=True
)
self._listener_thread.start()
# FIXME: Use stricter ZMQ socket type hint
self.req_sockets: typing.Dict[str, typing.Any] = {}
logger.info(f"Rank {self._rank} connected to the server")
@property
def world_size(self):
return self._world_size
@property
def rank(self):
return self._rank
def put_input_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
return self._put_tensor(tensor, rank, tensor_id, remote_address)
def put_output_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
return self._put_tensor(tensor, rank, tensor_id, remote_address)
def get_tensor(
self,
rank: int,
tensor_id: str,
remote_address: str,
) -> torch.Tensor:
return self._get_tensor(rank, tensor_id, remote_address)
def _put_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
logger.debug(
f"Rank {self._rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}"
)
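# Local put: stash the tensor so a later local _get_tensor can pick it up.
# Remote put: announce the transfer over the ZMQ REQ socket, wait for the receiver's OK,
# then stream the tensor with torch.distributed.isend.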
if remote_address is None:
self.store[tensor_id] = (tensor, rank, None)
else:
tensor_shape = "_".join(str(dim) for dim in tensor.shape)
if remote_address not in self.req_sockets:
self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
req_socket = self.req_sockets[remote_address]
req_socket.connect(f"tcp://{remote_address}")
req_socket.send_string(f"PUT {self._rank} {tensor_shape} {tensor_id}")
ret = req_socket.recv_string()
assert ret == "OK"
torch.distributed.isend(tensor, dst=rank)
def _get_tensor(
self,
rank: int,
tensor_id: str,
remote_address: str,
) -> torch.Tensor:
logger.debug(f"Rank {self._rank} receiving tensor from rank {rank}")
if tensor_id in self.store:
tensor, _, req = self.store.pop(tensor_id)
req.wait()  # TODO (ptarasiewicz): we should process other requests instead of blocking here
logger.debug(f"Rank {self._rank} received tensor from rank {rank}")
return tensor
raise NotImplementedError("Getting tensor from remote rank not implemented")
def _receive_tensor(
self,
tensor_id: str,
rank: int,
shape: typing.Sequence[int],
):
tensor = torch.empty(
shape, dtype=torch.uint8, device=f"cuda:{self._current_device}"
)
req = torch.distributed.irecv(tensor, src=rank)
self.store[tensor_id] = (tensor, rank, req)
def listen_for_requests(self):
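# Requests arrive as space-separated strings: "<CMD> <rank> <dims joined by _> <tensor_id>".
# A PUT announces an incoming isend from <rank>; we post a matching irecv and stash the
# pending tensor in self.store.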
while True:
cmd, _rank, _shape, tensor_id = self.rep_socket.recv_string().split()
logger.debug(f"Rank {self._rank} received request for tensor {tensor_id}")
self.rep_socket.send_string("OK")
if cmd == "GET":
raise NotImplementedError(
"Getting tensor from remote rank not implemented"
)
elif cmd == "PUT":
rank = int(_rank)
shape = [int(dim) for dim in _shape.split("_")]
self._receive_tensor(tensor_id, rank, shape)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# KV-cache transfer handler for disaggregated prefill/decode serving, built on the UCX and NCCL data planes.
# FIXME: Address type checking with divergent interfaces for Ucp/Nccl data planes
# type: ignore
import typing
if typing.TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata # type: ignore
from vllm.attention.backends.abstract import AttentionMetadata # type: ignore
import hashlib
import uuid
from concurrent.futures import ThreadPoolExecutor
import torch
import torch.distributed
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.flashinfer import FlashInferBackend, FlashInferMetadata
from vllm.distributed.data_plane import VllmNcclDataPlane, VllmUcpDataPlane
from vllm.distributed.parallel_state import get_store, get_tp_group # type: ignore
from vllm.logger import init_logger
logger = init_logger(__name__)
_kv_cache_handler = None
def get_kv_cache_handler():
global _kv_cache_handler
if _kv_cache_handler is None:
_kv_cache_handler = KVCacheHandler()
return _kv_cache_handler
class KVCacheHandler:
def __init__(self):
if _kv_cache_handler is not None:
raise ValueError("KVCacheHandler is a singleton")
self._data_plane_backend = envs.VLLM_DATA_PLANE_BACKEND
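# The nccl backend advertises its ZMQ endpoint through vLLM's distributed store; the ucx
# backend uses a TCPStore at VLLM_TORCH_HOST:VLLM_TORCH_PORT, with worker 0 / rank 0 as master.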
if self._data_plane_backend == "nccl":
self._data_plane = VllmNcclDataPlane()
self._store = get_store()
logger.info("Store set up")
self._store.set(
f"worker_{envs.VLLM_WORKER_ID}_rank_{get_tp_group().local_rank}",
f"{self._data_plane._hostname}:{self._data_plane._port}",
)
elif self._data_plane_backend == "ucx":
self._data_plane = VllmUcpDataPlane(keep_endpoints_open=True)
self._data_plane.connect()
rank = torch.distributed.get_rank()
is_master = envs.VLLM_WORKER_ID == 0 and rank == 0
self._store = torch.distributed.TCPStore(
envs.VLLM_TORCH_HOST, envs.VLLM_TORCH_PORT, is_master=is_master
)
self._store.set(
f"worker_{envs.VLLM_WORKER_ID}_rank_{rank}",
f"{self._data_plane.hostname}:{self._data_plane.port}",
)
else:
raise ValueError(f"Unknown data plane backend {self._data_plane_backend}")
self._local_store = {}
self.transport_thread = ThreadPoolExecutor(max_workers=1)
logger.info("KVCacheHandler initialized")
def send(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
with torch.cuda.nvtx.range("KV send"):
self._send(
model,
model_input,
kv_caches,
)
def _send(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
seq_lens = model_input.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model.start_layer
end_layer = model.end_layer
request_ids = list(model_input.request_ids_to_seq_ids.keys())
_, _, block_size, num_heads, _ = kv_caches[0].shape
attention_backend = _get_attention_backend(model_input.attn_metadata)
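# For each sequence in the batch, gather its K/V slices from the paged cache via the
# flattened slot mapping, then ship them to the matching generate rank(s).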
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
request_id = request_ids[idx]
keys, values = [], []
logger.debug(
f"seq_len {slen}, start_pos {start_pos}, end_pos {end_pos}, slot_mapping_flat {slot_mapping_flat.shape}"
)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
if len(model_input.attn_metadata.block_tables[idx]) > 0:
block_inds = torch.range(
0,
block_size - 1,
device=current_slot_mapping.device,
dtype=current_slot_mapping.dtype,
)
additional_inds = (
torch.cat(
[
block_inds + block_size * i
for i in model_input.attn_metadata.block_tables[idx]
]
)
.to(current_slot_mapping.device)
.to(current_slot_mapping.dtype)
)
logger.debug(f"additional_inds: {additional_inds.shape}")
logger.debug(f"current_slot_mapping: {current_slot_mapping.shape}")
current_slot_mapping = torch.cat(
[additional_inds, current_slot_mapping]
)
logger.debug(f"new current_slot_mapping: {current_slot_mapping.shape}")
current_slot_mapping_quotient = current_slot_mapping // block_size
current_slot_mapping_remainder = current_slot_mapping % block_size
logger.debug("kv_caches shape: %s", kv_caches[0].shape)
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
if attention_backend == "flash_attn":
key_cache = kv_cache[
0, current_slot_mapping_quotient, current_slot_mapping_remainder
]
value_cache = kv_cache[
1, current_slot_mapping_quotient, current_slot_mapping_remainder
]
elif attention_backend == "flash_infer":
key_cache = kv_cache[
current_slot_mapping_quotient, 0, current_slot_mapping_remainder
]
value_cache = kv_cache[
current_slot_mapping_quotient, 1, current_slot_mapping_remainder
]
else:
raise ValueError(f"Unknown attention backend {attention_backend}")
keys.append(key_cache)
values.append(value_cache)
keys = torch.stack(keys, dim=0) # type: ignore
values = torch.stack(values, dim=0) # type: ignore
tp_multipler = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
first_rank = envs.VLLM_CONTEXT_WORKERS * envs.VLLM_CONTEXT_TP_SIZE
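# When generate TP exceeds context TP, each context rank owns the heads of several generate
# ranks; split the heads evenly and send one slice per target generate rank.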
for i in range(tp_multipler):
num_heads_per_generate_rank = num_heads // tp_multipler
first_head = i * num_heads_per_generate_rank
partial_keys = keys[
:, :, first_head : first_head + num_heads_per_generate_rank, :
].clone() # type: ignore
partial_values = values[
:, :, first_head : first_head + num_heads_per_generate_rank, :
].clone() # type: ignore
target_local_rank = get_tp_group().local_rank * tp_multipler + i
target_rank = target_local_rank + first_rank
# torch.cuda.synchronize()
self._send_tensors(
request_id,
target_rank,
target_local_rank,
partial_keys,
partial_values,
)
logger.debug("Tensors sent")
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
with torch.cuda.nvtx.range("KV recv"):
self._recv(
model.start_layer,
model.end_layer,
[
model.layers[i].self_attn.attn.kv_cache_dtype
for i in range(model.start_layer, model.end_layer)
],
[
model.layers[i].self_attn.attn._k_scale
for i in range(model.start_layer, model.end_layer)
],
[
model.layers[i].self_attn.attn._v_scale
for i in range(model.start_layer, model.end_layer)
],
model_input,
kv_caches,
)
def _recv(
self,
start_layer: int,
end_layer: int,
kv_cache_dtypes: typing.List[torch.dtype],
k_scales: typing.List[float],
v_scales: typing.List[float],
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
seq_lens = model_input.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
request_ids = list(model_input.request_ids_to_seq_ids.keys())
_, _, block_size, num_heads, head_dim = kv_caches[0].shape
attention_backend = _get_attention_backend(model_input.attn_metadata)
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
request_id = request_ids[idx]
num_tokens = slen
base_request_id, context_worker_id = request_id.split("___")
context_worker_id = int(context_worker_id)
keys, values = self._recv_tensors(
base_request_id,
context_worker_id,
num_tokens,
end_layer - start_layer,
num_heads,
head_dim,
)
logger.debug(f"Received tensors for request_id {request_id}")
if kv_caches[0].dtype == torch.uint8:
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8_e4m3")
keys = keys.view(torch_dtype).to(torch.bfloat16)
values = values.view(torch_dtype).to(torch.bfloat16)
logger.debug("Converted caches to torch.blfoat16")
for i in range(start_layer, end_layer):
kv_cache = kv_caches[i - start_layer]
key, value = keys[i], values[i]
if attention_backend == "flash_attn":
key_cache, value_cache = kv_cache[0], kv_cache[1]
elif attention_backend == "flash_infer":
key_cache, value_cache = kv_cache[:, 0], kv_cache[:, 1]
else:
raise ValueError(f"Unknown attention backend {attention_backend}")
ops.reshape_and_cache_flash( # type: ignore
key,
value,
key_cache,
value_cache,
slot_mapping_flat[start_pos:end_pos],
kv_cache_dtypes[i],
k_scales[i],
v_scales[i],
)
logger.debug(f"KV receive DONE for rank {torch.distributed.get_rank()}")
def _send_tensors(self, request_id, target_rank, target_local_rank, keys, values):
logger.debug(
f"Sending tensors for request_id {request_id} to rank {target_rank}"
)
logger.debug(f"Tensor shapes: keys {keys.shape}, values {values.shape}")
logger.debug(f"Tensor dtypes: keys {keys.dtype}, values {values.dtype}")
if self._data_plane_backend == "nccl":
self._send_tensors_nccl(
request_id, target_rank, target_local_rank, keys, values
)
elif self._data_plane_backend == "ucx":
self._send_tensors_ucx(request_id, target_local_rank, keys, values)
def _send_tensors_nccl(
self, request_id, target_rank, target_local_rank, keys, values
):
generate_worker_id = envs.VLLM_CONTEXT_WORKERS
target_addr = self._store.get(
f"worker_{generate_worker_id}_rank_{target_local_rank}"
).decode()
self._data_plane.put_output_tensor(
keys,
rank=target_rank,
tensor_id=f"{request_id}_keys_rank{target_local_rank}",
remote_address=target_addr,
)
self._data_plane.put_output_tensor(
values,
rank=target_rank,
tensor_id=f"{request_id}_values_rank{target_local_rank}",
remote_address=target_addr,
)
def _send_tensors_ucx(self, request_id, target_local_rank, keys, values):
torch.cuda.synchronize()
self._data_plane.put_output_tensor(
keys,
_create_id_from_str(f"{request_id}_keys_rank{target_local_rank}"),
)
self._data_plane.put_output_tensor(
values,
_create_id_from_str(f"{request_id}_values_rank{target_local_rank}"),
)
def _recv_tensors(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
logger.debug(
f"Receiving tensors for request_id {request_id} from worker {context_worker_id}"
)
if self._data_plane_backend == "nccl":
return self._recv_tensors_nccl(
request_id,
context_worker_id,
num_tokens,
num_layers,
num_heads,
head_dim,
)
elif self._data_plane_backend == "ucx":
return self._recv_tensors_ucx(
request_id,
context_worker_id,
num_tokens,
num_layers,
num_heads,
head_dim,
)
def _recv_tensors_nccl(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
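# Map this decode rank back to the context rank that produced its KV slice: consecutive
# groups of decode ranks share a single context source rank.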
tp_rank = get_tp_group().local_rank
tp_multipler = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
source_tp_rank = tp_rank // tp_multipler
source_rank = context_worker_id * envs.VLLM_CONTEXT_TP_SIZE + source_tp_rank
worker_key = f"worker_{context_worker_id}_rank_{source_tp_rank}"
source_addr = self._local_store.get(worker_key)
if source_addr is None:
logger.info(
"Fetching source address for worker %d by key %s",
context_worker_id,
worker_key,
)
source_addr = self._store.get(worker_key).decode()
self._local_store[worker_key] = source_addr
keys = self._data_plane.get_tensor(
rank=source_rank,
tensor_id=f"{request_id}_keys_rank{tp_rank}",
remote_address=source_addr,
)
values = self._data_plane.get_tensor(
rank=source_rank,
tensor_id=f"{request_id}_values_rank{tp_rank}",
remote_address=source_addr,
)
return keys, values
def _recv_tensors_ucx(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
local_rank = get_tp_group().local_rank
tp_rank = get_tp_group().local_rank
tp_multipler = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
source_tp_rank = tp_rank // tp_multipler
source_addr = self._store.get(
f"worker_{context_worker_id}_rank_{source_tp_rank}"
).decode()
keys_id = _create_id_from_str(f"{request_id}_keys_rank{local_rank}")
keys_uri = f"ucp://{source_addr}/{keys_id}"
keys = self._data_plane.get_tensor(
keys_uri,
(num_layers, num_tokens, num_heads, head_dim),
torch.uint8,
device_id=local_rank,
)
values_id = _create_id_from_str(f"{request_id}_values_rank{local_rank}")
values_uri = f"ucp://{source_addr}/{values_id}"
values = self._data_plane.get_tensor(
values_uri,
(num_layers, num_tokens, num_heads, head_dim),
torch.uint8,
device_id=local_rank,
)
return keys, values
def _get_attention_backend(attn_metadata: "AttentionMetadata") -> str:
if isinstance(attn_metadata, FlashAttentionMetadata):
return "flash_attn"
elif isinstance(attn_metadata, FlashInferMetadata):
return "flash_infer"
else:
raise ValueError(
f"Unknown attention metadata type {type(attn_metadata)}. Only FlashAttentionMetadata and FlashInferMetadata are supported."
)
def _create_id_from_str(str_id: str) -> uuid.UUID:
# Create a hash of the string using SHA-1
hashed_key = hashlib.sha1(str_id.encode("utf-8"))
# Generate a UUID from the hash, ensuring it's in the correct format
hash_hex = hashed_key.hexdigest()[:32] # Get first 32 characters
uuid_generated = uuid.UUID(hash_hex)
return uuid_generated
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# Print usage information
print_usage() {
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --patch PATH Apply a patch file during installation"
echo " --ref REF Specify the vLLM git reference (branch/tag/commit) to install"
echo " --install-cmd CMD Specify the installation command (default: 'pip install')"
echo " --use-precompiled Use precompiled kernels during installation"
echo " --installation-dir DIR Specify the installation directory (default: 'vllm')"
echo " --help Show this help message"
}
# Default values
INSTALL_CMD="pip install"
VLLM_REF="main"
PATCH_PATH=""
USE_PRECOMPILED=false
INSTALLATION_DIR="vllm"
# Parse arguments
while [[ $# -gt 0 ]]; do
case $1 in
--patch)
PATCH_PATH="$2"
shift 2
;;
--ref)
VLLM_REF="$2"
shift 2
;;
--install-cmd)
INSTALL_CMD="$2"
shift 2
;;
--use-precompiled)
USE_PRECOMPILED=true
shift
;;
--installation-dir)
INSTALLATION_DIR="$2"
shift 2
;;
--help)
print_usage
exit 0
;;
*)
echo "Unknown argument: $1"
print_usage
exit 1
;;
esac
done
# Convert patch path to absolute path if it's relative
if [[ ! "$PATCH_PATH" = /* ]]; then
PATCH_PATH="$(pwd)/${PATCH_PATH}"
fi
# Clone vLLM repository
echo "Cloning vLLM repository at ref: $VLLM_REF"
git clone https://github.com/vllm-project/vllm.git "$INSTALLATION_DIR"
cd "$INSTALLATION_DIR"
git checkout "$VLLM_REF"
# Apply patch if provided
if [ -n "$PATCH_PATH" ]; then
echo "Applying patch from: $PATCH_PATH"
git apply "$PATCH_PATH"
fi
# Install using specified command
echo "Installing using: $INSTALL_CMD"
if [ "$USE_PRECOMPILED" = true ]; then
echo "Using precompiled kernels"
export VLLM_USE_PRECOMPILED=1
fi
$INSTALL_CMD .
echo "Installation complete!"
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# Function to print usage
print_usage() {
echo "Usage: $0 --original-ref <original_tag_or_branch> --fork-repo <fork_repo_url> --fork-ref <fork_tag_or_branch> --output <patch_output_path>"
echo
echo "Arguments:"
echo " --original-ref The tag or branch name from the original vllm-project/vllm repo"
echo " --fork-repo The URL of the forked repository"
echo " --fork-ref The tag or branch name from the forked repository"
echo " --output Path where the generated patch file should be saved"
echo
echo "Example:"
echo " $0 --original-ref v0.2.0 --fork-repo https://github.com/user/vllm.git --fork-ref feature-branch --output ./my-patch.diff"
exit 1
}
# Parse named arguments
while [[ $# -gt 0 ]]; do
case $1 in
--original-ref)
ORIGINAL_REF="$2"
shift 2
;;
--fork-repo)
FORK_REPO="$2"
shift 2
;;
--fork-ref)
FORK_REF="$2"
shift 2
;;
--output)
PATCH_OUTPUT="$2"
shift 2
;;
*)
print_usage
;;
esac
done
# Check if all required arguments are provided
if [ -z "$ORIGINAL_REF" ] || [ -z "$FORK_REPO" ] || [ -z "$FORK_REF" ] || [ -z "$PATCH_OUTPUT" ]; then
print_usage
fi
# Convert patch output path to absolute path if it's relative
if [[ ! "$PATCH_OUTPUT" = /* ]]; then
PATCH_OUTPUT="$(pwd)/${PATCH_OUTPUT}"
fi
TEMP_DIR=$(mktemp -d)
# Clean up temp directory on script exit
trap 'rm -rf "$TEMP_DIR"' EXIT
# Clone original vLLM to a temp directory
git clone https://github.com/vllm-project/vllm.git "$TEMP_DIR/original_vllm"
cd "$TEMP_DIR/original_vllm"
git remote add fork "$FORK_REPO"
git fetch fork "$FORK_REF"
git diff "$ORIGINAL_REF" fork/"$FORK_REF" > "$PATCH_OUTPUT"
echo "Patch created successfully: $PATCH_OUTPUT"
\ No newline at end of file
#!/bin/bash -e
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A script to download a Python wheel, patch it, copy additional files,
# repackage it, and optionally install the new wheel.
###############################################################################
# CONFIGURATION & DEFAULTS
###############################################################################
DEFAULT_WHEEL_URL="https://files.pythonhosted.org/packages/4a/4c/ee65ba33467a4c0de350ce29fbae39b9d0e7fcd887cc756fa993654d1228/vllm-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl"
DEFAULT_PATCH_FILE="vllm_patch_063post1.patch"
DEFAULT_DATA_PLANE_DIR="data_plane"
DEFAULT_WHEEL_DIR="wheel"
DEFAULT_OUTPUT_WHEEL="vllm-dist-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl"
# Optionally set a default SHA256 checksum for the downloaded wheel
# DEFAULT_CHECKSUM="SOME_SHA256_HERE"
###############################################################################
# HELPER FUNCTIONS
###############################################################################
usage() {
cat << EOF
Usage: $0 [OPTIONS]
Options:
-u, --url <URL> Wheel URL to download (default: $DEFAULT_WHEEL_URL)
-p, --patch <FILE> Patch file path (default: $DEFAULT_PATCH_FILE)
-d, --data-plane <DIR> Directory with additional data-plane files (default: $DEFAULT_DATA_PLANE_DIR)
-w, --wheel-dir <DIR> Extract destination directory (default: $DEFAULT_WHEEL_DIR)
-o, --output-wheel <FILE> Name/path for the repackaged wheel (default: $DEFAULT_OUTPUT_WHEEL)
-f, --force Force overwriting existing directories/files without prompts
-i, --install Install the new wheel after repackaging
-D, --debug Enable debug mode (verbose output)
-h, --help Show this help and exit
Example:
$0 \\
--url "https://example.com/path/to/vllm.whl" \\
--patch my_patch.patch \\
--data-plane custom_data/ \\
--wheel-dir extracted_wheel \\
--output-wheel my_vllm_dist.whl \\
--install \\
--force \\
--debug
EOF
}
info_log() {
# Always prints, showing a timestamp.
echo "[$(date +'%Y-%m-%d %H:%M:%S')] INFO: $*"
}
debug_log() {
# Prints only if DEBUG=true
if [[ "$DEBUG" == true ]]; then
echo "[$(date +'%Y-%m-%d %H:%M:%S')] DEBUG: $*"
fi
}
error_exit() {
echo "ERROR: $*" >&2
exit 1
}
###############################################################################
# PARSE ARGUMENTS
###############################################################################
FORCE_OVERWRITE=false
INSTALL_WHEEL=false
DEBUG=false
WHEEL_URL="$DEFAULT_WHEEL_URL"
PATCH_FILE="$DEFAULT_PATCH_FILE"
DATA_PLANE_DIR="$DEFAULT_DATA_PLANE_DIR"
WHEEL_DIR="$DEFAULT_WHEEL_DIR"
OUTPUT_WHEEL="$DEFAULT_OUTPUT_WHEEL"
while [[ $# -gt 0 ]]; do
case "$1" in
-u|--url)
WHEEL_URL="$2"
shift 2
;;
-p|--patch)
PATCH_FILE="$2"
shift 2
;;
-d|--data-plane)
DATA_PLANE_DIR="$2"
shift 2
;;
-w|--wheel-dir)
WHEEL_DIR="$2"
shift 2
;;
-o|--output-wheel)
OUTPUT_WHEEL="$2"
shift 2
;;
-f|--force)
FORCE_OVERWRITE=true
shift
;;
-i|--install)
INSTALL_WHEEL=true
shift
;;
-D|--debug)
DEBUG=true
shift
;;
-h|--help)
usage
exit 0
;;
*)
error_exit "Unknown option: $1"
;;
esac
done
###############################################################################
# MAIN SCRIPT
###############################################################################
# Enable debug mode if requested
if [[ "$DEBUG" == true ]]; then
set -x
fi
info_log "Starting wheel patching script..."
# Function to check if a command exists
command_exists() {
command -v "$1" >/dev/null 2>&1 || { echo >&2 "I require $1 but it's not installed. Aborting."; exit 1; }
}
# Check for required commands
command_exists pip
command_exists unzip
command_exists zip
command_exists patch
# ---------------------------------------------------------------------------
# 1. Check for existing wheel file or directory
# ---------------------------------------------------------------------------
WHEEL_FILENAME=$(basename "$WHEEL_URL")
if [[ -f "$WHEEL_FILENAME" && "$FORCE_OVERWRITE" != true ]]; then
info_log "File '$WHEEL_FILENAME' already exists. Reusing existing file."
info_log "If you want to redownload, remove '$WHEEL_FILENAME' or use --force."
else
info_log "Downloading wheel from $WHEEL_URL..."
rm -f "$WHEEL_FILENAME" 2>/dev/null || true # Remove existing file if forcing
wget -O "$WHEEL_FILENAME" "$WHEEL_URL"
fi
# ---------------------------------------------------------------------------
# 2. Optional: Verify checksum (commented out by default)
# ---------------------------------------------------------------------------
# if [[ -n "$DEFAULT_CHECKSUM" ]]; then
# info_log "Verifying SHA256 checksum..."
# echo "${DEFAULT_CHECKSUM} ${WHEEL_FILENAME}" | sha256sum --check - || error_exit "Checksum mismatch!"
# fi
# ---------------------------------------------------------------------------
# 3. Create/clean wheel extraction directory
# ---------------------------------------------------------------------------
if [[ -d "$WHEEL_DIR" ]]; then
if [[ "$FORCE_OVERWRITE" == true ]]; then
info_log "Removing existing directory '$WHEEL_DIR' due to --force..."
rm -rf "$WHEEL_DIR"
else
error_exit "Directory '$WHEEL_DIR' already exists. Use --force to overwrite."
fi
fi
info_log "Creating directory '$WHEEL_DIR'..."
mkdir -p "$WHEEL_DIR"
# ---------------------------------------------------------------------------
# 4. Unzip the wheel into the specified directory
# ---------------------------------------------------------------------------
info_log "Unzipping wheel into directory '$WHEEL_DIR'..."
unzip -q "$WHEEL_FILENAME" -d "$WHEEL_DIR"
# ---------------------------------------------------------------------------
# 5. Check/Apply patch
# ---------------------------------------------------------------------------
if [[ ! -f "$PATCH_FILE" ]]; then
error_exit "Patch file '$PATCH_FILE' not found."
fi
PATCH_TARGET_DIR="$WHEEL_DIR/vllm"
if [[ ! -d "$PATCH_TARGET_DIR" ]]; then
error_exit "Could not find directory '$PATCH_TARGET_DIR' in unzipped wheel."
fi
info_log "Applying patch '$PATCH_FILE' to '$PATCH_TARGET_DIR'..."
debug_log "Executing: (cd \"$PATCH_TARGET_DIR\" && patch -p1 < \"../../$PATCH_FILE\")"
(
cd "$PATCH_TARGET_DIR"
patch -p1 < "../../$PATCH_FILE"
)
# ---------------------------------------------------------------------------
# 6. Copy data plane files
# ---------------------------------------------------------------------------
if [[ ! -d "$DATA_PLANE_DIR" ]]; then
error_exit "Data plane directory '$DATA_PLANE_DIR' not found."
fi
info_log "Copying files from '$DATA_PLANE_DIR' to '$PATCH_TARGET_DIR/distributed'..."
mkdir -p "$PATCH_TARGET_DIR/distributed"
cp -r "$DATA_PLANE_DIR/"* "$PATCH_TARGET_DIR/distributed/"
# ---------------------------------------------------------------------------
# 7. Re-package into a new wheel
# ---------------------------------------------------------------------------
if [[ -f "$OUTPUT_WHEEL" && "$FORCE_OVERWRITE" != true ]]; then
error_exit "Output wheel '$OUTPUT_WHEEL' already exists. Use --force to overwrite."
fi
info_log "Creating new wheel file '$OUTPUT_WHEEL'..."
debug_log "Executing: (cd \"$WHEEL_DIR\" && zip -rq \"../$OUTPUT_WHEEL\" .)"
(
cd "$WHEEL_DIR"
zip -rq "../$OUTPUT_WHEEL" .
)
# ---------------------------------------------------------------------------
# 8. Optional: Install the new wheel
# ---------------------------------------------------------------------------
if [[ "$INSTALL_WHEEL" == true ]]; then
# Check if pip is installed
if ! command -v pip >/dev/null 2>&1; then
error_exit "pip is not installed or not found in PATH."
fi
info_log "Installing newly created wheel '$OUTPUT_WHEEL'..."
pip install --force-reinstall --upgrade --break-system-packages "$OUTPUT_WHEEL"
fi
info_log "Patch and repackage completed successfully!"
info_log "New wheel: $OUTPUT_WHEEL"
if [[ "$INSTALL_WHEEL" == true ]]; then
info_log "Wheel has been installed."
fi
......@@ -27,14 +27,4 @@ pytestmark = pytest.mark.pre_merge
@pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
def test_version():
# Verify that the image has the patched version of vllm
assert vllm.__version__ == "0.6.3.post2.dev16+gf61960ce"
@pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
def test_patch_imports():
# Verify patched files have no glaring syntax or import issues
import vllm.distributed.data_plane as d
import vllm.distributed.kv_cache as k
# Placeholder to avoid unused import errors or removal by linters
assert d, k
assert vllm.__version__.startswith("0.7.3.dev")
......@@ -17,6 +17,9 @@ limitations under the License.
# Disaggregated Serving with VLLM
> **Warning**
> This example is currently not tested and might not work as expected. For working disaggregated serving examples, please see the [vLLM example](/examples/python_rs/llm/vllm/).
This example demonstrates **disaggregated serving** [^1] using Triton Distributed together with vLLM engines. Disaggregated serving decouples the prefill (prompt encoding) and the decode (token generation) stages of large language model (LLM) inference into separate processes. This separation allows you to independently scale, optimize, and distribute resources for each stage.
In this example, you will deploy:
......
......@@ -42,17 +42,13 @@ The example is designed to run in a containerized environment using Triton Distr
```bash
# Build image
./container/build.sh
./container/build.sh --framework VLLM
```
## Launching the Environment
```
# Run image interactively
./container/run.sh -it
# Add vllm into the python virtual environment
source /opt/triton/venv/bin/activate
uv pip install vllm==0.7.2
./container/run.sh --framework VLLM -it
```
## Deployment Options
......@@ -113,11 +109,12 @@ source /opt/triton/venv/bin/activate
# Launch prefill worker
cd /workspace/examples/python_rs/llm/vllm
CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=0 python3 -m disaggregated.prefill_worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--enforce-eager \
--tensor-parallel-size 1 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
```
......@@ -129,11 +126,12 @@ source /opt/triton/venv/bin/activate
# Launch decode worker
cd /workspace/examples/python_rs/llm/vllm
CUDA_VISIBLE_DEVICES=1 python3 -m disaggregated.decode_worker \
VLLM_WORKER_MULTIPROC_METHOD=spawn CUDA_VISIBLE_DEVICES=1,2 python3 -m disaggregated.decode_worker \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--enforce-eager \
--tensor-parallel-size 2 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
```
......@@ -174,6 +172,11 @@ For disaggregated deployment, you will also need to pass the `kv_ip` and `kv_por
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":<rank>,"kv_parallel_size":2,"kv_ip":<master_node_ip>,"kv_port":<kv_port>}'
```
### 4. Known Issues and Limitations
- vLLM does not work well with the `fork` multiprocessing method when TP > 1. This is a known issue; the workaround is to use the `spawn` method instead. See [vLLM issue](https://github.com/vllm-project/vllm/issues/6152).
- The `kv_rank` of a `kv_producer` must be smaller than that of a `kv_consumer`.
- Instances with the same `kv_role` must have the same `--tensor-parallel-size`.
- Currently only `--pipeline-parallel-size 1` is supported for XpYd disaggregated deployment.
......@@ -14,11 +14,11 @@
# limitations under the License.
import argparse
import asyncio
import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker
from vllm.utils import FlexibleArgumentParser
from .protocol import Request
......@@ -40,22 +40,31 @@ async def worker(
print(client.endpoint_ids())
# issue request
stream = await client.generate(
Request(
prompt=prompt,
sampling_params={"temperature": temperature, "max_tokens": max_tokens},
).model_dump_json()
)
tasks = []
for _ in range(1):
tasks.append(
client.generate(
Request(
prompt=prompt,
sampling_params={
"temperature": temperature,
"max_tokens": max_tokens,
},
).model_dump_json()
)
)
streams = await asyncio.gather(*tasks)
# process response
async for resp in stream:
print(resp)
for stream in streams:
async for resp in stream:
print(resp)
if __name__ == "__main__":
uvloop.install()
parser = FlexibleArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, default="what is the capital of france?")
parser.add_argument("--max-tokens", type=int, default=10)
parser.add_argument("--temperature", type=float, default=0.5)
......
......@@ -15,6 +15,7 @@
import asyncio
import random
import uuid
import uvloop
......@@ -31,18 +32,27 @@ class VllmDecodeEngine:
Request handler for the generate endpoint
"""
def __init__(self, engine_args: AsyncEngineArgs, prefill):
def __init__(self, engine_args: AsyncEngineArgs):
assert (
engine_args.kv_transfer_config.is_kv_consumer
), "Decode worker must be a KV consumer"
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
self.prefill = prefill
self.prefills: list = []
self.prefill_workers = (
self.engine.engine.vllm_config.kv_transfer_config.kv_producers_parallel_size
)
self.kv_rank = self.engine.engine.vllm_config.kv_transfer_config.kv_rank
def add_prefill(self, prefill):
self.prefills.append(prefill)
@triton_endpoint(Request, Response)
async def generate(self, request):
vllm_logger.info(f"Received request: {request}")
sampling_params = vllm.SamplingParams(**request.sampling_params)
request_id = str(uuid.uuid4())
prefill_rank = random.choice(range(self.prefill_workers))
request_id = f"{uuid.uuid4()}___prefill_kv_rank_{prefill_rank}___decode_kv_rank_{self.kv_rank}"
prefill_sampling_params = {**request.sampling_params}
prefill_sampling_params["max_tokens"] = 1
......@@ -51,13 +61,9 @@ class VllmDecodeEngine:
sampling_params=prefill_sampling_params,
request_id=request_id,
)
prefill_generator = await self.prefill.generate(
prefill_request.model_dump_json()
self.prefills[prefill_rank].generate(
prefill_request.model_dump_json(),
)
prefill_response = [resp async for resp in prefill_generator]
assert len(prefill_response) == 1, "Prefill response should be a single boolean"
prefill_response = prefill_response[0]
vllm_logger.debug(f"Prefill response: {prefill_response}")
async for response in self.engine.generate(
request.prompt, sampling_params, request_id
......@@ -75,15 +81,18 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("triton-init").component("vllm")
await component.create_service()
prefill = (
await runtime.namespace("triton-init")
.component("prefill")
.endpoint("generate")
.client()
)
decode_engine = VllmDecodeEngine(engine_args)
for i in range(decode_engine.prefill_workers):
prefill = (
await runtime.namespace("triton-init")
.component("prefill")
.endpoint(f"generate_kv_rank_{i}")
.client()
)
decode_engine.add_prefill(prefill)
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(VllmDecodeEngine(engine_args, prefill).generate)
await endpoint.serve_endpoint(decode_engine.generate)
if __name__ == "__main__":
......
......@@ -35,6 +35,7 @@ class VllmPrefillEngine:
engine_args.kv_transfer_config.is_kv_producer
), "Prefill worker must be a KV producer"
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
self.kv_rank = self.engine.engine.vllm_config.kv_transfer_config.kv_rank
@triton_endpoint(PrefillRequest, PrefillResponse)
async def generate(self, request):
......@@ -56,8 +57,9 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("triton-init").component("prefill")
await component.create_service()
endpoint = component.endpoint("generate")
await endpoint.serve_endpoint(VllmPrefillEngine(engine_args).generate)
prefill_engine = VllmPrefillEngine(engine_args)
endpoint = component.endpoint(f"generate_kv_rank_{prefill_engine.kv_rank}")
await endpoint.serve_endpoint(prefill_engine.generate)
if __name__ == "__main__":
......