Unverified Commit 0590ec3f authored by Kuntai Du's avatar Kuntai Du Committed by GitHub
Browse files

[Core] Implement disagg prefill by StatelessProcessGroup (#10502)



This PR provides initial support for single-node disaggregated prefill in 1P1D scenario.
Signed-off-by: default avatarKuntaiDu <kuntai@uchicago.edu>
Co-authored-by: default avatarApostaC <yihua98@uchicago.edu>
Co-authored-by: default avatarYaoJiayi <120040070@link.cuhk.edu.cn>
parent c11f1721
"""
Simple KV Cache Connector for Distributed Machine Learning Inference
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
class SimpleConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
):
self.config = config.kv_transfer_config
logger.info("Initializing PyNcclConfig under kv_transfer_config %s",
self.config)
self.lookup_buffer_size = self.config.kv_buffer_size
self.producer_buffer: Optional[SimpleBuffer] = None
self.consumer_buffer: Optional[SimpleBuffer] = None
# 2 pipes for every rank in the world
port_offset_base = 2 * rank
# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:
self.producer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
self.producer_data_pipe,
self.config.kv_buffer_size)
else:
# the current vLLM instance is KV consumer, so it needs to connect
# its recv pipe to the send pipe of KV producder
self.consumer_data_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
device="cpu",
)
self.consumer_buffer = SimpleBuffer(
self.consumer_signal_pipe,
self.consumer_data_pipe,
self.config.kv_buffer_size,
)
def select(self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
return self.consumer_buffer.drop_select(input_tokens, roi)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
# FIXME(Kuntai): This assume that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
_, _, num_heads, head_size = kv_cache[0].shape
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
self.insert(current_tokens,
torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
# When bypass_model_exec is set to False, it means that at least for one
# request its corresponding KV cache or hidden state is missing.
# In this case we need to do prefilling to recompute missing KV cache
# and hidden states.
bypass_model_exec = True
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
num_computed_tokens_list = []
start_pos_list = []
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
ret = self.select(current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
bypass_model_exec = False
num_computed_tokens_list.append(0)
continue
roi: torch.Tensor = ret[1]
keys: torch.Tensor = ret[2]
values: torch.Tensor = ret[3]
hidden: torch.Tensor = ret[4]
num_computed_tokens = roi.shape[0]
num_computed_tokens_list.append(num_computed_tokens)
# check if both KV cache and the hidden states are received
# If not, need to redo the forwarding to compute missing states
if not all([(num_computed_tokens == num_tokens), hidden is not None
]):
bypass_model_exec = False
# update the end position based on how many tokens are cached.
end_pos = start_pos + num_computed_tokens
# put received KV caches into paged memory
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
logger.debug(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input
def close(self):
self.producer_data_pipe.close()
self.producer_signal_pipe.close()
self.consumer_data_pipe.close()
self.consumer_signal_pipe.close()
"""
This file contains a new class `KVLookupBufferBase` that allows developers to
think of KV cache operations as inserting new KV cache entries (`insert`)
into the lookup buffer and querying existing KV caches (`drop_select`)
from the lookup buffer.
All distributed communications are abstracted behind this class.
"""
from abc import ABC, abstractmethod
from typing import List, Optional
import torch
class KVLookupBufferBase(ABC):
"""
Abstract base class for a lookup buffer.
This class provides an abstraction for a key-value (KV) cache lookup buffer.
The key of the lookup buffer:
- input_tokens: token IDs of the request
- roi: a binary mask on top of input_tokens.
- Purpose of roi: Since KV cache may only be available for a subset of
tokens in the input (for example, when vLLM is connected to an external
KV cache service), roi specifies the subset of tokens that the KV cache
is associated with.
- NOTE: roi can be further extended to describe which part of KV the
current process is holding (each process may only hold a part of KV
due to TP and PP). This is not implemented for now.
The value of the lookup buffer:
- key: the key tensor in the KV cache
- value: the value tensor in the KV cache
- hidden: the final hidden state generated by model forwarding. This allows
vLLM to bypass further model forwarding by transmitting the hidden state.
"""
@abstractmethod
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
"""Insert into the lookup buffer.
The functionality is similar to the following python statement
```
buffer[input_tokens, roi] = [key, value, hidden]
```
FIXME: in the future, we should only have two arguments, key and value,
where key is a tensor dict and value is a tensor dict.
FIXME: we should transmit both sampler outputs and the hidden states.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
key (torch.Tensor): The key tensor in the KV cache.
value (torch.Tensor): The value tensor in the KV cache.
hidden (torch.Tensor): The final hidden state tensor generated
during model forwarding to bypass model
forwarding.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
"""Select and *drop* KV cache entries from the lookup buffer.
The functionality is similar to the following python statements
```
ret = buffer.pop(input_tokens, roi)
return ret
```
If `input_tokens` and `roi` is `None`, it means selecting any of the
KV caches in the buffer, return, and remove it from the buffer, useful
when offloading KV cache to KV cache storage service.
Args:
input_tokens (torch.Tensor): token IDs.
roi (torch.Tensor): A binary mask on top of the input tokens
Returns:
List[Optional[torch.Tensor]]: A list of tensors. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.
This method is responsible for cleaning up resources related to the
lookup buffer when it is no longer needed.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
"""
Implements a distributed key-value (KV) cache transfer mechanism.
Key Features:
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Use CPU signal pipe to avoid racing condition
- Handles buffer size constraints and provide backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import threading
import time
from collections import deque
from typing import Deque, List, Optional, Union
import torch
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVLookupBufferBase)
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
logger = init_logger(__name__)
class SimpleBuffer(KVLookupBufferBase):
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
buffer_size_thresh: float):
"""
signal_pipe: on CPU
NOTE: on-device recv will block all threads in the process, making the
KV cache producer unable to listen to new request while transmitting
KV cache. Luckily CPU recv only blocks the current thread so we use
CPU recv to listen to new request.
data_pipe: on device (e.g. GPU)
"""
self.buffer: Deque[List[torch.Tensor]] = deque()
self.buffer_size = 0
self.buffer_size_threshold = buffer_size_thresh
self.buffer_lock = threading.Lock()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
def _matches(self, tokens_roi_sender: List[torch.Tensor],
tokens_roi_recver: List[torch.Tensor]):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
tokens_sender = tokens_roi_sender[0]
tokens_recver = tokens_roi_recver[0]
roi_sender = tokens_roi_sender[1]
roi_recver = tokens_roi_recver[1]
if tokens_recver is None:
# consumer sends an empty request
# semantics: DROP SELECT * LIMIT 1
# so any of the data in the buffer can be drop-selected
return True
# Assuming that roi is a binary mask on tokens
tokens_sender = tokens_sender[roi_sender]
tokens_recver = tokens_recver[roi_recver]
# simple common prefix matching
min_length = min(len(tokens_sender), len(tokens_recver))
if torch.allclose(tokens_sender[:min_length],
tokens_recver[:min_length]):
return min_length
return 0
def _send_tensor_and_dec_size(self,
tensor: Optional[torch.Tensor]) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
if isinstance(data, torch.Tensor):
return data.element_size() * data.numel()
if not data:
# cannot perform `not data` on a tensor
# so this check needs to go after the check above
return 0
raise AssertionError(f"Unknown data type {type(data)}")
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor):
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone()
if isinstance(key, torch.Tensor):
key = key.clone()
if isinstance(value, torch.Tensor):
value = value.clone()
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
with self.buffer_lock:
for data in buffer_item:
self.buffer_size += self._get_element_size(data)
self.buffer.append(buffer_item)
def _is_end_signal(self, signal):
return signal is None
def drop_select_handler(self):
try:
while True:
signal = self.signal_pipe.recv_tensor()
if self._is_end_signal(signal):
logger.info("Received end signal!")
break
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
assert roi is not None, "Please provide the roi when sending "\
"drop-select request"
roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi]
matched_length = 0
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
with self.buffer_lock:
for _ in range(len(self.buffer)):
temp_length = self._matches(self.buffer[0],
tokens_roi_recver)
if temp_length > 0:
matched_length = temp_length
break
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
if matched_length > 0:
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
else:
# no match, just send None
for _ in range(5):
self.data_pipe.send_tensor(None)
except RuntimeError as e:
if 'Connection closed by peer' not in str(e):
raise e
logger.debug("Closing drop_select_handler")
def drop_select(
self, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
assert self.request_handling_thread is None, \
"drop_select should be called by the KV cache consumer "\
"(e.g. the decode vLLM instance)"
if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
self.signal_pipe.send_tensor(self.normal_signal)
self.data_pipe.send_tensor(input_tokens)
self.data_pipe.send_tensor(roi)
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = (roi > 0.5)
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
return [input_tokens, roi, key, value, hidden]
def full_handler(self):
time.sleep(0.001)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
if self.buffer_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
# repeatedly.
logger.debug("KV transfer buffer is full. Handling...")
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()
self._add_to_buffer(input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
if self.request_handling_thread is None:
self.request_handling_thread = threading.Thread(
target=self.drop_select_handler)
self.request_handling_thread.start()
def close(self):
if hasattr(self, "request_handling_thread"
) and self.request_handling_thread is not None:
self.request_handling_thread.join()
else:
# TODO: have a explicit close signal and have a explicit way to
# check if it's requester
self.signal_pipe.send_tensor(self.end_signal)
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.
All classes instantiated from this interface are assumed to be a FIFO pipe.
If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""
from abc import ABC, abstractmethod
from typing import Optional
import torch
class KVPipeBase(ABC):
"""
This class provides an interface for sending and receiving tensors, or
None, by distributed communications.
"""
@abstractmethod
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
TODO: add a `key` argument so that we can use traditional
key-value database as the distributed communication mechanism behind
the pipe.
Args:
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def recv_tensor(self) -> Optional[torch.Tensor]:
"""Receive a tensor (can be None) from the pipeline.
Returns:
Optional[torch.Tensor]: The tensor received from the pipeline. Can
be None.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the pipeline and release resources.
This method is responsible for closing the communication pipeline
and releasing any resources associated with it.
Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.
Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple
import torch
from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
logger = init_logger(__name__)
class BrokenPipeException(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
Metadata = Dict[str, Optional[torch.Tensor]]
class PyNcclPipe(KVPipeBase):
METADATA_LENGTH = 16
MAX_TENSOR_DIMENSIONS = 14
METADATA_DTYPE = torch.int64
def __init__(self,
local_rank: int,
config: KVTransferConfig,
device: Optional[str] = None,
port_offset: int = 0):
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
self.kv_parallel_size = self.config.kv_parallel_size
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:
self.device = self._select_device(device)
# build distributed connection and send/recv implementation
self.group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port + port_offset,
rank=self.kv_rank,
world_size=self.kv_parallel_size,
)
# add a barrier to make sure the connection is initiated properly
self.group.barrier()
impl = self._get_device_send_recv_impl(self.group)
self.device_send_func, self.device_recv_func = impl
# set target rank
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
# transportation-related variables
self.transport_thread: Optional[ThreadPoolExecutor] = None
self.buffer_size = 0
self.buffer_size_lock = threading.Lock()
self.buffer_size_thresh = self.config.kv_buffer_size
def _get_device_send_recv_impl(
self, group: StatelessProcessGroup
) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
[torch.Tensor, int], None]]:
send: Callable[[torch.Tensor, int], None]
recv: Callable[[torch.Tensor, int], None]
if self.device.type == "cuda":
# use PyNCCL for send / recv
comm = PyNcclCommunicator(group, device=self.local_rank)
comm.disabled = False
send, recv = comm.send, comm.recv # type: ignore
else:
# This send / recv implementation here is NOT intended to transfer
# KV caches (and should NOT be repurposed to transfer KV caches).
# Currently it is only used to transmit control-plane messages
# for PyNcclBuffer.
send = group.send_obj
def my_recv(x, src):
x[...] = group.recv_obj(src)
recv = my_recv
return send, recv
def _select_device(self, device: str):
logger.info("Selecting device: %s", device)
if device == "cuda":
return torch.device(f"cuda:{self.local_rank}")
else:
return torch.device("cpu")
def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
"""
Create the metadata as a dictionary based on the input tensor.
Parameters:
- tensor: The input tensor or None if no tensor is provided.
Returns:
- metadata: A dictionary with the following keys:
- "dtype": The data type of the tensor or None.
- "shape": The shape of the tensor or None.
"""
if tensor is None:
return {"dtype": None, "shape": None}
else:
return {"dtype": tensor.dtype, "shape": tensor.shape}
def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
"""
Create a buffer to receive the tensor based on the provided metadata.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape", describing
the tensor's data type and shape.
Returns:
- buffer: A tensor of the specified type and shape, allocated on
self.device.
"""
return torch.empty(metadata["shape"],
dtype=metadata["dtype"],
device=self.device)
def _send_metadata(self, metadata: Metadata):
"""
Send the metadata dictionary to the target rank.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape".
"""
self.group.send_obj(metadata, self.target_rank_for_send)
def _recv_metadata(self) -> Metadata:
"""
Receive the metadata dictionary from the target rank.
Returns:
- metadata: A dictionary with keys "dtype" and "shape" describing
the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
"""
The actual implementation of sending the tensor and its metadata to the
target rank.
Parameters:
- tensor: The input tensor to be sent, or None if no tensor is
being sent.
"""
metadata = self._make_metadata(tensor)
self._send_metadata(metadata)
if tensor is not None:
self.device_send_func(tensor.to(self.device),
self.target_rank_for_send)
def _recv_impl(self) -> Optional[torch.Tensor]:
"""
The actual implementation of receiving a tensor and its metadata from
the target rank.
Returns:
- buffer: The received tensor, or None if no tensor is received.
"""
metadata = self._recv_metadata()
if metadata["dtype"] is None:
return None
buffer = self._prepare_recv_buffer(metadata)
self.device_recv_func(buffer, self.target_rank_for_recv)
return buffer
def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
tensor_size: int) -> None:
"""
Wrapper for _send_impl to handle exceptions and update buffer size.
"""
try:
self._send_impl(tensor)
with self.buffer_size_lock:
self.buffer_size -= tensor_size
except Exception as e:
logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
torch.distributed.get_rank(), str(tensor), str(e))
import traceback
traceback.print_exc()
def block_if_full(self):
"""
Block the current thread if the buffer size is larger than the
threshold.
"""
while self.buffer_size > self.buffer_size_thresh:
logger.debug("KV cache transfer pipe is full. Waiting...")
time.sleep(0.05)
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
Parameters:
- tensor: The tensor to send, or None if no tensor is being sent.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
if tensor is not None:
tensor_size = tensor.element_size() * tensor.numel()
else:
tensor_size = 0
self.block_if_full()
with self.buffer_size_lock:
self.buffer_size += tensor_size
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
tensor_size)
def recv_tensor(self) -> Optional[torch.Tensor]:
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Returns:
- tensor: The received tensor, or None if no tensor is received.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
future = self.transport_thread.submit(self._recv_impl)
try:
tensor = future.result()
except Exception as e:
logger.error("Encountering exception in KV receiving thread")
logger.error("%s", e)
logger.error("My device: %s", self.device)
import traceback
traceback.print_exc()
raise e
return tensor
def close(self):
"""
Close the pipe and release associated resources.
"""
if hasattr(self,
"transport_thread") and self.transport_thread is not None:
self.transport_thread.shutdown()
"""A centralized entrypoint to perform distributed KV cache transfer.
This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
1. `send_kv_caches_and_hidden_states`
2. `recv_kv_caches_and_hidden_states
"""
from typing import TYPE_CHECKING, List, Tuple, Union
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.config import VllmConfig
import torch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
class KVTransferAgent:
"""
A class designated for distributed KV transfer
Target use cases:
1. Disaggregated prefill
2. Remote KV cache storage
"""
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
self.config = config
if config.kv_transfer_config is None:
raise ValueError("KVTransferConfig is not set in the VllmConfig,"
" cannot initialize KVConnector.")
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
"TransferAgent should only be used when kv_connector is set."
self.connector = KVConnectorFactory.create_connector(
rank, local_rank, config)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
self.connector.send_kv_caches_and_hidden_states(
model_executable, model_input, kv_caches,
hidden_or_intermediate_states)
def close(self) -> None:
self.connector.close()
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
return self.connector.recv_kv_caches_and_hidden_states(
model_executable, model_input, kv_caches)
...@@ -27,18 +27,23 @@ from collections import namedtuple ...@@ -27,18 +27,23 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
from unittest.mock import patch from unittest.mock import patch
import torch import torch
import torch.distributed import torch.distributed
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, supports_custom_op from vllm.utils import direct_register_custom_op, supports_custom_op
if TYPE_CHECKING:
from vllm.config import VllmConfig
@dataclass @dataclass
class GraphCaptureContext: class GraphCaptureContext:
...@@ -904,6 +909,14 @@ def get_pp_group() -> GroupCoordinator: ...@@ -904,6 +909,14 @@ def get_pp_group() -> GroupCoordinator:
# kept for backward compatibility # kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group get_pipeline_model_parallel_group = get_pp_group
_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None
def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
assert _KV_TRANSFER is not None, (
"disaggregated KV cache transfer parallel group is not initialized")
return _KV_TRANSFER
@contextmanager @contextmanager
def graph_capture(): def graph_capture():
...@@ -1052,6 +1065,26 @@ def initialize_model_parallel( ...@@ -1052,6 +1065,26 @@ def initialize_model_parallel(
group_name="pp") group_name="pp")
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
Initialize KV cache transfer parallel group.
"""
global _KV_TRANSFER
if vllm_config.kv_transfer_config is None:
return
if all([
vllm_config.kv_transfer_config.need_kv_parallel_group,
_KV_TRANSFER is None
]):
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config)
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
tensor_model_parallel_size: int, tensor_model_parallel_size: int,
pipeline_model_parallel_size: int, pipeline_model_parallel_size: int,
......
...@@ -9,10 +9,10 @@ import torch ...@@ -9,10 +9,10 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
DecodingConfig, DeviceConfig, HfOverrides, LoadConfig, DecodingConfig, DeviceConfig, HfOverrides,
LoadFormat, LoRAConfig, ModelConfig, KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig, ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig, PoolerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig, SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig) VllmConfig)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
...@@ -108,6 +108,7 @@ class EngineArgs: ...@@ -108,6 +108,7 @@ class EngineArgs:
# notice. # notice.
distributed_executor_backend: Optional[Union[str, distributed_executor_backend: Optional[Union[str,
Type[ExecutorBase]]] = None Type[ExecutorBase]]] = None
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = 1 pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1 tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None max_parallel_loading_workers: Optional[int] = None
...@@ -194,6 +195,8 @@ class EngineArgs: ...@@ -194,6 +195,8 @@ class EngineArgs:
compilation_config: Optional[CompilationConfig] = None compilation_config: Optional[CompilationConfig] = None
worker_cls: str = "auto" worker_cls: str = "auto"
kv_transfer_config: Optional[KVTransferConfig] = None
def __post_init__(self): def __post_init__(self):
if not self.tokenizer: if not self.tokenizer:
self.tokenizer = self.model self.tokenizer = self.model
...@@ -908,6 +911,12 @@ class EngineArgs: ...@@ -908,6 +911,12 @@ class EngineArgs:
'compilers, using -O without space is also ' 'compilers, using -O without space is also '
'supported. -O3 is equivalent to -O 3.') 'supported. -O3 is equivalent to -O 3.')
parser.add_argument('--kv-transfer-config',
type=KVTransferConfig.from_cli,
default=None,
help='The configurations for distributed KV cache '
'transfer. Should be a JSON string.')
parser.add_argument( parser.add_argument(
'--worker-cls', '--worker-cls',
type=str, type=str,
...@@ -1201,6 +1210,7 @@ class EngineArgs: ...@@ -1201,6 +1210,7 @@ class EngineArgs:
observability_config=observability_config, observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config, prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config, compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
) )
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
......
...@@ -21,7 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState ...@@ -21,7 +21,7 @@ from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context from vllm.compilation.compile_context import set_compile_context
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group from vllm.distributed import get_kv_transfer_group, get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
...@@ -1666,6 +1666,24 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1666,6 +1666,24 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
else: else:
model_executable = self.model model_executable = self.model
# Receive KV cache in distributed KV cache transfer setting
# In disagg prefill setting, it will also recv hidden states and bypass
# model forwarding
# In KV cache database setting, it will change the model input so that
# we can skip prefilling on tokens that successfully received KV caches
# NOTE: The receive operation is blocking
bypass_model_exec = False
if self.need_recv_kv(model_input, kv_caches):
hidden_or_intermediate_states, bypass_model_exec, model_input = \
get_kv_transfer_group().recv_kv_caches_and_hidden_states(
# model is used to know which layer the current worker
# is working on, so that we can receive KV for only those
# layers.
model_executable,
model_input,
kv_caches=kv_caches
)
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = { seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids, "finished_requests_ids": model_input.finished_requests_ids,
...@@ -1677,7 +1695,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1677,7 +1695,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record() model_forward_start.record()
with set_forward_context(model_input.attn_metadata, self.vllm_config): if not bypass_model_exec:
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
...@@ -1692,6 +1712,19 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1692,6 +1712,19 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
and self.observability_config.collect_model_forward_time): and self.observability_config.collect_model_forward_time):
model_forward_end.record() model_forward_end.record()
# Sending KV cache in distributed KV cache transfer setting
# NOTE: the send operation is non-blocking
if self.need_send_kv(model_input, kv_caches):
get_kv_transfer_group().send_kv_caches_and_hidden_states(
# model_executable is used to know which layer the current
# worker is working on, so that we can send KV for only those
# layers.
model_executable,
model_input,
kv_caches,
hidden_or_intermediate_states,
)
# Compute the logits in the last pipeline stage. # Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
if (self.is_driver_worker if (self.is_driver_worker
...@@ -1759,6 +1792,56 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1759,6 +1792,56 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
return [output] return [output]
def need_recv_kv(self, model_input, kv_caches) -> bool:
"""Check if we need to receive kv-cache from the other worker.
We need to receive KV when
1. current vLLM instance is KV cache consumer/decode vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run
Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
"""
prefill_meta = model_input.attn_metadata.prefill_metadata
# check if the current run is profiling
is_profile_run = (kv_caches[0].numel() == 0)
# check if the current run is prefill
is_prefill_run = prefill_meta is not None
if self.vllm_config.kv_transfer_config is None:
return False
return self.vllm_config.kv_transfer_config.is_kv_consumer and (
not is_profile_run) and is_prefill_run
def need_send_kv(self, model_input, kv_caches) -> bool:
"""Check if we need to send kv-cache to the other worker.
We need to send KV when
1. current vLLM instance is KV cache producer/prefill vLLM instance
2. this batch is not a profiling run
3. this batch is a prefill run
Args:
model_input: input to the model executable
kv_caches: vLLM's paged memory
"""
prefill_meta = model_input.attn_metadata.prefill_metadata
# check if the current run is profiling
is_profile_run = (kv_caches[0].numel() == 0)
# check if the current run is prefill
is_prefill_run = prefill_meta is not None
if self.vllm_config.kv_transfer_config is None:
return False
return self.vllm_config.kv_transfer_config.is_kv_producer and (
not is_profile_run) and is_prefill_run
# NOTE: this is nn.Module so the profiler can properly capture/group # NOTE: this is nn.Module so the profiler can properly capture/group
# kernels calls made within the graph # kernels calls made within the graph
......
...@@ -8,8 +8,9 @@ import torch ...@@ -8,8 +8,9 @@ import torch
import torch.distributed import torch.distributed
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_kv_transfer_initialized,
ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -144,7 +145,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -144,7 +145,7 @@ class Worker(LocalOrDistributedWorkerBase):
raise RuntimeError( raise RuntimeError(
f"Not support device type: {self.device_config.device}") f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment. # Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank, init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method, self.distributed_init_method,
self.local_rank) self.local_rank)
# Set random seed. # Set random seed.
...@@ -457,20 +458,22 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -457,20 +458,22 @@ class Worker(LocalOrDistributedWorkerBase):
def init_worker_distributed_environment( def init_worker_distributed_environment(
parallel_config: ParallelConfig, vllm_config: VllmConfig,
rank: int, rank: int,
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
parallel_config = vllm_config.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank, init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank) distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
......
...@@ -43,6 +43,7 @@ class WorkerBase(ABC): ...@@ -43,6 +43,7 @@ class WorkerBase(ABC):
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
@abstractmethod @abstractmethod
def init_device(self) -> None: def init_device(self) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment