"launch/README.md" did not exist on "e97493eb0065285c2775bfb5fcee7cd821f08842"
Commit a0e1da03 authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

build: Add vllm wheel patch scripts for disaggregated serving




---------
Co-authored-by: Piotr Marcinkiewicz <piotrm@nvidia.com>
parent 587addb9
@@ -15,6 +15,7 @@
ARG BASE_IMAGE="nvcr.io/nvidia/tritonserver"
ARG BASE_IMAGE_TAG="24.12-py3"
+ARG VLLM_WHEEL
FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS triton-distributed
@@ -56,6 +57,19 @@ RUN --mount=type=bind,source=./container/deps/requirements.tensorrtllm.txt,targe
RUN --mount=type=bind,source=./container/deps/requirements.vllm.txt,target=/tmp/requirements.txt \
if [[ "$FRAMEWORK" == "VLLM" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi
+# NOTE: python3-distro is a system package that causes conflicts with user
+# packages installed by `pip`. Removing the system package allows user packages
+# to correctly manage dependencies with the `distro` pip user package instead.
+RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \
+    if [[ "$FRAMEWORK" == "VLLM" ]]; then \
+        apt install -y zip && \
+        apt remove -y python3-distro && \
+        pip install distro && \
+        cp -r /tmp/deps/vllm /tmp/local_vllm && \
+        cd /tmp/local_vllm && \
+        bash ./prepare_wheel.sh --install --debug --force ; \
+    fi
RUN --mount=type=bind,source=./container/deps/requirements.standard.txt,target=/tmp/requirements.txt \
if [[ "$FRAMEWORK" == "STANDARD" ]] ; then pip install --timeout=2000 -r /tmp/requirements.txt ; fi
@@ -44,6 +44,8 @@ TENSORRTLLM_BASE_IMAGE_TAG=${TENSORRTLLM_BASE_VERSION}-trtllm-python-py3
# IMPORTANT NOTE: Ensure the commit matches the TRTLLM backend version used in the base image above
TENSORRTLLM_BACKEND_COMMIT=v0.16.0
+# vllm installation is done later in the Dockerfile so it will overwrite the
+# vllm version installed in the base image.
VLLM_BASE_VERSION=24.12
VLLM_BASE_IMAGE=nvcr.io/nvidia/tritonserver
VLLM_BASE_IMAGE_TAG=${VLLM_BASE_VERSION}-vllm-python-py3
@@ -17,5 +17,5 @@
--extra-index-url https://flashinfer.ai/whl/cu121/torch2.4
flashinfer
ucx-py-cu12
-# TODO update to branch / fork
-vllm==0.6.3post1
+# vLLM is installed by the wheel patching script
+# vllm==0.6.3post1
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
Apply this patch to the Python source code of vLLM release [v0.6.3.post1](https://github.com/vllm-project/vllm/releases/tag/v0.6.3.post1), then copy the files from the ``data_plane`` folder into the vLLM ``vllm/distributed`` folder. The ``prepare_wheel.sh`` script included below automates these steps against the published wheel.
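For reference, the manual steps that ``prepare_wheel.sh`` automates look roughly like this, using its default wheel, patch, and output names:

```bash
# Download the published wheel and unpack it.
wget https://files.pythonhosted.org/packages/4a/4c/ee65ba33467a4c0de350ce29fbae39b9d0e7fcd887cc756fa993654d1228/vllm-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl
unzip -q vllm-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl -d wheel

# Patch the unpacked sources and add the data-plane modules.
(cd wheel/vllm && patch -p1 < ../../vllm_patch_063post1.patch)
mkdir -p wheel/vllm/distributed
cp -r data_plane/* wheel/vllm/distributed/

# Repackage and install the patched wheel.
(cd wheel && zip -rq ../vllm-dist-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl .)
pip install --force-reinstall --upgrade vllm-dist-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl
```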
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Data-plane transports for disaggregated serving: thin wrappers that move
# torch tensors between vLLM workers over UCX (via triton_distributed) or
# NCCL (torch.distributed) with a ZMQ metadata side channel.
import logging
import math
import socket
import threading
import typing
import uuid
import torch
import torch.distributed
import tritonserver
import zmq
from triton_distributed.icp.data_plane import (
set_icp_data_type,
set_icp_memory_type,
set_icp_shape,
set_icp_tensor_size,
set_icp_tensor_uri,
)
from triton_distributed.icp.protos.icp_pb2 import ModelInferRequest
from triton_distributed.icp.ucp_data_plane import DataPlaneError, UcpDataPlane
logger = logging.getLogger(__name__)
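# VllmUcpDataPlane wraps the triton_distributed UcpDataPlane for torch tensors:
# payloads cross the wire via DLPack, and bfloat16 is transported as float16 of
# the same bit width since the tritonserver dtype mapping used here has no BF16.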
class VllmUcpDataPlane:
def __init__(
self,
hostname: typing.Optional[str] = None,
port: int = 0,
keep_endpoints_open: bool = False,
) -> None:
self._data_plane = UcpDataPlane(hostname, port, keep_endpoints_open)
@property
def hostname(self) -> str:
return self._data_plane.hostname
@property
def port(self) -> int:
return self._data_plane.port
def connect(self) -> None:
self._data_plane.connect()
def close(self) -> None:
self._data_plane.close()
def put_input_tensor(
self,
tensor: torch.Tensor,
tensor_id: typing.Optional[uuid.UUID] = None,
):
logger.debug(
f"Putting input tensor with id {tensor_id} on {self.hostname}:{self.port}"
)
if tensor.dtype == torch.bfloat16:
tensor = tensor.view(torch.float16)
triton_tensor = tritonserver.Tensor.from_dlpack(tensor)
self._data_plane.put_input_tensor(triton_tensor, tensor_id)
def put_output_tensor(
self,
tensor: torch.Tensor,
tensor_id: typing.Optional[uuid.UUID] = None,
):
logger.debug(
f"Putting input tensor with id {tensor_id} on {self.hostname}:{self.port}"
)
if tensor.dtype == torch.bfloat16:
tensor = tensor.view(torch.float16)
triton_tensor = tritonserver.Tensor.from_dlpack(tensor)
self._data_plane.put_output_tensor(triton_tensor, tensor_id)
def get_tensor(
self,
tensor_uri: str,
shape: typing.Sequence[int],
dtype: torch.dtype,
device_id: int,
) -> torch.Tensor:
logger.debug("Getting tensor from %s", tensor_uri)
result = ModelInferRequest.InferInputTensor()
triton_dtype = {
torch.float32: tritonserver.DataType.FP32,
torch.float16: tritonserver.DataType.FP16,
torch.bfloat16: tritonserver.DataType.FP16,
torch.uint8: tritonserver.DataType.UINT8,
}.get(dtype)
if triton_dtype is None:
raise DataPlaneError(f"Unsupported dtype {dtype}")
tensor_size = math.prod(shape) * dtype.itemsize
set_icp_data_type(result, triton_dtype)
set_icp_shape(result, shape)
set_icp_tensor_uri(result, tensor_uri)
set_icp_memory_type(result, tritonserver.MemoryType.GPU)
set_icp_tensor_size(result, tensor_size)
triton_tensor = self._data_plane.get_tensor(
remote_tensor=result,
requested_memory_type=tritonserver.MemoryType.GPU,
requested_memory_type_id=device_id,
)
tensor = torch.utils.dlpack.from_dlpack(triton_tensor)
if dtype == torch.bfloat16:
tensor = tensor.view(torch.bfloat16)
logger.debug("Got tensor from %s", tensor_uri)
return tensor
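# VllmNcclDataPlane moves tensor payloads over the already-initialized
# torch.distributed (NCCL) world, using a ZMQ REQ/REP side channel only to
# exchange metadata (rank, shape, tensor id) before each transfer.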
class VllmNcclDataPlane:
def __init__(
self,
hostname: str = "",
port: int = 0,
# FIXME: world_size and rank both unused
world_size: int = -1,
rank: int = -1,
) -> None:
if not torch.distributed.is_initialized():
raise RuntimeError("NCCL backend not initialized")
if not hostname:
hostname = socket.gethostname()
if port == 0:
port = 13337 + torch.distributed.get_rank()
self._hostname = hostname
self._port = port
self._rank = torch.distributed.get_rank()
self._world_size: int = world_size
self._current_device = torch.cuda.current_device()
# FIXME: Use stricter type for req value in tuple
self.store: typing.Dict[str, typing.Tuple[torch.Tensor, int, typing.Any]] = {}
self.context = zmq.Context()
self.rep_socket = self.context.socket(zmq.REP)
logger.info(f"Rank {self._rank} binding to {self._hostname}:{self._port}")
self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
self._listener_thread = threading.Thread(
target=self.listen_for_requests, daemon=True
)
self._listener_thread.start()
# FIXME: Use stricter ZMQ socket type hint
self.req_sockets: typing.Dict[str, typing.Any] = {}
logger.info(f"Rank {self._rank} connected to the server")
@property
def world_size(self):
return self._world_size
@property
def rank(self):
return self._rank
def put_input_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
return self._put_tensor(tensor, rank, tensor_id, remote_address)
def put_output_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
return self._put_tensor(tensor, rank, tensor_id, remote_address)
def get_tensor(
self,
rank: int,
tensor_id: str,
remote_address: str,
) -> torch.Tensor:
return self._get_tensor(rank, tensor_id, remote_address)
def _put_tensor(
self,
tensor: torch.Tensor,
rank: int,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
logger.debug(
f"Rank {self._rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}"
)
if remote_address is None:
self.store[tensor_id] = (tensor, rank, None)
else:
tensor_shape = "_".join(str(dim) for dim in tensor.shape)
            if remote_address not in self.req_sockets:
                self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
                self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
            req_socket = self.req_sockets[remote_address]
req_socket.send_string(f"PUT {self._rank} {tensor_shape} {tensor_id}")
ret = req_socket.recv_string()
assert ret == "OK"
torch.distributed.isend(tensor, dst=rank)
def _get_tensor(
self,
rank: int,
tensor_id: str,
remote_address: str,
) -> torch.Tensor:
logger.debug(f"Rank {self._rank} receiving tensor from rank {rank}")
if tensor_id in self.store:
tensor, _, req = self.store.pop(tensor_id)
            req.wait()  # TODO ptarasiewicz: serve other requests instead of blocking on wait
logger.debug(f"Rank {self._rank} received tensor from rank {rank}")
return tensor
raise NotImplementedError("Getting tensor from remote rank not implemented")
def _receive_tensor(
self,
tensor_id: str,
rank: int,
shape: typing.Sequence[int],
):
tensor = torch.empty(
shape, dtype=torch.uint8, device=f"cuda:{self._current_device}"
)
req = torch.distributed.irecv(tensor, src=rank)
self.store[tensor_id] = (tensor, rank, req)
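    # Listener protocol: a peer sends "PUT <rank> <shape> <tensor_id>" over the
    # ZMQ REP socket; we reply "OK" and post an irecv so the peer's subsequent
    # isend has a matching receive. "GET" is accepted on the wire but unimplemented.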
def listen_for_requests(self):
while True:
cmd, _rank, _shape, tensor_id = self.rep_socket.recv_string().split()
logger.debug(f"Rank {self._rank} received request for tensor {tensor_id}")
self.rep_socket.send_string("OK")
if cmd == "GET":
raise NotImplementedError(
"Getting tensor from remote rank not implemented"
)
elif cmd == "PUT":
rank = int(_rank)
shape = [int(dim) for dim in _shape.split("_")]
self._receive_tensor(tensor_id, rank, shape)
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# KV-cache handler for disaggregated serving: prefill (context) workers push
# per-request KV-cache slices to generate workers over the configured data
# plane (NCCL or UCX), as used by the patched model runner below.
import typing
if typing.TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata # type: ignore
from vllm.attention.backends.abstract import AttentionMetadata # type: ignore
import hashlib
import uuid
from concurrent.futures import ThreadPoolExecutor
import torch
import torch.distributed
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.flashinfer import FlashInferBackend, FlashInferMetadata
from vllm.distributed.data_plane import VllmNcclDataPlane, VllmUcpDataPlane
from vllm.distributed.parallel_state import get_store, get_tp_group # type: ignore
from vllm.logger import init_logger
logger = init_logger(__name__)
_kv_cache_handler = None
def get_kv_cache_handler():
global _kv_cache_handler
if _kv_cache_handler is None:
_kv_cache_handler = KVCacheHandler()
return _kv_cache_handler
class KVCacheHandler:
def __init__(self):
if _kv_cache_handler is not None:
raise ValueError("KVCacheHandler is a singleton")
self._data_plane_backend = envs.VLLM_DATA_PLANE_BACKEND
if self._data_plane_backend == "nccl":
self._data_plane = VllmNcclDataPlane()
self._store = get_store()
logger.info("Store set up")
self._store.set(
f"worker_{envs.VLLM_WORKER_ID}_rank_{get_tp_group().local_rank}",
f"{self._data_plane._hostname}:{self._data_plane._port}",
)
elif self._data_plane_backend == "ucx":
self._data_plane = VllmUcpDataPlane(keep_endpoints_open=True)
self._data_plane.connect()
rank = torch.distributed.get_rank()
is_master = envs.VLLM_WORKER_ID == 0 and rank == 0
self._store = torch.distributed.TCPStore(
envs.VLLM_TORCH_HOST, envs.VLLM_TORCH_PORT, is_master=is_master
)
self._store.set(
f"worker_{envs.VLLM_WORKER_ID}_rank_{rank}",
f"{self._data_plane.hostname}:{self._data_plane.port}",
)
else:
raise ValueError(f"Unknown data plane backend {self._data_plane_backend}")
self._local_store = {}
self.transport_thread = ThreadPoolExecutor(max_workers=1)
logger.info("KVCacheHandler initialized")
def send(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
with torch.cuda.nvtx.range("KV send"):
self._send(
model,
model_input,
kv_caches,
)
def _send(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
seq_lens = model_input.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model.start_layer
end_layer = model.end_layer
request_ids = list(model_input.request_ids_to_seq_ids.keys())
_, _, block_size, num_heads, _ = kv_caches[0].shape
attention_backend = _get_attention_backend(model_input.attn_metadata)
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
request_id = request_ids[idx]
keys, values = [], []
logger.debug(
f"seq_len {slen}, start_pos {start_pos}, end_pos {end_pos}, slot_mapping_flat {slot_mapping_flat.shape}"
)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
if len(model_input.attn_metadata.block_tables[idx]) > 0:
                # Indices 0..block_size-1 within one block (arange excludes the
                # endpoint, unlike the deprecated torch.range it replaces).
                block_inds = torch.arange(
                    0,
                    block_size,
                    device=current_slot_mapping.device,
                    dtype=current_slot_mapping.dtype,
                )
additional_inds = (
torch.cat(
[
block_inds + block_size * i
for i in model_input.attn_metadata.block_tables[idx]
]
)
.to(current_slot_mapping.device)
.to(current_slot_mapping.dtype)
)
logger.debug(f"additional_inds: {additional_inds.shape}")
logger.debug(f"current_slot_mapping: {current_slot_mapping.shape}")
current_slot_mapping = torch.cat(
[additional_inds, current_slot_mapping]
)
logger.debug(f"new current_slot_mapping: {current_slot_mapping.shape}")
current_slot_mapping_quotient = current_slot_mapping // block_size
current_slot_mapping_remainder = current_slot_mapping % block_size
logger.debug("kv_caches shape: %s", kv_caches[0].shape)
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
if attention_backend == "flash_attn":
key_cache = kv_cache[
0, current_slot_mapping_quotient, current_slot_mapping_remainder
]
value_cache = kv_cache[
1, current_slot_mapping_quotient, current_slot_mapping_remainder
]
elif attention_backend == "flash_infer":
key_cache = kv_cache[
current_slot_mapping_quotient, 0, current_slot_mapping_remainder
]
value_cache = kv_cache[
current_slot_mapping_quotient, 1, current_slot_mapping_remainder
]
else:
raise ValueError(f"Unknown attention backend {attention_backend}")
keys.append(key_cache)
values.append(value_cache)
keys = torch.stack(keys, dim=0) # type: ignore
values = torch.stack(values, dim=0) # type: ignore
            tp_multiplier = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
first_rank = envs.VLLM_CONTEXT_WORKERS * envs.VLLM_CONTEXT_TP_SIZE
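            # The generate stage may shard attention heads across more TP ranks
            # than the context stage: split this rank's heads into tp_multiplier
            # contiguous chunks and send each chunk to its target generate rank
            # (generate ranks start at first_rank in the global world).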
            for i in range(tp_multiplier):
                num_heads_per_generate_rank = num_heads // tp_multiplier
first_head = i * num_heads_per_generate_rank
partial_keys = keys[
:, :, first_head : first_head + num_heads_per_generate_rank, :
].clone() # type: ignore
partial_values = values[
:, :, first_head : first_head + num_heads_per_generate_rank, :
].clone() # type: ignore
                target_local_rank = get_tp_group().local_rank * tp_multiplier + i
target_rank = target_local_rank + first_rank
# torch.cuda.synchronize()
self._send_tensors(
request_id,
target_rank,
target_local_rank,
partial_keys,
partial_values,
)
logger.debug("Tensors sent")
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv(
self,
model: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
with torch.cuda.nvtx.range("KV recv"):
self._recv(
model.start_layer,
model.end_layer,
[
model.layers[i].self_attn.attn.kv_cache_dtype
for i in range(model.start_layer, model.end_layer)
],
[
model.layers[i].self_attn.attn._k_scale
for i in range(model.start_layer, model.end_layer)
],
[
model.layers[i].self_attn.attn._v_scale
for i in range(model.start_layer, model.end_layer)
],
model_input,
kv_caches,
)
def _recv(
self,
start_layer: int,
end_layer: int,
kv_cache_dtypes: typing.List[torch.dtype],
k_scales: typing.List[float],
v_scales: typing.List[float],
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: typing.List[torch.Tensor],
):
seq_lens = model_input.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
request_ids = list(model_input.request_ids_to_seq_ids.keys())
_, _, block_size, num_heads, head_dim = kv_caches[0].shape
attention_backend = _get_attention_backend(model_input.attn_metadata)
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
request_id = request_ids[idx]
num_tokens = slen
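            # Request ids arrive as "<base_id>___<context_worker_id>", telling
            # this generate worker which context worker produced the KV cache.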
base_request_id, context_worker_id = request_id.split("___")
context_worker_id = int(context_worker_id)
keys, values = self._recv_tensors(
base_request_id,
context_worker_id,
num_tokens,
end_layer - start_layer,
num_heads,
head_dim,
)
logger.debug(f"Received tensors for request_id {request_id}")
if kv_caches[0].dtype == torch.uint8:
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8_e4m3")
keys = keys.view(torch_dtype).to(torch.bfloat16)
values = values.view(torch_dtype).to(torch.bfloat16)
logger.debug("Converted caches to torch.blfoat16")
for i in range(start_layer, end_layer):
kv_cache = kv_caches[i - start_layer]
key, value = keys[i], values[i]
if attention_backend == "flash_attn":
key_cache, value_cache = kv_cache[0], kv_cache[1]
elif attention_backend == "flash_infer":
key_cache, value_cache = kv_cache[:, 0], kv_cache[:, 1]
else:
raise ValueError(f"Unknown attention backend {attention_backend}")
ops.reshape_and_cache_flash( # type: ignore
key,
value,
key_cache,
value_cache,
slot_mapping_flat[start_pos:end_pos],
kv_cache_dtypes[i],
k_scales[i],
v_scales[i],
)
logger.debug(f"KV receive DONE for rank {torch.distributed.get_rank()}")
def _send_tensors(self, request_id, target_rank, target_local_rank, keys, values):
logger.debug(
f"Sending tensors for request_id {request_id} to rank {target_rank}"
)
logger.debug(f"Tensor shapes: keys {keys.shape}, values {values.shape}")
logger.debug(f"Tensor dtypes: keys {keys.dtype}, values {values.dtype}")
if self._data_plane_backend == "nccl":
self._send_tensors_nccl(
request_id, target_rank, target_local_rank, keys, values
)
elif self._data_plane_backend == "ucx":
self._send_tensors_ucx(request_id, target_local_rank, keys, values)
def _send_tensors_nccl(
self, request_id, target_rank, target_local_rank, keys, values
):
generate_worker_id = envs.VLLM_CONTEXT_WORKERS
target_addr = self._store.get(
f"worker_{generate_worker_id}_rank_{target_local_rank}"
).decode()
self._data_plane.put_output_tensor(
keys,
rank=target_rank,
tensor_id=f"{request_id}_keys_rank{target_local_rank}",
remote_address=target_addr,
)
self._data_plane.put_output_tensor(
values,
rank=target_rank,
tensor_id=f"{request_id}_values_rank{target_local_rank}",
remote_address=target_addr,
)
def _send_tensors_ucx(self, request_id, target_local_rank, keys, values):
torch.cuda.synchronize()
self._data_plane.put_output_tensor(
keys,
_create_id_from_str(f"{request_id}_keys_rank{target_local_rank}"),
)
self._data_plane.put_output_tensor(
values,
_create_id_from_str(f"{request_id}_values_rank{target_local_rank}"),
)
def _recv_tensors(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
logger.debug(
f"Receiving tensors for request_id {request_id} from worker {context_worker_id}"
)
if self._data_plane_backend == "nccl":
return self._recv_tensors_nccl(
request_id,
context_worker_id,
num_tokens,
num_layers,
num_heads,
head_dim,
)
elif self._data_plane_backend == "ucx":
return self._recv_tensors_ucx(
request_id,
context_worker_id,
num_tokens,
num_layers,
num_heads,
head_dim,
)
def _recv_tensors_nccl(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
tp_rank = get_tp_group().local_rank
        tp_multiplier = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
        source_tp_rank = tp_rank // tp_multiplier
source_rank = context_worker_id * envs.VLLM_CONTEXT_TP_SIZE + source_tp_rank
worker_key = f"worker_{context_worker_id}_rank_{source_tp_rank}"
source_addr = self._local_store.get(worker_key)
if source_addr is None:
logger.info(
"Fetching source address for worker %d by key %s",
context_worker_id,
worker_key,
)
source_addr = self._store.get(worker_key).decode()
self._local_store[worker_key] = source_addr
keys = self._data_plane.get_tensor(
rank=source_rank,
tensor_id=f"{request_id}_keys_rank{tp_rank}",
remote_address=source_addr,
)
values = self._data_plane.get_tensor(
rank=source_rank,
tensor_id=f"{request_id}_values_rank{tp_rank}",
remote_address=source_addr,
)
return keys, values
def _recv_tensors_ucx(
self, request_id, context_worker_id, num_tokens, num_layers, num_heads, head_dim
):
local_rank = get_tp_group().local_rank
tp_rank = get_tp_group().local_rank
        tp_multiplier = envs.VLLM_GENERATE_TP_SIZE // envs.VLLM_CONTEXT_TP_SIZE
        source_tp_rank = tp_rank // tp_multiplier
source_addr = self._store.get(
f"worker_{context_worker_id}_rank_{source_tp_rank}"
).decode()
keys_id = _create_id_from_str(f"{request_id}_keys_rank{local_rank}")
keys_uri = f"ucp://{source_addr}/{keys_id}"
keys = self._data_plane.get_tensor(
keys_uri,
(num_layers, num_tokens, num_heads, head_dim),
torch.uint8,
device_id=local_rank,
)
values_id = _create_id_from_str(f"{request_id}_values_rank{local_rank}")
values_uri = f"ucp://{source_addr}/{values_id}"
values = self._data_plane.get_tensor(
values_uri,
(num_layers, num_tokens, num_heads, head_dim),
torch.uint8,
device_id=local_rank,
)
return keys, values
def _get_attention_backend(attn_metadata: "AttentionMetadata") -> str:
if isinstance(attn_metadata, FlashAttentionMetadata):
return "flash_attn"
elif isinstance(attn_metadata, FlashInferMetadata):
return "flash_infer"
else:
raise ValueError(
f"Unknown attention metadata type {type(attn_metadata)}. Only FlashAttentionMetadata and FlashInferMetadata are supported."
)
def _create_id_from_str(str_id: str) -> uuid.UUID:
# Create a hash of the str string using SHA-1
hashed_key = hashlib.sha1(str_id.encode("utf-8"))
# Generate a UUID from the hash, ensuring it's in the correct format
hash_hex = hashed_key.hexdigest()[:32] # Get first 32 characters
uuid_generated = uuid.UUID(hash_hex)
return uuid_generated
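The handler above is configured entirely through the ``VLLM_*`` environment variables this patch adds to ``envs.py`` (see the patch file below). As an illustrative sketch only, not a tested launch recipe, a prefill worker in a hypothetical two-context-worker, one-generate-worker layout could be configured as:

```bash
# Hypothetical prefill worker 0 of 2 (context TP=1, generate TP=2); the values
# here are illustrative assumptions, not defaults shipped by this commit.
export VLLM_DISAGG_STAGE=PREFILL
export VLLM_WORKER_ID=0
export VLLM_CONTEXT_WORKERS=2
export VLLM_CONTEXT_TP_SIZE=1
export VLLM_GENERATE_WORKERS=1
export VLLM_GENERATE_TP_SIZE=2
export VLLM_TORCH_HOST=localhost
export VLLM_TORCH_PORT=36183
export VLLM_DATA_PLANE_BACKEND=nccl
```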
#!/bin/bash -e
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# A script to download a Python wheel, patch it, copy additional files,
# repackage it, and optionally install the new wheel.
###############################################################################
# CONFIGURATION & DEFAULTS
###############################################################################
DEFAULT_WHEEL_URL="https://files.pythonhosted.org/packages/4a/4c/ee65ba33467a4c0de350ce29fbae39b9d0e7fcd887cc756fa993654d1228/vllm-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl"
DEFAULT_PATCH_FILE="vllm_patch_063post1.patch"
DEFAULT_DATA_PLANE_DIR="data_plane"
DEFAULT_WHEEL_DIR="wheel"
DEFAULT_OUTPUT_WHEEL="vllm-dist-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl"
# Optionally set a default SHA256 checksum for the downloaded wheel
# DEFAULT_CHECKSUM="SOME_SHA256_HERE"
###############################################################################
# HELPER FUNCTIONS
###############################################################################
usage() {
cat << EOF
Usage: $0 [OPTIONS]
Options:
-u, --url <URL> Wheel URL to download (default: $DEFAULT_WHEEL_URL)
-p, --patch <FILE> Patch file path (default: $DEFAULT_PATCH_FILE)
-d, --data-plane <DIR> Directory with additional data-plane files (default: $DEFAULT_DATA_PLANE_DIR)
-w, --wheel-dir <DIR> Extract destination directory (default: $DEFAULT_WHEEL_DIR)
-o, --output-wheel <FILE> Name/path for the repackaged wheel (default: $DEFAULT_OUTPUT_WHEEL)
-f, --force Force overwriting existing directories/files without prompts
-i, --install Install the new wheel after repackaging
-D, --debug Enable debug mode (verbose output)
-h, --help Show this help and exit
Example:
$0 \\
--url "https://example.com/path/to/vllm.whl" \\
--patch my_patch.patch \\
--data-plane custom_data/ \\
--wheel-dir extracted_wheel \\
--output-wheel my_vllm_dist.whl \\
--install \\
--force \\
--debug
EOF
}
info_log() {
# Always prints, showing a timestamp.
echo "[$(date +'%Y-%m-%d %H:%M:%S')] INFO: $*"
}
debug_log() {
# Prints only if DEBUG=true
if [[ "$DEBUG" == true ]]; then
echo "[$(date +'%Y-%m-%d %H:%M:%S')] DEBUG: $*"
fi
}
error_exit() {
echo "ERROR: $*" >&2
exit 1
}
###############################################################################
# PARSE ARGUMENTS
###############################################################################
FORCE_OVERWRITE=false
INSTALL_WHEEL=false
DEBUG=false
WHEEL_URL="$DEFAULT_WHEEL_URL"
PATCH_FILE="$DEFAULT_PATCH_FILE"
DATA_PLANE_DIR="$DEFAULT_DATA_PLANE_DIR"
WHEEL_DIR="$DEFAULT_WHEEL_DIR"
OUTPUT_WHEEL="$DEFAULT_OUTPUT_WHEEL"
while [[ $# -gt 0 ]]; do
case "$1" in
-u|--url)
WHEEL_URL="$2"
shift 2
;;
-p|--patch)
PATCH_FILE="$2"
shift 2
;;
-d|--data-plane)
DATA_PLANE_DIR="$2"
shift 2
;;
-w|--wheel-dir)
WHEEL_DIR="$2"
shift 2
;;
-o|--output-wheel)
OUTPUT_WHEEL="$2"
shift 2
;;
-f|--force)
FORCE_OVERWRITE=true
shift
;;
-i|--install)
INSTALL_WHEEL=true
shift
;;
-D|--debug)
DEBUG=true
shift
;;
-h|--help)
usage
exit 0
;;
*)
error_exit "Unknown option: $1"
;;
esac
done
###############################################################################
# MAIN SCRIPT
###############################################################################
# Enable debug mode if requested
if [[ "$DEBUG" == true ]]; then
set -x
fi
info_log "Starting wheel patching script..."
# Function to check if a command exists
command_exists() {
command -v "$1" >/dev/null 2>&1 || { echo >&2 "I require $1 but it's not installed. Aborting."; exit 1; }
}
# Check for required commands
command_exists pip
command_exists unzip
command_exists zip
command_exists patch
# ---------------------------------------------------------------------------
# 1. Check for existing wheel file or directory
# ---------------------------------------------------------------------------
WHEEL_FILENAME=$(basename "$WHEEL_URL")
if [[ -f "$WHEEL_FILENAME" && "$FORCE_OVERWRITE" != true ]]; then
info_log "File '$WHEEL_FILENAME' already exists. Reusing existing file."
info_log "If you want to redownload, remove '$WHEEL_FILENAME' or use --force."
else
info_log "Downloading wheel from $WHEEL_URL..."
rm -f "$WHEEL_FILENAME" 2>/dev/null || true # Remove existing file if forcing
wget -O "$WHEEL_FILENAME" "$WHEEL_URL"
fi
# ---------------------------------------------------------------------------
# 2. Optional: Verify checksum (commented out by default)
# ---------------------------------------------------------------------------
# if [[ -n "$DEFAULT_CHECKSUM" ]]; then
# info_log "Verifying SHA256 checksum..."
# echo "${DEFAULT_CHECKSUM} ${WHEEL_FILENAME}" | sha256sum --check - || error_exit "Checksum mismatch!"
# fi
# ---------------------------------------------------------------------------
# 3. Create/clean wheel extraction directory
# ---------------------------------------------------------------------------
if [[ -d "$WHEEL_DIR" ]]; then
if [[ "$FORCE_OVERWRITE" == true ]]; then
info_log "Removing existing directory '$WHEEL_DIR' due to --force..."
rm -rf "$WHEEL_DIR"
else
error_exit "Directory '$WHEEL_DIR' already exists. Use --force to overwrite."
fi
fi
info_log "Creating directory '$WHEEL_DIR'..."
mkdir -p "$WHEEL_DIR"
# ---------------------------------------------------------------------------
# 4. Unzip the wheel into the specified directory
# ---------------------------------------------------------------------------
info_log "Unzipping wheel into directory '$WHEEL_DIR'..."
unzip -q "$WHEEL_FILENAME" -d "$WHEEL_DIR"
# ---------------------------------------------------------------------------
# 5. Check/Apply patch
# ---------------------------------------------------------------------------
if [[ ! -f "$PATCH_FILE" ]]; then
error_exit "Patch file '$PATCH_FILE' not found."
fi
PATCH_TARGET_DIR="$WHEEL_DIR/vllm"
if [[ ! -d "$PATCH_TARGET_DIR" ]]; then
error_exit "Could not find directory '$PATCH_TARGET_DIR' in unzipped wheel."
fi
info_log "Applying patch '$PATCH_FILE' to '$PATCH_TARGET_DIR'..."
debug_log "Executing: (cd \"$PATCH_TARGET_DIR\" && patch -p1 < \"../../$PATCH_FILE\")"
(
cd "$PATCH_TARGET_DIR"
patch -p1 < "../../$PATCH_FILE"
)
# ---------------------------------------------------------------------------
# 6. Copy data plane files
# ---------------------------------------------------------------------------
if [[ ! -d "$DATA_PLANE_DIR" ]]; then
error_exit "Data plane directory '$DATA_PLANE_DIR' not found."
fi
info_log "Copying files from '$DATA_PLANE_DIR' to '$PATCH_TARGET_DIR/distributed'..."
mkdir -p "$PATCH_TARGET_DIR/distributed"
cp -r "$DATA_PLANE_DIR/"* "$PATCH_TARGET_DIR/distributed/"
# ---------------------------------------------------------------------------
# 7. Re-package into a new wheel
# ---------------------------------------------------------------------------
if [[ -f "$OUTPUT_WHEEL" && "$FORCE_OVERWRITE" != true ]]; then
error_exit "Output wheel '$OUTPUT_WHEEL' already exists. Use --force to overwrite."
fi
info_log "Creating new wheel file '$OUTPUT_WHEEL'..."
debug_log "Executing: (cd \"$WHEEL_DIR\" && zip -rq \"../$OUTPUT_WHEEL\" .)"
(
cd "$WHEEL_DIR"
zip -rq "../$OUTPUT_WHEEL" .
)
# ---------------------------------------------------------------------------
# 8. Optional: Install the new wheel
# ---------------------------------------------------------------------------
if [[ "$INSTALL_WHEEL" == true ]]; then
# Check if pip is installed
if ! command -v pip >/dev/null 2>&1; then
error_exit "pip is not installed or not found in PATH."
fi
info_log "Installing newly created wheel '$OUTPUT_WHEEL'..."
pip install --force-reinstall --upgrade --break-system-packages "$OUTPUT_WHEEL"
fi
info_log "Patch and repackage completed successfully!"
info_log "New wheel: $OUTPUT_WHEEL"
if [[ "$INSTALL_WHEEL" == true ]]; then
info_log "Wheel has been installed."
fi
import pytest
try:
import vllm
except ImportError:
vllm = None # type: ignore
pytestmark = pytest.mark.pre_merge
# TODO: Consider `pytest.mark.vllm` and running tests based on environment
@pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
def test_version():
# Verify that the image has the patched version of vllm
assert vllm.__version__ == "0.6.3.post2.dev16+gf61960ce"
@pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed")
def test_patch_imports():
    # Verify the patched files have no glaring syntax or import issues by
    # importing the modules this patch adds into vllm.distributed.
    import vllm.distributed.data_plane  # noqa: F401
    import vllm.distributed.kv_cache  # noqa: F401
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
diff -Naur v0.6.3.post1_vllm/_version.py patched_vllm/_version.py
--- v0.6.3.post1_vllm/_version.py 2025-01-09 03:03:32.439278263 -0800
+++ patched_vllm/_version.py 2025-01-09 01:49:43.785300620 -0800
@@ -12,5 +12,5 @@
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
-__version__ = version = '0.6.3.post1'
-__version_tuple__ = version_tuple = (0, 6, 3)
+__version__ = version = '0.6.3.post2.dev16+gf61960ce'
+__version_tuple__ = version_tuple = (0, 6, 3, 'dev16', 'gf61960ce')
diff -Naur v0.6.3.post1_vllm/config.py patched_vllm/config.py
--- v0.6.3.post1_vllm/config.py 2025-01-09 03:03:32.439278263 -0800
+++ patched_vllm/config.py 2025-01-09 01:49:43.785300620 -0800
@@ -1046,7 +1046,7 @@
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len.")
- if self.max_num_batched_tokens < self.max_num_seqs:
+ if self.max_num_seqs is not None and self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
diff -Naur v0.6.3.post1_vllm/core/scheduler.py patched_vllm/core/scheduler.py
--- v0.6.3.post1_vllm/core/scheduler.py 2025-01-09 03:03:32.291290245 -0800
+++ patched_vllm/core/scheduler.py 2025-01-09 01:49:43.785300620 -0800
@@ -17,6 +17,7 @@
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
from vllm.utils import Device, PyObjectCache
+import vllm.envs as envs
logger = init_logger(__name__)
@@ -883,12 +884,17 @@
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
- num_new_tokens = self._get_num_new_tokens(seq_group,
- SequenceStatus.WAITING,
- enable_chunking, budget)
- if not enable_chunking:
- num_prompt_tokens = waiting_seqs[0].get_len()
- assert num_new_tokens == num_prompt_tokens
+ real_num_new_tokens = self._get_num_new_tokens(seq_group,
+ SequenceStatus.WAITING,
+ enable_chunking, budget)
+ if envs.VLLM_DISAGG_STAGE == "GENERATE":
+ num_new_tokens = 1
+ assert not enable_chunking
+ else:
+ num_new_tokens = real_num_new_tokens
+ if not enable_chunking:
+ num_prompt_tokens = waiting_seqs[0].get_len()
+ assert num_new_tokens == num_prompt_tokens
prompt_limit = self._get_prompt_limit(seq_group)
if num_new_tokens > prompt_limit:
@@ -967,7 +973,7 @@
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
- token_chunk_size=num_new_tokens))
+ token_chunk_size=real_num_new_tokens))
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
diff -Naur v0.6.3.post1_vllm/distributed/parallel_state.py patched_vllm/distributed/parallel_state.py
--- v0.6.3.post1_vllm/distributed/parallel_state.py 2025-01-09 03:03:32.291290245 -0800
+++ patched_vllm/distributed/parallel_state.py 2025-01-09 01:49:43.785300620 -0800
@@ -889,6 +889,14 @@
get_pipeline_model_parallel_group = get_pp_group
+_STORE: Optional[Any] = None
+
+def get_store() -> Any:
+ assert _STORE is not None, ("store is not initialized")
+ return _STORE
+
+
+
@contextmanager
def graph_capture():
"""
@@ -926,20 +934,51 @@
local_rank: int = -1,
backend: str = "nccl",
):
+
+ # TODO ptarasiewicz this is a hack to get the stage from the environment
+ logger.info("="*50)
+ logger.info("Patching init_distributed_environment")
+ stage = envs.VLLM_DISAGG_STAGE
+ logger.info(f"Stage: {stage}")
+ store = None
+ if stage is not None and envs.VLLM_DATA_PLANE_BACKEND == "nccl":
+ context_workers = envs.VLLM_CONTEXT_WORKERS
+ context_tp_size = envs.VLLM_CONTEXT_TP_SIZE
+ generate_workers = envs.VLLM_GENERATE_WORKERS
+ generate_tp_size = envs.VLLM_GENERATE_TP_SIZE
+ world_size = context_workers * context_tp_size + generate_workers * generate_tp_size
+ if stage == "PREFILL":
+ worker_id = envs.VLLM_WORKER_ID
+ rank += worker_id * context_tp_size
+ if stage == "GENERATE":
+ rank += context_workers * context_tp_size # TODO ptarasiewicz this only works for 1 generate worker
+ # distributed_init_method = f"tcp://{envs.VLLM_TORCH_HOST}:{envs.VLLM_TORCH_PORT}"
+ distributed_init_method = None
+ store = torch.distributed.TCPStore(envs.VLLM_TORCH_HOST, envs.VLLM_TORCH_PORT, world_size=world_size, is_master = rank == 0)
+ logger.info(f"world_size: {world_size}, rank: {rank}, distributed_init_method: {distributed_init_method}, local_rank: {local_rank}, backend: {backend}")
+
logger.debug(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
distributed_init_method, backend)
if not torch.distributed.is_initialized():
- assert distributed_init_method is not None, (
- "distributed_init_method must be provided when initializing "
+ assert distributed_init_method is not None or store is not None, (
+ "distributed_init_method or store must be provided when initializing "
"distributed environment")
# this backend is used for WORLD
- torch.distributed.init_process_group(
- backend=backend,
- init_method=distributed_init_method,
- world_size=world_size,
- rank=rank)
+ if store is None:
+ torch.distributed.init_process_group(
+ backend=backend,
+ init_method=distributed_init_method,
+ world_size=world_size,
+ rank=rank)
+ else:
+ torch.distributed.init_process_group(
+ backend=backend,
+ # init_method=distributed_init_method,
+ world_size=world_size,
+ rank=rank,
+ store=store)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
@@ -958,6 +997,10 @@
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size")
+ global _STORE
+ if store is not None and _STORE is None:
+ _STORE = store
+
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
@@ -986,32 +1029,60 @@
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
+ logger.debug("="*50)
+ logger.debug("Patching initialize_model_parallel")
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
- if (world_size !=
- tensor_model_parallel_size * pipeline_model_parallel_size):
- raise RuntimeError(
- f"world_size ({world_size}) is not equal to "
- f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
- f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
+
+ # TODO ptarasiewicz this assertion does not work with disagg
+ # if (world_size !=
+ # tensor_model_parallel_size * pipeline_model_parallel_size):
+ # raise RuntimeError(
+ # f"world_size ({world_size}) is not equal to "
+ # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
+ # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size //
- tensor_model_parallel_size)
+ tensor_model_parallel_size)
+
+ stage = envs.VLLM_DISAGG_STAGE
+
global _TP
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
- ranks = list(
- range(i * tensor_model_parallel_size,
- (i + 1) * tensor_model_parallel_size))
- group_ranks.append(ranks)
+
+ # TODO ptarasiewicz this is a hack to adjust the ranks for the stage
+ if stage is not None and envs.VLLM_DATA_PLANE_BACKEND == "nccl":
+ ranks = []
+ context_workers = envs.VLLM_CONTEXT_WORKERS
+ context_tp_size = envs.VLLM_CONTEXT_TP_SIZE
+ generate_workers = envs.VLLM_GENERATE_WORKERS
+ generate_tp_size = envs.VLLM_GENERATE_TP_SIZE
+ for context_id in range(context_workers):
+ ranks.append(
+ [context_id * context_tp_size + i for i in range(context_tp_size)]
+ )
+ for generate_id in range(generate_workers):
+ ranks.append(
+ [context_workers * context_tp_size + generate_id * generate_tp_size + i for i in range(generate_tp_size)]
+ )
+ group_ranks.extend(ranks)
+ break
+ else:
+ ranks = list(
+ range(i * tensor_model_parallel_size,
+ (i + 1) * tensor_model_parallel_size))
+ group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
+ logger.debug("initializing tensor model parallel group")
+ logger.debug(f"group_ranks {group_ranks}")
_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
@@ -1020,15 +1091,32 @@
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
- pipeline_model_parallel_size)
+ pipeline_model_parallel_size)
+
global _PP
assert _PP is None, (
"pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
- ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
- group_ranks.append(ranks)
+
+
+
+ # TODO ptarasiewicz this is a hack to adjust the ranks for the stage
+ if stage is not None and envs.VLLM_DATA_PLANE_BACKEND == "nccl":
+ context_workers = envs.VLLM_CONTEXT_WORKERS
+ context_tp_size = envs.VLLM_CONTEXT_TP_SIZE
+ generate_workers = envs.VLLM_GENERATE_WORKERS
+ generate_tp_size = envs.VLLM_GENERATE_TP_SIZE
+ world_size = context_workers * context_tp_size + generate_workers * generate_tp_size
+ ranks = [[i] for i in range(world_size)]
+ group_ranks.extend(ranks)
+ break
+ else:
+ ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
+ group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
+ logger.debug("initializing pipeline model parallel group")
+ logger.debug(f"group_ranks {group_ranks}")
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
diff -Naur v0.6.3.post1_vllm/engine/async_llm_engine.py patched_vllm/engine/async_llm_engine.py
--- v0.6.3.post1_vllm/engine/async_llm_engine.py 2025-01-09 03:03:32.443277939 -0800
+++ patched_vllm/engine/async_llm_engine.py 2025-01-09 01:49:43.785300620 -0800
@@ -371,7 +371,7 @@
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
-
+
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
diff -Naur v0.6.3.post1_vllm/engine/llm_engine.py patched_vllm/engine/llm_engine.py
--- v0.6.3.post1_vllm/engine/llm_engine.py 2025-01-09 03:03:32.443277939 -0800
+++ patched_vllm/engine/llm_engine.py 2025-01-09 01:49:43.785300620 -0800
@@ -479,8 +479,28 @@
The workers will determine the number of blocks in both the GPU cache
and the swap CPU cache.
"""
- num_gpu_blocks, num_cpu_blocks = (
- self.model_executor.determine_num_available_blocks())
+
+
+ max_num_seqs = None
+ if self.scheduler_config.max_num_seqs is not None:
+ num_gpu_blocks, num_cpu_blocks = (
+ self.model_executor.determine_num_available_blocks())
+ else:
+ max_num_seqs = 1
+ max_concurrency = None
+ max_iter_count_left = 5
+ while True:
+ logger.info("Profiling with %d sequences", max_num_seqs)
+ num_gpu_blocks, num_cpu_blocks = (
+ self.model_executor.determine_num_available_blocks(max_num_seqs))
+ max_concurrency = (num_gpu_blocks * self.cache_config.block_size / self.model_config.max_model_len)
+ logger.info("Maximum concurrency for %d sequences and %s tokens per request: %.2fx",
+ max_num_seqs, self.model_config.max_model_len, max_concurrency)
+ max_iter_count_left -= 1
+ if max_iter_count_left < 1 or (max_concurrency - max_num_seqs) ** 2 < 2:
+ break
+ max_num_seqs = int(max_concurrency)
+
if self.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
@@ -494,6 +514,10 @@
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+
+ if max_num_seqs is not None and self.scheduler_config.max_num_seqs is None:
+ self.scheduler_config.max_num_seqs = max_num_seqs
+ logger.info("Setting max_num_seqs to %d", max_num_seqs)
@classmethod
def _get_executor_cls(cls,
diff -Naur v0.6.3.post1_vllm/envs.py patched_vllm/envs.py
--- v0.6.3.post1_vllm/envs.py 2025-01-09 03:03:32.439278263 -0800
+++ patched_vllm/envs.py 2025-01-09 01:49:43.789300297 -0800
@@ -66,6 +66,15 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_DISABLED_KERNELS: List[str] = []
+ VLLM_DISAGG_STAGE: Optional[str] = None
+ VLLM_CONTEXT_TP_SIZE: int = 0
+ VLLM_CONTEXT_WORKERS: int = 0
+ VLLM_GENERATE_TP_SIZE: int = 0
+ VLLM_GENERATE_WORKERS: int = 0
+ VLLM_TORCH_HOST: str = "localhost"
+ VLLM_TORCH_PORT: int = 36183
+ VLLM_WORKER_ID: int = 0
+ VLLM_DATA_PLANE_BACKEND: str = "nccl"
def get_default_cache_root():
@@ -433,6 +442,35 @@
"VLLM_DISABLED_KERNELS":
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
"VLLM_DISABLED_KERNELS"].split(","),
+
+ # Stage of the disaggregated model
+ "VLLM_DISAGG_STAGE":
+ lambda: os.getenv("VLLM_DISAGG_STAGE", None),
+
+ # Disaggregation global config
+ "VLLM_CONTEXT_TP_SIZE":
+ lambda: int(os.getenv("VLLM_CONTEXT_TP_SIZE", "0")),
+
+ "VLLM_CONTEXT_WORKERS":
+ lambda: int(os.getenv("VLLM_CONTEXT_WORKERS", "0")),
+
+ "VLLM_GENERATE_TP_SIZE":
+ lambda: int(os.getenv("VLLM_GENERATE_TP_SIZE", "0")),
+
+ "VLLM_GENERATE_WORKERS":
+ lambda: int(os.getenv("VLLM_GENERATE_WORKERS", "0")),
+
+ "VLLM_TORCH_HOST":
+ lambda: os.getenv("VLLM_TORCH_HOST", "localhost"),
+
+ "VLLM_TORCH_PORT":
+ lambda: int(os.getenv("VLLM_TORCH_PORT", "36183")),
+
+ "VLLM_WORKER_ID":
+ lambda: int(os.getenv("VLLM_WORKER_ID", "0")),
+
+ "VLLM_DATA_PLANE_BACKEND":
+ lambda: os.getenv("VLLM_DATA_PLANE_BACKEND", "nccl"),
}
# end-env-vars-definition
diff -Naur v0.6.3.post1_vllm/executor/distributed_gpu_executor.py patched_vllm/executor/distributed_gpu_executor.py
--- v0.6.3.post1_vllm/executor/distributed_gpu_executor.py 2025-01-09 03:03:32.443277939 -0800
+++ patched_vllm/executor/distributed_gpu_executor.py 2025-01-09 01:49:43.789300297 -0800
@@ -25,7 +25,7 @@
super().__init__(*args, **kwargs)
- def determine_num_available_blocks(self) -> Tuple[int, int]:
+ def determine_num_available_blocks(self, max_num_seqs: Optional[int] = None) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
@@ -36,7 +36,7 @@
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
- num_blocks = self._run_workers("determine_num_available_blocks", )
+ num_blocks = self._run_workers("determine_num_available_blocks", max_num_seqs=max_num_seqs)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
diff -Naur v0.6.3.post1_vllm/executor/gpu_executor.py patched_vllm/executor/gpu_executor.py
--- v0.6.3.post1_vllm/executor/gpu_executor.py 2025-01-09 03:03:32.443277939 -0800
+++ patched_vllm/executor/gpu_executor.py 2025-01-09 01:49:43.789300297 -0800
@@ -107,11 +107,11 @@
rank=rank,
distributed_init_method=distributed_init_method))
- def determine_num_available_blocks(self) -> Tuple[int, int]:
+ def determine_num_available_blocks(self, max_num_seqs: Optional[int] = None) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
- return self.driver_worker.determine_num_available_blocks()
+ return self.driver_worker.determine_num_available_blocks(max_num_seqs=max_num_seqs)
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
"""Initialize the KV cache by invoking the underlying worker.
diff -Naur v0.6.3.post1_vllm/worker/model_runner.py patched_vllm/worker/model_runner.py
--- v0.6.3.post1_vllm/worker/model_runner.py 2025-01-09 03:03:32.559268548 -0800
+++ patched_vllm/worker/model_runner.py 2025-01-09 01:49:43.993283816 -0800
@@ -1,7 +1,9 @@
+import time
import dataclasses
import gc
import inspect
import itertools
+import os
import time
import warnings
import weakref
@@ -25,6 +27,7 @@
PromptAdapterConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
+from vllm.distributed.kv_cache import get_kv_cache_handler
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -46,7 +49,7 @@
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata, Logprob, SequenceOutput, CompletionSequenceGroupOutput
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, is_hip, is_pin_memory_available,
@@ -58,6 +61,7 @@
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict, dump_input_when_exception)
+
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -120,6 +124,7 @@
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
+ "seq_lens": self.seq_lens,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@@ -158,6 +163,7 @@
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
+ "seq_lens": self.seq_lens,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -959,6 +965,7 @@
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
+ self._kv_cache_handler = None
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
@@ -978,8 +985,7 @@
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
- self.max_batchsize_to_capture = _get_max_graph_batch_size(
- self.scheduler_config.max_num_seqs)
+ self.max_batchsize_to_capture = None
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -995,9 +1001,7 @@
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
- self.graph_block_tables = np.zeros(
- (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
- dtype=np.int32)
+ self.graph_block_tables = None
# Attention-free but stateful models like Mamba need a placeholder attn
# backend, as the attention metadata is needed to manage internal state.
@@ -1196,11 +1200,18 @@
return builder.build() # type: ignore
@torch.inference_mode()
- def profile_run(self) -> None:
+ def profile_run(self, max_num_seqs: Optional[int] = None) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
- max_num_seqs = self.scheduler_config.max_num_seqs
+ if max_num_seqs is None:
+ max_num_seqs = self.scheduler_config.max_num_seqs
+
+ self.max_batchsize_to_capture = _get_max_graph_batch_size(
+ max_num_seqs)
+ self.graph_block_tables = np.zeros(
+ (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
+ dtype=np.int32)
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
@@ -1522,6 +1533,11 @@
# This usually takes < 10 seconds.
logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
+ def init_kv_cache_handler(self) -> None:
+ if envs.VLLM_DISAGG_STAGE is not None:
+ self._kv_cache_handler = get_kv_cache_handler()
+ torch.distributed.barrier()
+
def _update_inputs_to_capture_for_enc_dec_model(self,
capture_inputs: Dict[str,
Any]):
@@ -1552,6 +1568,13 @@
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
+ _first_tokens = {}
+
+ def set_first_token(self, request_id: str, first_token: int) -> None:
+ self._first_tokens[request_id] = first_token
+
+ def pop_first_token(self, request_id: str) -> int:
+ return self._first_tokens.pop(request_id)
def make_model_input_from_broadcasted_tensor_dict(
self,
@@ -1610,6 +1633,8 @@
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+ logger.debug(f"execute model input_ids: {model_input.input_tokens.shape}")
+ # logger.info(f"model input {model_input}")
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1654,21 +1679,98 @@
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
+ # Check if we need to run inference or not
+ # We do not run inference if we are in generating stage of disaggregated serving
+ disagg_stage = envs.VLLM_DISAGG_STAGE
+ assert disagg_stage in ["PREFILL", "GENERATE", None], f"Invalid disagg_stage: {disagg_stage}, must be one of ['PREFILL', 'GENERATE', None]"
+ is_profile_run = (
+ (kv_caches is None) or (kv_caches[0] is None) or (kv_caches[0].numel() == 0)
+ )
+ is_generate = disagg_stage == "GENERATE" and prefill_meta is None
+ run_inference = any(
+ [
+ disagg_stage == "PREFILL",
+ is_profile_run,
+ is_generate,
+ disagg_stage is None,
+ ]
+ )
+
+ logger.debug(
+ f"Running inference: {run_inference}, disagg_stage: {disagg_stage}, "
+ f"is_profile_run: {is_profile_run}, is_generate: {is_generate}"
+ )
+ logger.debug(f"Batch size: {model_input.input_tokens.shape} Prefill: {prefill_meta is not None}, Generate: {prefill_meta is None}")
+
+ if not run_inference:
+ if self.is_driver_worker:
+ logger.debug(f"Pulling KV cache for seq lens: {model_input.seq_lens}")
+ # We are in the GENERATE stage of disaggregated serving
+ # instead of running inference, we just need to pull the KV cache
+ # and hidden states from the context stage
+                # TODO ptarasiewicz check why without torch.cuda.synchronize() throughput is lower
+ logger.debug("PULLING KV CACHE")
+ # torch.cuda.synchronize()
+ # start_time = time.perf_counter_ns()
+ self._kv_cache_handler.recv(
+ self.model.model,
+ model_input,
+ kv_caches,
+ )
+ # torch.cuda.synchronize()
+ # end_time = time.perf_counter_ns()
+ # logger.info(f"KV CACHE PULL TIME: {(end_time - start_time) / 1e6} ms")
+
+ if not self.is_driver_worker:
+ return []
+
+            first_token = self.pop_first_token(list(model_input.request_ids_to_seq_ids.keys())[0])
+ mocked_output = SamplerOutput(
+ outputs=[
+ CompletionSequenceGroupOutput(
+ samples=[
+ SequenceOutput(
+ parent_seq_id=seq_group.seq_ids[0],
+                                output_token=first_token,
+                                logprobs={first_token: Logprob(float('inf'))},
+ )
+ ],
+ prompt_logprobs=None,
+ )
+ for seq_group in model_input.sampling_metadata.seq_groups
+ ],
+ )
+ logger.debug(f"MOCKED OUTPUT {mocked_output}")
+ return [mocked_output]
+
+ logger.debug("RUNNING INFERENCE")
with set_forward_context(model_input.attn_metadata):
- hidden_or_intermediate_states = model_executable(
- input_ids=model_input.input_tokens,
- positions=model_input.input_positions,
- kv_caches=kv_caches,
- attn_metadata=model_input.attn_metadata,
- intermediate_tensors=intermediate_tensors,
- **MultiModalInputs.as_kwargs(multi_modal_kwargs,
- device=self.device),
- **seqlen_agnostic_kwargs)
+ with torch.cuda.nvtx.range("model_executable"):
+ hidden_or_intermediate_states = model_executable(
+ input_ids=model_input.input_tokens,
+ positions=model_input.input_positions,
+ kv_caches=kv_caches,
+ attn_metadata=model_input.attn_metadata,
+ intermediate_tensors=intermediate_tensors,
+ **MultiModalInputs.as_kwargs(multi_modal_kwargs,
+ device=self.device),
+ **seqlen_agnostic_kwargs)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
+ if disagg_stage == "PREFILL" and not is_profile_run:
+            # TODO ptarasiewicz check why without torch.cuda.synchronize() throughput is lower
+ logger.debug("Pushing KV cache")
+ torch.cuda.synchronize()
+ self._kv_cache_handler.send(
+ self.model.model,
+ model_input,
+ kv_caches,
+ )
+ logger.debug("finished sending")
+
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
@@ -1687,6 +1789,31 @@
hidden_or_intermediate_states.tensors["model_forward_time"] = (
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states
+
+ if not run_inference:
+
+ if not self.is_driver_worker:
+ return []
+
+            first_token = self.pop_first_token(list(model_input.request_ids_to_seq_ids.keys())[0])
+ mocked_output = SamplerOutput(
+ outputs=[
+ CompletionSequenceGroupOutput(
+ samples=[
+ SequenceOutput(
+ parent_seq_id=seq_group.seq_ids[0],
+                                output_token=first_token,
+                                logprobs={first_token: Logprob(float('inf'))},
+ )
+ ],
+ prompt_logprobs=None,
+ )
+ for seq_group in model_input.sampling_metadata.seq_groups
+ ],
+ )
+ # print("OUTPUT", output)
+            logger.debug(f"MOCKED OUTPUT {mocked_output}")
+ return [mocked_output]
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
@@ -1702,6 +1829,7 @@
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
+
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
diff -Naur v0.6.3.post1_vllm/worker/worker.py patched_vllm/worker/worker.py
--- v0.6.3.post1_vllm/worker/worker.py 2025-01-09 03:03:32.559268548 -0800
+++ patched_vllm/worker/worker.py 2025-01-09 01:49:43.993283816 -0800
@@ -202,7 +202,7 @@
tensorizer_config=tensorizer_config, )
@torch.inference_mode()
- def determine_num_available_blocks(self) -> Tuple[int, int]:
+ def determine_num_available_blocks(self, max_num_seqs: Optional[int] = None) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
@@ -214,13 +214,14 @@
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
+
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
- self.model_runner.profile_run()
+ self.model_runner.profile_run(max_num_seqs)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
@@ -242,16 +243,18 @@
else:
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
- peak_memory) // cache_block_size)
+ peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
- cache_block_size)
+ cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache()
- return num_gpu_blocks, num_cpu_blocks
+ return num_gpu_blocks, num_cpu_blocks
+
+
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
@@ -285,6 +288,7 @@
def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model(self.gpu_cache)
+ self.model_runner.init_kv_cache_handler()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)