"vscode:/vscode.git/clone" did not exist on "46ed649cb3bdf9fd8526036d291ae4b95cc1ce58"
Commit a0e1da03 authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

build: Add vllm wheel patch scripts for disaggregated serving




---------
Co-authored-by: default avatarPiotr Marcinkiewicz <piotrm@nvidia.com>
parent 587addb9
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
ARG BASE_IMAGE="nvcr.io/nvidia/tritonserver" ARG BASE_IMAGE="nvcr.io/nvidia/tritonserver"
ARG BASE_IMAGE_TAG="24.12-py3" ARG BASE_IMAGE_TAG="24.12-py3"
ARG VLLM_WHEEL
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS triton-distributed FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS triton-distributed
...@@ -56,6 +57,19 @@ RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,targe ...@@ -56,6 +57,19 @@ RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,targe
RUN --mount=type=bind,source=./container/deps/requirements.vllm.txt,target=/tmp/requirements.txt \ RUN --mount=type=bind,source=./container/deps/requirements.vllm.txt,target=/tmp/requirements.txt \
if [[ "$FRAMEWORK" == "VLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi if [[ "$FRAMEWORK" == "VLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi
# NOTE: python3-distro is a system package that causes conflicts with user
# packages installed by `pip`. Removing the system package allows user packages
# to correctly manage dependencies with the `distro` pip user package instead.
RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
if [[ "$FRAMEWORK" == "VLLM" ]]; then \
apt install -y zip && \
apt remove -y python3-distro && \
pip install distro && \
cp -r /tmp/deps/vllm /tmp/local_vllm && \
cd /tmp/local_vllm && \
bash ./prepare_wheel.sh --install --debug --force ; \
fi
RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \ RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \
if [[ "$FRAMEWORK" == "STANDARD" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi if [[ "$FRAMEWORK" == "STANDARD" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi
......
...@@ -44,6 +44,8 @@ TENSORRTLLM_BASE_IMAGE_TAG=${TENSORRTLLM_BASE_VERSION}-trtllm-python-py3 ...@@ -44,6 +44,8 @@ TENSORRTLLM_BASE_IMAGE_TAG=${TENSORRTLLM_BASE_VERSION}-trtllm-python-py3
# IMPORTANT NOTE: Ensure the commit matches the TRTLLM backend version used in the base image above # IMPORTANT NOTE: Ensure the commit matches the TRTLLM backend version used in the base image above
TENSORRTLLM_BACKEND_COMMIT=v0.16.0 TENSORRTLLM_BACKEND_COMMIT=v0.16.0
# vllm installation is done later in the Dockerfile so it will overwrite the
# vllm version installed in the base image.
VLLM_BASE_VERSION=24.12 VLLM_BASE_VERSION=24.12
VLLM_BASE_IMAGE=nvcr.io/nvidia/tritonserver VLLM_BASE_IMAGE=nvcr.io/nvidia/tritonserver
VLLM_BASE_IMAGE_TAG=${VLLM_BASE_VERSION}-vllm-python-py3 VLLM_BASE_IMAGE_TAG=${VLLM_BASE_VERSION}-vllm-python-py3
......
...@@ -17,5 +17,5 @@ ...@@ -17,5 +17,5 @@
--extra-index-url https://flashinfer.ai/whl/cu121/torch2.4 --extra-index-url https://flashinfer.ai/whl/cu121/torch2.4
flashinfer flashinfer
ucx-py-cu12 ucx-py-cu12
# TODO update to branch / fork # vLLM is installed by patching script
vllm==0.6.3post1 # vllm==0.6.3post1
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
Apply this patch to Python source code from vLLM release [v0.6.3.post1](https://github.com/vllm-project/vllm/releases/tag/v0.6.3.post1). Copy files in ``data_plane`` folder into vLLM folder ``vllm/distributed``.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A script to download a Python wheel, patch it, copy additional files,
# repackage it, and optionally install the new wheel.
import logging
import math
import socket
import threading
import typing
import uuid
import torch
import torch.distributed
import tritonserver
import zmq
from triton_distributed.icp.data_plane import (
set_icp_data_type,
set_icp_memory_type,
set_icp_shape,
set_icp_tensor_size,
set_icp_tensor_uri,
)
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest
from triton_distributed.icp.ucp_data_plane import DataPlaneError, UcpDataPlane
logger = logging.getLogger(__name__)
class VllmUcpDataPlane:
def __init__(
self,
hostname: typing.Optional[str] = None,
port: int = 0,
keep_endpoints_open: bool = False,
) -> None:
self._data_plane = UcpDataPlane(hostname, port, keep_endpoints_open)
@property
def hostname(self) -> str:
return self._data_plane.hostname
@property
def port(self) -> int:
return self._data_plane.port
def connect(self) -> None:
self._data_plane.connect()
def close(self) -> None:
self._data_plane.close()
def put_input_tensor(
self,
tensor: torch.Tensor,
tensor_id: typing.Optional[uuid.UUID] = None,
):
logger.debug(
f"Putting input tensor with id {tensor_id} on {self.hostname}:{self.port}"
)
if tensor.dtype == torch.bfloat16:
tensor = tensor.view(torch.float16)
triton_tensor = tritonserver.Tensor.from_dlpack(tensor)
self._data_plane.put_input_tensor(triton_tensor, tensor_id)
def put_output_tensor(
self,
tensor: torch.Tensor,
tensor_id: typing.Optional[uuid.UUID] = None,
):
logger.debug(
f"Putting input tensor with id {tensor_id} on {self.hostname}:{self.port}"
)
if tensor.dtype == torch.bfloat16:
tensor = tensor.view(torch.float16)
triton_tensor = tritonserver.Tensor.from_dlpack(tensor)
self._data_plane.put_output_tensor(triton_tensor, tensor_id)
def get_tensor(
self,
tensor_uri: str,
shape: typing.Sequence[int],
dtype: torch.dtype,
device_id: int,
) -> torch.Tensor:
logger.debug("Getting tensor from %s", tensor_uri)
result = ModelInferRequest.InferInputTensor()
triton_dtype = {
torch.float32: tritonserver.DataType.FP32,
torch.float16: tritonserver.DataType.FP16,
torch.bfloat16: tritonserver.DataType.FP16,
torch.uint8: tritonserver.DataType.UINT8,
}.get(dtype)
if triton_dtype is None:
raise DataPlaneError(f"Unsupported dtype {dtype}")
tensor_size = math.prod(shape) * dtype.itemsize
set_icp_data_type(result, triton_dtype)
set_icp_shape(result, shape)
set_icp_tensor_uri(result, tensor_uri)
set_icp_memory_type(result, tritonserver.MemoryType.GPU)
set_icp_tensor_size(result, tensor_size)
triton_tensor = self._data_plane.get_tensor(
remote_tensor=result,
requested_memory_type=tritonserver.MemoryType.GPU,
requested_memory_type_id=device_id,
)
tensor = torch.utils.dlpack.from_dlpack(triton_tensor)
if dtype == torch.bfloat16:
tensor = tensor.view(torch.bfloat16)
logger.debug("Got tensor from %s", tensor_uri)
return tensor
class VllmNcclDataPlane:
def __init__(
self,
hostname: str = "",
port: int = 0,
# FIXME: world_size and rank both unused
world_size: int = -1,
rank: int = -1,
) -> None:
if not torch.distributed.is_initialized():
raise RuntimeError("NCCL backend not initialized")
if not hostname:
hostname = socket.gethostname()
if port == 0:
port = 13337 + torch.distributed.get_rank()
self._hostname = hostname
self._port = port
self._rank = torch.distributed.get_rank()
self._world_size: int = world_size
self._current_device = torch.cuda.current_device()
# FIXME: Use stricter type for req value in tuple
self.store: typing.Dict[str, typing.Tuple[torch.Tensor, int, typing.Any]] = {}
self.context = zmq.Context()
self.rep_socket = self.context.socket(zmq.REP)
logger.info(f"Rank {self._rank} binding to {self._hostname}:{self._port}")
self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
self._listener_thread = threading.Thread(
target=self.listen_for_requests, daemon=True
)
self._listener_thread.start()
# FIXME: Use stricter ZMQ socket type hint
self.req_sockets: typing.Dict[str, typing.Any] = {}
logger.info(f"Rank {self._rank} connected to the server")
@property
def world_size(self):
return self._world_size
@property
def rank(self):
return self._rank
def put_input_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
return self._put_tensor(tensor, rank, tensor_id, remote_address)
def put_output_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
return self._put_tensor(tensor, rank, tensor_id, remote_address)
def get_tensor(
self,
rank: int,
tensor_id: str,
remote_address: str,
) -> torch.Tensor:
return self._get_tensor(rank, tensor_id, remote_address)
def _put_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
logger.debug(
f"Rank {self._rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}"
)
if remote_address is None:
self.store[tensor_id] = (tensor, rank, None)
else:
tensor_shape = "_".join(str(dim) for dim in tensor.shape)
if remote_address not in self.req_sockets:
self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
req_socket = self.req_sockets[remote_address]
req_socket.connect(f"tcp://{remote_address}")
req_socket.send_string(f"PUT {self._rank} {tensor_shape} {tensor_id}")
ret = req_socket.recv_string()
assert ret == "OK"
torch.distributed.isend(tensor, dst=rank)
def _get_tensor(
self,
rank: int,
tensor_id: str,
remote_address: str,
) -> torch.Tensor:
logger.debug(f"Rank {self._rank} receiving tensor from rank {rank}")
if tensor_id in self.store:
tensor, _, req = self.store.pop(tensor_id)
req.wait() # TODO ptarasiewicz we should run other request instead of wait
logger.debug(f"Rank {self._rank} received tensor from rank {rank}")
return tensor
raise NotImplementedError("Getting tensor from remote rank not implemented")
def _receive_tensor(
self,
tensor_id: str,
rank: int,
shape: typing.Sequence[int],
):
tensor = torch.empty(
shape, dtype=torch.uint8, device=f"cuda:{self._current_device}"
)
req = torch.distributed.irecv(tensor, src=rank)
self.store[tensor_id] = (tensor, rank, req)
def listen_for_requests(self):
while True:
cmd, _rank, _shape, tensor_id = self.rep_socket.recv_string().split()
logger.debug(f"Rank {self._rank} received request for tensor {tensor_id}")
self.rep_socket.send_string("OK")
if cmd == "GET":
raise NotImplementedError(
"Getting tensor from remote rank not implemented"
)
elif cmd == "PUT":
rank = int(_rank)
shape = [int(dim) for dim in _shape.split("_")]
self._receive_tensor(tensor_id, rank, shape)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A script to download a Python wheel, patch it, copy additional files,
# repackage it, and optionally install the new wheel.
import typing
if typing.TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata # type: ignore
from vllm.attention.backends.abstract import AttentionMetadata # type: ignore
import hashlib
import uuid
from concurrent.futures import ThreadPoolExecutor
import torch
import torch.distributed
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.flashinfer import FlashInferBackend, FlashInferMetadata
from vllm.distributed.data_plane import VllmNcclDataPlane, VllmUcpDataPlane
from vllm.distributed.parallel_state import get_store, get_tp_group # type: ignore
from vllm.logger import init_logger
logger = init_logger(__name__)
_kv_cache_handler = None
def get_kv_cache_handler():
global _kv_cache_handler
if _kv_cache_handler is None:
_kv_cache_handler = KVCacheHandler()
return _kv_cache_handler
class KVCacheHandler:
def __init__(self):
if _kv_cache_handler is not None:
raise ValueError("KVCacheHandler is a singleton")
self._data_plane_backend = envs.VLLM_DATA_PLANE_BACKEND
if self._data_plane_backend == "nccl":
self._data_plane = VllmNcclDataPlane()
self._store = get_store()
logger.info("Store set up")
self._store.set(
f"worker_{envs.VLLM_WORKER_ID}_rank_{get_tp_group().local_rank}",
f"{self._data_plane._hostname}:{self._data_plane._port}",
)
elif self._data_plane_backend == "ucx":
self._data_plane = VllmUcpDataPlane(keep_endpoints_open=True)
self._data_plane.connect()
rank = torch.distributed.get_rank()
is_master = envs.VLLM_WORKER_ID == 0 and rank == 0
self._store = torch.distributed.TCPStore(
envs.VLLM_TORCH_HOST, envs.VLLM_TORCH_PORT, is_master=is_master
)
self._store.set(
f"worker_{envs.VLLM_WORKER_ID}_rank_{rank}",
f"{self._data_plane.hostname}:{self._data_plane.port}",
)
else:
raise ValueError(f"Unknown data plane backend {self._data_plane_backend}")
self._local_store = {}
self.transport_thread = ThreadPoolExecutor(max_workers=1)
logger.info("KVCacheHandler initialized")
def send(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
with torch.cuda.nvtx.range("KV send"):
self._send(
model,
model_input,
kv_caches,
)
def _send(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
seq_lens = model_input.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model.start_layer
end_layer = model.end_layer
request_ids = list(model_input.request_ids_to_seq_ids.keys())
_, _, block_size, num_heads, _ = kv_caches[0].shape
attention_backend = _get_attention_backend(model_input.attn_metadata)
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
request_id = request_ids[idx]
keys, values = [], []
logger.debug(
f"seq_len {slen}, start_pos {start_pos}, end_pos {end_pos}, slot_mapping_flat {slot_mapping_flat.shape}"
)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
if len(model_input.attn_metadata.block_tables[idx]) > 0:
block_inds = torch.range(
0,
block_size - 1,
device=current_slot_mapping.device,
dtype=current_slot_mapping.dtype,
)
additional_inds = (
torch.cat(
[
block_inds + block_size * i
for i in model_input.attn_metadata.block_tables[idx]
]
)
.to(current_slot_mapping.device)
.to(current_slot_mapping.dtype)
)
logger.debug(f"additional_inds: {additional_inds.shape}")
logger.debug(f"current_slot_mapping: {current_slot_mapping.shape}")
current_slot_mapping = torch.cat(
[additional_inds, current_slot_mapping]
)
logger.debug(f"new current_slot_mapping: {current_slot_mapping.shape}")
current_slot_mapping_quotient = current_slot_mapping // block_size
current_slot_mapping_remainder = current_slot_mapping % block_size
logger.debug("kv_caches shape: %s", kv_caches[0].shape)
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
if attention_backend == "flash_attn":
key_cache = kv_cache[
0, current_slot_mapping_quotient, current_slot_mapping_remainder
]
value_cache = kv_cache[
1, current_slot_mapping_quotient, current_slot_mapping_remainder
]
elif attention_backend == "flash_infer":
key_cache = kv_cache[
current_slot_mapping_quotient, 0, current_slot_mapping_remainder
]
value_cache = kv_cache[
current_slot_mapping_quotient, 1, current_slot_mapping_remainder
]
else:
raise ValueError(f"Unknown attention backend {attention_backend}")
keys.append(key_cache)
values.append(value_cache)
keys = torch.stack(keys, dim=0) # type: ignore
values = torch.stack(values, dim=0) # type: ignore
tp_multipler = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
first_rank = envs.VLLM_CONTEXT_WORKERS * envs.VLLM_CONTEXT_TP_SIZE
for i in range(tp_multipler):
num_heads_per_generate_rank = num_heads // tp_multipler
first_head = i * num_heads_per_generate_rank
partial_keys = keys[
:, :, first_head : first_head + num_heads_per_generate_rank, :
].clone() # type: ignore
partial_values = values[
:, :, first_head : first_head + num_heads_per_generate_rank, :
].clone() # type: ignore
target_local_rank = get_tp_group().local_rank * tp_multipler + i
target_rank = target_local_rank + first_rank
# torch.cuda.synchronize()
self._send_tensors(
request_id,
target_rank,
target_local_rank,
partial_keys,
partial_values,
)
logger.debug("Tensors sent")
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
with torch.cuda.nvtx.range("KV recv"):
self._recv(
model.start_layer,
model.end_layer,
[
model.layers[i].self_attn.attn.kv_cache_dtype
for i in range(model.start_layer, model.end_layer)
],
[
model.layers[i].self_attn.attn._k_scale
for i in range(model.start_layer, model.end_layer)
],
[
model.layers[i].self_attn.attn._v_scale
for i in range(model.start_layer, model.end_layer)
],
model_input,
kv_caches,
)
def _recv(
self,
start_layer: int,
end_layer: int,
kv_cache_dtypes: typing.List[torch.dtype],
k_scales: typing.List[float],
v_scales: typing.List[float],
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
seq_lens = model_input.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
request_ids = list(model_input.request_ids_to_seq_ids.keys())
_, _, block_size, num_heads, head_dim = kv_caches[0].shape
attention_backend = _get_attention_backend(model_input.attn_metadata)
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
request_id = request_ids[idx]
num_tokens = slen
base_request_id, context_worker_id = request_id.split("___")
context_worker_id = int(context_worker_id)
keys, values = self._recv_tensors(
base_request_id,
context_worker_id,
num_tokens,
end_layer - start_layer,
num_heads,
head_dim,
)
logger.debug(f"Received tensors for request_id {request_id}")
if kv_caches[0].dtype == torch.uint8:
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8_e4m3")
keys = keys.view(torch_dtype).to(torch.bfloat16)
values = values.view(torch_dtype).to(torch.bfloat16)
logger.debug("Converted caches to torch.blfoat16")
for i in range(start_layer, end_layer):
kv_cache = kv_caches[i - start_layer]
key, value = keys[i], values[i]
if attention_backend == "flash_attn":
key_cache, value_cache = kv_cache[0], kv_cache[1]
elif attention_backend == "flash_infer":
key_cache, value_cache = kv_cache[:, 0], kv_cache[:, 1]
else:
raise ValueError(f"Unknown attention backend {attention_backend}")
ops.reshape_and_cache_flash( # type: ignore
key,
value,
key_cache,
value_cache,
slot_mapping_flat[start_pos:end_pos],
kv_cache_dtypes[i],
k_scales[i],
v_scales[i],
)
logger.debug(f"KV receive DONE for rank {torch.distributed.get_rank()}")
def _send_tensors(self, request_id, target_rank, target_local_rank, keys, values):
logger.debug(
f"Sending tensors for request_id {request_id} to rank {target_rank}"
)
logger.debug(f"Tensor shapes: keys {keys.shape}, values {values.shape}")
logger.debug(f"Tensor dtypes: keys {keys.dtype}, values {values.dtype}")
if self._data_plane_backend == "nccl":
self._send_tensors_nccl(
request_id, target_rank, target_local_rank, keys, values
)
elif self._data_plane_backend == "ucx":
self._send_tensors_ucx(request_id, target_local_rank, keys, values)
def _send_tensors_nccl(
self, request_id, target_rank, target_local_rank, keys, values
):
generate_worker_id = envs.VLLM_CONTEXT_WORKERS
target_addr = self._store.get(
f"worker_{generate_worker_id}_rank_{target_local_rank}"
).decode()
self._data_plane.put_output_tensor(
keys,
rank=target_rank,
tensor_id=f"{request_id}_keys_rank{target_local_rank}",
remote_address=target_addr,
)
self._data_plane.put_output_tensor(
values,
rank=target_rank,
tensor_id=f"{request_id}_values_rank{target_local_rank}",
remote_address=target_addr,
)
def _send_tensors_ucx(self, request_id, target_local_rank, keys, values):
torch.cuda.synchronize()
self._data_plane.put_output_tensor(
keys,
_create_id_from_str(f"{request_id}_keys_rank{target_local_rank}"),
)
self._data_plane.put_output_tensor(
values,
_create_id_from_str(f"{request_id}_values_rank{target_local_rank}"),
)
def _recv_tensors(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
logger.debug(
f"Receiving tensors for request_id {request_id} from worker {context_worker_id}"
)
if self._data_plane_backend == "nccl":
return self._recv_tensors_nccl(
request_id,
context_worker_id,
num_tokens,
num_layers,
num_heads,
head_dim,
)
elif self._data_plane_backend == "ucx":
return self._recv_tensors_ucx(
request_id,
context_worker_id,
num_tokens,
num_layers,
num_heads,
head_dim,
)
def _recv_tensors_nccl(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
tp_rank = get_tp_group().local_rank
tp_multipler = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
source_tp_rank = tp_rank // tp_multipler
source_rank = context_worker_id * envs.VLLM_CONTEXT_TP_SIZE + source_tp_rank
worker_key = f"worker_{context_worker_id}_rank_{source_tp_rank}"
source_addr = self._local_store.get(worker_key)
if source_addr is None:
logger.info(
"Fetching source address for worker %d by key %s",
context_worker_id,
worker_key,
)
source_addr = self._store.get(worker_key).decode()
self._local_store[worker_key] = source_addr
keys = self._data_plane.get_tensor(
rank=source_rank,
tensor_id=f"{request_id}_keys_rank{tp_rank}",
remote_address=source_addr,
)
values = self._data_plane.get_tensor(
rank=source_rank,
tensor_id=f"{request_id}_values_rank{tp_rank}",
remote_address=source_addr,
)
return keys, values
def _recv_tensors_ucx(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
local_rank = get_tp_group().local_rank
tp_rank = get_tp_group().local_rank
tp_multipler = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
source_tp_rank = tp_rank // tp_multipler
source_addr = self._store.get(
f"worker_{context_worker_id}_rank_{source_tp_rank}"
).decode()
keys_id = _create_id_from_str(f"{request_id}_keys_rank{local_rank}")
keys_uri = f"ucp://{source_addr}/{keys_id}"
keys = self._data_plane.get_tensor(
keys_uri,
(num_layers, num_tokens, num_heads, head_dim),
torch.uint8,
device_id=local_rank,
)
values_id = _create_id_from_str(f"{request_id}_values_rank{local_rank}")
values_uri = f"ucp://{source_addr}/{values_id}"
values = self._data_plane.get_tensor(
values_uri,
(num_layers, num_tokens, num_heads, head_dim),
torch.uint8,
device_id=local_rank,
)
return keys, values
def _get_attention_backend(attn_metadata: "AttentionMetadata") -> str:
if isinstance(attn_metadata, FlashAttentionMetadata):
return "flash_attn"
elif isinstance(attn_metadata, FlashInferMetadata):
return "flash_infer"
else:
raise ValueError(
f"Unknown attention metadata type {type(attn_metadata)}. Only FlashAttentionMetadata and FlashInferMetadata are supported."
)
def _create_id_from_str(str_id: str) -> uuid.UUID:
# Create a hash of the str string using SHA-1
hashed_key = hashlib.sha1(str_id.encode("utf-8"))
# Generate a UUID from the hash, ensuring it's in the correct format
hash_hex = hashed_key.hexdigest()[:32] # Get first 32 characters
uuid_generated = uuid.UUID(hash_hex)
return uuid_generated
#!/bin/bash -e
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A script to download a Python wheel, patch it, copy additional files,
# repackage it, and optionally install the new wheel.
###############################################################################
# CONFIGURATION & DEFAULTS
###############################################################################
DEFAULT_WHEEL_URL="https://files.pythonhosted.org/packages/4a/4c/ee65ba33467a4c0de350ce29fbae39b9d0e7fcd887cc756fa993654d1228/vllm-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl"
DEFAULT_PATCH_FILE="vllm_patch_063post1.patch"
DEFAULT_DATA_PLANE_DIR="data_plane"
DEFAULT_WHEEL_DIR="wheel"
DEFAULT_OUTPUT_WHEEL="vllm-dist-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl"
# Optionally set a default SHA256 checksum for the downloaded wheel
# DEFAULT_CHECKSUM="SOME_SHA256_HERE"
###############################################################################
# HELPER FUNCTIONS
###############################################################################
usage() {
cat << EOF
Usage: $0 [OPTIONS]
Options:
-u, --url <URL> Wheel URL to download (default: $DEFAULT_WHEEL_URL)
-p, --patch <FILE> Patch file path (default: $DEFAULT_PATCH_FILE)
-d, --data-plane <DIR> Directory with additional data-plane files (default: $DEFAULT_DATA_PLANE_DIR)
-w, --wheel-dir <DIR> Extract destination directory (default: $DEFAULT_WHEEL_DIR)
-o, --output-wheel <FILE> Name/path for the repackaged wheel (default: $DEFAULT_OUTPUT_WHEEL)
-f, --force Force overwriting existing directories/files without prompts
-i, --install Install the new wheel after repackaging
-D, --debug Enable debug mode (verbose output)
-h, --help Show this help and exit
Example:
$0 \\
--url "https://example.com/path/to/vllm.whl" \\
--patch my_patch.patch \\
--data-plane custom_data/ \\
--wheel-dir extracted_wheel \\
--output-wheel my_vllm_dist.whl \\
--install \\
--force \\
--debug
EOF
}
info_log() {
# Always prints, showing a timestamp.
echo "[$(date +'%Y-%m-%d %H:%M:%S')] INFO: $*"
}
debug_log() {
# Prints only if DEBUG=true
if [[ "$DEBUG" == true ]]; then
echo "[$(date +'%Y-%m-%d %H:%M:%S')] DEBUG: $*"
fi
}
error_exit() {
echo "ERROR: $*" >&2
exit 1
}
###############################################################################
# PARSE ARGUMENTS
###############################################################################
FORCE_OVERWRITE=false
INSTALL_WHEEL=false
DEBUG=false
WHEEL_URL="$DEFAULT_WHEEL_URL"
PATCH_FILE="$DEFAULT_PATCH_FILE"
DATA_PLANE_DIR="$DEFAULT_DATA_PLANE_DIR"
WHEEL_DIR="$DEFAULT_WHEEL_DIR"
OUTPUT_WHEEL="$DEFAULT_OUTPUT_WHEEL"
while [[ $# -gt 0 ]]; do
case "$1" in
-u|--url)
WHEEL_URL="$2"
shift 2
;;
-p|--patch)
PATCH_FILE="$2"
shift 2
;;
-d|--data-plane)
DATA_PLANE_DIR="$2"
shift 2
;;
-w|--wheel-dir)
WHEEL_DIR="$2"
shift 2
;;
-o|--output-wheel)
OUTPUT_WHEEL="$2"
shift 2
;;
-f|--force)
FORCE_OVERWRITE=true
shift
;;
-i|--install)
INSTALL_WHEEL=true
shift
;;
-D|--debug)
DEBUG=true
shift
;;
-h|--help)
usage
exit 0
;;
*)
error_exit "Unknown option: $1"
;;
esac
done
###############################################################################
# MAIN SCRIPT
###############################################################################
# Enable debug mode if requested
if [[ "$DEBUG" == true ]]; then
set -x
fi
info_log "Starting wheel patching script..."
# Function to check if a command exists
command_exists() {
command -v "$1" >/dev/null 2>&1 || { echo >&2 "I require $1 but it's not installed. Aborting."; exit 1; }
}
# Check for required commands
command_exists pip
command_exists unzip
command_exists zip
command_exists patch
# ---------------------------------------------------------------------------
# 1. Check for existing wheel file or directory
# ---------------------------------------------------------------------------
WHEEL_FILENAME=$(basename "$WHEEL_URL")
if [[ -f "$WHEEL_FILENAME" && "$FORCE_OVERWRITE" != true ]]; then
info_log "File '$WHEEL_FILENAME' already exists. Reusing existing file."
info_log "If you want to redownload, remove '$WHEEL_FILENAME' or use --force."
else
info_log "Downloading wheel from $WHEEL_URL..."
rm -f "$WHEEL_FILENAME" 2>/dev/null || true # Remove existing file if forcing
wget -O "$WHEEL_FILENAME" "$WHEEL_URL"
fi
# ---------------------------------------------------------------------------
# 2. Optional: Verify checksum (commented out by default)
# ---------------------------------------------------------------------------
# if [[ -n "$DEFAULT_CHECKSUM" ]]; then
# info_log "Verifying SHA256 checksum..."
# echo "${DEFAULT_CHECKSUM} ${WHEEL_FILENAME}" | sha256sum --check - || error_exit "Checksum mismatch!"
# fi
# ---------------------------------------------------------------------------
# 3. Create/clean wheel extraction directory
# ---------------------------------------------------------------------------
if [[ -d "$WHEEL_DIR" ]]; then
if [[ "$FORCE_OVERWRITE" == true ]]; then
info_log "Removing existing directory '$WHEEL_DIR' due to --force..."
rm -rf "$WHEEL_DIR"
else
error_exit "Directory '$WHEEL_DIR' already exists. Use --force to overwrite."
fi
fi
info_log "Creating directory '$WHEEL_DIR'..."
mkdir -p "$WHEEL_DIR"
# ---------------------------------------------------------------------------
# 4. Unzip the wheel into the specified directory
# ---------------------------------------------------------------------------
info_log "Unzipping wheel into directory '$WHEEL_DIR'..."
unzip -q "$WHEEL_FILENAME" -d "$WHEEL_DIR"
# ---------------------------------------------------------------------------
# 5. Check/Apply patch
# ---------------------------------------------------------------------------
if [[ ! -f "$PATCH_FILE" ]]; then
error_exit "Patch file '$PATCH_FILE' not found."
fi
PATCH_TARGET_DIR="$WHEEL_DIR/vllm"
if [[ ! -d "$PATCH_TARGET_DIR" ]]; then
error_exit "Could not find directory '$PATCH_TARGET_DIR' in unzipped wheel."
fi
info_log "Applying patch '$PATCH_FILE' to '$PATCH_TARGET_DIR'..."
debug_log "Executing: (cd \"$PATCH_TARGET_DIR\" && patch -p1 < \"../../$PATCH_FILE\")"
(
cd "$PATCH_TARGET_DIR"
patch -p1 < "../../$PATCH_FILE"
)
# ---------------------------------------------------------------------------
# 6. Copy data plane files
# ---------------------------------------------------------------------------
if [[ ! -d "$DATA_PLANE_DIR" ]]; then
error_exit "Data plane directory '$DATA_PLANE_DIR' not found."
fi
info_log "Copying files from '$DATA_PLANE_DIR' to '$PATCH_TARGET_DIR/distributed'..."
mkdir -p "$PATCH_TARGET_DIR/distributed"
cp -r "$DATA_PLANE_DIR/"* "$PATCH_TARGET_DIR/distributed/"
# ---------------------------------------------------------------------------
# 7. Re-package into a new wheel
# ---------------------------------------------------------------------------
if [[ -f "$OUTPUT_WHEEL" && "$FORCE_OVERWRITE" != true ]]; then
error_exit "Output wheel '$OUTPUT_WHEEL' already exists. Use --force to overwrite."
fi
info_log "Creating new wheel file '$OUTPUT_WHEEL'..."
debug_log "Executing: (cd \"$WHEEL_DIR\" && zip -rq \"../$OUTPUT_WHEEL\" .)"
(
cd "$WHEEL_DIR"
zip -rq "../$OUTPUT_WHEEL" .
)
# ---------------------------------------------------------------------------
# 8. Optional: Install the new wheel
# ---------------------------------------------------------------------------
if [[ "$INSTALL_WHEEL" == true ]]; then
# Check if pip is installed
if ! command -v pip >/dev/null 2>&1; then
error_exit "pip is not installed or not found in PATH."
fi
info_log "Installing newly created wheel '$OUTPUT_WHEEL'..."
pip install --force-reinstall --upgrade --break-system-packages "$OUTPUT_WHEEL"
fi
info_log "Patch and repackage completed successfully!"
info_log "New wheel: $OUTPUT_WHEEL"
if [[ "$INSTALL_WHEEL" == true ]]; then
info_log "Wheel has been installed."
fi
import pytest
try:
import vllm
except ImportError:
vllm = None # type: ignore
pytestmark = pytest.mark.pre_merge
# TODO: Consider `pytest.mark.vllm` and running tests based on environment
@pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
def test_version():
# Verify that the image has the patched version of vllm
assert vllm.__version__ == "0.6.3.post2.dev16+gf61960ce"
@pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
def test_patch_imports():
# Verify patched files have no glaring syntax or import issues
pass
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment