Commit d29c39ca authored by chenzk's avatar chenzk
Browse files

vllm kvprune wo:v1.1.0

parent f81ce56b
"""Distributed helpers for kvprune when embedded in vLLM (use TP process group)."""
from __future__ import annotations
import torch
import torch.distributed as dist
def broadcast_from_tp_rank0(
    tensor: torch.Tensor, *, use_tp_group: bool
) -> None:
    """Broadcast ``tensor`` from rank 0 of the selected process group.

    Args:
        tensor: Tensor handed to the broadcast collective.
        use_tp_group: Selects the group the collective runs on.
            ``False`` (standalone compactor subprocesses) uses the default
            process group, where world size equals the tensor-parallel size.
            ``True`` (embedded in a vLLM worker) uses vLLM's tensor-parallel
            group, so DP/PP ranks are not pulled into the collective when the
            default group is global.
    """
    if use_tp_group:
        # Imported only on this path, so the standalone path never touches vLLM.
        from vllm.distributed.parallel_state import get_tp_group

        tp_group = get_tp_group()
        tp_group.broadcast(tensor, src=0)
    else:
        dist.broadcast(tensor, src=0)
def barrier_sync(*, use_tp_group: bool) -> None:
    """Synchronize ranks with a barrier on the selected process group.

    Args:
        use_tp_group: ``False`` runs the barrier on the default process group;
            ``True`` runs it on vLLM's tensor-parallel group (same selection
            rule as :func:`broadcast_from_tp_rank0`).
    """
    if use_tp_group:
        # Imported only on this path, so the standalone path never touches vLLM.
        from vllm.distributed.parallel_state import get_tp_group

        tp_group = get_tp_group()
        tp_group.barrier()
    else:
        dist.barrier()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Bridge vLLM paged KV layout to compactor Triton kernels.
vLLM FlashAttention KV cache is shaped
[num_blocks, block_size, num_kv_heads, head_dim].
Compactor kernels expect a flat buffer [CACHE_SIZE, head_dim] and a page table
global_page_table[batch, kv_head, logical_page] -> physical_page_id
where each physical page holds ``block_size`` consecutive rows belonging to that
KV head only.
When num_kv_heads == 1 (MQA), a vLLM block maps 1:1 to compactor rows:
row_index = physical_block_id * block_size + offset_in_block.
When ``num_kv_heads > 1``, we permute to head-major
``[num_kv_heads, num_blocks, block_size, head_dim]`` and flatten to
``[num_kv_heads * num_blocks * block_size, head_dim]`` so each KV head occupies
a disjoint row range in the flat buffer. The page table is built so each
logical compression page maps to ``global_row // PAGE_SIZE`` in that layout
(see ``build_page_table_head_major``).
"""
from __future__ import annotations
import torch
def _cdiv(n: int, d: int) -> int:
return (n + d - 1) // d
def flatten_kv_cache_head_major(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``[nb, bs, H, D]`` caches as ``[H*nb*bs, D]`` in head-major order.

    Row ``h * nb * bs + block * bs + offset`` of each result holds the vector
    stored at ``cache[block, offset, h]``.

    Raises:
        ValueError: If the two caches do not have the same shape.
    """
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")
    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    total_rows = num_heads * num_blocks * block_size

    def _to_head_major_rows(cache: torch.Tensor) -> torch.Tensor:
        # Move the head axis to the front, materialize that order, then
        # collapse (head, block, offset) into a single row axis.
        head_first = cache.permute(2, 0, 1, 3).contiguous()
        return head_first.reshape(total_rows, head_dim)

    return _to_head_major_rows(key_cache), _to_head_major_rows(value_cache)
def write_head_major_flat_to_interleaved(
    k_flat: torch.Tensor,
    v_flat: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> None:
    """Copy ``[H*nb*bs, D]`` head-major flats back into ``[nb, bs, H, D]`` caches.

    ``key_cache`` and ``value_cache`` are overwritten in place; the target
    layout is taken from ``key_cache.shape``.
    """
    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    head_major_shape = (num_heads, num_blocks, block_size, head_dim)
    for flat, cache in ((k_flat, key_cache), (v_flat, value_cache)):
        # Reinterpret the flat rows as (head, block, offset, dim), then move
        # the head axis back behind the in-block offset before copying.
        cache.copy_(flat.view(head_major_shape).permute(1, 2, 0, 3))
def build_page_table_head_major(
    block_table: torch.Tensor,
    num_kv_heads: int,
    num_blocks: int,
    block_size: int,
    page_size: int,
    max_batches: int,
) -> torch.Tensor:
    """Build ``[max_batches, H, max_chain]`` page table for head-major flat KV.

    For each (batch, head) the valid block ids of ``block_table`` are chained
    in table order; negative ids are skipped, so later entries shift left and
    the unused tail of each chain stays zero. Each entry is
    ``global_row // page_size`` where ``global_row`` indexes rows in the
    head-major flat buffer (see ``flatten_kv_cache_head_major``):
    ``h * num_blocks * block_size + block_id * block_size + p * page_size``.

    Args:
        block_table: 2-D tensor ``[batch, max_blocks]`` of block ids.
        num_kv_heads: Number of KV heads ``H`` in the flat buffer.
        num_blocks: Number of blocks per head; ids must be below this.
        block_size: Rows per block.
        page_size: Rows per compression page.
        max_batches: Leading dimension of the returned table.

    Returns:
        ``int32`` tensor of shape ``[max_batches, H, max_chain]`` on
        ``block_table.device``, with
        ``max_chain = max_blocks * ceil(block_size / page_size)``.

    Raises:
        ValueError: If the batch exceeds ``max_batches`` or a block id is
            ``>= num_blocks``.
    """
    bsz, max_blocks = block_table.shape
    if bsz > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")
    # Ceil division: compression pages needed to cover one block.
    num_pages_per_block = (block_size + page_size - 1) // page_size
    max_chain = max_blocks * num_pages_per_block
    out = torch.zeros(
        (max_batches, num_kv_heads, max_chain),
        dtype=torch.int32,
        device=block_table.device,
    )
    if num_kv_heads == 0:
        # No head rows to fill.
        return out

    head_stride = num_blocks * block_size
    # Row offsets of the pages inside one block; each is < block_size because
    # p < ceil(block_size / page_size).
    page_offsets = [p * page_size for p in range(num_pages_per_block)]
    # Single transfer of the whole table to Python ints instead of one
    # ``.item()`` call per element per head.
    host_table = block_table.to(torch.int64).tolist()

    for b, block_ids in enumerate(host_table):
        # Head-independent part of the chain: rows relative to the head's base.
        rel_rows = []
        for blk_i, bid in enumerate(block_ids):
            if bid < 0:
                continue
            if bid >= num_blocks:
                raise ValueError(
                    f"block_table[{b},{blk_i}]={bid} out of range "
                    f"num_blocks={num_blocks}"
                )
            block_row = bid * block_size
            rel_rows.extend(block_row + offset for offset in page_offsets)
        if not rel_rows:
            continue
        chains = [
            [(h * head_stride + row) // page_size for row in rel_rows]
            for h in range(num_kv_heads)
        ]
        # One slice write per batch row covers every head at once.
        chain_tensor = torch.tensor(chains, dtype=torch.int32)
        out[b, :, : len(rel_rows)] = chain_tensor.to(out.device)
    return out
def flatten_kv_cache_plane(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Flatten ``[num_blocks, block_size, 1, D]`` caches to ``[num_blocks*block_size, D]``.

    Only the single-KV-head layout is accepted, because that is the only case
    where this flattening matches compactor row indexing (see module doc).

    Raises:
        ValueError: If ``num_kv_heads`` is not 1, the cache shapes differ, or
            the caches' head dimension is not 1.
    """
    if num_kv_heads != 1:
        raise ValueError(
            "flatten_kv_cache_plane requires num_kv_heads==1 for compactor layout"
        )
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")
    num_blocks, block_size, cache_heads, head_dim = key_cache.shape
    if cache_heads != 1:
        raise ValueError("expected num_kv_heads==1")
    total_rows = num_blocks * block_size
    # ``contiguous()`` hands back the same tensor when the reshape result is
    # already contiguous, so this only copies when it has to.
    flat_keys = key_cache.reshape(total_rows, head_dim).contiguous()
    flat_values = value_cache.reshape(total_rows, head_dim).contiguous()
    return flat_keys, flat_values
def block_table_to_global_page_table(
    block_table: torch.Tensor,
    num_kv_heads: int,
    max_batches: int,
) -> torch.Tensor:
    """Build a ``[max_batches, HKV, num_logical_pages]`` int32 page table.

    Every KV head receives the same block ids as ``block_table``; rows past
    the table's batch size stay zero.

    Raises:
        ValueError: If the table has more rows than ``max_batches``.
    """
    # block_table: [num_reqs_padded, max_num_blocks]
    num_rows, num_logical_pages = block_table.shape
    if num_rows > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")
    table = torch.zeros(
        (max_batches, num_kv_heads, num_logical_pages),
        dtype=torch.int32,
        device=block_table.device,
    )
    block_ids = block_table.to(torch.int32)
    # Insert a head axis of size 1 and let assignment broadcast it over heads.
    table[:num_rows] = block_ids.unsqueeze(1)
    return table
def build_batch_mapping(num_reqs: int, device: torch.device) -> torch.Tensor:
    """Return the identity map ``[0, 1, ..., num_reqs - 1]`` as int32 on ``device``.

    Entry ``i`` is the global batch row for local batch index ``i``.
    """
    identity = torch.arange(0, num_reqs, dtype=torch.int32, device=device)
    return identity
from dataclasses import dataclass, field
from enum import Enum, auto
from itertools import count
from typing import List
from vllm.kvprune.compression.compression_config import SequenceCompressionParams
from vllm.kvprune.config.sampling_params import SamplingParams
class SequenceStatus(Enum):
    """Status values a :class:`Sequence` can carry."""

    # Default status assigned to a newly created Sequence.
    WAITING = auto()
    RUNNING = auto()
    FINISHED = auto()
@dataclass
class Sequence:
    """
    A single user request: its prompt tokens, the tokens generated so far,
    and the per-request sampling and compression settings.
    """

    # Shared id source. Left un-annotated so the dataclass machinery does not
    # turn it into an instance field.
    _counter = count()

    prompt_token_ids: List[int]
    completion_token_ids: List[int] = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    compression_params: SequenceCompressionParams = field(
        default_factory=SequenceCompressionParams
    )
    status: SequenceStatus = SequenceStatus.WAITING
    # Drawn from the shared counter at construction; not an __init__ argument.
    seq_id: int = field(default_factory=lambda: next(Sequence._counter), init=False)
    num_tokens_processed: int = 0

    @property
    def num_prompt_tokens(self) -> int:
        """Number of tokens in the prompt."""
        return len(self.prompt_token_ids)

    @property
    def num_generated_tokens(self) -> int:
        """Number of completion tokens recorded so far."""
        return len(self.completion_token_ids)

    def add_new_token(self, token_id: int) -> None:
        """Append a generated token and advance ``num_tokens_processed``.

        The first generated token also adds the whole prompt length to the
        processed count; every token adds one for itself.
        """
        is_first_token = not self.completion_token_ids
        if is_first_token:
            self.num_tokens_processed += self.num_prompt_tokens
        self.completion_token_ids.append(token_id)
        self.num_tokens_processed += 1

    def tokens_to_retain_per_layer(self, num_kv_heads: int) -> int:
        """Return ``compression_ratio * prompt_len * num_kv_heads`` truncated to int, at least 1."""
        ratio = self.compression_params.compression_ratio
        budget = int(ratio * self.num_prompt_tokens * num_kv_heads)
        return budget if budget > 1 else 1

    def __getstate__(self):
        """Return the pickled state, with both token lists copied."""
        return {
            "prompt_token_ids": list(self.prompt_token_ids),
            "completion_token_ids": list(self.completion_token_ids),
            "sampling_params": self.sampling_params,
            "compression_params": self.compression_params,
            "status": self.status,
            "seq_id": self.seq_id,
            "num_tokens_processed": self.num_tokens_processed,
        }

    def __setstate__(self, state):
        """Restore every field from ``state``; token lists are copied again."""
        self.prompt_token_ids = list(state["prompt_token_ids"])
        self.completion_token_ids = list(state["completion_token_ids"])
        for attr in (
            "sampling_params",
            "compression_params",
            "status",
            "seq_id",
            "num_tokens_processed",
        ):
            setattr(self, attr, state[attr])

    @property
    def prompt_len(self) -> int:
        """Same value as :attr:`num_prompt_tokens`."""
        return self.num_prompt_tokens

    @property
    def completion_len(self) -> int:
        """Same value as :attr:`num_generated_tokens`."""
        return self.num_generated_tokens
"""Tensor-parallel collectives for kvprune (match vLLM TP process group when embedded)."""
from __future__ import annotations
import torch.distributed as dist
def tensor_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce ``tensor`` across tensor-parallel ranks and return it.

    The reduced values always end up in ``tensor`` itself, so call sites that
    ignore the return value still see the reduced result.

    Group selection:

    * If ``torch.distributed`` is not initialized, or the default group has a
      world size of at most 1, ``tensor`` is returned untouched.
    * If vLLM's :mod:`vllm.distributed.parallel_state` is importable and
      reports model parallelism as initialized (e.g. kvprune runs inside a
      vLLM GPU worker), vLLM's
      :func:`~vllm.distributed.communication_op.tensor_model_parallel_all_reduce`
      is used. That call is out-of-place and returns a new tensor, so the
      result is copied back into ``tensor``; without the copy, TP>1 would keep
      per-rank partial sums (wrong RowParallel / VocabParallel outputs).
    * Otherwise (standalone kvprune subprocesses, where the default group's
      world size equals ``tensor_parallel_size``), falls back to
      :func:`torch.distributed.all_reduce` on the default group.

    Only the *probe* for vLLM is best-effort. Errors raised by the collective
    itself propagate: swallowing them and then running a second all-reduce on
    a different group could double-reduce the data or leave ranks waiting on
    mismatched collectives.

    Args:
        tensor: Per-rank partial values; overwritten with the reduced values.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    if not dist.is_initialized() or dist.get_world_size() <= 1:
        return tensor

    vllm_tp_all_reduce = None
    try:
        from vllm.distributed.parallel_state import model_parallel_is_initialized

        if model_parallel_is_initialized():
            from vllm.distributed.communication_op import (
                tensor_model_parallel_all_reduce as vllm_tp_all_reduce,
            )
    except Exception:
        # vLLM missing or unusable in this process: use the default group.
        vllm_tp_all_reduce = None

    if vllm_tp_all_reduce is None:
        dist.all_reduce(tensor)
        return tensor

    reduced = vllm_tp_all_reduce(tensor)
    if reduced is not tensor:
        # Out-of-place result: materialize it into the caller's tensor.
        tensor.copy_(reduced)
    return tensor
"""Tensor-parallel helpers for kvprune when embedded in a vLLM worker."""
from __future__ import annotations
import torch.distributed as dist
def tensor_parallel_rank_for_sharding() -> int:
    """Return this process's rank within the tensor-parallel group.

    Uses vLLM's tensor-parallel rank when that lookup succeeds (so it matches
    vLLM weight shards when embedded). If the lookup raises for any reason,
    falls back to :func:`torch.distributed.get_rank` when the default process
    group is initialized, and to ``0`` otherwise.
    """
    try:
        from vllm.distributed.parallel_state import get_tensor_model_parallel_rank

        rank = int(get_tensor_model_parallel_rank())
    except Exception:
        rank = int(dist.get_rank()) if dist.is_initialized() else 0
    return rank
def tensor_parallel_world_size_for_sharding() -> int:
    """Return the world size of the tensor-parallel group.

    Uses vLLM's tensor-parallel world size when that lookup succeeds. If the
    lookup raises for any reason, falls back to
    :func:`torch.distributed.get_world_size` when the default process group is
    initialized, and to ``1`` otherwise.
    """
    try:
        from vllm.distributed.parallel_state import (
            get_tensor_model_parallel_world_size,
        )

        world_size = int(get_tensor_model_parallel_world_size())
    except Exception:
        world_size = int(dist.get_world_size()) if dist.is_initialized() else 1
    return world_size
def kv_heads_shard_divisor() -> int:
    """Return the number of ranks KV heads are sharded across.

    This is exactly the value of
    :func:`tensor_parallel_world_size_for_sharding`.
    """
    shard_count = tensor_parallel_world_size_for_sharding()
    return shard_count
from __future__ import annotations
import inspect
from typing import Any, Callable, Mapping
import torch
def _filter_kwargs_for_callable(
fn: Callable[..., Any], kwargs: Mapping[str, Any]
) -> dict[str, Any]:
try:
params = inspect.signature(fn).parameters
except (TypeError, ValueError):
return dict(kwargs)
return {k: v for k, v in kwargs.items() if k in params}
def autotune(*, configs, key, **kwargs):
    """
    Compatibility wrapper around `triton.autotune`.

    Some Triton builds (e.g., custom vendor builds) may not support newer
    keyword arguments like `cache_results`. Extra keyword arguments are passed
    through `_filter_kwargs_for_callable` against the runtime
    `triton.autotune` signature before the call; `configs` and `key` are
    always forwarded.
    """
    import triton

    native_autotune = triton.autotune
    supported = _filter_kwargs_for_callable(native_autotune, kwargs)
    return native_autotune(configs=configs, key=key, **supported)
def maybe_set_allocator(alloc_fn: Callable[[int, int, int | None], Any]) -> bool:
    """
    Call `triton.set_allocator(alloc_fn)` if this Triton build provides it.

    Returns True if the allocator was set, False if `triton.set_allocator`
    is missing (or is None) and nothing was done.
    """
    import triton

    setter = getattr(triton, "set_allocator", None)
    if setter is not None:
        setter(alloc_fn)
        return True
    return False
def cuda_capability_geq(major: int, minor: int = 0, device: int | None = None) -> bool:
"""
Host-side CUDA capability check that works even when `tl.target_info` is absent.
"""
if not torch.cuda.is_available():
return False
if device is None:
try:
device = torch.cuda.current_device()
except Exception:
device = 0
cap = torch.cuda.get_device_capability(device)
return cap >= (major, minor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment