Commit f81ce56b authored by chenzk's avatar chenzk
Browse files

vllm kvprune:v1.0.1

parent 2b7160c6
import heapq
import logging
from enum import Enum, auto
from typing import List, Optional, Union
import torch
from vllm.kvprune.config.constants import RESERVED_BATCH
from vllm.kvprune.kv_cache.write_page_table import scatter_to_page_table
logger = logging.getLogger(__name__)
def cdiv(a, b):
    """Return ``a`` divided by ``b``, rounded up (for positive ``b``).

    Implemented purely with ``+``, ``-`` and ``//`` so it works for any
    operands that support those operators.
    """
    biased = a + b - 1
    return biased // b
def next_multiple(a, b):
    """Round ``a`` up to the nearest multiple of ``b`` (for positive ``b``)."""
    # Ceiling-divide first, then scale back up to a multiple of ``b``.
    quotient = (a + b - 1) // b
    return quotient * b
class KVAllocationStatus(Enum):
    """Outcome of a KV-cache allocation request."""

    # A (layer, head) would need more logical pages than the per-head limit.
    EXCEEDS_MAX_SEQUENCE_LENGTH = auto()
    # Some layer does not currently have enough free physical pages.
    EXCEEDS_CURRENTLY_AVAILABLE_PAGES = auto()
    # No batch slot is available.
    EXCEEDS_MAX_NUM_BATCHES = auto()
    # The request was satisfied.
    SUCCESS = auto()
class PagedKVCache(torch.nn.Module):
    """
    Global paged KV cache.

    This module manages:

    * A global K/V backing buffer for all layers:
      ``kv_cache[2, num_layers, n_pages * page_size, head_dim]``,
      where the first dimension indexes K vs V.
    * A per-layer page table:
      ``page_table[num_layers, max_num_seqs, H_kv, max_pages_per_head]``,
      mapping logical (batch, kv-head, logical_page) to a physical page ID
      in the global K/V buffer.
    * Per-layer, per-(batch, kv-head) logical sequence lengths
      ``bh_seq_lens[num_layers, max_num_seqs, H_kv]`` (in tokens), and
      the number of allocated pages ``bh_num_pages`` for each (layer, batch,
      head).
    * A page allocator implemented as a min-heap of free physical pages
      per layer, plus free batch indices.

    Pages are of fixed size ``page_size`` tokens.

    Args:
        :param num_layers:
            Number of transformer layers that will use this cache.
        :param max_logical_pages_per_head:
            Maximum number of logical pages that can be assigned to a single
            (batch, kv-head) pair.
        :param num_pages:
            Total number of physical pages available in the global cache per
            layer. The global K/V buffers are of length
            ``num_pages * page_size`` along the token dimension.
        :param page_size:
            Number of tokens stored per page.
        :param H_kv:
            Number of KV heads per layer.
        :param head_dim:
            Head dimension for K/V.
        :param max_num_batches:
            Maximum number of concurrent batches / sequences supported. One
            additional batch index is allocated internally and reserved
            (``RESERVED_BATCH``).
        :param dtype:
            Data type of K/V entries (e.g. ``torch.float16`` or
            ``torch.bfloat16``).
        :param device:
            Device on which to allocate the cache (string, torch.device, or
            int; defaults to ``"cuda"``).
    """

    def __init__(
        self,
        num_layers: int,
        max_logical_pages_per_head: int,
        num_pages: int,
        page_size: int,  # tokens per page
        H_kv: int,
        head_dim: int,
        max_num_batches: int,
        dtype: torch.dtype,
        device: Union[str, torch.device, int] = "cuda",
    ):
        super().__init__()
        self.n_pages = num_pages
        self.num_layers = num_layers
        self.page_size: int = int(page_size)
        self.H_kv = int(H_kv)
        self.max_pages_per_head = max_logical_pages_per_head
        # One extra slot is allocated because RESERVED_BATCH is never handed out.
        max_num_batches += 1
        self.max_num_batches = max_num_batches
        self.head_dim = head_dim

        cache_shape = (2, num_layers, num_pages * page_size, head_dim)
        self.kv_cache = torch.empty(cache_shape, dtype=dtype, device=device)
        # Left uninitialised: only the first ``bh_num_pages`` entries of each
        # (layer, batch, head) row are meaningful.
        self.page_table = torch.empty(
            (num_layers, max_num_batches, H_kv, self.max_pages_per_head),
            device=device,
            dtype=torch.int32,
        )
        # Per-(batch, head) logical seq length (tokens)
        self.bh_seq_lens = torch.zeros(
            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
        )
        # Per-(batch, head) number of logical pages currently allocated
        self.bh_num_pages = torch.zeros(
            (num_layers, max_num_batches, H_kv), device=device, dtype=torch.int32
        )

        # Page allocator (one min-heap of free physical pages per layer)
        self.free_pages: List[List[int]] = [
            list(range(num_pages)) for _ in range(num_layers)
        ]
        for free_pages in self.free_pages:
            heapq.heapify(free_pages)

        # Free batch indices; stored in descending order so that ``pop()``
        # hands out the lowest index first. RESERVED_BATCH is never free.
        self.free_batches: List[int] = list(reversed(range(max_num_batches)))
        self.free_batches.remove(RESERVED_BATCH)

        # Record of physical page ids owned by a batch (for freeing),
        # indexed as [batch_index][layer_index].
        self.pages_indices_per_batch: List[List[set[int]]] = [
            [set() for _ in range(num_layers)] for _ in range(max_num_batches)
        ]

    def new_batch(self) -> Optional[int]:
        """
        Reserve a new batch slot.

        A batch slot corresponds to a row in ``bh_seq_lens`` /
        ``bh_num_pages`` and a slice in ``page_table`` for all layers and KV
        heads. This method checks whether a free batch index is available, and
        whether each layer has at least ``H_kv`` free pages remaining.
        If both checks pass, it returns a batch index and removes it from
        ``free_batches``. Otherwise, it returns ``None``.

        Returns:
            :return Optional[int]:
                Newly reserved batch index, or ``None`` if no capacity is
                available.
        """
        if self.free_batches and all(
            self.H_kv <= len(fp) for fp in self.free_pages
        ):
            return self.free_batches.pop()
        return None

    def reserve_tokens(
        self, batch_index: int, add_tokens: int
    ) -> "KVAllocationStatus":
        """
        Ensure enough pages are allocated to handle ``add_tokens`` new tokens.

        Args:
            :param batch_index:
                Batch index to reserve space for.
            :param add_tokens:
                Number of additional tokens to reserve capacity for.
                All heads in this batch and all layers reserve
                the same number of extra tokens.

        Returns:
            :return KVAllocationStatus:
                ``SUCCESS`` if capacity is (now) sufficient;
                ``EXCEEDS_MAX_SEQUENCE_LENGTH`` if some (layer, head) would
                need more than ``max_pages_per_head`` pages;
                ``EXCEEDS_CURRENTLY_AVAILABLE_PAGES`` if some layer does not
                have enough free physical pages. Nothing is allocated on
                failure.
        """
        cur_bh_lens = self.bh_seq_lens[:, batch_index]  # [L, H]
        curr_pages = self.bh_num_pages[:, batch_index]  # [L, H]
        curr_cap_tokens = curr_pages * self.page_size  # [L, H]
        need_tokens = cur_bh_lens + add_tokens  # [L, H]
        if (need_tokens <= curr_cap_tokens).all():
            return KVAllocationStatus.SUCCESS

        missing_tokens = need_tokens - curr_cap_tokens
        # Heads that already hold spare pages have a negative deficit; clamp
        # so that they contribute zero pages instead of a negative count,
        # which would under-count the per-layer demand and shrink
        # ``bh_num_pages`` without actually freeing anything.
        add_pages = cdiv(missing_tokens, self.page_size).clamp(min=0)
        new_total_pages = curr_pages + add_pages
        if (new_total_pages > self.max_pages_per_head).any():
            return KVAllocationStatus.EXCEEDS_MAX_SEQUENCE_LENGTH

        # CPU work: check every layer first so a failure allocates nothing.
        pages_per_layer_cpu = add_pages.sum(dim=-1).tolist()
        for layer_index, n_needed in enumerate(pages_per_layer_cpu):
            if n_needed > len(self.free_pages[layer_index]):
                return KVAllocationStatus.EXCEEDS_CURRENTLY_AVAILABLE_PAGES

        # Pop the pages layer by layer; the concatenated order is what the
        # scatter step consumes.
        popped_pages: List[int] = []
        for layer_index, n_needed in enumerate(pages_per_layer_cpu):
            layer_heap = self.free_pages[layer_index]
            this_layer_pages = [heapq.heappop(layer_heap) for _ in range(n_needed)]
            self.pages_indices_per_batch[batch_index][layer_index].update(
                this_layer_pages
            )
            popped_pages.extend(this_layer_pages)

        # Allocate on the page table's device rather than a hard-coded "cuda".
        new_phys_pages = torch.tensor(
            popped_pages, dtype=torch.int32, device=self.page_table.device
        )
        scatter_to_page_table(
            add_pages=add_pages,
            new_phys_pages=new_phys_pages,
            curr_pages=curr_pages,
            page_table=self.page_table[:, batch_index],
            max_pages_per_head=self.max_pages_per_head,
        )
        self.bh_num_pages[:, batch_index, :] = new_total_pages.to(
            self.bh_num_pages.dtype
        )
        return KVAllocationStatus.SUCCESS

    def reclaim_pages(
        self,
        batch_index: int,
        future_reserve_tokens: int = 0,
    ) -> int:
        """
        Reclaim unused pages for a single batch index. This shrinks the KV
        allocation for the batch down to the minimum number of pages needed
        to hold the current (plus optional future) sequence length.

        Args:
            :param batch_index:
                Batch index whose pages should be compacted.
            :param future_reserve_tokens:
                Optional number of extra tokens to keep capacity for, beyond
                the current sequence length. This can reduce churn when
                sequences are expected to grow slightly in the near future.

        Returns:
            :return int:
                Approximate number of bytes freed across both K and V.
        """
        device = self.bh_seq_lens.device
        _, B, H = self.bh_seq_lens.shape
        assert 0 <= batch_index < B

        seq = self.bh_seq_lens[:, batch_index, :] + future_reserve_tokens  # [L, H]
        alloc = self.bh_num_pages[:, batch_index, :]  # [L, H]
        # Flattened [L * H * P] copy of this batch's page-table rows.
        pt = self.page_table[:, batch_index, :, :].reshape(-1)

        # Compute used pages: ceil_div(seq, page_size), clamped to at most alloc
        used_pages = cdiv(seq, self.page_size)
        used_pages = torch.minimum(used_pages, alloc)

        # page indices [0..P-1], broadcasted over [L, H, P]
        p = torch.arange(
            self.max_pages_per_head, device=device, dtype=torch.int32
        ).view(1, 1, self.max_pages_per_head)
        # allocated: p < alloc
        alloc_mask = p < alloc.unsqueeze(-1)  # [L, H, P]
        # to free: allocated and p in [used_pages, alloc)
        free_mask = alloc_mask & (p >= used_pages.unsqueeze(-1))
        free_mask_flat = free_mask.view(-1)  # [L*H*P]
        if not free_mask_flat.any():
            return 0

        # indices of freed slots
        idx = free_mask_flat.nonzero(as_tuple=False).squeeze(-1)
        # Freed physical page ids
        freed_pages = pt[idx]
        # Compute layer index for each freed slot:
        # layout is [L, H, P] -> flat index = ((l * H) + h) * P + p
        freed_layers = (idx // (H * self.max_pages_per_head)).to(torch.int32)

        freed_pages = freed_pages.tolist()
        layer_mapping = freed_layers.tolist()

        self.bh_num_pages[:, batch_index, :] = used_pages
        for page, layer in zip(freed_pages, layer_mapping):
            self.pages_indices_per_batch[batch_index][layer].remove(page)
            heapq.heappush(self.free_pages[layer], page)

        approximate_bytes_freed = (
            len(freed_pages)
            * (self.page_size * self.head_dim * self.kv_cache.element_size())
            * 2
        )  # multiply by two for K + V
        return approximate_bytes_freed

    def _free_batch_layer(self, layer_index: int, batch_index: int) -> None:
        """
        Return every page that ``batch_index`` owns in ``layer_index`` to that
        layer's free heap and clear the ownership record.
        """
        for phys in self.pages_indices_per_batch[batch_index][layer_index]:
            heapq.heappush(self.free_pages[layer_index], int(phys))
        self.pages_indices_per_batch[batch_index][layer_index] = set()

    def free_batch(self, batch_index: int) -> None:
        """
        Free all resources associated with a batch index.

        Freeing an index that is already free is a logged no-op; appending it
        again would put the same slot in ``free_batches`` twice and later hand
        it to two different sequences.

        Args:
            :param batch_index:
                Batch index to release. Must have been previously allocated
                via :meth:`new_batch`.
        """
        if batch_index in self.free_batches:
            logging.getLogger(__name__).warning(
                "free_batch called for batch index %d, which is already free; "
                "ignoring.",
                batch_index,
            )
            return
        for layer in range(self.num_layers):
            self._free_batch_layer(layer, batch_index)
        self.bh_seq_lens[:, batch_index].zero_()
        self.bh_num_pages[:, batch_index].zero_()
        self.free_batches.append(batch_index)

    def layer_slices(self, layer: int):
        """
        Return layer-local views needed by the attention module.

        For a given ``layer`` index, this method returns the slices of the
        global K/V cache, page table, and per-(batch, head) sequence lengths
        corresponding to that layer.

        Args:
            :param layer:
                Layer index ``l`` in ``[0, num_layers)``.

        Returns:
            :return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                ``(k, v, pt, bh)``: ``kv_cache[0, layer]``,
                ``kv_cache[1, layer]``, ``page_table[layer]`` and
                ``bh_seq_lens[layer]``.
        """
        assert 0 <= layer < self.num_layers
        k = self.kv_cache[0, layer]
        v = self.kv_cache[1, layer]
        pt = self.page_table[layer]
        bh = self.bh_seq_lens[layer]
        return k, v, pt, bh
import torch
import triton
import triton.language as tl
from vllm.kvprune.config.constants import (
TRITON_RESERVED_BATCH as _TRITON_RESERVED_BATCH,
)
# Prefill store for compressed KV: each program copies up to K_TILE of the
# selected (token, head) entries of one local batch into the paged K/V cache.
# Grid: (local batch, tile of the selection list).
@triton.jit
def _prefill_store_topk_kv_kernel(
    key,
    value,  # [N_total, H, D] (D stride assumed 1)
    batch_mapping,  # [B] int32 (local b -> true batch)
    num_tokens_to_retain,  # [B] int32
    indices_topk,  # [B, MAX_SEL] int32 (across all heads)
    # Lengths & page table:
    bh_lens,  # [B, H] int32 (contiguous)
    page_table,  # [B_total * H * N_LOGICAL_PAGES_MAX] int32 (flattened), read-only
    k_cache,
    v_cache,  # [N_PAGES * PAGE_SIZE, D]
    sk_n,
    sk_h,  # strides for key,value. D stride assumed 1
    sv_n,
    sv_h,
    # Runtime ints
    MAX_SEL,  # num tokens that are ranked in indices for each batch (might be bigger than num_tokens_to_retain)
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    K_TILE: tl.constexpr,  # how many selected tokens each program processes
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    b_local = tl.program_id(0)
    tile_id = tl.program_id(1)
    offs = tl.arange(0, D)
    # how many tokens we actually keep for this batch
    k_total = tl.load(num_tokens_to_retain + b_local)
    if k_total == 0:
        return
    # map to true batch row in the page table
    b_true = tl.load(batch_mapping + b_local)
    # rows mapped to the reserved batch are skipped entirely
    if b_true == TRITON_RESERVED_BATCH:
        return
    base = tile_id * K_TILE
    # process up to K_TILE tokens
    for j in tl.range(0, K_TILE):
        sel_idx = base + j
        if sel_idx < k_total and sel_idx < MAX_SEL:
            # flattened selection: sel = token * H + head
            sel = tl.load(indices_topk + b_local * MAX_SEL + sel_idx)
            tok = sel // HKV
            head = sel - (tok * HKV)
            # Atomically reserve one position in (b_local, head). The slot
            # comes from whichever program increments first, so entries are
            # not stored in selection order within a head.
            len_ptr = bh_lens + b_local * HKV + head
            pos = tl.atomic_add(len_ptr, 1)  # old length (int32)
            lp = pos // PAGE_SIZE
            off = pos - lp * PAGE_SIZE
            # translate logical page to physical page
            pt_base = (b_true * HKV + head) * N_LOGICAL_PAGES_MAX
            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
            # destination row and element offset (int64 via ``phys``)
            dst_row = phys * PAGE_SIZE + off
            dst_off = dst_row * D + offs
            # load one vector from [N_total, H, D]
            k_src = key + tok * sk_n + head * sk_h + offs
            v_src = value + tok * sv_n + head * sv_h + offs
            tl.store(
                k_cache + dst_off,
                tl.load(k_src, cache_modifier=".cv", eviction_policy="evict_first"),
                eviction_policy="evict_first",
            )
            tl.store(
                v_cache + dst_off,
                tl.load(v_src, cache_modifier=".cv", eviction_policy="evict_first"),
                eviction_policy="evict_first",
            )
def prefill_store_topk_kv(
    *,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    indices_topk: torch.Tensor,  # [B, MAX_SEL] int32 (global flattened token*H + head)
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    cu_seqlens_k: torch.Tensor | None = None,
    K_TILE: int = 16,
    TRITON_RESERVED_BATCH: int | None = None,
):
    """
    Store the selected (token, head) K/V vectors of a prefill into the paged
    cache and advance ``bh_lens`` in place.

    The first kernel writes the first ``num_tokens_to_retain[b]`` entries of
    ``indices_topk[b]``. When ``PAD_TO_PAGE_SIZE`` is true, a second kernel
    then continues through the remaining entries of ``indices_topk[b]`` to
    fill each head's last partially used page.

    Args:
        new_keys / new_vals: source tensors ``[N_total, H, D]`` with a
            contiguous last dimension.
        indices_topk: ``[B, MAX_SEL]`` flattened ``token * H + head`` ids.
        num_tokens_to_retain: ``[B]`` number of entries to store per batch.
        page_table: ``[B_total, H, N_LOGICAL_PAGES_MAX]`` logical -> physical
            page mapping (read-only).
        batch_mapping: ``[B]`` local batch -> row of ``page_table``; rows
            equal to ``TRITON_RESERVED_BATCH`` are skipped by the kernels.
        bh_lens: ``[B, H]`` per-head lengths, updated in place.
        k_cache / v_cache: ``[N_PAGES * PAGE_SIZE, D]`` contiguous caches.
        PAGE_SIZE: tokens per page.
        PAD_TO_PAGE_SIZE: run the padding kernel (requires ``cu_seqlens_k``).
        cu_seqlens_k: cumulative sequence lengths, needed for padding only.
        K_TILE: selected entries handled per program.
        TRITON_RESERVED_BATCH: override for the reserved batch id; defaults
            to the package constant.
    """
    assert new_keys.shape == new_vals.shape
    N_total, H, D = new_keys.shape
    B = indices_topk.shape[0]
    assert page_table.shape[1] == H
    assert bh_lens.shape == (B, H)
    assert new_keys.device == k_cache.device == v_cache.device
    assert page_table.is_contiguous(), "page table must be contiguous."
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
    assert k_cache.is_contiguous() and v_cache.is_contiguous()
    assert new_keys.stride(-1) == 1 and new_vals.stride(-1) == 1, (
        "new_keys/new_vals last dim must be contiguous."
    )
    assert (D & (D - 1)) == 0, "D must be a power of 2"

    page_table = page_table.to(torch.int32)
    batch_mapping = batch_mapping.to(torch.int32)
    # Both kernels index these flat, so make them contiguous up front.
    indices_topk = indices_topk.to(torch.int32).contiguous()
    num_tokens_to_retain = num_tokens_to_retain.to(torch.int32).contiguous()
    # ``.to`` returns the same tensor when the dtype already matches and a
    # copy otherwise; in the latter case the kernels update the copy, so the
    # result is written back to the caller's tensor at the end.
    bh_lens_i32 = bh_lens.to(torch.int32)

    # strides (elements) for [N_total, H, D]
    sk_n, sk_h, _ = new_keys.stride()
    sv_n, sv_h, _ = new_vals.stride()

    # tile second grid dim
    MAX_SEL = indices_topk.shape[-1]
    N_TILES = (MAX_SEL + K_TILE - 1) // K_TILE
    grid = (B, max(1, N_TILES))
    if TRITON_RESERVED_BATCH is None:
        TRITON_RESERVED_BATCH = _TRITON_RESERVED_BATCH

    _prefill_store_topk_kv_kernel[grid](
        key=new_keys,
        value=new_vals,
        batch_mapping=batch_mapping,
        num_tokens_to_retain=num_tokens_to_retain,
        indices_topk=indices_topk,
        bh_lens=bh_lens_i32,
        page_table=page_table,
        k_cache=k_cache,
        v_cache=v_cache,
        sk_n=sk_n,
        sk_h=sk_h,
        sv_n=sv_n,
        sv_h=sv_h,
        MAX_SEL=int(MAX_SEL),
        HKV=H,
        N_LOGICAL_PAGES_MAX=page_table.shape[2],
        D=D,
        PAGE_SIZE=PAGE_SIZE,
        K_TILE=K_TILE,
        TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
    )

    if PAD_TO_PAGE_SIZE:
        assert cu_seqlens_k is not None
        assert page_table.is_contiguous()
        _prefill_store_topk_pad_kernel[(B, H)](
            key=new_keys,
            value=new_vals,
            batch_mapping=batch_mapping,
            num_tokens_to_retain=num_tokens_to_retain,
            indices=indices_topk,
            local_lens=bh_lens_i32,
            page_table_flat=page_table,
            k_cache=k_cache,
            v_cache=v_cache,
            cu_seqlens_k=cu_seqlens_k,
            sk_n=sk_n,
            sk_h=sk_h,
            sv_n=sv_n,
            sv_h=sv_h,
            MAX_SEL=int(MAX_SEL),
            H=H,  # type: ignore
            N_LOGICAL_PAGES_MAX=page_table.shape[2],  # type: ignore
            D=D,  # type: ignore
            PAGE_SIZE=PAGE_SIZE,  # type: ignore
            TRITON_RESERVED_BATCH=TRITON_RESERVED_BATCH,
        )

    if bh_lens_i32 is not bh_lens:
        # The caller passed a non-int32 tensor: propagate the new lengths.
        bh_lens.copy_(bh_lens_i32)
# Padding pass after the top-k store: one program per (local batch, head)
# appends further entries for that head, taken from the selection list past
# ``num_tokens_to_retain``, until the head's length reaches a page boundary,
# the list is exhausted, or the per-sequence budget is used up.
@triton.jit
def _prefill_store_topk_pad_kernel(
    key,  # [N_total, H, D]
    value,  # [N_total, H, D]
    batch_mapping,  # [B] int32 (local b -> true batch)
    num_tokens_to_retain,  # [B] int32
    indices,  # [B, MAX_SEL] int32 (across all heads)
    local_lens,  # [B, H] int32 (contiguous)
    page_table_flat,  # [B_total*H*N_LOGICAL_PAGES_MAX] int32
    k_cache,
    v_cache,  # [N_PAGES*PAGE_SIZE, D]
    cu_seqlens_k,
    sk_n,
    sk_h,
    sv_n,
    sv_h,
    MAX_SEL,
    # Constexprs
    H: tl.constexpr,  # number of KV heads
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    b_local = tl.program_id(0)
    h = tl.program_id(1)
    offs_d = tl.arange(0, D)
    # current length of this (batch, head)
    L = tl.load(local_lens + b_local * H + h)
    # L mod PAGE_SIZE; zero means the last page is already full
    modulo_page_size = L - (L // PAGE_SIZE) * PAGE_SIZE
    if modulo_page_size == 0:
        return
    # number of slots left in the last page
    need = PAGE_SIZE - modulo_page_size
    b_true = tl.load(batch_mapping + b_local)
    if b_true == TRITON_RESERVED_BATCH:
        return
    pt_base = (b_true * H + h) * N_LOGICAL_PAGES_MAX
    written_tokens = 0
    # start scanning right after the entries the first pass considered
    idx = tl.load(num_tokens_to_retain + b_local)
    this_batch_ctx_len = tl.load(cu_seqlens_k + b_local + 1) - tl.load(
        cu_seqlens_k + b_local
    )
    # never grow the head beyond the sequence's context length
    max_additional = this_batch_ctx_len - L
    while (written_tokens < need and idx < MAX_SEL) and (
        written_tokens < max_additional
    ):
        # candidate head
        cand_idx = tl.load(indices + b_local * MAX_SEL + idx)
        cand_h = cand_idx % H
        # only entries that belong to this program's head are stored
        if cand_h == h:
            tok = cand_idx // H
            pos = L + written_tokens
            lp = pos // PAGE_SIZE
            off = pos - lp * PAGE_SIZE
            phys = tl.load(page_table_flat + pt_base + lp).to(tl.int32)
            dst_row = phys * PAGE_SIZE + off
            # widen before multiplying by D to form the element offset
            dst_off = dst_row.to(tl.int64) * D + offs_d
            k_src = key + tok * sk_n + h * sk_h + offs_d
            v_src = value + tok * sv_n + h * sv_h + offs_d
            tl.store(
                k_cache + dst_off,
                tl.load(k_src),
            )
            tl.store(
                v_cache + dst_off,
                tl.load(v_src),
            )
            written_tokens += 1
        idx += 1
    # publish the new length for this (batch, head)
    tl.store(local_lens + b_local * H + h, L + written_tokens)
# Uncompressed prefill store: copies every (token, head) vector of each
# sequence into the paged cache, appended after the head's current length.
# Grid: (local batch, tile of K_TILE (token, head) pairs). The kernel only
# reads ``bh_lens``; it does not write the lengths back.
@triton.jit
def _prefill_store_all_kv_kernel(
    key,
    value,  # [N, H, D] (D contiguous)
    cu_seqlens_k,  # [B + 1] int32
    batch_mapping,  # [B] int32 (local b -> true batch index)
    bh_lens,  # [B * HKV] int32 (read here; advanced by the host wrapper)
    pt_flat,  # [B_total * HKV * N_LOGICAL_PAGES_MAX] int32 (flattened)
    k_cache,
    v_cache,  # [N_PAGES * PAGE_SIZE, D]
    # source strides (elements)
    sk_n,
    sk_h,
    sv_n,
    sv_h,
    # constexpr
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    K_TILE: tl.constexpr,  # number of (token, head) pairs processed per program
):
    pid_b = tl.program_id(0)
    pid_blk = tl.program_id(1)
    # token range [start, end) of this sequence in the packed input
    start = tl.load(cu_seqlens_k + pid_b)
    end = tl.load(cu_seqlens_k + pid_b + 1)
    num_toks_this_batch = end - start
    if num_toks_this_batch <= 0:
        return
    total_elems = num_toks_this_batch * HKV
    # base linear index in (token, head) grid for this program
    base = pid_blk * K_TILE
    offs_d = tl.arange(0, D)
    # Iterate K_TILE elements in this tile
    for i in tl.range(0, K_TILE):
        idx = base + i
        if idx < total_elems:
            # map linear idx -> (t, h)
            t = idx // HKV
            h = idx - t * HKV
            len_idx = pid_b * HKV + h
            # length of this head before the prefill
            L0 = tl.load(bh_lens + len_idx)
            token_idx_in_cache = L0 + t
            lp = token_idx_in_cache // PAGE_SIZE  # logical page
            off_in_pg = token_idx_in_cache - lp * PAGE_SIZE  # pos in page
            # physical page
            b_true = tl.load(batch_mapping + pid_b).to(tl.int32)
            pt_base = (b_true * HKV + h) * N_LOGICAL_PAGES_MAX
            phys = tl.load(pt_flat + pt_base + lp).to(tl.int64)
            row = phys * PAGE_SIZE + off_in_pg
            dst_off = row * D + offs_d
            n_global = (start + t).to(tl.int64)
            # Use strides for non-contiguous [N, H, D] (D stride == 1)
            k_src = key + n_global * sk_n + h * sk_h + offs_d
            v_src = value + n_global * sv_n + h * sv_h + offs_d
            tl.store(k_cache + dst_off, tl.load(k_src))
            tl.store(v_cache + dst_off, tl.load(v_src))
def prefill_store_all_kv(
    *,
    new_keys: torch.Tensor,
    new_values: torch.Tensor,  # [N, H_kv, D]
    cu_seqlens_k: torch.Tensor,  # [B + 1] int32
    max_seqlen_k: int,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,  # [B_total, H_kv, N_LOGICAL_PAGES_MAX] int32
    bh_lens: torch.Tensor,  # [B, H_kv] int32 (UPDATED)
    batch_mapping: torch.Tensor,  # [B] int32 (local->true)
    PAGE_SIZE: int,
    K_TILE: int = 32,  # how many (token, head) pairs per program
):
    """
    Append every token of each prefill sequence to the paged K/V cache.

    The kernel writes each sequence's tokens after the current per-head
    length; afterwards ``bh_lens`` is advanced in place by each sequence's
    length (``cu_seqlens_k.diff()``), identically for all heads.
    """
    assert new_keys.stride(-1) == 1 and new_values.stride(-1) == 1, (
        "last dim must be contiguous"
    )
    assert page_table.is_contiguous(), "page table must be contiguous"
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous"
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous"
    assert k_cache.is_contiguous() and v_cache.is_contiguous()

    _, num_kv_heads, head_dim = new_keys.shape
    num_seqs = batch_mapping.shape[0]
    assert (head_dim & (head_dim - 1)) == 0, "D must be a power of 2"

    key_stride_n, key_stride_h, _ = new_keys.stride()
    val_stride_n, val_stride_h, _ = new_values.stride()

    # Enough tiles to cover the longest sequence's (token, head) pairs.
    max_pairs = max_seqlen_k * num_kv_heads
    tiles_per_seq = (max_pairs + K_TILE - 1) // K_TILE
    launch_grid = (num_seqs, tiles_per_seq)

    _prefill_store_all_kv_kernel[launch_grid](
        new_keys,
        new_values,
        cu_seqlens_k,
        batch_mapping,
        bh_lens,
        page_table,
        k_cache,
        v_cache,
        sk_n=key_stride_n,
        sk_h=key_stride_h,
        sv_n=val_stride_n,
        sv_h=val_stride_h,
        HKV=num_kv_heads,
        N_LOGICAL_PAGES_MAX=page_table.shape[-1],
        D=head_dim,
        PAGE_SIZE=PAGE_SIZE,
        K_TILE=K_TILE,
    )

    # Advance every head of sequence b by that sequence's token count.
    seq_lengths = cu_seqlens_k.diff()
    bh_lens.add_(seq_lengths.unsqueeze(-1))
# Decode store: one program per (local batch, head) appends a single K/V
# vector at the head's current length and then increments that length.
@triton.jit
def _decode_store_kv_kernel(
    key,
    value,
    batch_mapping,  # [B] int32
    bh_lens,  # [B*HKV] int32
    page_table,  # [B_total*HKV*N_LOGICAL_PAGES_MAX]
    k_cache,
    v_cache,  # [N_PAGES*PAGE_SIZE, D]
    sk_b,
    sk_h,
    sv_b,
    sv_h,
    HKV: tl.constexpr,
    N_LOGICAL_PAGES_MAX: tl.constexpr,
    D: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    TRITON_RESERVED_BATCH: tl.constexpr,
):
    pid_b = tl.program_id(0)
    h = tl.program_id(1)
    mapped_b = tl.load(batch_mapping + pid_b)
    # rows mapped to the reserved batch are skipped (nothing stored, length unchanged)
    if mapped_b == TRITON_RESERVED_BATCH:
        return
    offs_d = tl.arange(0, D)
    # the new token goes to position ``length`` of this (batch, head)
    length = tl.load(bh_lens + pid_b * HKV + h)
    logical_page = length // PAGE_SIZE
    internal_offset = length - logical_page * PAGE_SIZE
    pt_base = (mapped_b * HKV + h) * N_LOGICAL_PAGES_MAX
    physical_page = tl.load(page_table + pt_base + logical_page).to(tl.int64)
    dst_row = physical_page * PAGE_SIZE + internal_offset
    # Source addressing using strides (D stride == 1)
    k_src = key + pid_b * sk_b + h * sk_h + offs_d
    v_src = value + pid_b * sv_b + h * sv_h + offs_d
    dst_off = dst_row * D + offs_d
    tl.store(k_cache + dst_off, tl.load(k_src))
    tl.store(v_cache + dst_off, tl.load(v_src))
    # publish the incremented length
    tl.store(bh_lens + pid_b * HKV + h, length + 1)
def decode_store_kv(
    *,
    key: torch.Tensor,  # [B, HKV, D]
    value: torch.Tensor,  # [B, HKV, D]
    batch_mapping: torch.Tensor,  # [B] int32
    bh_lens: torch.Tensor,  # [B, HKV] or flattened [B*HKV] int32
    page_table: torch.Tensor,  # [B_total, HKV, N_LOGICAL_PAGES_MAX] int32
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,  # [N_PAGES*PAGE_SIZE, D]
    PAGE_SIZE: int,
    TRITON_RESERVED_BATCH: int = None,
):
    """
    Append one decode-step K/V vector per (batch, head) to the paged cache.

    Launches one program per (entry of ``batch_mapping``, KV head); the kernel
    writes at the head's current length and increments ``bh_lens`` in place.
    When ``TRITON_RESERVED_BATCH`` is ``None`` the package constant is used.
    """
    assert key.shape == value.shape and key.ndim == 3, "key/value must be [B, HKV, D]"
    _, num_kv_heads, head_dim = key.shape
    assert key.stride(-1) == 1 and value.stride(-1) == 1, (
        "key/value last dim must be contiguous."
    )
    assert page_table.is_contiguous(), "page table must be contiguous."
    assert bh_lens.is_contiguous(), "bh_lens must be contiguous."
    assert batch_mapping.is_contiguous(), "batch mapping must be contiguous."
    assert k_cache.is_contiguous() and v_cache.is_contiguous()
    assert (head_dim & (head_dim - 1)) == 0, "D must be a power of 2"

    key_stride_b, key_stride_h, _ = key.stride()
    val_stride_b, val_stride_h, _ = value.stride()

    # The grid's batch extent follows batch_mapping, not key.shape[0].
    launch_grid = (int(batch_mapping.shape[0]), num_kv_heads)
    reserved_batch = (
        TRITON_RESERVED_BATCH
        if TRITON_RESERVED_BATCH is not None
        else _TRITON_RESERVED_BATCH
    )

    _decode_store_kv_kernel[launch_grid](
        key=key,
        value=value,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        page_table=page_table,
        k_cache=k_cache,
        v_cache=v_cache,
        sk_b=key_stride_b,
        sk_h=key_stride_h,
        sv_b=val_stride_b,
        sv_h=val_stride_h,
        HKV=num_kv_heads,
        N_LOGICAL_PAGES_MAX=page_table.shape[2],
        D=head_dim,
        PAGE_SIZE=PAGE_SIZE,
        TRITON_RESERVED_BATCH=reserved_batch,
    )
import torch
import triton
import triton.language as tl
def scatter_to_page_table(
    add_pages: torch.Tensor,  # [L, H] int32
    new_phys_pages: torch.Tensor,  # [N]
    curr_pages: torch.Tensor,  # [L, H] int32
    page_table: torch.Tensor,  # [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
    max_pages_per_head: int,
):
    """
    Append newly allocated physical pages into a layered page table via Triton.

    For each (layer ``l``, head ``h``), the next ``add_pages[l, h]`` entries of
    ``new_phys_pages`` are written to logical pages
    ``curr_pages[l, h] .. curr_pages[l, h] + add_pages[l, h] - 1`` of
    ``page_table[l, h]``.

    Args:
        :param add_pages:
            Tensor of shape ``[L, H]`` (int32) indicating how many pages to
            append for each (layer, head).
        :param new_phys_pages:
            1D tensor of shape ``[N]`` (int32) containing physical page IDs
            for all (layer, head) pairs, concatenated in row-major (L, H)
            order. ``N`` must equal ``add_pages.sum()``.
        :param curr_pages:
            Tensor of shape ``[L, H]`` (int32) with the current logical page
            counts per (layer, head) before this update.
        :param page_table:
            Tensor of shape ``[L, H, max_pages_per_head]`` (int32) holding
            the logical to physical page mapping. It is addressed through its
            strides, so it may be a non-contiguous view.
        :param max_pages_per_head:
            Maximum number of logical pages permitted per (layer, head). The
            kernel skips writes beyond this bound.

    Returns:
        None. The function updates ``page_table`` in-place.
    """
    L, H = add_pages.shape
    if L == 0 or H == 0:
        return

    add_flat = add_pages.to(torch.int32).contiguous().view(-1)
    curr_flat = curr_pages.to(torch.int32).contiguous().view(-1)

    # Exclusive prefix sum of ``add_flat``: entry ``lh`` is the offset of
    # (l, h)'s first page inside ``new_phys_pages``. Allocated on the same
    # device as the inputs rather than a hard-coded "cuda".
    cum_page_heads = torch.empty(L * H + 1, device=add_flat.device, dtype=torch.int32)
    cum_page_heads[0] = 0
    torch.cumsum(add_flat, 0, out=cum_page_heads[1:])

    stride_pl, stride_ph, stride_pp = page_table.stride()
    grid = (L, H)
    _scatter_pages_kernel_lh[grid](
        add_flat,
        cum_page_heads,
        new_phys_pages,
        curr_flat,
        page_table,
        stride_pl,
        stride_ph,
        stride_pp,
        L=L,
        H=H,
        max_pages_per_head=max_pages_per_head,
    )
# One program per (layer, head): copies that pair's newly allocated physical
# page ids from ``flat_new_phys`` into the tail of its page-table row.
@triton.jit
def _scatter_pages_kernel_lh(
    add_pages,  # int32 [L*H]
    cum_page_heads,  # int32, indexed by l*H + h: base offset in flat_new_phys per (l,h)
    flat_new_phys,  # int32 [total_pages]
    curr_pages,  # int32 [L*H], existing logical pages per (l,h)
    page_table_ptr,  # int32* base pointer to page_table
    stride_pl,  # int, stride for layer dim
    stride_ph,  # int, stride for head dim
    stride_pp,  # int, stride for page dim
    L: tl.constexpr,
    H: tl.constexpr,
    max_pages_per_head: tl.constexpr,
):
    layer_idx = tl.program_id(0)
    h = tl.program_id(1)
    if layer_idx >= L or h >= H:
        return
    lh = layer_idx * H + h
    ap = tl.load(add_pages + lh)
    # nothing to append for this (layer, head)
    if ap <= 0:
        return
    base = tl.load(cum_page_heads + lh)
    cp = tl.load(curr_pages + lh)
    # Append ap pages: logical pages [cp .. cp+ap)
    for i in tl.range(0, ap):
        phys = tl.load(flat_new_phys + base + i)
        lp = cp + i
        # writes past the per-head page limit are dropped
        if lp < max_pages_per_head:
            offset = layer_idx * stride_pl + h * stride_ph + lp * stride_pp
            tl.store(page_table_ptr + offset, phys)
# TODO: write reclaim kernel
# Placeholder for a future Triton page-reclaim kernel; it takes no arguments
# and does nothing yet.
@triton.jit
def reclaim_page_kernel():
    pass
def reclaim_pages(
    batch_index: int,
    bh_seq_lens: torch.Tensor,
    bh_num_pages: torch.Tensor,
    page_table: torch.Tensor,
):
    """Placeholder: not implemented yet; accepts its arguments and does nothing."""
    return None
# KV-prune 与上游 vLLM 的集成说明
本文说明:**剪枝/压缩(Compactor)功能**在「官网 vLLM 主仓库」里改动了哪些位置、是否只有少量文件、以及随 vLLM 版本升级时如何预期合并成本。
## 1. 是否「仅仅」改了少数几个脚本?
**核心运行时接线**确实集中在少数几个**非** `vllm/kvprune/` 下的文件;功能主体在 `vllm/kvprune/` 包内独立维护。
| 路径 | 作用简述 |
|------|-----------|
| `vllm/env_override.py` | 在 `import vllm` 最早阶段设置与 kvprune 相关的默认环境变量(如 v1 多进程默认、压缩默认开关、可选释放 v1 KV 等)。 |
| `vllm/__init__.py` | 对外导出 `CompressionParams`(懒加载至 `vllm.kvprune.integration.compression_params`)。 |
| `vllm/entrypoints/llm.py` | `kvprune_compression` 参数、`generate(..., compression=...)`、v1 `enforce_eager` / `num_gpu_blocks_override` 策略、懒加载 compactor、委托 `compressed_generate`。 |
| `vllm/v1/worker/gpu_worker.py` | `kvprune_v1_compressed_generate`:供 `collective_rpc` 调用的 TP 多卡压缩生成入口。 |
| `tests/conftest.py` | 测试在导入 vLLM 前覆盖部分 `VLLM_KVPRUNE_*` 默认值,避免全量测试默认走压缩路径。 |
| `vllm/envs.py` | 集中注册 `VLLM_KVPRUNE_*` 环境变量。 |
**此外(可选/示例,非引擎必需):**
- `examples/offline_inference/` 下若干 `*kvprune*` 示例脚本:演示用法,不参与核心引擎加载。
**结论:**
- **「官网 vLLM 主包」里与 kvprune 强相关的改动,主要就是上表 5 个主包文件 + 测试根配置**(若把测试也算进「集成面」,共 6 处)。
- **算法、Compactor、TP 内嵌 runner 等**均在 `vllm/kvprune/`(及该目录下的 `integration/`)中,与上游 diff 相对隔离。
## 2. 随 vLLM 版本更新,是否「很容易」同步剪枝压缩功能?
**相对容易的部分:**
- **集成面小**:合并冲突主要出现在上述少数文件,而不是遍布整个 executor / attention / model 层。
- **逻辑内聚**:大量代码在 `vllm/kvprune/`,可整体移植或 `git` 三方合并时以子树为主处理。
**仍需人工跟进的点(不能假设「自动无痛」):**
- **`entrypoints/llm.py` 属于高频变更文件**:上游每次大版本可能重构 `LLM` 构造参数、`generate` 签名或引擎初始化;需要**逐次解决冲突**并回归压缩路径。
- **`v1/worker/gpu_worker.py`** 同样会随 executor / RPC 接口变动;`collective_rpc` 方法名或 worker 基类若有变化,需对齐。
- **`env_override.py`** 若上游调整导入顺序或新增全局默认环境变量,需避免覆盖冲突或行为打架。
- **vLLM v1 内部 API**(如 `worker.get_model()`、`vllm_config` 结构)若变更,`vllm/kvprune/integration/*` 也可能要跟着改——这类改动**不在**上表所列的少数文件里,但仍是**集成层**维护成本。
**建议同步流程(简版):**
1. 在新上游 tag 上先合并/应用 `vllm/kvprune/` 目录。
2. 再手动合并上表所列的主包文件 + `tests/conftest.py`。
3. 跑与 kvprune 相关的测试与至少一条离线 `compression` 示例。
4. 关注发行说明中 `LLM`、`EngineArgs`、`gpu_worker`、多进程默认的破坏性变更。
## 3. 与「深度改内核」方案的区别
当前设计**没有**在 `model_executor` 的统一注意力路径上大规模插入 kvprune 钩子(相关辅助逻辑主要在 `vllm/kvprune` 内部)。因此:
- **上游同步时**,通常不必与 FlashAttention / 每层模型代码逐文件对打;
- **代价是**:功能边界以「共享权重 + compactor 引擎 + 可选 TP RPC」为主,与「原生 KV 算子级一体化」的改动面不同。
---
*文档随仓库维护;若集成文件列表有增删,请同步更新本节表格。*
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layers from upstream compactor (attention, linear, MoE, …).
Prefer importing concrete modules, e.g. ``from vllm.kvprune.layers.attention import ...``.
"""
# Nothing is re-exported from this package; import the concrete submodules.
__all__: list[str] = []
import torch
import torch.nn.functional as F
from torch import nn
class SiluAndMul(nn.Module):
    """Gated SiLU: split the last dimension in two and return ``silu(a) * b``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First chunk is passed through SiLU, second chunk scales the result.
        gate, up = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * up
from typing import Optional
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from torch import nn
from vllm.kvprune.attention.fa_paged_bridge import (
flash_decode_from_paged,
flash_prefill_from_paged,
)
from vllm.kvprune.attention.sparse_decode_kernel import head_sparse_decode_attention
from vllm.kvprune.attention.sparse_varlen_kernel import (
causal_sparse_varlen_with_cache,
)
from vllm.kvprune.compression.common import extract_and_store_top_kv
from vllm.kvprune.config.engine_config import KvpruneAttentionSchedule
from vllm.kvprune.kv_cache.store_kv_cache import decode_store_kv, prefill_store_all_kv
from vllm.kvprune.utils.context import Context, get_context
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
class Attention(nn.Module):
    """Attention layer that writes K/V into a paged cache and dispatches
    to one of several prefill / decode attention kernels.

    The cache tensors, page table, per-(batch, head) length table and page
    size all start as ``None`` in ``__init__``; nothing in this class
    assigns them, so they are presumably attached by the cache / model
    runner before use (TODO confirm against the caller). The decode path
    asserts that ``k_cache`` has been set.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        """Record head geometry and the softmax scale; no cache is bound yet."""
        super().__init__()
        self.num_heads: int = num_heads
        self.head_dim = head_dim
        self.scale: float = scale
        self.num_kv_heads = int(num_kv_heads)
        # Paged-cache handles; populated from outside this class.
        self.k_cache: Optional[torch.Tensor] = None
        self.v_cache: Optional[torch.Tensor] = None
        self.page_table: Optional[torch.Tensor] = None
        self.bh_seq_lens: Optional[torch.Tensor] = None
        self.page_size: Optional[int] = None

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        scores: Optional[torch.Tensor] = None,
    ):
        """Store the new K/V and compute attention for the current step.

        Behaviour is driven by the ambient ``Context`` from ``get_context()``:
        ``is_prefill`` selects prefill vs. decode, ``do_compression`` together
        with ``scores`` selects top-k retention vs. storing every token, and
        ``attention_schedule`` selects which kernel computes the output.

        :param q: query tensor passed through to the selected kernel.
        :param k: new keys for this step (stored and/or attended to).
        :param v: new values for this step.
        :param scores: optional per-token scores; only used on the prefill
            compression path.
        :return: the attention output ``o`` of whichever kernel ran.
        """
        context: Context = get_context()
        batch_mapping = context.batch_mapping
        # Gather the length rows for the active batch slots. ``index_select``
        # returns a new tensor, so ``seq_lens`` is a working copy that is
        # written back into ``self.bh_seq_lens`` at the end of this method.
        seq_lens = (
            None
            if self.bh_seq_lens is None
            else self.bh_seq_lens.index_select(0, batch_mapping).contiguous()
        )
        sched = context.attention_schedule
        use_triton_prefill_attn = (
            sched == KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
        )
        use_fa_decode = sched == KvpruneAttentionSchedule.PDFA
        if context.is_prefill:
            # Snapshot of the lengths taken before any store call; the
            # cache-aware prefill kernels below receive this snapshot rather
            # than ``seq_lens``. Presumably the store helpers advance
            # ``bh_lens`` in place -- TODO confirm.
            seq_lens_copy = seq_lens.clone() if seq_lens is not None else None
            if (
                self.k_cache is not None
                and context.do_compression
                and scores is not None
            ):
                # Compression path: keep only the top-scoring K/V entries.
                compression_context = context.compression_context
                assert scores is not None
                assert compression_context is not None
                maybe_execute_in_stream(
                    extract_and_store_top_kv,
                    scores=scores,
                    cu_seqlens_k=context.cu_seqlens_k,
                    max_k_len=context.max_seqlen_k,
                    top_k=compression_context.max_tokens_to_retain,
                    H=int(self.num_kv_heads),
                    new_keys=k,
                    new_vals=v,
                    num_tokens_to_retain=compression_context.batch_tokens_to_retain,
                    page_table=self.page_table,
                    batch_mapping=batch_mapping,
                    bh_lens=seq_lens,
                    k_cache=self.k_cache,
                    v_cache=self.v_cache,
                    PAGE_SIZE=self.page_size,
                    PAD_TO_PAGE_SIZE=True,
                    STORE_STREAM=context.STORE_STREAM,
                )
            elif self.k_cache is not None:
                # No compression (or no scores): store every new K/V token.
                maybe_execute_in_stream(
                    prefill_store_all_kv,
                    new_keys=k,
                    new_values=v,
                    cu_seqlens_k=context.cu_seqlens_k,
                    max_seqlen_k=context.max_seqlen_k,
                    k_cache=self.k_cache,
                    v_cache=self.v_cache,
                    page_table=self.page_table,
                    bh_lens=seq_lens,
                    batch_mapping=batch_mapping,
                    PAGE_SIZE=self.page_size,
                    STORE_STREAM=context.STORE_STREAM,
                )
            if use_triton_prefill_attn:
                # Under compression, make the current stream wait for the
                # store stream before the kernel reads the cache.
                if context.do_compression and context.STORE_STREAM is not None:
                    torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
                assert seq_lens_copy is not None
                o = causal_sparse_varlen_with_cache(
                    q,
                    k,
                    v,
                    self.k_cache,
                    self.v_cache,
                    seq_lens_bh=seq_lens_copy,
                    global_page_table=self.page_table,
                    batch_mapping=batch_mapping,
                    cu_seqlens_q=context.cu_seqlens_q,
                    max_seqlen_q=context.max_seqlen_q,
                    max_seqlen_k_cache=context.max_bh_len,
                    HKV=int(self.num_kv_heads),
                    PAGE_SIZE=self.page_size,
                    sm_scale=self.scale,
                )
            elif context.do_compression:
                # FlashAttention-based prefill that also reads the paged cache.
                if context.STORE_STREAM is not None:
                    torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
                assert seq_lens_copy is not None
                o = flash_prefill_from_paged(
                    q,
                    k,
                    v,
                    self.k_cache,
                    self.v_cache,
                    seq_lens_bh_before=seq_lens_copy,
                    global_page_table=self.page_table,
                    batch_mapping=batch_mapping,
                    cu_seqlens_q=context.cu_seqlens_q,
                    max_seqlen_q=context.max_seqlen_q,
                    PAGE_SIZE=self.page_size,
                    HKV=int(self.num_kv_heads),
                    sm_scale=self.scale,
                )
            else:
                # Plain causal varlen FlashAttention over the new K/V only;
                # the paged cache is not passed to this call.
                o = flash_attn_varlen_func(
                    q,
                    k,
                    v,
                    max_seqlen_q=context.max_seqlen_q,
                    cu_seqlens_q=context.cu_seqlens_q,
                    max_seqlen_k=context.max_seqlen_k,
                    cu_seqlens_k=context.cu_seqlens_k,
                    softmax_scale=self.scale,
                    causal=True,
                )
        else:
            # Decode: the new token's K/V is stored first (directly, not via
            # ``maybe_execute_in_stream``), then attention reads the cache.
            assert self.k_cache is not None, "KV Cache must be initialized for decoding"
            decode_store_kv(
                key=k,
                value=v,
                batch_mapping=batch_mapping,
                bh_lens=seq_lens,
                page_table=self.page_table,
                k_cache=self.k_cache,
                v_cache=self.v_cache,
                PAGE_SIZE=self.page_size,
            )
            if use_fa_decode:
                assert seq_lens is not None
                o = flash_decode_from_paged(
                    q,
                    self.k_cache,
                    self.v_cache,
                    seq_lens_bh=seq_lens,
                    global_page_table=self.page_table,
                    batch_mapping=batch_mapping,
                    PAGE_SIZE=self.page_size,
                    HKV=int(self.num_kv_heads),
                    sm_scale=self.scale,
                )
            else:
                o = head_sparse_decode_attention(
                    q,
                    self.k_cache,
                    self.v_cache,
                    seq_lens,
                    self.page_table,
                    batch_mapping,
                    int(self.num_kv_heads),
                    self.page_size,
                    self.scale,
                    key_split=context.key_split,
                )
        # Match compactor_vllm ``Attention``: ``index_copy_`` into the global
        # ``bh_seq_lens`` table. The Triton masked copy was a CUDA fast path but
        # disagreed with decode_store_kv / paged attention bookkeeping in edge
        # cases and could leave lengths stale → garbage logits / immediate EOS.
        if self.bh_seq_lens is not None:
            longbm = batch_mapping.to(
                device=self.bh_seq_lens.device, dtype=torch.long
            )
            # The write-back goes through the store stream only during prefill;
            # on decode ``STORE_STREAM=None`` is passed instead.
            maybe_execute_in_stream(
                self.bh_seq_lens.index_copy_,
                0,
                longbm,
                seq_lens,
                STORE_STREAM=context.STORE_STREAM if context.is_prefill else None,
            )
        return o
import torch
import torch.distributed as dist
import torch.nn.functional as F
from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
from vllm.kvprune.utils.tp_utils import (
tensor_parallel_rank_for_sharding,
tensor_parallel_world_size_for_sharding,
)
from torch import nn
class VocabParallelEmbedding(nn.Module):
    """Embedding table sharded along the vocabulary dimension.

    Each tensor-parallel rank owns a contiguous block of
    ``num_embeddings // tp_size`` rows. With more than one rank, ids
    outside the local block contribute zero rows and the partial results
    are combined with ``tensor_parallel_all_reduce``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        """Allocate the local shard; ``num_embeddings`` must divide evenly by the TP size."""
        super().__init__()
        self.tp_rank = tensor_parallel_rank_for_sharding()
        self.tp_size = tensor_parallel_world_size_for_sharding()
        assert num_embeddings % self.tp_size == 0
        self.num_embeddings = num_embeddings
        rows_per_rank = num_embeddings // self.tp_size
        self.num_embeddings_per_partition = rows_per_rank
        # Half-open id range [vocab_start_idx, vocab_end_idx) owned by this rank.
        self.vocab_start_idx = rows_per_rank * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + rows_per_rank
        self.weight = nn.Parameter(torch.empty(rows_per_rank, embedding_dim))
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's row block of the full checkpoint tensor into ``param``."""
        local_rows = param.data.size(0)
        first_row = self.tp_rank * local_rows
        param.data.copy_(loaded_weight.narrow(0, first_row, local_rows))

    def forward(self, x: torch.Tensor):
        """Look up embeddings for the token ids in ``x``."""
        if self.tp_size <= 1:
            return F.embedding(x, self.weight)
        owned = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
        # Ids owned by other ranks are remapped to local row 0; their rows
        # are zeroed below before the cross-rank sum.
        local_ids = owned * (x - self.vocab_start_idx)
        y = owned.unsqueeze(1) * F.embedding(local_ids, self.weight)
        tensor_parallel_all_reduce(y)
        return y
class ParallelLMHead(VocabParallelEmbedding):
    """LM head with TP vocab sharding.

    When embedded in a vLLM worker, logits must be gathered on the **tensor-
    parallel** process group (see :func:`~vllm.distributed.communication_op.tensor_model_parallel_gather`),
    not the default :func:`torch.distributed.gather` — otherwise shard order / group
    mismatch yields garbage logits and decoded gibberish.

    After gather, logits are truncated to ``org_vocab_size`` (HF tokenizer vocab),
    matching :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
    removal of padded vocabulary columns.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
        *,
        org_vocab_size: int | None = None,
    ):
        """Create the sharded head; ``bias`` must be False."""
        assert not bias
        super().__init__(num_embeddings, embedding_dim)
        # Original (unpadded) vocab size for logits truncation; defaults to num_embeddings.
        self.org_vocab_size = (
            int(org_vocab_size) if org_vocab_size is not None else num_embeddings
        )

    def forward(self, x: torch.Tensor):
        """Project hidden states to vocabulary logits.

        During prefill only the last token of each sequence (taken from
        ``context.cu_seqlens_q``) is projected. With tensor parallelism the
        shards are gathered, so the result may be ``None`` on ranks that do
        not receive the gathered tensor.
        """
        context = get_context()
        if context.is_prefill:
            cu = context.cu_seqlens_q
            last_indices = (cu[1:] - 1).to(torch.long)
            n_tok = x.shape[0]
            if n_tok > 0:
                # Keep indices inside [0, n_tok - 1] so the gather below cannot go out of range.
                last_indices = last_indices.clamp(min=0, max=n_tok - 1)
            x = x[last_indices].contiguous()
        logits = F.linear(x, self.weight)
        if self.tp_size > 1:
            logits = self._gather_logits_tp(logits)
        if logits is not None and logits.shape[-1] > self.org_vocab_size:
            logits = logits[..., : self.org_vocab_size]
        return logits

    def _gather_logits_tp(self, logits: torch.Tensor) -> torch.Tensor | None:
        """Gather vocab shards onto rank 0 and concatenate along the last dim.

        Uses vLLM's tensor-parallel gather when vLLM's distributed module is
        importable and model parallelism is initialised; otherwise falls back
        to :func:`torch.distributed.gather` on the default group.

        Only the optional import is guarded. An error raised by the
        collective itself now propagates instead of being swallowed and
        retried on the default group, which the class docstring identifies
        as producing garbage logits.
        """
        try:
            from vllm.distributed.communication_op import (
                tensor_model_parallel_gather,
            )
            from vllm.distributed.parallel_state import model_parallel_is_initialized
        except ImportError:
            use_vllm_group = False
        else:
            use_vllm_group = model_parallel_is_initialized()
        if use_vllm_group:
            return tensor_model_parallel_gather(logits, dst=0, dim=-1)
        # Fallback: plain torch.distributed gather onto rank 0.
        all_logits = (
            [torch.empty_like(logits) for _ in range(self.tp_size)]
            if self.tp_rank == 0
            else None
        )
        dist.gather(logits, all_logits, 0)
        return torch.cat(all_logits, -1) if self.tp_rank == 0 else None
import torch
from torch import nn
class RMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
# @torch.compile
def rms_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x
# @torch.compile
def add_rms_forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x, residual
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None:
return self.rms_forward(x)
else:
return self.add_rms_forward(x, residual)
import torch
import torch.distributed as dist
import torch.nn.functional as F
from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
from vllm.kvprune.utils.tp_utils import (
tensor_parallel_rank_for_sharding,
tensor_parallel_world_size_for_sharding,
)
from torch import nn
def divide(numerator, denominator):
    """Return ``numerator // denominator``, asserting the division is exact."""
    quotient, remainder = divmod(numerator, denominator)
    assert remainder == 0
    return quotient
class LinearBase(nn.Module):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_dim: int | None = None,
):
super().__init__()
self.tp_dim = tp_dim
self.tp_rank = tensor_parallel_rank_for_sharding()
self.tp_size = tensor_parallel_world_size_for_sharding()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
    """Linear layer whose full weight is held by every rank (no sharding)."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, bias=bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy the checkpoint tensor into ``param`` unchanged."""
        destination = param.data
        destination.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x @ weight.T + bias``."""
        weight, bias = self.weight, self.bias
        return F.linear(x, weight, bias)
class ColumnParallelLinear(LinearBase):
    """Linear layer sharded along the output dimension (weight dim 0).

    Each rank holds ``output_size // tp_size`` output rows; ``forward``
    returns only the local slice of the output, with no communication.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        world_size = tensor_parallel_world_size_for_sharding()
        local_output_size = divide(output_size, world_size)
        super().__init__(input_size, local_output_size, bias, 0)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice (along ``tp_dim``) of the full tensor into ``param``."""
        dim = self.tp_dim
        local_width = param.data.size(dim)
        shard = loaded_weight.narrow(dim, self.tp_rank * local_width, local_width)
        param.data.copy_(shard)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the local weight slice to ``x``."""
        weight, bias = self.weight, self.bias
        return F.linear(x, weight, bias)
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Several column-parallel projections fused into one weight.

    The logical projections (sizes given by ``output_sizes``) are laid out
    back to back along the output dimension; each is loaded separately by
    its index.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        # Set before super().__init__() as in the original construction order.
        self.output_sizes = output_sizes
        super().__init__(input_size, sum(output_sizes), bias)

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int
    ):
        """Load the ``loaded_shard_id``-th projection into its slot of ``param``.

        ``loaded_weight`` is that projection's full tensor; it is chunked
        across ranks and the local chunk is copied to the slot.
        """
        world = self.tp_size
        preceding = sum(self.output_sizes[:loaded_shard_id])
        slot_offset = preceding // world
        slot_size = self.output_sizes[loaded_shard_id] // world
        destination = param.data.narrow(self.tp_dim, slot_offset, slot_size)
        source = loaded_weight.chunk(world, self.tp_dim)[self.tp_rank]
        destination.copy_(source)
class QKVParallelLinear(ColumnParallelLinear):
    """Fused Q/K/V projection sharded by attention head.

    The local weight is laid out as ``[q | k | v]`` along the output
    dimension, holding ``num_heads`` query heads and ``num_kv_heads``
    key/value heads per rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        """``total_num_kv_heads`` falls back to ``total_num_heads`` when falsy."""
        world_size = tensor_parallel_world_size_for_sharding()
        total_num_kv_heads = total_num_kv_heads or total_num_heads
        self.head_size = head_size
        self.num_heads = divide(total_num_heads, world_size)
        self.num_kv_heads = divide(total_num_kv_heads, world_size)
        total_head_count = total_num_heads + 2 * total_num_kv_heads
        super().__init__(hidden_size, total_head_count * self.head_size, bias)

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str
    ):
        """Load one of the ``"q"``, ``"k"`` or ``"v"`` projections into its slot."""
        assert loaded_shard_id in ["q", "k", "v"]
        q_width = self.num_heads * self.head_size
        kv_width = self.num_kv_heads * self.head_size
        # shard id -> (offset, size) of its slot inside the local fused weight.
        layout = {
            "q": (0, q_width),
            "k": (q_width, kv_width),
            "v": (q_width + kv_width, kv_width),
        }
        slot_offset, slot_size = layout[loaded_shard_id]
        destination = param.data.narrow(self.tp_dim, slot_offset, slot_size)
        source = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        destination.copy_(source)
class RowParallelLinear(LinearBase):
    """Linear layer sharded along the input dimension (weight dim 1).

    Each rank multiplies its slice of the input features by its slice of
    the weight; partial outputs are summed with
    ``tensor_parallel_all_reduce``. The bias is added on rank 0 only so it
    is counted once in the reduced result.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        tp_size = tensor_parallel_world_size_for_sharding()
        super().__init__(divide(input_size, tp_size), output_size, bias, 1)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Load the weight shard, or the full bias, into ``param``.

        The bias is 1-D and has no input dimension, so it is copied whole.
        Previously the loader called ``param.size(1)`` on it, which raises
        ``IndexError`` whenever the layer is built with ``bias=True``.
        """
        param_data = param.data
        if param_data.dim() <= self.tp_dim:
            # No ``tp_dim`` axis on this parameter (the bias): not sharded.
            param_data.copy_(loaded_weight)
            return
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the all-reduced output for the local input slice ``x``."""
        # Only rank 0 contributes the bias so the all-reduce adds it once.
        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
        if self.tp_size > 1:
            tensor_parallel_all_reduce(y)
        return y
import torch
import torch.distributed as dist
from vllm.kvprune.triton_kernels.matmul_ogs import matmul_ogs
from vllm.kvprune.utils.tp_collectives import tensor_parallel_all_reduce
from vllm.kvprune.utils.tp_utils import (
tensor_parallel_rank_for_sharding,
tensor_parallel_world_size_for_sharding,
)
from torch import nn
def divide(numerator, denominator):
    """Exact integer division: assert there is no remainder, then return the quotient."""
    result, leftover = divmod(numerator, denominator)
    assert leftover == 0
    return result
class TritonFusedMoeLinearBase(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
num_experts: int,
bias: bool = False,
tp_dim: int | None = None,
) -> None:
super().__init__()
self.tp_dim = tp_dim
self.tp_rank = tensor_parallel_rank_for_sharding()
self.tp_size = tensor_parallel_world_size_for_sharding()
self.in_features = in_features
self.out_features = out_features
self.num_experts = num_experts
self.weight = nn.Parameter(
torch.empty((num_experts, in_features, out_features)).transpose(-1, -2)
)
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty((num_experts, out_features)))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
raise NotImplementedError
class ReplicatedTritonFusedMoeLinear(TritonFusedMoeLinearBase):
    """Per-expert linear layer whose full weights are held by every rank."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_experts: int,
        bias: bool = False,
    ) -> None:
        super().__init__(in_features, out_features, num_experts, bias)

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
    ):
        """Copy one expert's checkpoint tensor into row ``expert_idx`` of ``param``."""
        param.data[expert_idx].copy_(loaded_weight, non_blocking=True)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run ``matmul_ogs`` with the contiguous transposed weight view.

        The transposed view ``w`` is what the sibling row/column-parallel
        layers pass to ``matmul_ogs``. This method previously built and
        asserted on ``w`` but then passed ``self.weight``; it now passes
        ``w`` as well. Extra keyword arguments are forwarded untouched.
        """
        w = self.weight.transpose(-1, -2)
        assert w.is_contiguous()
        return matmul_ogs(
            x,
            w,
            self.bias,
            **kwargs,
        )
class RowParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
    """Per-expert linear layer sharded along the input dimension.

    Each rank holds ``in_features // tp_size`` input columns of every
    expert; partial outputs are summed with ``tensor_parallel_all_reduce``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_experts: int,
        bias: bool = False,
    ) -> None:
        # Without an initialised process group the layer is treated as unsharded.
        tp_size = (
            tensor_parallel_world_size_for_sharding()
            if dist.is_initialized()
            else 1
        )
        super().__init__(
            divide(in_features, tp_size), out_features, num_experts, bias, 2
        )

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
    ):
        """Load one expert's weight shard, or its full bias, into ``param``.

        The bias is ``(num_experts, out_features)`` and has no input
        dimension, so it is copied whole. Previously the loader called
        ``param.size(2)`` on it, which raises ``IndexError``.
        """
        if param.dim() == 2:
            # Bias: not sharded along the input dimension.
            param.data[expert_idx].copy_(loaded_weight, non_blocking=True)
            return
        shard_size = param.size(2)
        start_idx = self.tp_rank * shard_size
        local_shard = loaded_weight[:, start_idx : start_idx + shard_size]
        param.data[expert_idx].copy_(local_shard, non_blocking=True)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the all-reduced expert outputs for the local input slice."""
        w = self.weight.transpose(-1, -2)
        assert w.is_contiguous()
        # Bias on rank 0 only, as in ``RowParallelLinear``: otherwise the
        # all-reduce below would add it once per rank.
        bias = self.bias if self.tp_rank == 0 else None
        y = matmul_ogs(
            x,
            w,
            bias,
            **kwargs,
        )
        if self.tp_size > 1:
            tensor_parallel_all_reduce(y)
        return y
class ColumnParallelTritonFusedMoeLinear(TritonFusedMoeLinearBase):
    """Per-expert linear layer sharded along the output dimension.

    Each rank holds ``out_features // tp_size`` output rows of every
    expert; ``forward`` returns the local output slice with no
    communication.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_experts: int,
        bias: bool = False,
    ) -> None:
        # Without an initialised process group the layer is treated as unsharded.
        tp_size = (
            tensor_parallel_world_size_for_sharding()
            if dist.is_initialized()
            else 1
        )
        super().__init__(
            in_features, divide(out_features, tp_size), num_experts, bias, 1
        )

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, expert_idx: int
    ):
        """Copy this rank's output rows of one expert into ``param``.

        Only the leading dimension is sliced, so the same code handles the
        2-D weight and the 1-D bias. The previous ``[..., :]`` indexing
        raised ``IndexError`` for the bias.
        """
        shard_size = param.size(1)
        start_idx = self.tp_rank * shard_size
        local_shard = loaded_weight[start_idx : start_idx + shard_size]
        param.data[expert_idx].copy_(local_shard, non_blocking=True)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run ``matmul_ogs`` on the contiguous transposed weight view."""
        w = self.weight.transpose(-1, -2)
        assert w.is_contiguous()
        y = matmul_ogs(
            x,
            w,
            self.bias,
            **kwargs,
        )
        return y
class MergedColumnParallelTritonFusedMoeLinear(ColumnParallelTritonFusedMoeLinear):
    """Several column-parallel per-expert projections fused into one weight.

    The logical projections (sizes in ``out_feature_list``) sit back to
    back along the output dimension and are loaded one at a time.
    """

    def __init__(
        self,
        in_features: int,
        out_feature_list: list[int],
        num_experts: int,
        bias: bool = False,
    ):
        # Set before super().__init__() as in the original construction order.
        self.out_feature_list = out_feature_list
        super().__init__(in_features, sum(out_feature_list), num_experts, bias)

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        expert_idx: int,
        shard_id: int,
    ):
        """Load projection ``shard_id`` of expert ``expert_idx`` into its slot.

        ``loaded_weight`` has no expert dimension, hence the chunk along
        ``tp_dim - 1``.
        """
        world = self.tp_size
        preceding = sum(self.out_feature_list[:shard_id])
        slot_offset = preceding // world
        slot_size = self.out_feature_list[shard_id] // world
        slot = param.data.narrow(self.tp_dim, slot_offset, slot_size)
        local_chunks = loaded_weight.chunk(world, dim=self.tp_dim - 1)
        slot[expert_idx].copy_(local_chunks[self.tp_rank], non_blocking=True)
import math
from functools import lru_cache
from typing import Any
import torch
from torch import nn
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate the two halves of ``x``'s last dimension by ``cos`` / ``sin``.

    The computation runs in float32 and the result is cast back to
    ``x``'s dtype.
    """
    first_half, second_half = x.float().chunk(2, dim=-1)
    rotated_first = first_half * cos - second_half * sin
    rotated_second = second_half * cos + first_half * sin
    rotated = torch.cat((rotated_first, rotated_second), dim=-1)
    return rotated.to(x.dtype)
def rope_theta_from_hf_config(config: Any) -> float:
    """Return the RoPE base frequency from an HF-style config object.

    A ``rope_theta`` entry inside a dict-valued ``rope_parameters``
    attribute wins; otherwise the top-level ``rope_theta`` attribute is
    used, defaulting to ``1_000_000.0`` when that is missing as well.
    """
    rope_params = getattr(config, "rope_parameters", None)
    has_nested_theta = isinstance(rope_params, dict) and "rope_theta" in rope_params
    if has_nested_theta:
        theta = rope_params["rope_theta"]
    else:
        theta = getattr(config, "rope_theta", 1_000_000.0)
    return float(theta)
class RotaryEmbedding(nn.Module):
    """Rotary position embedding backed by a precomputed cos/sin table.

    The table ``cos_sin_cache`` has shape
    ``(max_position_embeddings, 1, rotary_dim)`` with the cosines in the
    first half of the last dimension and the sines in the second half.
    Only ``"llama3"`` frequency scaling is supported.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        rope_scaling: tuple | None,
    ) -> None:
        """
        :param head_size: stored as ``self.head_size``.
        :param rotary_dim: number of rotated channels.
        :param max_position_embeddings: number of positions in the table.
        :param base: RoPE base frequency.
        :param rope_scaling: ``None`` or the 5-tuple ``(rope_type, factor,
            low_freq_factor, high_freq_factor,
            original_max_position_embeddings)``.
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
        inv_freq = 1.0 / (base**exponents)
        if rope_scaling is not None:
            inv_freq = self._llama3_scaled_inv_freq(inv_freq, rope_scaling)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        table = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
        # Singleton dim 1 lets the table broadcast over the head dimension.
        self.register_buffer("cos_sin_cache", table.unsqueeze(1), persistent=False)

    @staticmethod
    def _llama3_scaled_inv_freq(
        inv_freq: torch.Tensor, rope_scaling: tuple
    ) -> torch.Tensor:
        """Apply llama3 frequency scaling to ``inv_freq``.

        Low frequencies are divided by ``factor``, high frequencies are
        kept, and the band in between is interpolated.
        """
        (
            rope_type,
            factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position_embeddings,
        ) = rope_scaling
        assert rope_type == "llama3"
        old_context_len = original_max_position_embeddings
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        wavelen = 2 * math.pi / inv_freq
        is_low_freq = wavelen > low_freq_wavelen
        is_high_freq = wavelen < high_freq_wavelen
        scaled = torch.where(is_low_freq, inv_freq / factor, inv_freq)
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smoothed = (1 - smooth_factor) * scaled / factor + smooth_factor * scaled
        is_medium_freq = ~is_high_freq & ~is_low_freq
        return torch.where(is_medium_freq, smoothed, scaled)

    # NOTE: the ``@torch.compile`` decorator is currently commented out here.
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``query`` and ``key`` by the table rows at ``positions``.

        :raises ValueError: if any position falls outside the table (not
            checked while a CUDA graph is being captured).
        """
        cache_len = self.cos_sin_cache.shape[0]
        # CUDA graph capture forbids device→CPU sync (e.g. ``.item()``) inside the
        # captured region; :meth:`ModelRunner.capture_cudagraph` runs decode with
        # placeholder positions. Skip the range check while capturing; eager runs
        # still validate.
        _capturing = (
            torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
        )
        should_validate = positions.numel() > 0 and not _capturing
        if should_validate:
            pmax = int(positions.max().item())
            pmin = int(positions.min().item())
            in_range = pmin >= 0 and pmax < cache_len
            if not in_range:
                raise ValueError(
                    f"RoPE positions out of range: need 0 <= pos < {cache_len}, "
                    f"got min={pmin}, max={pmax}. "
                    "Shorten the prompt or increase max_model_len (and align vLLM "
                    "RoPE cos_sin_cache with tie_kvprune_rope_buffers_from_vllm)."
                )
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
        rotated_query = apply_rotary_emb(query, cos, sin)
        rotated_key = apply_rotary_emb(key, cos, sin)
        return rotated_query, rotated_key
@lru_cache(1)
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: tuple | None = None,
):
    """Build a ``RotaryEmbedding``, memoised for the most recent argument set.

    The cache holds one entry, so repeated calls with identical arguments
    return the same module instance. ``rope_scaling`` must be hashable (a
    tuple or ``None``) because it is part of the cache key.
    """
    return RotaryEmbedding(head_size, rotary_dim, max_position, base, rope_scaling)
import torch
from torch import nn
class Sampler(nn.Module):
    """Pick one token id per row from logits, honouring per-row temperature.

    Rows with temperature ``0`` are decoded greedily. Other rows are divided
    by their temperature and perturbed with ``-log(E)``, ``E`` drawn from
    Exponential(1), before the argmax.
    """

    def __init__(self):
        super().__init__()

    # NOTE: the ``@torch.compile`` decorator is currently commented out here.
    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
        """Return the selected token index for every row of ``logits``."""
        temps = temperatures.view(-1)
        scores = logits.float()
        needs_sampling = ~(temps == 0.0)
        if needs_sampling.any():
            tau = temps[needs_sampling].unsqueeze(-1)  # [B_sample, 1]
            tempered = scores[needs_sampling].div(tau)  # temperature scaling
            # Clamp before the log so a zero draw cannot produce -inf.
            log_noise = (
                torch.empty_like(tempered).exponential_(1).clamp_min_(1e-10).log()
            )
            # Clone first: ``logits.float()`` may alias the caller's tensor.
            scores = scores.clone()
            scores[needs_sampling] = tempered - log_noise
        return scores.argmax(dim=-1)
import torch
import triton
import triton.language as tl
@triton.jit
def _masked_index_select_kernel(
    X_ptr,
    IDX_ptr,
    OUT_ptr,
    N,
    stride_xn,
    stride_xh,
    stride_ob,
    stride_oh,
):
    # Gather kernel: each program copies one element
    # OUT[b, h] = X[IDX[b], h], or writes 0 when IDX[b] is outside [0, N).
    # X_ptr / OUT_ptr are addressed through the element strides passed in.
    b = tl.program_id(0)  # which output row (0..B-1)
    h = tl.program_id(1)  # which column of that row
    idx = tl.load(IDX_ptr + b)  # source row index for output row b
    valid = (idx >= 0) & (idx < N)
    out_ptrs = OUT_ptr + b * stride_ob + h * stride_oh
    if not valid:
        # Out-of-range index: emit zero instead of reading X.
        tl.store(out_ptrs, 0)
    else:
        x_ptrs = X_ptr + idx * stride_xn + h * stride_xh
        vals = tl.load(x_ptrs)
        tl.store(out_ptrs, vals)
def masked_index_select_triton_dim0(
    input: torch.Tensor, index: torch.Tensor
) -> torch.Tensor:
    """
    Row gather along dim 0 with masking of out-of-range indices.

    input: [N, H] (its strides are passed to the kernel)
    index: [B], on the same device as ``input``; entries outside
        ``[0, N)`` yield an all-zero output row
    Returns: [B, H], same dtype and device as ``input``
    """
    assert input.ndim == 2 and index.ndim == 1
    num_rows, num_cols = input.shape
    num_selected = index.numel()
    out = torch.empty(
        (num_selected, num_cols), dtype=input.dtype, device=input.device
    )
    # One kernel program per output element.
    grid = (num_selected, num_cols)
    in_row_stride, in_col_stride = input.stride(0), input.stride(1)
    out_row_stride, out_col_stride = out.stride(0), out.stride(1)
    _masked_index_select_kernel[grid](
        input,
        index,
        out,
        num_rows,
        in_row_stride,
        in_col_stride,
        out_row_stride,
        out_col_stride,
    )
    return out
@triton.jit
def _masked_index_copy_kernel(
    DST_ptr,
    IDX_ptr,
    SRC_ptr,
    N,
    stride_dn,
    stride_dh,
    stride_sb,
    stride_sh,
):
    # Scatter kernel: each program copies one element
    # DST[IDX[b], h] = SRC[b, h], and does nothing when IDX[b] is outside
    # [0, N). DST_ptr / SRC_ptr are addressed through the strides passed in.
    b = tl.program_id(0)  # which source row
    h = tl.program_id(1)  # which column of that row
    idx = tl.load(IDX_ptr + b)  # destination row index for source row b
    valid = (idx >= 0) & (idx < N)
    if valid:
        src_ptrs = SRC_ptr + b * stride_sb + h * stride_sh
        dst_ptrs = DST_ptr + idx * stride_dn + h * stride_dh
        tl.store(dst_ptrs, tl.load(src_ptrs))
def masked_index_copy_triton_dim0(
    dst: torch.Tensor, index: torch.Tensor, src: torch.Tensor
):
    """
    Masked, in-place counterpart of ``dst.index_copy_(0, index, src)``.

    Row ``b`` of ``src`` is written to row ``index[b]`` of ``dst``; rows
    whose index is negative or ``>= dst.shape[0]`` are skipped (no write).
    Nothing is returned.

    Shapes:
        dst: [N, H]
        src: [B, H]
        index: [B]
    """
    assert dst.ndim == 2 and src.ndim == 2 and index.ndim == 1
    num_rows, num_cols = dst.shape
    num_src_rows, src_cols = src.shape
    assert src_cols == num_cols and index.numel() == num_src_rows
    # One kernel program per source element.
    grid = (num_src_rows, num_cols)
    dst_row_stride, dst_col_stride = dst.stride(0), dst.stride(1)
    src_row_stride, src_col_stride = src.stride(0), src.stride(1)
    _masked_index_copy_kernel[grid](
        dst,
        index,
        src,
        num_rows,
        dst_row_stride,
        dst_col_stride,
        src_row_stride,
        src_col_stride,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from vllm.kvprune.models.llama3 import LlamaForCausalLM
from vllm.kvprune.models.qwen3 import Qwen3ForCausalLM
logger = logging.getLogger(__name__)
# Maps a model-family key to the kvprune model class that implements it.
MODEL_REGISTRY = {
    "llama": LlamaForCausalLM,
    "qwen3": Qwen3ForCausalLM,
}
# The qwen3_moe model is optional: it is registered only when its module
# imports cleanly. Any import-time failure is logged as a warning and the
# entry is simply left out, so importing this package still succeeds.
try:
    from vllm.kvprune.models.qwen3_moe import Qwen3MoeForCausalLM
except Exception as exc:
    logger.warning("Disabling qwen3_moe due to import error: %s", exc)
else:
    MODEL_REGISTRY["qwen3_moe"] = Qwen3MoeForCausalLM
import os
from glob import glob
import torch
import tqdm
from safetensors import safe_open
from torch import nn
from transformers import LlamaConfig
from vllm.kvprune.compression import (
CompressionMethod,
apply_postrope_compression,
apply_prerope_compression,
)
from vllm.kvprune.layers.activation import SiluAndMul
from vllm.kvprune.layers.attention import Attention
from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
from vllm.kvprune.layers.layernorm import RMSNorm
from vllm.kvprune.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.kvprune.layers.rotary_embedding import get_rope
from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.tp_utils import tensor_parallel_world_size_for_sharding
class LlamaAttention(nn.Module):
    """Llama self-attention block: fused QKV projection, RoPE, paged
    attention and output projection, with optional KV-compression scoring
    during prefill.
    """

    # Keys read from a ``rope_scaling`` dict, in the order RotaryEmbedding unpacks them.
    _ROPE_SCALING_KEYS = (
        "rope_type",
        "factor",
        "low_freq_factor",
        "high_freq_factor",
        "original_max_position_embeddings",
    )

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: dict | None = None,
    ) -> None:
        """
        :param hidden_size: model hidden width.
        :param num_heads: total query heads (across all TP ranks).
        :param num_kv_heads: total key/value heads (across all TP ranks).
        :param max_position: number of positions in the RoPE table.
        :param head_dim: per-head width; falls back to
            ``hidden_size // num_heads`` when falsy.
        :param qkv_bias: add a bias to the QKV projection.
        :param rope_theta: RoPE base frequency.
        :param rope_scaling: optional llama3-style scaling dict.
        """
        super().__init__()
        tp_size = tensor_parallel_world_size_for_sharding()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        # get_rope takes the scaling config as a tuple rather than a dict.
        rope_scaling_tuple = (
            None
            if rope_scaling is None
            else tuple(rope_scaling[key] for key in self._ROPE_SCALING_KEYS)
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling_tuple,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output."""
        context = get_context()
        fused = self.qkv_proj(hidden_states)
        q_flat, k_flat, v_flat = fused.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )
        q = q_flat.view(-1, self.num_heads, self.head_dim)
        k = k_flat.view(-1, self.num_kv_heads, self.head_dim)
        v = v_flat.view(-1, self.num_kv_heads, self.head_dim)

        # Compression scores are only produced during a compressing prefill.
        scores = None
        if context.is_prefill and context.do_compression:
            scores = apply_prerope_compression(q, k, v, context)
        q, k = self.rotary_emb(positions, q, k)
        if context.is_prefill and context.do_compression:
            cc = context.compression_context
            if cc is not None and cc.compression_method == CompressionMethod.CRITICALADAKV:
                # Key step: inject the output-projection weight into the
                # compression context, reshaped to (heads, head_dim, hidden)
                # in float32.
                wo_raw = self.o_proj.weight
                hidden_size, _ = wo_raw.shape
                per_head = wo_raw.transpose(0, 1).contiguous()
                per_head = per_head.view(self.num_heads, self.head_dim, hidden_size)
                cc.wo_weight = per_head.to(dtype=torch.float32)
            scores = apply_postrope_compression(q, k, v, scores, context)

        attn_out = self.attn(q, k, v, scores)
        return self.o_proj(attn_out.flatten(1, -1))
class LlamaMLP(nn.Module):
    """Llama feed-forward block: fused gate/up projection, gated SiLU, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        mlp_bias: bool,
    ) -> None:
        """Only ``hidden_act == "silu"`` is supported (asserted)."""
        super().__init__()
        # Gate and up projections share one weight: two equal-sized slots.
        fused_widths = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=mlp_bias,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=mlp_bias,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
class LlamaDecoderLayer(nn.Module):
    """One Llama transformer layer: pre-norm attention followed by pre-norm MLP.

    The residual stream is threaded through the layer as a separate tensor
    and folded in by the norm layers.
    """

    def __init__(
        self,
        config: LlamaConfig,
    ) -> None:
        """Build the attention, MLP and the two norms from ``config``."""
        super().__init__()
        attention_kwargs = dict(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            # Optional config fields fall back to these defaults.
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            rope_theta=getattr(config, "rope_theta", 500000.0),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.self_attn = LlamaAttention(**attention_kwargs)
        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            mlp_bias=config.mlp_bias,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mlp_output, residual)``.

        ``residual`` is ``None`` for the first layer; the incoming hidden
        states then become the residual.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            normed = self.input_layernorm(hidden_states)
            residual = hidden_states
            hidden_states = normed
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        return self.mlp(hidden_states), residual
class LlamaModel(nn.Module):
    """Llama backbone: token embedding, the decoder layer stack and the final norm."""

    def __init__(
        self,
        config: LlamaConfig,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        decoder_layers = [
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the final normalised hidden states for ``input_ids``."""
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual and creates it.
        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)
        # The final norm folds the last residual in; only the normalised
        # tensor is returned.
        normed, _ = self.norm(hidden_states, residual)
        return normed
class LlamaForCausalLM(nn.Module):
    """Llama backbone plus LM head, with a safetensors checkpoint loader."""

    # Checkpoint sub-name -> (fused parameter sub-name, shard id handed to
    # the fused parameter's ``weight_loader``).
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_vocab_size=config.vocab_size,
        )
        if config.tie_word_embeddings:
            # Share storage between the LM head and the input embedding.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (no logits)."""
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.lm_head(hidden_states)

    def load_model(
        self,
        path: str,
        *,
        use_tqdm: bool = False,
    ) -> None:
        """Load every ``*.safetensors`` shard under ``path`` into this model.

        Tensors whose name contains a ``packed_modules_mapping`` key go to
        the corresponding fused parameter with its shard id. All other
        tensors are loaded by name, using the parameter's ``weight_loader``
        when it has one and a plain copy otherwise. An unknown parameter
        name surfaces as the exception raised by ``get_parameter``.

        The former ``is_loaded`` flag and trailing ``assert is_loaded``
        were removed: both paths set the flag, so the assert could never
        fail.

        :param path: directory containing the safetensors shards.
        :param use_tqdm: show a progress bar over the shard files.
        """
        all_shards = glob(os.path.join(path, "*.safetensors"))
        shard_iter = (
            tqdm.tqdm(all_shards, desc="Loading model") if use_tqdm else all_shards
        )
        for file in shard_iter:
            with safe_open(file, "pt", "cpu") as f:
                for weight_name in f.keys():
                    weight_tensor = f.get_tensor(weight_name)
                    # Load packed modules: the first matching key wins.
                    for k, (v, shard_id) in self.packed_modules_mapping.items():
                        if k in weight_name:
                            param = self.get_parameter(weight_name.replace(k, v))
                            param.weight_loader(param, weight_tensor, shard_id)
                            break
                    else:
                        # Load other modules.
                        param = self.get_parameter(weight_name)
                        weight_loader = getattr(
                            param,
                            "weight_loader",
                            lambda p, loaded_weight: p.data.copy_(loaded_weight),
                        )
                        weight_loader(param, weight_tensor)
import os
from glob import glob
import torch
import tqdm
from safetensors import safe_open
from torch import nn
from transformers import Qwen3Config
from vllm.kvprune.compression import (
CompressionMethod,
apply_postrope_compression,
apply_prerope_compression,
)
from vllm.kvprune.layers.activation import SiluAndMul
from vllm.kvprune.layers.attention import Attention
from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
from vllm.kvprune.layers.layernorm import RMSNorm
from vllm.kvprune.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.kvprune.layers.rotary_embedding import get_rope, rope_theta_from_hf_config
from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.tp_utils import tensor_parallel_world_size_for_sharding
class Qwen3Attention(nn.Module):
    """Self-attention block with per-head Q/K RMSNorm, RoPE and optional KV scoring.

    Head counts are divided by the tensor-parallel world size returned by
    ``tensor_parallel_world_size_for_sharding``; both the query-head and
    KV-head totals must be divisible by it.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
    ) -> None:
        """Create the projections, rotary embedding, attention op and Q/K norms.

        :param hidden_size: Model hidden width.
        :param num_heads: Total number of query heads (before TP split).
        :param num_kv_heads: Total number of KV heads (before TP split).
        :param max_position: Forwarded to ``get_rope`` as ``max_position``.
        :param head_dim: Per-head width; when falsy it falls back to
            ``hidden_size // num_heads``.
        :param rms_norm_eps: Epsilon for the Q/K RMSNorm layers.
        :param qkv_bias: Whether the fused QKV projection has a bias.
        :param rope_theta: Forwarded to ``get_rope`` as ``base``.
        :param rope_scaling: Forwarded unchanged to ``get_rope``.
        """
        super().__init__()
        tp_size = tensor_parallel_world_size_for_sharding()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused QKV output in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states``.

        During a compressing prefill (``context.is_prefill and
        context.do_compression``) the pre-RoPE and post-RoPE compression hooks
        are called and their resulting ``scores`` are passed to the attention
        op; otherwise ``scores`` stays ``None``.
        """
        context = get_context()
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Q and K are viewed per head and normalised over head_dim before RoPE.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
        scores = None
        if context.is_prefill and context.do_compression:
            # NOTE: v is still in its flat (un-viewed) layout at this point.
            scores = apply_prerope_compression(q, k, v, context)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q, k = self.rotary_emb(positions, q, k)
        if context.is_prefill and context.do_compression:
            cc = context.compression_context
            if cc is not None and cc.compression_method == CompressionMethod.CRITICALADAKV:
                # Key step: inject wo_weight into the compression_context.
                # The output-projection weight is transposed and viewed as
                # (Hq, D, hidden_size) in float32; this is rebuilt on every call.
                # NOTE(review): assumes o_proj.weight is laid out as
                # (hidden_size, Hq * D) on this rank — confirm against RowParallelLinear.
                wo_raw = self.o_proj.weight
                hidden_size, _ = wo_raw.shape
                Hq, D = self.num_heads, self.head_dim
                cc.wo_weight = (
                    wo_raw.transpose(0, 1)
                    .contiguous()
                    .view(Hq, D, hidden_size)
                    .to(dtype=torch.float32)
                )
            scores = apply_postrope_compression(q, k, v, scores, context)
        o = self.attn(q, k, v, scores)
        # Merge all trailing dims (heads, head_dim) before the output projection.
        output = self.o_proj(o.flatten(1, -1))
        return output
class Qwen3MLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU-and-mul, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        """Create the projections; only ``hidden_act == "silu"`` is accepted."""
        super().__init__()
        # Gate and up projections share one fused column-parallel weight.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
class Qwen3DecoderLayer(nn.Module):
    """One decoder layer: attention and MLP, each preceded by an RMSNorm.

    The residual stream is threaded through the norm layers: when a residual
    is supplied, each norm call returns ``(normed, new_residual)``.
    """

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        """Build the layer's sub-modules from the Hugging Face ``config``."""
        super().__init__()
        hidden = config.hidden_size
        eps = config.rms_norm_eps

        explicit_head_dim = getattr(config, "head_dim", None)
        head_dim = (
            hidden // config.num_attention_heads
            if explicit_head_dim is None
            else explicit_head_dim
        )
        rope_theta = rope_theta_from_hf_config(config)
        # Only a tuple-valued ``rope_scaling`` is forwarded; anything else
        # (including a missing attribute) becomes None.
        raw_scaling = getattr(config, "rope_scaling", None)
        rope_scaling_tuple: tuple | None = (
            raw_scaling if isinstance(raw_scaling, tuple) else None
        )

        self.self_attn = Qwen3Attention(
            hidden_size=hidden,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=head_dim,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling_tuple,
        )
        self.mlp = Qwen3MLP(
            hidden_size=hidden,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(hidden, eps=eps)
        self.post_attention_layernorm = RMSNorm(hidden, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(mlp_output, residual)``.

        ``residual`` is ``None`` for the first layer; in that case the layer
        input itself becomes the residual.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            normed = self.input_layernorm(hidden_states)
            residual = hidden_states
            hidden_states = normed

        attn_out = self.self_attn(positions, hidden_states)
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual
class Qwen3Model(nn.Module):
    """Token embedding, a stack of ``Qwen3DecoderLayer`` modules, and a final RMSNorm."""

    def __init__(
        self,
        config: Qwen3Config,
    ) -> None:
        """Instantiate ``config.num_hidden_layers`` decoder layers."""
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        decoder_layers = [
            Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and return the normed hidden states."""
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual and creates it from its input.
        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)
        normed, _ = self.norm(hidden_states, residual)
        return normed
class Qwen3ForCausalLM(nn.Module):
    """Qwen3 backbone plus LM head, with a ``*.safetensors`` checkpoint loader.

    ``forward`` returns the backbone's hidden states; logits are produced
    separately by ``compute_logits``.
    """

    # Checkpoint sub-module name -> (fused parameter name, shard id).
    # The shard id is forwarded to the fused parameter's ``weight_loader``.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, config: Qwen3Config) -> None:
        """Build the backbone and LM head described by ``config``."""
        super().__init__()
        self.model = Qwen3Model(config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_vocab_size=config.vocab_size,
        )
        if config.tie_word_embeddings:
            # Rebind the head's storage to the embedding table so both share
            # the same underlying tensor.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Run the backbone and return its output hidden states."""
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the LM head to ``hidden_states``."""
        return self.lm_head(hidden_states)

    def load_model(
        self,
        path: str,
        *,
        use_tqdm: bool = False,
    ) -> None:
        """Load every ``*.safetensors`` shard found directly under ``path``.

        Each tensor name is matched (by substring, in mapping order) against
        ``packed_modules_mapping``. On a match the tensor is handed to the
        fused parameter's ``weight_loader`` together with its shard id.
        Otherwise the same-named parameter is loaded through its
        ``weight_loader`` if it has one, else by a plain ``copy_``.

        A tensor whose name does not resolve to a parameter makes
        ``get_parameter`` raise; there is no silent skip.

        :param path: Directory containing the checkpoint shards.
        :param use_tqdm: Wrap the shard iteration in a tqdm progress bar.
        """
        shard_files = glob(os.path.join(path, "*.safetensors"))
        if use_tqdm:
            shard_files = tqdm.tqdm(shard_files, desc="Loading model")

        for file in shard_files:
            with safe_open(file, "pt", "cpu") as f:
                for weight_name in f.keys():
                    weight_tensor = f.get_tensor(weight_name)

                    # First mapping key contained in the tensor name, if any.
                    packed_key = next(
                        (k for k in self.packed_modules_mapping if k in weight_name),
                        None,
                    )
                    if packed_key is not None:
                        fused_name, shard_id = self.packed_modules_mapping[packed_key]
                        param = self.get_parameter(
                            weight_name.replace(packed_key, fused_name)
                        )
                        # Fused parameters must provide a loader; a missing
                        # attribute raises AttributeError as before.
                        param.weight_loader(param, weight_tensor, shard_id)
                        continue

                    param = self.get_parameter(weight_name)
                    weight_loader = getattr(
                        param,
                        "weight_loader",
                        lambda p, loaded_weight: p.data.copy_(loaded_weight),
                    )
                    weight_loader(param, weight_tensor)
import os
from glob import glob
import torch
import tqdm
from safetensors import safe_open
from torch import nn
from transformers import Qwen3MoeConfig
from vllm.kvprune.compression import (
CompressionMethod,
apply_postrope_compression,
apply_prerope_compression,
)
from vllm.kvprune.layers.activation import SiluAndMul
from vllm.kvprune.layers.attention import Attention
from vllm.kvprune.layers.embed_head import ParallelLMHead, VocabParallelEmbedding
from vllm.kvprune.layers.layernorm import RMSNorm
from vllm.kvprune.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.kvprune.layers.moe import (
MergedColumnParallelTritonFusedMoeLinear,
RowParallelTritonFusedMoeLinear,
)
from vllm.kvprune.layers.rotary_embedding import get_rope, rope_theta_from_hf_config
from vllm.kvprune.triton_kernels.routing import routing
from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.tp_utils import (
tensor_parallel_rank_for_sharding,
tensor_parallel_world_size_for_sharding,
)
class Qwen3MoeAttention(nn.Module):
    """Self-attention block with per-head Q/K RMSNorm, RoPE and optional KV scoring.

    Head counts are divided by the tensor-parallel world size returned by
    ``tensor_parallel_world_size_for_sharding``; both the query-head and
    KV-head totals must be divisible by it.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
        sliding_window: int | None = None,
    ) -> None:
        """Create the projections, rotary embedding, attention op and Q/K norms.

        :param hidden_size: Model hidden width.
        :param num_heads: Total number of query heads (before TP split).
        :param num_kv_heads: Total number of KV heads (before TP split).
        :param max_position: Forwarded to ``get_rope`` as ``max_position``.
        :param head_dim: Per-head width; when falsy it falls back to
            ``hidden_size // num_heads``.
        :param rms_norm_eps: Epsilon for the Q/K RMSNorm layers.
        :param qkv_bias: Whether the fused QKV projection has a bias.
        :param rope_theta: Forwarded to ``get_rope`` as ``base``.
        :param rope_scaling: Forwarded unchanged to ``get_rope``.
        :param sliding_window: Stored on the instance; not read anywhere else
            in this class.
        """
        super().__init__()
        tp_size = tensor_parallel_world_size_for_sharding()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused QKV output in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.sliding_window = sliding_window
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states``.

        During a compressing prefill (``context.is_prefill and
        context.do_compression``) the pre-RoPE and post-RoPE compression hooks
        are called and their resulting ``scores`` are passed to the attention
        op; otherwise ``scores`` stays ``None``.
        """
        context = get_context()
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Q and K are viewed per head and normalised over head_dim before RoPE.
        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
        scores = None
        if context.is_prefill and context.do_compression:
            # NOTE: v is still in its flat (un-viewed) layout at this point.
            scores = apply_prerope_compression(q, k, v, context)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q, k = self.rotary_emb(positions, q, k)
        if context.is_prefill and context.do_compression:
            cc = context.compression_context
            if cc is not None and cc.compression_method == CompressionMethod.CRITICALADAKV:
                # Key step: inject wo_weight into the compression_context.
                # The output-projection weight is transposed and viewed as
                # (Hq, D, hidden_size) in float32; this is rebuilt on every call.
                # NOTE(review): assumes o_proj.weight is laid out as
                # (hidden_size, Hq * D) on this rank — confirm against RowParallelLinear.
                wo_raw = self.o_proj.weight
                hidden_size, _ = wo_raw.shape
                Hq, D = self.num_heads, self.head_dim
                cc.wo_weight = (
                    wo_raw.transpose(0, 1)
                    .contiguous()
                    .view(Hq, D, hidden_size)
                    .to(dtype=torch.float32)
                )
            scores = apply_postrope_compression(q, k, v, scores, context)
        o = self.attn(q, k, v, scores)
        # Merge all trailing dims (heads, head_dim) before the output projection.
        output = self.o_proj(o.flatten(1, -1))
        return output
class Qwen3MoeMLP(nn.Module):
    """Dense gated feed-forward block used by the non-MoE decoder layers."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        """Create the projections; only ``hidden_act == "silu"`` is accepted."""
        super().__init__()
        # Gate and up projections share one fused column-parallel weight.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
        return self.down_proj(self.act_fn(self.gate_up_proj(x)))
class Qwen3MoeTritonSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts feed-forward block built on fused Triton MoE linears.

    A replicated linear ``gate`` produces per-expert logits, ``routing`` turns
    them into routing data plus gather/scatter indices, and the fused
    gate/up and down projections consume those to run the selected experts.
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        num_experts_per_tok: int,
        norm_topk_prob: bool,
        hidden_act: str,
    ) -> None:
        """Create the router and the fused expert projections.

        :param num_experts: Number of experts held by the fused projections.
        :param hidden_size: Model hidden width.
        :param intermediate_size: Per-expert intermediate width.
        :param num_experts_per_tok: Top-k passed to ``routing``.
        :param norm_topk_prob: Stored on the instance only; ``forward`` does
            not read it. NOTE(review): confirm ``routing`` already normalises
            the selected weights as the checkpoint expects.
        :param hidden_act: Must be ``"silu"``; the activation is hard-wired
            to ``SiluAndMul``.
        """
        super().__init__()
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.hidden_size = hidden_size
        self.moe_intermediate_size = intermediate_size
        self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False)
        self.gate_up_proj = MergedColumnParallelTritonFusedMoeLinear(
            hidden_size, [intermediate_size] * 2, num_experts
        )
        self.down_proj = RowParallelTritonFusedMoeLinear(
            intermediate_size, hidden_size, num_experts
        )
        # Same guard as the dense Qwen3MoeMLP: without it a non-SiLU config
        # would silently run with the wrong activation.
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` through the selected experts.

        An input with zero elements is returned unchanged without touching
        the router or the experts.
        """
        x = hidden_states
        if x.numel() == 0:
            return x
        logits = self.gate(x)
        rdata, gather_indx, scatter_indx = routing(
            logits,
            self.num_experts_per_tok,
            simulated_ep=1,  # single device, replicated experts
        )
        x = self.gate_up_proj(x, routing_data=rdata, gather_indx=gather_indx)
        x = self.act_fn(x)
        # The routing data's gate scales are passed as per-token expert weights.
        x = self.down_proj(
            x, routing_data=rdata, scatter_indx=scatter_indx, gammas=rdata.gate_scal
        )
        return x
class Qwen3MoeBlock(Qwen3MoeTritonSparseMoeBlock):
    """Alias subclass of ``Qwen3MoeTritonSparseMoeBlock``; adds no behaviour."""
class Qwen3MoeRMSNorm(RMSNorm):
    """Alias subclass of ``RMSNorm``; adds no behaviour."""
class Qwen3MoeDecoderLayer(nn.Module):
    """One decoder layer whose feed-forward part is either sparse MoE or dense.

    Residual connections are applied explicitly in ``forward``; the norm
    layers are called with a single argument.
    """

    def __init__(
        self,
        config: Qwen3MoeConfig,
        layer_idx: int,
    ) -> None:
        """Build the layer's sub-modules.

        The MoE block is used when ``layer_idx`` is not listed in
        ``config.mlp_only_layers``, ``config.num_experts`` is positive and
        ``layer_idx + 1`` is a multiple of ``config.decoder_sparse_step``;
        otherwise a dense ``Qwen3MoeMLP`` is used.
        """
        super().__init__()
        hidden = config.hidden_size
        eps = config.rms_norm_eps

        explicit_head_dim = getattr(config, "head_dim", None)
        head_dim = (
            hidden // config.num_attention_heads
            if explicit_head_dim is None
            else explicit_head_dim
        )
        rope_theta = rope_theta_from_hf_config(config)
        # Only a tuple-valued ``rope_scaling`` is forwarded; anything else
        # (including a missing attribute) becomes None.
        raw_scaling = getattr(config, "rope_scaling", None)
        rope_scaling_tuple: tuple | None = (
            raw_scaling if isinstance(raw_scaling, tuple) else None
        )

        self.self_attn = Qwen3MoeAttention(
            hidden_size=hidden,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            head_dim=head_dim,
            rms_norm_eps=eps,
            qkv_bias=getattr(config, "attention_bias", False),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling_tuple,
            sliding_window=config.sliding_window,
        )

        use_sparse_moe = (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if not use_sparse_moe:
            self.mlp = Qwen3MoeMLP(
                hidden_size=hidden,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
            )
        else:
            self.mlp = Qwen3MoeBlock(
                num_experts=config.num_experts,
                hidden_size=hidden,
                intermediate_size=config.moe_intermediate_size,
                num_experts_per_tok=config.num_experts_per_tok,
                norm_topk_prob=config.norm_topk_prob,
                hidden_act=config.hidden_act,
            )

        self.input_layernorm = Qwen3MoeRMSNorm(hidden, eps=eps)
        self.post_attention_layernorm = Qwen3MoeRMSNorm(hidden, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Apply norm → attention → residual add, then norm → MLP → residual add."""
        # Attention sub-layer.
        attn_out = self.self_attn(positions, self.input_layernorm(hidden_states))
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-layer.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
class Qwen3MoeModel(nn.Module):
    """Token embedding, a stack of ``Qwen3MoeDecoderLayer`` modules, and a final norm."""

    def __init__(
        self,
        config: Qwen3MoeConfig,
    ) -> None:
        """Instantiate one decoder layer per index in ``range(config.num_hidden_layers)``."""
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        # Each layer receives its own index so it can choose MoE vs. dense MLP.
        decoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_layers.append(Qwen3MoeDecoderLayer(config, layer_idx))
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer in order, and apply the final norm."""
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        return self.norm(hidden_states)
class Qwen3MoeForCausalLM(nn.Module):
    """Qwen3-MoE backbone plus LM head, with an expert-aware checkpoint loader.

    ``forward`` returns the backbone's hidden states; logits are produced
    separately by ``compute_logits``.
    """

    # Checkpoint sub-module name -> (fused parameter name, shard id).
    # The shard id is forwarded to the fused parameter's ``weight_loader``.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen3MoeConfig,
    ) -> None:
        """Build the backbone and LM head described by ``config``."""
        super().__init__()
        self.model = Qwen3MoeModel(config)
        self.num_experts = config.num_experts
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_vocab_size=config.vocab_size,
        )
        if config.tie_word_embeddings:
            # Rebind the head's storage to the embedding table so both share
            # the same underlying tensor.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the backbone and return its output hidden states."""
        return self.model(input_ids, position_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the LM head to ``hidden_states``."""
        return self.lm_head(hidden_states)

    def load_model(
        self,
        path: str,
        *,
        use_tqdm: bool = False,
    ) -> None:
        """Load every ``*.safetensors`` shard found directly under ``path``.

        Tensors are opened directly on ``cuda:<device>``, where ``<device>``
        is the current CUDA device, or the tensor-parallel rank when
        ``torch.cuda.is_available()`` is false.

        For expert weights (names containing ``"mlp.experts"``) the
        ``.experts.<idx>.`` segment is removed from the name and ``<idx>`` is
        passed to the parameter's ``weight_loader`` as an extra argument.
        Names matching ``packed_modules_mapping`` (by substring, in mapping
        order) are redirected to the fused parameter together with a shard id.

        A tensor whose name does not resolve to a parameter makes
        ``get_parameter`` raise; there is no silent skip.

        :param path: Directory containing the checkpoint shards.
        :param use_tqdm: Wrap the shard iteration in a tqdm progress bar.
        """
        rank = tensor_parallel_rank_for_sharding()
        device = torch.cuda.current_device() if torch.cuda.is_available() else rank
        shard_files = glob(os.path.join(path, "*.safetensors"))
        if use_tqdm:
            shard_files = tqdm.tqdm(shard_files, desc="Loading model")

        for file in shard_files:
            with safe_open(file, "pt", f"cuda:{device}") as f:
                for weight_name in f.keys():
                    weight_tensor = f.get_tensor(weight_name)
                    is_expert = "mlp.experts" in weight_name
                    expert_idx = None

                    if is_expert:
                        # "<mlp>.experts.<idx>.<proj...>" -> "<mlp>.<proj...>"
                        mlp_module_name, expert_module_name = weight_name.split(
                            ".experts."
                        )
                        # Split on the first dot only, so the index prefix is
                        # removed exactly once regardless of how it is spelled.
                        expert_id, _, proj_name = expert_module_name.partition(".")
                        expert_idx = int(expert_id)
                        weight_name = f"{mlp_module_name}.{proj_name}"

                    # First mapping key contained in the tensor name, if any.
                    packed_key = next(
                        (k for k in self.packed_modules_mapping if k in weight_name),
                        None,
                    )
                    if packed_key is not None:
                        fused_name, shard_id = self.packed_modules_mapping[packed_key]
                        param = self.get_parameter(
                            weight_name.replace(packed_key, fused_name)
                        )
                        # Fused parameters must provide a loader; a missing
                        # attribute raises AttributeError as before.
                        weight_loader = param.weight_loader
                        if is_expert:
                            weight_loader(param, weight_tensor, expert_idx, shard_id)
                        else:
                            weight_loader(param, weight_tensor, shard_id)
                        continue

                    param = self.get_parameter(weight_name)
                    weight_loader = getattr(
                        param,
                        "weight_loader",
                        lambda p, lw: p.data.copy_(lw, non_blocking=True),
                    )
                    if is_expert:
                        weight_loader(param, weight_tensor, expert_idx)
                    else:
                        weight_loader(param, weight_tensor)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Triton kernel utilities (matmul_ogs, MoE, topk, …) plus KV-facing entrypoints.
For KV pruning attention/store, see also ``vllm.kvprune.attention`` and
``vllm.kvprune.kv_cache``.
"""
from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
from vllm.kvprune.kv_cache.store_kv_cache import (
decode_store_kv,
prefill_store_all_kv,
prefill_store_topk_kv,
)
__all__ = [
"causal_sparse_varlen_with_cache",
"decode_store_kv",
"prefill_store_all_kv",
"prefill_store_topk_kv",
]
import torch
from .compaction_details._masked_compaction import _masked_compaction
from .tensor import Bitmatrix
def compaction(yv, yi, bitmask, sentinel=-1):
    """
    Compact *yv* and *yi* row by row according to a per-row bitmask.

    One ``_masked_compaction`` kernel program is launched per row of *yi*.
    The results are written to freshly allocated tensors; the inputs are
    passed to the kernel as-is.

    Parameters
    ----------
    yv : torch.Tensor
        Values tensor; the output has the same shape, dtype and device.
    yi : torch.Tensor, 2-D
        Indices associated with *yv*; its shape supplies the launch grid
        (rows) and the kernel constant ``K`` (columns).
    bitmask : torch.Tensor or Bitmatrix
        Per-row mask of active indices. A ``Bitmatrix`` is unwrapped to its
        ``storage.data`` tensor. The tensor must be at least 2-D because both
        ``stride(0)`` and ``stride(1)`` are passed to the kernel.
    sentinel : int, default -1
        Forwarded to the kernel — presumably the fill value for dropped
        positions, as in ``compaction_torch``; confirm against the kernel.

    Returns
    -------
    (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor]
        New tensors allocated with ``torch.empty_like`` on the inputs.
    """
    num_rows, num_cols = yi.shape
    out_values = torch.empty_like(yv)
    out_indices = torch.empty_like(yi)
    # Accept the Bitmatrix wrapper as well as a raw tensor.
    mask_data = bitmask.storage.data if isinstance(bitmask, Bitmatrix) else bitmask
    grid = (num_rows,)
    _masked_compaction[grid](
        # inputs
        yv,
        yi,
        mask_data,
        mask_data.stride(0),
        mask_data.stride(1),
        # outputs
        out_values,
        out_indices,
        # sentinel
        sentinel,
        # constants
        K=num_cols,
    )
    return out_values, out_indices
def compaction_torch(
    yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1
):
    """
    Pure-PyTorch reference for ``compaction``.

    *bitmask* is a 2-D tensor of integer words; bit ``j`` of word ``w`` in a
    row marks index ``32 * w + j`` as active for that row. Entries of *yi*
    whose index is active are moved to the front of their row in their
    original order, and the remaining slots of both outputs are filled with
    *sentinel*. The inputs are not modified.
    """
    # Unpacking doubles as a check that ``yi`` is two-dimensional.
    _n_rows, _n_cols = yi.shape
    # One single-bit weight per bit position, in the bitmask's own dtype.
    bit_weights = 1 << torch.arange(32, device=yi.device, dtype=bitmask.dtype)
    # (B, W) words -> (B, W, 32) bit tests -> (B, W * 32) boolean lookup table.
    active = ((bitmask.unsqueeze(-1) & bit_weights) != 0).flatten(start_dim=-2)
    # True where the corresponding yi entry has to be dropped.
    dropped = ~active.gather(1, yi.long())
    # Stable sort on the 0/1 "dropped" flag: kept entries first, order preserved.
    order = dropped.to(torch.int).argsort(dim=1, stable=True)
    yv_out = yv.gather(1, order)
    yi_out = yi.gather(1, order)
    # Overwrite the tail of each row (the dropped entries) with the sentinel.
    dropped_sorted = dropped.gather(1, order)
    yi_out[dropped_sorted] = sentinel
    yv_out[dropped_sorted] = sentinel
    return yv_out, yi_out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment