Commit a99300bd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev

parents cc3e01c7 5438967f
...@@ -352,7 +352,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -352,7 +352,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
with num_lookahead_slots. with num_lookahead_slots.
Args: Args:
sequence_group (SequenceGroup): The sequence group to swap in. seq_group (SequenceGroup): The sequence group to swap in.
num_lookahead_slots (int): Number of lookahead slots used in num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0. speculative decoding, default to 0.
...@@ -405,8 +405,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -405,8 +405,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
Args: Args:
seq_group (SequenceGroup): The sequence group to swap out. seq_group (SequenceGroup): The sequence group to swap out.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
Returns: Returns:
bool: Whether it's possible to swap out current sequence group. bool: Whether it's possible to swap out current sequence group.
...@@ -420,7 +418,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -420,7 +418,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
swapping out the given sequence_group with num_lookahead_slots. swapping out the given sequence_group with num_lookahead_slots.
Args: Args:
sequence_group (SequenceGroup): The sequence group to swap out. seq_group (SequenceGroup): The sequence group to swap out.
Returns: Returns:
List[Tuple[int, int]]: The mapping of swapping block from List[Tuple[int, int]]: The mapping of swapping block from
...@@ -473,7 +471,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -473,7 +471,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
on to the 'device'. on to the 'device'.
Args: Args:
sequence_group (SequenceGroup): The sequence group to swap in/out. seq_group (SequenceGroup): The sequence group to swap in/out.
device (Device): device to swap the 'seq_group' on. device (Device): device to swap the 'seq_group' on.
status (SequenceStatus): The status of sequence which is needed status (SequenceStatus): The status of sequence which is needed
for action. RUNNING for swap out and SWAPPED for swap in for action. RUNNING for swap out and SWAPPED for swap in
......
...@@ -76,7 +76,7 @@ class LRUEvictor(Evictor): ...@@ -76,7 +76,7 @@ class LRUEvictor(Evictor):
that's recorded in the Block. If there are multiple blocks with that's recorded in the Block. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chose arbitrarily highest num_hashed_tokens value, then one will be chosen arbitrarily
""" """
# CLEANUP_THRESHOLD determines the maximum allowable size of the priority # CLEANUP_THRESHOLD determines the maximum allowable size of the priority
......
...@@ -657,7 +657,7 @@ class Scheduler: ...@@ -657,7 +657,7 @@ class Scheduler:
`budget.num_batched_tokens` has not enough capacity to schedule `budget.num_batched_tokens` has not enough capacity to schedule
all tokens. all tokens.
partial_prefill_metadata: information about the partial prefills partial_prefill_metadata: information about the partial prefills
that are currently running that are currently running
Returns: Returns:
SchedulerRunningOutputs. SchedulerRunningOutputs.
...@@ -1591,7 +1591,6 @@ class Scheduler: ...@@ -1591,7 +1591,6 @@ class Scheduler:
encoder_seq_data=encoder_seq_data, encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table, cross_block_table=cross_block_table,
state=seq_group.state, state=seq_group.state,
token_type_ids=seq_group.token_type_ids,
# `multi_modal_data` will only be present for the 1st comm # `multi_modal_data` will only be present for the 1st comm
# between engine and worker. # between engine and worker.
# the subsequent comms can still use delta, but # the subsequent comms can still use delta, but
......
...@@ -152,8 +152,13 @@ class CuMemAllocator: ...@@ -152,8 +152,13 @@ class CuMemAllocator:
self.pointer_to_data: dict[int, AllocationData] = {} self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: dict[str, Any] = {} self.allocator_and_pools: dict[str, Any] = {}
# Creating strong references to the two callbacks here to prevent
# these ephemeral bound-method objects being garbage collected.
# See discussions in https://github.com/vllm-project/vllm/pull/22724
self.python_malloc_callback = self._python_malloc_callback
self.python_free_callback = self._python_free_callback
def python_malloc_callback(self, allocation_handle: HandleType) -> None: def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
""" """
Internal method to store the allocation data Internal method to store the allocation data
when memory is allocated in the memory pool.""" when memory is allocated in the memory pool."""
...@@ -162,7 +167,7 @@ class CuMemAllocator: ...@@ -162,7 +167,7 @@ class CuMemAllocator:
allocation_handle, self.current_tag) allocation_handle, self.current_tag)
return return
def python_free_callback(self, ptr: int) -> HandleType: def _python_free_callback(self, ptr: int) -> HandleType:
""" """
Internal method to look up the allocation data Internal method to look up the allocation data
when memory is freed in the memory pool.""" when memory is freed in the memory pool."""
...@@ -212,9 +217,9 @@ class CuMemAllocator: ...@@ -212,9 +217,9 @@ class CuMemAllocator:
def wake_up(self, tags: Optional[list[str]] = None) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
""" """
Wake up the allocator from sleep mode. Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory. memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded :param tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory. back to GPU memory.
......
...@@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless, ...@@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless,
logger = init_logger(__name__) logger = init_logger(__name__)
# Number of bytes in one mebibyte; all size limits below are multiples of it.
MiB = 1 << 20

# Max size for each world size in case symmetric memory is available,
# for different SM architectures.
# Layout: {device capability version string: {world size: limit in bytes}}.
CUSTOM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: MiB // 2,  # 512 KiB
        8: MiB // 4,  # 256 KiB
    },
    "10.0": {
        2: 2 * MiB,  # 2 MiB
        4: 2 * MiB,  # 2 MiB
        6: 2 * MiB,  # 2 MiB
        8: 2 * MiB,  # 2 MiB
    },
}

# Same layout as above, giving the largest message (in bytes) handled by the
# symmetric-memory all-reduce for each device capability and world size.
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 64 * MiB,  # 64 MiB
        8: 64 * MiB,  # 64 MiB
    },
    "10.0": {
        2: 8 * MiB,  # 8 MiB
        4: 32 * MiB,  # 32 MiB
        6: 128 * MiB,  # 128 MiB
        8: 128 * MiB,  # 128 MiB
    },
}
def producer(batch_src: Sequence[int], def producer(batch_src: Sequence[int],
producer_queue, producer_queue,
......
...@@ -255,7 +255,7 @@ class DeviceCommunicatorBase: ...@@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE" if module.__class__.__name__ == "FusedMoE"
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize() module.quant_method.init_prepare_finalize(module)
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
......
...@@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
PyNcclCommunicator) PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import ( from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce) QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1: if use_pynccl and self.world_size > 1:
...@@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_custom_allreduce and self.world_size > 1: if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation. # Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce( self.ca_comm = CustomAllreduce(
...@@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series. # currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group, self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device) device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_all2all: if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive": if all2all_backend == "naive":
...@@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = ca_comm.custom_all_reduce(input_) out = ca_comm.custom_all_reduce(input_)
assert out is not None assert out is not None
return out return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_) out = pynccl_comm.all_reduce(input_)
...@@ -137,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -137,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
dtype=input_tensor.dtype, dtype=input_tensor.dtype,
device=input_tensor.device) device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_) pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning # Reshape before returning
return output.movedim(0, dim).contiguous() return output.movedim(0, dim).contiguous()
...@@ -171,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -171,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=input_tensor.device) device=input_tensor.device)
if sizes is not None: if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes) pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else: else:
pynccl_comm.reduce_scatter(output, input_) pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning # Reshape before returning
return output.movedim(0, dim).contiguous() return output.movedim(0, dim).contiguous()
......
...@@ -11,8 +11,8 @@ from torch.distributed import ProcessGroup ...@@ -11,8 +11,8 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import ( from vllm.distributed.device_communicators.all_reduce_utils import (
gpu_p2p_access_check) CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -114,7 +114,13 @@ class CustomAllreduce: ...@@ -114,7 +114,13 @@ class CustomAllreduce:
# now `device` is a `torch.device` object # now `device` is a `torch.device` object
assert isinstance(device, torch.device) assert isinstance(device, torch.device)
self.device = device self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices: if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(","))) device_ids = list(map(int, cuda_visible_devices.split(",")))
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform
try:
import torch.distributed._symmetric_memory as torch_symm_mem
symm_mem_available = True
except ImportError:
symm_mem_available = False
logger = init_logger(__name__)
class SymmMemCommunicator:
    """All-reduce communicator built on PyTorch symmetric memory.

    The communicator owns one symmetric-memory staging buffer of
    ``max_size`` bytes (looked up in ``SYMM_MEM_ALL_REDUCE_MAX_SIZES`` by
    device capability and world size). Inputs are copied into that buffer,
    reduced in place with a ``torch.ops.symm_mem`` kernel, and copied out.

    The instance stays ``disabled`` (and ``should_use_symm_mem`` always
    returns ``False``) unless every setup step in ``__init__`` succeeds.
    """

    # Device capability -> world sizes that use the multimem all-reduce
    # kernel; any other supported world size uses the two-shot kernel.
    _WORLD_SIZES_MULTIMEM = {
        "9.0": [4, 6, 8],
        "10.0": [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str,
                                                          torch.device]):
        """Set up the symmetric-memory buffer for ``group`` on ``device``.

        Args:
            group: Process group whose world size and ``group_name`` are
                used for the rendezvous and the all-reduce kernels.
            device: CUDA device as an index, a device string, or a
                ``torch.device``. It is made the current CUDA device.
        """
        # Remain disabled until all prerequisites below are satisfied.
        self.disabled = True

        if not symm_mem_available:
            return

        if not current_platform.is_cuda():
            logger.warning("SymmMemCommunicator: symmetric "
                           "memory is not available.")
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)

        # Only tensors of this dtype are eligible (see should_use_symm_mem).
        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = current_platform.get_device_capability(
        ).as_version_str()

        if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
                self.device_capability]:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return

        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
            self.world_size]
        # max_size is in bytes; the buffer is allocated in elements.
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning("SymmMemCommunicator: symmetric memory "
                           "multicast operations are not supported.")
            return

        self.disabled = False

    def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
        """Return whether ``inp`` can be reduced through this communicator.

        ``inp`` qualifies only if the communicator is enabled, the dtype
        equals ``self.dtype``, the byte size is a multiple of 4, and the
        byte size is strictly smaller than ``self.max_size``.
        """
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def all_reduce(
            self,
            inp: torch.Tensor,
            *,
            out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        """Sum-reduce ``inp`` across the group via the symmetric buffer.

        Args:
            inp: Tensor to reduce; it is not modified.
            out: Optional destination tensor. A new tensor shaped like
                ``inp`` is allocated when omitted.

        Returns:
            The tensor holding the reduced values, or ``None`` when
            ``should_use_symm_mem(inp)`` is false.
        """
        if not self.should_use_symm_mem(inp):
            return None
        if out is None:
            out = torch.empty_like(inp)

        # View over the leading part of the symmetric buffer used for this
        # call; the reduction happens in place on it.
        staging = self.buffer[:inp.numel()]
        # reshape() instead of view(): identical for contiguous inputs, but
        # also accepts non-contiguous ones instead of raising.
        staging.copy_(inp.reshape(-1))

        if self.world_size in self._WORLD_SIZES_MULTIMEM[
                self.device_capability]:
            torch.ops.symm_mem.multimem_all_reduce_(staging, "sum",
                                                    self.group.group_name)
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(staging, "sum",
                                                    self.group.group_name)

        out.copy_(staging.view(out.shape))
        return out
...@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup ...@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
...@@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config( ...@@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config(
logger = init_logger(__name__) logger = init_logger(__name__)
if current_platform.is_tpu(): if not USE_TPU_COMMONS:
import torch_xla logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
import torch_xla.core.xla_model as xm if current_platform.is_tpu():
import torch_xla.runtime as xr import torch_xla
from torch_xla._internal import pjrt import torch_xla.core.xla_model as xm
from torch_xla.distributed.xla_multiprocessing import ( import torch_xla.runtime as xr
create_optimized_replica_groups) from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
if USE_RAY: create_optimized_replica_groups)
from vllm.executor import ray_utils if USE_RAY:
from vllm.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase): class TpuCommunicator(DeviceCommunicatorBase):
...@@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase): ...@@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase):
return xm.all_gather(input_, dim=dim) return xm.all_gather(input_, dim=dim)
try: if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import ( from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator) TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator = TpuCommonsCommunicator # type: ignore TpuCommunicator = TpuCommonsCommunicator # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
pass
...@@ -7,8 +7,13 @@ import torch ...@@ -7,8 +7,13 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class XpuCommunicator(DeviceCommunicatorBase): class XpuCommunicator(DeviceCommunicatorBase):
...@@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase): ...@@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
device_group: Optional[ProcessGroup] = None, device_group: Optional[ProcessGroup] = None,
unique_name: str = ""): unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name) super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def all_reduce(self, input_) -> torch.Tensor: def all_reduce(self, input_) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group) dist.all_reduce(input_, group=self.device_group)
......
...@@ -244,7 +244,7 @@ class EplbState: ...@@ -244,7 +244,7 @@ class EplbState:
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
expert_load_window_size = parallel_config.eplb_window_size expert_load_window_size = parallel_config.eplb_config.window_size
expert_load_window = torch.zeros( expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers, (expert_load_window_size, model.num_moe_layers,
model.num_physical_experts), model.num_physical_experts),
...@@ -253,7 +253,7 @@ class EplbState: ...@@ -253,7 +253,7 @@ class EplbState:
) )
# Set the initial progress of rearrangement to 3/4 # Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_step_interval eplb_step_interval = parallel_config.eplb_config.step_interval
expert_rearrangement_step = max( expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4) 0, eplb_step_interval - eplb_step_interval // 4)
...@@ -409,12 +409,14 @@ class EplbState: ...@@ -409,12 +409,14 @@ class EplbState:
self.expert_rearrangement_step = 0 self.expert_rearrangement_step = 0
self.rearrange(model) self.rearrange(model)
def rearrange(self, def rearrange(
model: MixtureOfExperts, self,
is_profile: bool = False, model: MixtureOfExperts,
execute_shuffle: bool = True, is_profile: bool = False,
global_expert_load: Optional[torch.Tensor] = None, execute_shuffle: bool = True,
rank_mapping: Optional[dict[int, int]] = None) -> None: global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int,
int]] = None) -> Optional[torch.Tensor]:
""" """
Rearrange the experts according to the current load. Rearrange the experts according to the current load.
""" """
...@@ -548,6 +550,7 @@ class EplbState: ...@@ -548,6 +550,7 @@ class EplbState:
" (profile) " if is_profile else " ", " (profile) " if is_profile else " ",
time_end - time_start, time_end - time_start,
) )
return None
@staticmethod @staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]: def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
...@@ -613,4 +616,4 @@ def _node_count_with_rank_mapping( ...@@ -613,4 +616,4 @@ def _node_count_with_rank_mapping(
if is_same_node and node_assignment[other_rank] == 0: if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id node_assignment[other_rank] = next_node_id
return next_node_id return next_node_id
\ No newline at end of file
...@@ -40,16 +40,21 @@ class KVCacheEvent( ...@@ -40,16 +40,21 @@ class KVCacheEvent(
"""Base class for all KV cache-related events""" """Base class for all KV cache-related events"""
MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent): class BlockStored(KVCacheEvent):
block_hashes: list[int] block_hashes: list[int]
parent_block_hash: Optional[int] parent_block_hash: Optional[int]
token_ids: list[int] token_ids: list[int]
block_size: int block_size: int
lora_id: Optional[int] lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
block_hashes: list[int] block_hashes: list[int]
medium: Optional[str]
class AllBlocksCleared(KVCacheEvent): class AllBlocksCleared(KVCacheEvent):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Distributed KV cache transfer # Distributed KV cache transfer
This folder implements distributed KV cache transfer across vLLM instances. This folder implements distributed KV cache transfer across vLLM instances.
Currently the main usecase is for disaggregated prefilling. Currently the main use case is for disaggregated prefilling.
## Abstractions ## Abstractions
...@@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions: ...@@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions:
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or communication service already supports key-value-based lookup (like redis or
RDMA database). RDMA database).
......
...@@ -19,6 +19,8 @@ The class provides the following primitives: ...@@ -19,6 +19,8 @@ The class provides the following primitives:
Returns whether KV cache should be freed now or will be Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer freed asynchronously and optionally returns KV transfer
params. params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata. the Connector based on the metadata.
...@@ -34,6 +36,7 @@ The class provides the following primitives: ...@@ -34,6 +36,7 @@ The class provides the following primitives:
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch import torch
...@@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput ...@@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -131,8 +135,8 @@ class KVConnectorBase_V1(ABC): ...@@ -131,8 +135,8 @@ class KVConnectorBase_V1(ABC):
Initialize with the KV caches. Useful for pre-registering the Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL). KV Caches in the KVConnector (e.g. for NIXL).
Args: kv_caches: Args:
dictionary of layer names, kv cache kv_caches: dictionary of layer names, kv cache
""" """
return return
...@@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC): ...@@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
""" """
return False, None return False, None
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Take the KV cache events from the connector.

    This base implementation has nothing to report and returns an empty
    tuple.

    Returns:
        New KV cache events since the last call.
    """
    return tuple()
@classmethod @classmethod
def get_required_kvcache_layout( def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]: cls, vllm_config: "VllmConfig") -> Optional[str]:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from vllm.config import KVTransferConfig, VllmConfig from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
...@@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1):
return async_saves > 0, kv_txfer_params return async_saves > 0, kv_txfer_params
def take_events(self) -> Iterable[KVCacheEvent]:
    """Yield the KV cache events taken from every wrapped connector.

    Connectors are visited in the order of ``self._connectors``; because
    this is a generator, each connector's ``take_events()`` is only called
    once iteration reaches it.
    """
    for connector in self._connectors:
        yield from connector.take_events()
@classmethod @classmethod
def get_required_kvcache_layout( def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]: cls, vllm_config: "VllmConfig") -> Optional[str]:
......
...@@ -686,9 +686,6 @@ class NixlConnectorWorker: ...@@ -686,9 +686,6 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl.""" """Register the KV Cache data in nixl."""
_, first_kv_cache = next(iter(kv_caches.items()))
kv_elem_size = first_kv_cache.element_size()
if self.use_host_buffer: if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches) self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), ( assert len(self.host_xfer_buffers) == len(kv_caches), (
...@@ -701,66 +698,16 @@ class NixlConnectorWorker: ...@@ -701,66 +698,16 @@ class NixlConnectorWorker:
"host_xfer_buffer should not be initialized when " "host_xfer_buffer should not be initialized when "
f"kv_buffer_device is {self.kv_buffer_device}") f"kv_buffer_device is {self.kv_buffer_device}")
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# will only affects the strides. For MLA instead, we make require no
# such thing and resort to the standard layout.
use_mla = len(first_kv_cache.shape) == 3
if self.device_type == "tpu":
assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
# tpu (v1) kv shape per layer:
# (num_blocks, block_size, num_kv_heads * 2, head_size)
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads_x_2, head_dim = block_shape
self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
elif self.device_type == "cuda":
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator.
# Enforce check once it goes live, as a single kv layout
# is expected for xfers.
if use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
else:
raise RuntimeError(
f"{self.device_type} ({self.backend_name}) is not supported.")
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.info( logger.info(
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
"use_host_buffer: %s, num_blocks: %s, block_shape: %s, " "use_host_buffer: %s", self.use_mla, self.kv_buffer_device,
"per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, self.use_host_buffer)
self.use_host_buffer, self.num_blocks, block_shape,
first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.device_kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = [] caches_data = []
# With hybrid allocator, layers can share a kv cache tensor
seen_base_addresses = []
xfer_buffers = (self.host_xfer_buffers
if self.use_host_buffer else kv_caches)
# Note(tms): I modified this from the original region setup code. # Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can # K and V are now in different regions. Advantage is that we can
...@@ -770,42 +717,35 @@ class NixlConnectorWorker: ...@@ -770,42 +717,35 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB). # (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor # Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim). # to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in xfer_buffers.values(): split_k_and_v = not (self.use_mla or self._use_pallas_v1
# Normalize to always be a list of caches or self._use_flashinfer)
cache_list = [cache_or_caches] if use_mla \ tensor_size_bytes = None
or self._use_pallas_v1 or self._use_flashinfer \ for layer_name, cache_or_caches in xfer_buffers.items():
else cache_or_caches cache_list = cache_or_caches if split_k_and_v else [
cache_or_caches
]
for cache in cache_list: for cache in cache_list:
base_addr = cache.data_ptr() base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len if base_addr in seen_base_addresses:
# NOTE: use tp_rank for device_id since multi-node TP continue
# is rarely used.
caches_data.append((base_addr, region_len, self.tp_rank, "")) seen_base_addresses.append(base_addr)
kv_caches_base_addr.append(base_addr) curr_tensor_size_bytes = cache.numel() * cache.element_size()
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
caches_data.append(
(base_addr, tensor_size_bytes, self.tp_rank, ""))
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data) self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys()) self.num_layers = len(xfer_buffers.keys())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
descs = self.nixl_wrapper.get_reg_descs(caches_data, descs = self.nixl_wrapper.get_reg_descs(caches_data,
self.nixl_memory_type) self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data) logger.debug("Registering descs: %s", caches_data)
...@@ -813,9 +753,20 @@ class NixlConnectorWorker: ...@@ -813,9 +753,20 @@ class NixlConnectorWorker:
logger.debug("Done registering descs") logger.debug("Done registering descs")
self._registered_descs.append(descs) self._registered_descs.append(descs)
assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size
if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes /= 2
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
# Register local/src descr for NIXL xfer. # Register local/src descr for NIXL xfer.
blocks_data = [] blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]: for base_addr in seen_base_addresses:
# NOTE With heter-TP, more blocks are prepared than what are # NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to # could create fewer, but then _get_block_descs_ids needs to
...@@ -836,6 +787,26 @@ class NixlConnectorWorker: ...@@ -836,6 +787,26 @@ class NixlConnectorWorker:
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs) "NIXL_INIT_AGENT", descs)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections. # After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata( metadata = NixlAgentMetadata(
engine_id=self.engine_id, engine_id=self.engine_id,
......
...@@ -30,27 +30,19 @@ logger = init_logger(__name__) ...@@ -30,27 +30,19 @@ logger = init_logger(__name__)
class ReqMeta: class ReqMeta:
# Request Id # Request Id
request_id: str request_id: str
# Request tokens # Request block ids
token_ids: torch.Tensor block_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids # Request num tokens
slot_mapping: torch.Tensor num_tokens: int
@staticmethod @staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
block_size: int) -> "ReqMeta": block_size: int) -> "ReqMeta":
valid_num_tokens = len(token_ids)
token_ids_tensor = torch.tensor(token_ids)
block_ids_tensor = torch.tensor(block_ids) block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta( return ReqMeta(
request_id=request_id, request_id=request_id,
token_ids=token_ids_tensor, block_ids=block_ids_tensor,
slot_mapping=slot_mapping, num_tokens=len(token_ids),
) )
...@@ -123,63 +115,58 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -123,63 +115,58 @@ class P2pNcclConnector(KVConnectorBase_V1):
return return
def inject_kv_into_layer( def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor, layer: torch.Tensor,
src_kv_cache: torch.Tensor, kv_cache: torch.Tensor,
slot_mapping: torch.Tensor, block_ids: torch.Tensor,
request_id: str, request_id: str,
) -> None: ) -> None:
"""Inject the KV cache into the layer. """
Inject KV cache data into a given attention layer tensor.
This function updates `layer` in-place with values from `kv_cache`,
handling different backend layouts:
- MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
If the number of provided block IDs does not match the number of KV
blocks, only the overlapping portion is updated, and a warning is
logged.
Args: Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache layer (torch.Tensor): The attention layer KV tensor to update.
layer. In shape [2, num_pages, page_size, xxx] if not kv_cache (torch.Tensor): The KV cache tensor to inject.
using MLA, [num_pages, page_size, xxx] otherwise. block_ids (torch.Tensor): Indices of the blocks to update.
src_kv_cache (torch.Tensor): the source KV cache. In shape request_id (str): Request identifier used for logging.
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise. Returns:
slot_mapping (torch.Tensor): the slot mapping. In shape None. The function modifies `layer` in-place.
[num_tokens].
request_id (str): request id for log
""" """
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape if (isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values())
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()): or layer.shape[1] == 2): # MLA or FlashInfer
num_pages = dst_kv_cache_layer_shape[0] num_block = kv_cache.shape[0]
page_size = dst_kv_cache_layer_shape[1] self.check_tensors_except_dim(layer, kv_cache, 0)
dst_kv_cache_layer = dst_kv_cache_layer.reshape( if len(block_ids) == num_block:
num_pages * page_size, -1) layer[block_ids, ...] = kv_cache
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
0)
num_token = src_kv_cache.shape[0]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
else: else:
dst_kv_cache_layer[slot_mapping[:num_token], layer[block_ids[:num_block], ...] = kv_cache
...] = src_kv_cache
logger.warning( logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, " "🚧kv_cache does not match, block_ids:%d, "
"num_token:%d, request_id:%s", len(slot_mapping), "num_block:%d, request_id:%s", len(block_ids),
num_token, request_id) num_block, request_id)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) elif layer.shape[0] == 2: # FlashAttention
else: num_block = kv_cache.shape[1]
num_pages = dst_kv_cache_layer_shape[1] self.check_tensors_except_dim(layer, kv_cache, 1)
page_size = dst_kv_cache_layer_shape[2] if len(block_ids) == num_block:
dst_kv_cache_layer = dst_kv_cache_layer.reshape( layer[:, block_ids, ...] = kv_cache
2, num_pages * page_size, -1)
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
1)
num_token = src_kv_cache.shape[1]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
else: else:
dst_kv_cache_layer[:, slot_mapping[:num_token], layer[:, block_ids[:num_block], ...] = kv_cache
...] = src_kv_cache
logger.warning( logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, " "🚧kv_cache does not match, block_ids:%d, "
"num_token:%d, request_id:%s", len(slot_mapping), "num_block:%d, request_id:%s", len(block_ids),
num_token, request_id) num_block, request_id)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata # Get the metadata
metadata: KVConnectorMetadata = \ metadata: KVConnectorMetadata = \
...@@ -201,19 +188,17 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -201,19 +188,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
if kv_cache is None: if kv_cache is None:
continue continue
kv_cache_layer = kv_cache[ \ layer = kv_cache[forward_context.virtual_engine]
forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor( kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name) request.request_id + "#" + layer_name)
if kv_cache is None: if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s", logger.warning("🚧kv_cache is None, %s", request.request_id)
request.request_id)
continue continue
inject_kv_into_layer(kv_cache_layer, kv_cache, inject_kv_into_layer(layer, kv_cache, request.block_ids,
request.slot_mapping, request.request_id) request.request_id)
tensor = self.p2p_nccl_engine.recv_store.pop(request.request_id + "#" + layer_name, None) tensor = self.p2p_nccl_engine.recv_store.pop(request.request_id + "#" + layer_name, None)
if tensor is not None: if tensor is not None:
del tensor del tensor
...@@ -248,16 +233,46 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -248,16 +233,46 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
block_ids: torch.Tensor,
) -> torch.Tensor:
"""
Extract KV cache slices from a given attention layer tensor.
This function handles multiple backend layouts:
- MLA (Multi-head Latent Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
Args:
layer (torch.Tensor): The KV cache from the attention layer.
block_ids (torch.Tensor): Indices of blocks to extract.
Returns:
torch.Tensor: A tensor containing the extracted KV slices.
Returns None if the layout is unsupported.
"""
if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer
return layer[block_ids, ...]
if layer.shape[0] == 2: # FlashAttention
return layer[:, block_ids, ...]
return None
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata) assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True) ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank) remote_address = ip + ":" + str(port + self._rank)
self.p2p_nccl_engine.send_tensor(
request_id + "#" + layer_name, kv_layer, remote_address, kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
request.slot_mapping, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
isinstance(attn_metadata, MLACommonMetadata)) kv_cache, remote_address)
def wait_for_save(self): def wait_for_save(self):
if self.is_producer: if self.is_producer:
......
...@@ -62,8 +62,6 @@ class SendQueueItem: ...@@ -62,8 +62,6 @@ class SendQueueItem:
tensor_id: str tensor_id: str
remote_address: str remote_address: str
tensor: torch.Tensor tensor: torch.Tensor
slot_mapping: torch.Tensor
is_mla: bool
class P2pNcclEngine: class P2pNcclEngine:
...@@ -202,8 +200,6 @@ class P2pNcclEngine: ...@@ -202,8 +200,6 @@ class P2pNcclEngine:
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[str] = None,
slot_mapping: torch.Tensor = None,
is_mla: bool = False,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
with self.recv_store_cv: with self.recv_store_cv:
...@@ -213,9 +209,7 @@ class P2pNcclEngine: ...@@ -213,9 +209,7 @@ class P2pNcclEngine:
item = SendQueueItem(tensor_id=tensor_id, item = SendQueueItem(tensor_id=tensor_id,
remote_address=remote_address, remote_address=remote_address,
tensor=tensor, tensor=tensor)
slot_mapping=slot_mapping,
is_mla=is_mla)
if self.send_type == "PUT": if self.send_type == "PUT":
return self.send_sync(item) return self.send_sync(item)
...@@ -433,9 +427,7 @@ class P2pNcclEngine: ...@@ -433,9 +427,7 @@ class P2pNcclEngine:
if item.remote_address not in self.socks: if item.remote_address not in self.socks:
self.create_connect(item.remote_address) self.create_connect(item.remote_address)
with self.send_stream: tensor = item.tensor
tensor = self.extract_kv_from_layer(item.is_mla, item.tensor,
item.slot_mapping)
sock = self.socks[item.remote_address] sock = self.socks[item.remote_address]
comm, rank = self.comms[item.remote_address] comm, rank = self.comms[item.remote_address]
...@@ -548,21 +540,3 @@ class P2pNcclEngine: ...@@ -548,21 +540,3 @@ class P2pNcclEngine:
self._send_thread.join() self._send_thread.join()
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
@staticmethod
def extract_kv_from_layer(
is_mla: bool,
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if is_mla:
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
...@@ -99,8 +99,9 @@ class TensorMemoryPool: ...@@ -99,8 +99,9 @@ class TensorMemoryPool:
addr=self.base_address) addr=self.base_address)
self.free_lists[self.max_block_size][ self.free_lists[self.max_block_size][
initial_block.addr] = initial_block initial_block.addr] = initial_block
logger.debug("TensorMemoryPool, base_address:", self.base_address,
self.base_address % self.max_block_size) logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d",
self.base_address, self.max_block_size)
def allocate(self, size: int) -> int: def allocate(self, size: int) -> int:
"""Allocates a memory block of at least the requested size. """Allocates a memory block of at least the requested size.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment