Commit d29c39ca authored by chenzk's avatar chenzk
Browse files

vllm kvprune wo:v1.1.0

parent f81ce56b
import torch
from .base import Layout
class BlackwellMXValueLayout(Layout):
name: str = "BLACKWELL_VALUE"
def __init__(self, shape) -> None:
super().__init__(shape)
self.shape = shape
def swizzle_data(self, data):
# permutation needed to make `data` row major
to_row_major = sorted(range(data.ndim), key=lambda d: (data.stride(d), d))[::-1]
# permutation needed to retrieve original order
inv = [0] * data.ndim
for i, d in enumerate(to_row_major):
inv[d] = i
# leading dimension must be padded to be aligned to 128
align_dim = lambda x: (x + 128 - 1) // 128 * 128
major_dim = data.stride().index(1)
pad = align_dim(data.shape[major_dim]) - data.shape[major_dim]
data = torch.nn.functional.pad(data.permute(to_row_major), (0, pad)).permute(
inv
)
return data
def unswizzle_data(self, data: torch.Tensor):
# Trim padding along all dims back to the original shape recorded at init.
assert data.ndim == len(self.shape), (
"Rank mismatch between data and recorded shape"
)
sizes = [min(data.size(i), self.shape[i]) for i in range(data.ndim)]
return data[tuple(slice(0, s) for s in sizes)]
def swizzle_block_shape(self, block_shape):
return block_shape
import triton
import triton.language as tl
from .base import Layout
NON_K_PRESHUFFLE_BLOCK_SIZE = 32
class CDNA4MXScaleLayout(Layout):
name: str = "CDNA4_SCALE"
def __init__(self, shape) -> None:
super().__init__(shape)
def swizzle_data(self, data):
block_shape = data.shape
SCALE_K = block_shape[-2]
N = block_shape[-1]
data = data.transpose(-1, -2)
data = data.view(
-1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1
)
data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
if len(block_shape) == 3:
E = block_shape[0]
data = data.reshape(E, N // 32, SCALE_K * 32)
else:
assert len(block_shape) == 2
data = data.reshape(N // 32, SCALE_K * 32)
return data.transpose(-1, -2)
def unswizzle_data(self, data):
raise NotImplementedError()
def swizzle_block_shape(self, block_shape):
SCALE_K = block_shape[-2]
N = block_shape[-1]
return block_shape[:-2] + [N // 32, SCALE_K * 32]
@triton.jit
def unswizzle_mx_scale_cdna4(
x,
BLOCK_N: tl.constexpr,
MX_SCALE_BLOCK_K: tl.constexpr,
N_PRESHUFFLE_FACTOR: tl.constexpr = NON_K_PRESHUFFLE_BLOCK_SIZE,
):
x = x.reshape(BLOCK_N // N_PRESHUFFLE_FACTOR, MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2, 1)
x = x.permute(0, 5, 3, 1, 4, 2, 6)
x = x.reshape(BLOCK_N, MX_SCALE_BLOCK_K)
return x
import torch
import triton
import triton.language as tl
from .base import Layout
class HopperMXScaleLayout(Layout):
name: str = "HOPPER_SCALE"
def __init__(self, shape, mx_axis, num_warps=8) -> None:
assert num_warps & (num_warps - 1) == 0, "warps_n must be a power of 2"
super().__init__(shape)
self.mx_axis = mx_axis
self.num_warps = num_warps
*self.leading_shape, _, _ = shape
def _maybe_mT(self, data):
if self.mx_axis == len(self.leading_shape):
return data.contiguous().mT
return data
def swizzle_data(self, data):
data = self._maybe_mT(data).contiguous()
*batch, M, K = data.shape
SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8
SWIZZLE_ALIGN_K = 2
pad_m = (SWIZZLE_ALIGN_M - (M % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M
pad_k = (SWIZZLE_ALIGN_K - (K % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K
data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
*batch, M, K = data.shape
assert data.is_contiguous()
assert M % (2 * self.num_warps * 2 * 8) == 0 and K % 2 == 0, (
f"Input tensor must have a subtile of shape (..., {2 * self.num_warps * 2 * 8}, 2)"
)
b = len(batch)
data = data.reshape(
*batch,
M // (2 * self.num_warps * 2 * 8),
2,
self.num_warps,
2,
8,
K // 2,
2,
)
perm = [0, 2, 5, 1, 4, 6, 3]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.flatten(-5, -1)
data = data.flatten(-3, -2)
assert data.shape[-2] == M // 32
assert data.shape[-1] == K * 32
data = self._maybe_mT(data)
return data
def unswizzle_data(self, data):
data = self._maybe_mT(data)
*batch, M, K = data.shape
b = len(batch)
data = data.reshape(
*batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2
)
perm = [0, 3, 1, 6, 4, 2, 5]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.reshape(*batch, M * 32, K // 32)
data = self._maybe_mT(data)
return data
def swizzle_block_shape(self, block_shape):
return block_shape
@triton.jit
def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr):
"""
Triton inverse of swizzle_mxfp4_scale_hopper
"""
tl.static_assert(len(x.shape) == 2, "NYI")
# implementation assumes mxfp data is packed along the last dimension
x = x.trans() if mx_axis == 0 else x
M: tl.constexpr = x.shape[0]
K: tl.constexpr = x.shape[1]
tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}")
tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}")
x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2)
x = x.trans(0, 3, 1, 6, 4, 2, 5)
x = x.reshape(M * 32, K // 32)
# implementation assumed mxfp data is packed along the last dimension
x = x.trans() if mx_axis == 0 else x
return x
import torch
import triton
import triton.language as tl
from .base import Layout
def right_shift_unsigned(x, shift):
return (x >> shift) & ((1 << (32 - shift)) - 1)
# -----------------------------------------------------------------------
# Interleave the bits of four consecutive fp4 values (i.e. 16-bits) as:
# 1000000111000000 (first fp4)
# 1000000111000000 (second fp4)
# 1000000111000000 (third fp4)
# 0110110000000000 (fourth fp4)
# This is done so that dequantization can be done in 14 SASS instructions
# -----------------------------------------------------------------------
def _compress_fp4(x):
x = x.to(torch.int32)
return ((x & 0x8) << 12) | ((x & 0x7) << 6)
def _compress_fourth(x):
x = x.to(torch.int32)
return ((x & 0x8) << 11) | ((x & 0x6) << 9) | ((x & 0x1) << 13)
def _pack_bits(x: torch.Tensor, mx_axis: int):
x = x.contiguous()
assert x.shape[-1] % 4 == 0, (
"Input tensor must have a last dimension divisible by 4"
)
x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))
first = _compress_fp4(x[..., 0]) | (_compress_fp4(x[..., 0] >> 4) << 16)
second = _compress_fp4(x[..., 1]) | (_compress_fp4(x[..., 1] >> 4) << 16)
third = _compress_fp4(x[..., 2]) | (_compress_fp4(x[..., 2] >> 4) << 16)
fourth = _compress_fourth(x[..., 3]) | (_compress_fourth(x[..., 3] >> 4) << 16)
x = (
first
| right_shift_unsigned(second, 3)
| right_shift_unsigned(third, 6)
| fourth
)
assert x.is_contiguous()
x = x.view(torch.uint8)
return x
# -----------------------------------------------------------------------
# inverse operation of _pack_bits
# -----------------------------------------------------------------------
def _bf16_to_fp4e2m1(x):
# 0bAxxxxxxBCDxxxxxx (int16) -> 0b0000ABCD (uint8)
assert x.dtype == torch.int16
s = (right_shift_unsigned(x, 15) & 0x1) << 3
em = right_shift_unsigned(x, 6) & 0x7
return (s | em).to(torch.uint8)
def _bf16x2_to_fp4e2m1x2(x):
# 0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx (int32) -> 0bABCD_EFGH (uint8)
assert x.dtype == torch.int32
lo = (x & 0xFFFF).to(torch.int16)
hi = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16)
ret_lo = _bf16_to_fp4e2m1(lo)
ret_hi = _bf16_to_fp4e2m1(hi)
return ret_lo | (ret_hi << 4)
def _unpack_bits(x, mx_axis: int):
x = x.view(torch.int32)
m = 0b10000001110000001000000111000000
a = (x << 1) & 0b10000000000000001000000000000000
b = right_shift_unsigned(x, 3) & 0b00000001100000000000000110000000
c = right_shift_unsigned(x, 7) & 0b00000000010000000000000001000000
unpacked = [x & m, (x << 3) & m, (x << 6) & m, (a | b) | c]
x = torch.stack(unpacked, dim=-1)
x = x.flatten(-2, -1)
x = _bf16x2_to_fp4e2m1x2(x)
return x
# -----------------------------------------------------------------------
class HopperMXValueLayout(Layout):
name: str = "HOPPER_VALUE"
def __init__(self, shape, mx_axis, mma_version=3):
super().__init__(shape)
assert mx_axis in range(len(shape))
self.mx_axis = mx_axis
self.mma_version = mma_version
(
*self.leading_shape,
self.K,
self.N,
) = shape
def _maybe_mT(self, data):
if self.mx_axis == len(self.leading_shape):
return data.mT
return data
def swizzle_data(self, data):
"""
Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
(*, M // 4, K * 4) such that:
1) Groups contiguously all the elements owned by the same thread of 4
mma tiles along the K axis. The following animation shows a similar
grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
as done here:
https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif
2) Moves the elements belonging to thread 4-7 to be contiguous with those
from thread 0-3. This is done to get a full cache line when loading them
from HBM.
mx_axis selects the lhs or rhs of the matmul.
WARNING: Assumes that the matmul will be done in bf16 or fp16!
Implementing it for fp8 is as easy as making the tile size (8, 8)
"""
batch = data.ndim - 2
assert batch >= 0
assert self.mma_version in (2, 3)
data = self._maybe_mT(data)
init_shape = data.shape
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
# Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
contig = (1, u8_kwidth)
scott_trick = (2, 1)
threads = (4, 4)
warp_tile = (2, 2)
k_tile = (1, 4 // u8_kwidth)
sizes = list(data.shape[:-2])
pads = []
# [rest, K, tile, threads] per dimension
for i, (a, b, c, s, d) in enumerate(
zip(k_tile, warp_tile, threads, scott_trick, contig)
):
pack = a * b * c * s * d
size = data.shape[batch + i]
pad = (pack - size % pack) % pack
pads += [(0, pad)]
sizes.append((size + pad) // pack)
sizes += [a, b, c, s, d]
pads = tuple(x for t in pads[::-1] for x in t)
data = torch.nn.functional.pad(data, pads)
init_shape = data.shape
# 0: rest[0]
# 1: k_tile[0]
# 2: warp_tile[0]
# 3: threads[0]
# 4: scott_trick[0]
# 5: contig[0]
# 6: rest[1]
# 7: k_tile[1]
# 8: warp_tile[1]
# 9: threads[1]
# 10: scott_trick[1]
# 11: contig[1]
data = data.view(*sizes)
# Want [rest[0], threads[0], rest[1], scott_trick[0], scott_trick[0], threads[1], contig[1], contig[0], k_tile[1], k_tile[0], warp_tile[1], warp_tile[0]]
perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11]
perm = list(range(batch)) + [batch + p for p in perm]
data = data.permute(*perm).contiguous()
# These are views
data = data.flatten(-10, -1)
data = data.flatten(-3, -2)
assert data.is_contiguous()
assert data.shape[-2] == init_shape[-2] // 4
assert data.shape[-1] == init_shape[-1] * 4
# twiddle the bits
data = _pack_bits(data, self.mx_axis)
data = self._maybe_mT(data)
return data
def unswizzle_data(self, data):
data = self._maybe_mT(data)
data = _unpack_bits(data, self.mx_axis)
*batch, M, K = data.shape
# We have two times the elements if we already upcasted to bfloat16
mult = 2 if data.dtype == torch.bfloat16 else 1
assert M % 4 == 0, "M must be divisible by 4"
assert K % (4 * 8 * 2 * 2 * mult) == 0, (
f"K must be divisible by {4 * 8 * 2 * 2 * mult}"
)
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
data = data.reshape(
*batch,
M // 4,
4,
K // (4 * 8 * 2 * 2 * mult),
2,
4,
8 // u8_kwidth,
2,
u8_kwidth * mult,
)
b = len(batch)
perm = [0, 6, 1, 3, 2, 5, 4, 7]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.reshape(*batch, M * 4, K // 4)
data = self._maybe_mT(data)
return data[..., : self.K, : self.N]
def swizzle_block_shape(self, block_shape):
return block_shape
@triton.jit
def _unshuffle_triton(x, mma_version: tl.constexpr):
"""
Triton inverse of swizzle_mxfp4_value_hopper
"""
tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3")
# if mx_axis == 0:
# x = x.trans()
# We have two times the elements if we already upcasted to bfloat16
mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1
M: tl.constexpr = x.shape[0]
K: tl.constexpr = x.shape[1]
tl.static_assert(M % 4 == 0, "M must be divisible by 4")
tl.static_assert(
K % (4 * 8 * 2 * 2 * mult) == 0,
f"K must be divisible by {4 * 8 * 2 * 2 * mult}",
)
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1
x = x.reshape(
M // 4,
4,
K // (4 * 8 * 2 * 2 * mult),
2,
4,
8 // u8_kwidth,
2,
u8_kwidth * mult,
)
x = x.trans(0, 6, 1, 3, 2, 5, 4, 7)
x = x.reshape(M * 4, K // 4)
# if mx_axis == 0:
# x = x.trans()
return x
@triton.jit
def _unpack_fp4_to_bf16_triton(x):
# For now we implement just H100 support (mul.bf16x2)
# A100 support is possible via fma
r0, r1 = tl.inline_asm_elementwise(
r"""
{
.reg .b32 b, c, d<7>, scale;
.reg .b32 bias;
mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
// We add the missing bias to the scale directly
and.b32 $0, $4, 0b10000001110000001000000111000000;
mul.bf16x2 $0, $0, bias;
shl.b32 b, $4, 3;
and.b32 $1, b, 0b10000001110000001000000111000000;
mul.bf16x2 $1, $1, bias;
shl.b32 c, $4, 6;
and.b32 $2, c, 0b10000001110000001000000111000000;
mul.bf16x2 $2, $2, bias;
// Unpack last two elements
shl.b32 d0, $4, 1;
and.b32 d1, d0, 0b10000000000000001000000000000000;
shr.b32 d2, $4, 3;
and.b32 d3, d2, 0b00000001100000000000000110000000;
or.b32 d4, d1, d3;
shr.b32 d5, $4, 7;
and.b32 d6, d5, 0b00000000010000000000000001000000;
or.b32 $3, d4, d6;
mul.bf16x2 $3, $3, bias;
}
""",
constraints="=r,=r,=r,=r,r",
args=[x],
dtype=(tl.bfloat16, tl.bfloat16),
is_pure=True,
pack=4,
)
# Concat each pack of 4
x = tl.join(r0, r1)
x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2])
x = x.trans(0, 1, 3, 2)
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
return x
@triton.jit
def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
"""
Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
(x << 0) & 0b1000000111000000
(x << 3) & 0b1000000111000000
(x << 6) & 0b1000000111000000
((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)
"""
# upcast values to bfloat16
tl.static_assert(len(x.shape) == 2)
tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1")
tl.static_assert(x.shape[1] % 4 == 0)
tl.static_assert(x.dtype == tl.uint8)
if mx_axis == 0:
x = x.trans()
x = _unpack_fp4_to_bf16_triton(x)
x = _unshuffle_triton(x, mma_version=3)
if mx_axis == 0:
x = x.trans()
# upcast scale to bfloat16
# Add bias missing from the bf16 upcasting sequence
# triton / LLVM generates terrible code for this sequence
# scale = scale.to(tl.uint16)
# scale = scale << 7
# scale = scale.to(tl.bfloat16, bitcast=True)
scale = tl.inline_asm_elementwise(
r"""
{
prmt.b32 $0, $2, 0, 0x5140;
shl.b32 $0, $0, 7;
prmt.b32 $1, $2, 0, 0x7362;
shl.b32 $1, $1, 7;
}
""",
constraints="=r,=r,r",
args=[scale],
dtype=tl.bfloat16,
is_pure=True,
pack=4,
)
# Broadcast scale
scale = scale.expand_dims(mx_axis + 1)
scale = scale.broadcast_to(
scale.shape[: mx_axis + 1] + [32] + scale.shape[mx_axis + 2 :]
)
scale = scale.reshape(x.shape)
# Combine scale and x
x = x * scale
return x
from .base import Layout
class StridedLayout(Layout):
name: str = None
def __init__(self, shape) -> None:
super().__init__(shape)
def swizzle_data(self, data):
return data
def unswizzle_data(self, data):
return data
def swizzle_block_shape(self, block_shape):
return block_shape
import enum
import functools
import os
import subprocess
import sys
import torch
from compactor_vllm.triton_kernels.numerics import (
MAX_FINITE_FLOAT8E4B8,
MAX_FINITE_FLOAT8E4NV,
MAX_FINITE_FLOAT8E5,
)
def assert_equal(ref, tri):
if isinstance(ref, torch.Tensor):
assert torch.all(ref == tri)
else:
assert ref == tri
def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
if tri.dtype.itemsize == 1:
ref_as_type = ref.to(tri.dtype)
if ref.dtype == tri.dtype:
assert torch.all(ref_as_type == tri)
return
ref = ref_as_type
if ref.numel() == 0:
return
if maxtol is None:
maxtol = 2e-2
if rmstol is None:
rmstol = 4e-3
"""
Compare reference values against obtained values.
"""
# cast to float32:
ref = ref.to(torch.float32).detach()
tri = tri.to(torch.float32).detach()
assert ref.shape == tri.shape, (
f"Tensors must have same size {ref.shape=} {tri.shape=}"
)
# deal with infinite elements:
inf_mask_ref = torch.isinf(ref)
inf_mask_tri = torch.isinf(tri)
assert torch.equal(inf_mask_ref, inf_mask_tri), (
"Tensor must have same infinite elements"
)
refn = torch.where(inf_mask_ref, 0, ref)
trin = torch.where(inf_mask_tri, 0, tri)
# normalise so that RMS calculation doesn't overflow:
eps = 1.0e-30
multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
refn *= multiplier
trin *= multiplier
ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
max_err = torch.max(rel_err).item()
rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
if verbose:
print(
"%s maximum relative error = %s (threshold = %s)"
% (description, max_err, maxtol)
)
print(
"%s RMS relative error = %s (threshold = %s)"
% (description, rms_err, rmstol)
)
if max_err > maxtol:
bad_idxs = torch.nonzero(rel_err > maxtol)
num_nonzero = bad_idxs.size(0)
bad_idxs = bad_idxs[:1000]
print(
"%d / %d mismatched elements (shape = %s) at coords %s"
% (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())
)
bad_idxs = bad_idxs.unbind(-1)
print("ref values: ", ref[tuple(bad_idxs)].cpu())
print("tri values: ", tri[tuple(bad_idxs)].cpu())
assert max_err <= maxtol
assert rms_err <= rmstol
class ComputeSanitizerTool(enum.Enum):
MEMCHECK = "memcheck"
RACECHECK = "racecheck"
SYNCCHECK = "synccheck"
INITCHECK = "initcheck"
def compute_sanitizer(**target_kwargs):
"""
Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
to expose potential memory access errors.
This decorator requires the `request` fixture to be present.
If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
Running tests under compute sanitizer requires launching subprocess and is slow,
so use sparingly
"""
def decorator(test_fn):
@functools.wraps(test_fn)
def wrapper(*args, **kwargs):
if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
test_fn(*args, **kwargs)
return
import psutil
if target_kwargs.pop("clear_torch_cache", False):
# If we don't pop clear_torch_cache, it won't pass
# target_kwargs.items() <= kwargs.items() condition below.
torch.cuda.empty_cache()
tools_to_check = target_kwargs.pop(
"tools_to_check", [ComputeSanitizerTool.MEMCHECK]
)
assert isinstance(tools_to_check, list), f"{tools_to_check=}"
assert all(tool in ComputeSanitizerTool for tool in tools_to_check), (
f"{(tool for tool in tools_to_check if tool not in ComputeSanitizerTool)=}"
)
ppid_name = psutil.Process(os.getppid()).exe()
run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
if "run_sanitizer" in kwargs:
run_compute_sanitizer &= kwargs["run_sanitizer"]
if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
for tool in tools_to_check:
path = os.path.realpath(test_fn.__globals__["__file__"])
# get path of current file
env = {
"PATH": os.environ["PATH"],
"PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
"TORCH_SHOW_CPP_STACKTRACES": "1",
"CUDA_LAUNCH_BLOCKING": "1",
}
if "CUDA_VISIBLE_DEVICES" in os.environ:
env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
assert "request_fixture" in kwargs, (
"memcheck'ed test must have a (possibly unused) `request` fixture"
)
test_id = kwargs["request_fixture"].node.callspec.id
cmd = f"{path}::{test_fn.__name__}[{test_id}]"
cmd = [
"compute-sanitizer",
"--target-processes=application-only",
"--destroy-on-device-error=context",
f"--tool={tool.value}",
sys.executable,
"-m",
"pytest",
"-vsx",
cmd,
]
for opt in ["--update_checksum", "--ignore_checksum_error"]:
if opt in sys.argv:
cmd.append(opt)
out = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env=env,
)
sanitizer_ok = "ERROR SUMMARY: 0 errors" in str(
out.stdout
) or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout)
test_output = out.stdout
if type(test_output) is bytes:
test_output = test_output.decode()
fail = False
if not sanitizer_ok:
print("compute-sanitizer returned an error")
fail = True
elif out.returncode != 0:
print(
"The test failed due to some other reason: consider running without compute-sanitizer to verify."
)
print(f"{out.returncode=}")
fail = True
if fail:
print("*****************************************************")
print("******************** TEST OUTPUT ********************")
print("*****************************************************")
print(test_output)
print("*****************************************************")
print("****************** TEST OUTPUT END ******************")
print("*****************************************************")
assert None
else:
test_fn(*args, **kwargs)
return wrapper
return decorator
def compute_actual_scale(x, dtype):
max_finite = {
torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
}[dtype]
return x.abs().max() / max_finite
import torch
import triton
from compactor_vllm.triton_kernels.topk_details._topk_forward import _topk_forward
from compactor_vllm.triton_kernels.topk_details import _topk_backward
from compactor_vllm.triton_kernels.tensor import Tensor, Bitmatrix
from typing import Optional, Union
def topk_forward(
x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None
):
if not isinstance(x, Tensor):
x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]]
x_shape_max = [x.shape[0], x.shape[1]]
x = Tensor(x, shape=x_shape, shape_max=x_shape_max)
cdiv = lambda a, b: (a + b - 1) // b
BLOCK_M = 32
BLOCK_N = 32
BLOCK_S = 128
assert len(x.shape) == 2
assert x.shape_max[-1] < 32768
assert dim == 1
assert return_bitmatrix
n_rows, n_cols = x.shape
n_rows_max, _ = x.shape_max
dev = x.device
# scratchpad tensors
# NOTE: these are not returned
y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
if y_indx is not None:
use_provided_indx = True
else:
y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)
use_provided_indx = False
# create bitmatrix in transposed memory layout:
n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
n_cols_words = n_cols_pad // 32
bitmatrix = torch.empty(
(n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev
)
bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max]
s_blocks = cdiv(n_cols, BLOCK_S)
s_cols = s_blocks * BLOCK_S
scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev)
pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks)
_topk_forward[(pids,)](
x,
x.stride(0), # inputs
y_vals,
y_indx,
y_vals.stride(0),
use_provided_indx, # output [topk]
bitmatrix,
bitmatrix.stride(0),
bitmatrix.stride(1), # output [bitmatrix]
n_rows,
n_cols, # shapes
scratchpad,
BLOCK_S,
s_blocks, # thing to memset to zero
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, # tunable parameter
APPLY_SOFTMAX=apply_softmax,
N_EXPTS_PAD=n_cols_pad,
N_EXPTS_ACT=k, # constants
)
bitmatrix_shape = [n_rows, n_cols_words * 32]
bitmatrix_shape_max = [n_rows_max, None]
bitmatrix = Bitmatrix(
bitmatrix,
shape=bitmatrix_shape,
shape_max=bitmatrix_shape_max,
scratchpad=scratchpad,
)
return y_vals, y_indx, bitmatrix
def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
assert dy_vals.shape[-1] == k
n_expts_pad = triton.next_power_of_2(x.shape[-1])
dx = torch.empty_like(x)
_topk_backward[(dy_vals.shape[0],)](
y_indx,
y_indx.stride(0),
dy_vals,
dy_vals.stride(0),
x,
x.stride(0), # inputs
dx, # outputs
dx.stride(0),
x.shape[0],
n_rows,
x.shape[-1],
APPLY_SOFTMAX=apply_softmax,
N_EXPTS_ACT=k,
N_EXPTS_PAD=n_expts_pad,
)
return dx
class TopK(torch.autograd.Function):
@staticmethod
def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
y_vals, y_indx, bitmatrix = topk_forward(
x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows
)
ctx.save_for_backward(x, y_indx)
ctx.apply_softmax = apply_softmax
ctx.k = k
ctx.n_rows = n_rows
return y_vals, y_indx, bitmatrix
@staticmethod
def backward(ctx, dy_vals, _0, _1):
x, y_indx = ctx.saved_tensors
dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax)
return dx, None, None, None, None, None, None
def topk(
x: Union[Tensor, torch.Tensor],
k: int,
apply_softmax: bool = True,
dim: int = 1,
return_bitmatrix: bool = True,
y_indx: Optional[torch.Tensor] = None,
n_rows: Optional[int] = None,
):
"""
Computes the top-k values and indices along a specified dimension of a tensor.
Note that the input can be either a `Tensor` or a `torch.Tensor`, but the output will always be a `torch.Tensor`.
Parameters
----------
x : Union[triton_kernels.Tensor, torch.Tensor]
Input tensor of shape (n_tokens, n_expts).
k : int
Number of top elements to retrieve.
apply_softmax : bool, default True
Whether to apply softmax to the input tensor before computing top-k.
dim : int, default 1
Dimension along which to compute top-k.
return_bitmatrix : bool, default True
A bitmatrix of shape (n_tokens, cdiv(n_expts, 32)).
Each bit on [t, b] indicates whether the b-th expert was selected for the t-th token.
y_indx : torch.Tensor, optional
Pre-allocated tensor for storing indices of top-k elements with shape (n_tokens, k).
If provided, we skip the computation of top-k indices and use this tensor instead.
n_rows : int, optional
Number of rows to apply top-k on. If None, we consider all rows in `x`.
Returns
-------
(expt_scal, expt_indx, bitmatrix) : Tuple[torch.Tensor, torch.Tensor, Bitmatrix]
"""
ret = TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
return ret
import triton
import triton.language as tl
@triton.jit
def _topk_backward(
Yi,
stride_ym, # topk indices
DY,
stride_dym, # output gradient values
X,
stride_xm, # input values
DX,
stride_dxm, # input gradient values
n_rows,
NRows,
n_expts_tot,
APPLY_SOFTMAX: tl.constexpr,
N_EXPTS_ACT: tl.constexpr,
N_EXPTS_PAD: tl.constexpr,
):
pid_m = tl.program_id(0)
if NRows is not None:
n_rows = tl.load(NRows)
if pid_m >= n_rows:
return
Yi += pid_m * stride_ym
DY += pid_m * stride_dym
X += pid_m * stride_xm
DX += pid_m * stride_dxm
# --
offs_xn = tl.arange(0, N_EXPTS_PAD)
offs_yn = tl.arange(0, N_EXPTS_ACT)
mask_xn = offs_xn < n_expts_tot
# recompute softmax
y_indx = tl.load(Yi + offs_yn)
x = tl.load(X + y_indx)
x = x.to(tl.float32)
y = tl.softmax(x)
# compute input-gradient
dy = tl.load(DY + offs_yn)
dy = dy.to(tl.float32)
s = tl.sum(y * dy, 0)
# write-back input gradient
tl.store(DX + offs_xn, 0, mask=mask_xn)
tl.debug_barrier()
if APPLY_SOFTMAX:
dx = y * (dy - s)
else:
dx = dy
tl.store(DX + y_indx, dx)
import triton
import triton.language as tl
@triton.jit
def get_topmask_and_fullmask(x):
tl.static_assert(
x.dtype.is_int_unsigned(), "floating-point value must be passed as bits"
)
tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
return tm_arr, fm_arr
@triton.jit
def fpval_to_key(x):
tm, fm = get_topmask_and_fullmask(x)
return x ^ tl.where((x & tm) != 0, fm, tm)
@triton.jit
def key_to_fpval(x):
tm, fm = get_topmask_and_fullmask(x)
return x ^ tl.where((x & tm) == 0, fm, tm)
# stable top-k tie-breaks to value with smaller index
@triton.jit
def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr):
return N_EXPTS_PAD - indx
@triton.jit
def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr):
return N_EXPTS_PAD - indx
@triton.jit
def streaming_topk(
X,
stride_xm,
n_expts_tot,
offs_m,
mask_m,
N_EXPTS_PAD: tl.constexpr,
N_EXPTS_ACT: tl.constexpr,
BLOCK_N: tl.constexpr,
):
x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
if x_nbits < 16:
# this ensures that we leave at least 16 bits for expert index
# even if the input dtype is smaller than 16 bits:
y_nbits: tl.constexpr = 32
else:
y_nbits: tl.constexpr = x_nbits * 2
x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}")
x_dtype: tl.constexpr = X.dtype.element_ty
# subtract 1 from loop iterations because we peel the first (masked) iteration:
loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_x_n[None, :] < n_expts_tot
# first iteration:
X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
x = fpval_to_key(x.to(x_utype, bitcast=True))
x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
acc = tl.topk(x, N_EXPTS_ACT, dim=1)
# subsequent iterations:
for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge
X_ptrs -= BLOCK_N
offs_x_n -= BLOCK_N
x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
x = fpval_to_key(x.to(x_utype, bitcast=True))
x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))
# rotate expert index into upper 16 bits:
# 0000vvvvvvvviiii --> iiii0000vvvvvvvv
acc = (acc << (y_nbits - 16)) | (acc >> 16)
# sort in ascending order of expert (descending order of key)
acc = tl.sort(acc, dim=1, descending=True)
# iiii0000vvvvvvvv --> 0000iiii:
y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
# iiii0000vvvvvvvv --> vvvvvvvv:
y_values_raw = acc.to(x_utype)
y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)
return y_values, y_indices
@triton.jit
def _topk_forward(
X,
stride_xm, # inputs
Yv,
Yi,
stride_ym, # topk values/indices
USE_PROVIDED_INDX: tl.constexpr,
Bits,
stride_rm: tl.constexpr,
stride_rn: tl.constexpr, # bitmatrix
n_rows,
n_expts_tot, # shape
S,
BLOCK_S: tl.constexpr,
s_blocks, # thing to memset
APPLY_SOFTMAX: tl.constexpr, # constant
BLOCK_M: tl.constexpr,
N_EXPTS_PAD: tl.constexpr,
N_EXPTS_ACT: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid = tl.program_id(0)
if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
n_rows = tl.load(n_rows)
if pid < s_blocks:
tl.store(
S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32)
)
if pid * BLOCK_M >= n_rows:
# early exit:
return
tl.static_assert(BLOCK_N % 32 == 0)
tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
x_dtype: tl.constexpr = X.dtype.element_ty
# load logits
offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
offs_y_n = tl.arange(0, N_EXPTS_ACT)
mask_m = offs_m[:, None] < n_rows
if USE_PROVIDED_INDX:
Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
y_indices = tl.load(Yi_ptrs, mask=mask_m)
Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
y_values = tl.load(Xv_ptrs, mask=mask_m)
else:
y_values, y_indices = streaming_topk(
X,
stride_xm,
n_expts_tot,
offs_m,
mask_m, #
N_EXPTS_PAD,
N_EXPTS_ACT,
BLOCK_N,
)
# normalize selected values
if APPLY_SOFTMAX:
y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(
x_dtype
)
# write back
Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
tl.store(Yv_ptrs, y_values, mask=mask_m)
if not USE_PROVIDED_INDX:
Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
tl.store(Yi_ptrs, y_indices, mask=mask_m)
# pack into bitmatrix
y_div = y_indices // 32
y_rem = y_indices % 32
loop_iterations = N_EXPTS_PAD // BLOCK_N
for i in range(loop_iterations):
offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)
y2 = tl.where(
y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0
)
r = tl.reduce_or(y2, axis=1)
BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
tl.store(BitsPtrs, r, mask=mask_m)
import itertools
import math
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.distributed as dist
from compactor_vllm.compression import CompressionMethod
from compactor_vllm.compression.compression_config import BatchCompressionParams
from compactor_vllm.config.engine_config import LLMConfig
from compactor_vllm.utils.sequence import Sequence
@dataclass
class PrefillBatchArguments:
B: int
N: int
do_compression: bool
compression_method: CompressionMethod
compression_chunk_size: int
seq_ids: torch.Tensor
input_ids: torch.Tensor
positions: torch.Tensor
cu_seqlens_q: torch.Tensor
cu_seqlens_k: torch.Tensor
max_seqlen_q: int
max_seqlen_k: int
batch_tokens_to_retain: Optional[torch.Tensor]
max_tokens_to_retain: Optional[int]
protected_first: Optional[List[int]]
protected_last: Optional[List[int]]
PHI: Optional[torch.Tensor]
# args needed for memory reservation
context_lens: torch.Tensor
max_new_tokens: torch.Tensor
# 与 kvpress ``CompactorPress`` blending 默认(未显式指定时用 compression_ratio)对齐
compression_ratio: float = 1.0
class PackedTensorArguments:
def __init__(
self, rank: int, max_batched_tokens: int, config: LLMConfig, seed: int = 42
) -> None:
hf_config = config.hf_config
self.rank = rank
self.device = torch.device(f"cuda:{rank}")
self.max_num_batches = config.max_num_seqs
self.max_batched_tokens = max_batched_tokens
self.num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
self.world_size = config.tensor_parallel_size
self.page_size = int(config.kvcache_page_size)
self.head_dim = getattr(hf_config, "head_dim", None)
self.sketch_dim = config.leverage_sketch_size
self.model_dtype = hf_config.torch_dtype
# i64 pack = [seq_ids (BMAX)] || [input_ids (NMAX)] || [positions (NMAX)] || max_new_tok (BMAX)
self.i64_len_max = (
self.max_num_batches + 2 * self.max_batched_tokens + self.max_num_batches
)
self.packed_context_i64 = torch.empty(
self.i64_len_max, dtype=torch.int64, device=self.device
)
# i32 pack = [header (6): ... + compression_ratio*1e6] || [cu_q (BMAX+1)] || ...
# || [protected_first_tokens (BMAX)] || [protected_last_tokens (BMAX)]
self.i32_len_max = (
6
+ (self.max_num_batches + 1)
+ (self.max_num_batches + 1)
+ self.max_num_batches
+ self.max_num_batches
+ self.max_num_batches
+ self.max_num_batches
)
self.packed_context_i32 = torch.empty(
self.i32_len_max, dtype=torch.int32, device=self.device
)
self.generator = torch.Generator(device=self.device).manual_seed(seed)
self.PHI = torch.randn(
(self.head_dim, self.sketch_dim),
device=self.packed_context_i32.device,
generator=self.generator,
).to(self.model_dtype) * (1 / math.sqrt(self.sketch_dim))
def _master_build_prefill(
self, seqs: List[Sequence], batch_compression_params: BatchCompressionParams
) -> PrefillBatchArguments:
B = len(seqs)
Ls = [x.prompt_len for x in seqs]
N = sum(Ls)
assert N <= self.max_batched_tokens
do_compression = any(x.compression_params.compression_ratio < 1.0 for x in seqs)
do_compression = (
do_compression
and batch_compression_params.compression_method != CompressionMethod.NONE
)
pack_slices_64 = self.packed_i64_slices(B, N)
pack_slices_32 = self.packed_i32_slices(B)
# max_retain = max(retain)
protected_first_list = [
x.compression_params.protected_first_tokens for x in seqs
]
protected_last_list = [x.compression_params.protected_last_tokens for x in seqs]
retain = [
max(
int(
round(
x.compression_params.compression_ratio
* (L - s - e)
* self.num_kv_heads
)
),
1,
)
for s, e, L, x in zip(protected_first_list, protected_last_list, Ls, seqs)
]
retain = torch.tensor(retain, dtype=torch.int32, device="cpu", pin_memory=True)
protected_first = torch.tensor(
protected_first_list, dtype=torch.int32, device="cpu", pin_memory=True
)
protected_last = torch.tensor(
protected_last_list, dtype=torch.int32, device="cpu", pin_memory=True
)
self.packed_context_i32[pack_slices_32["protected_first"]].copy_(
protected_first, non_blocking=True
)
self.packed_context_i32[pack_slices_32["protected_last"]].copy_(
protected_last, non_blocking=True
)
compression_chunk_size = (
batch_compression_params.chunk_size
if batch_compression_params.do_chunked_compression
else -1
)
min_compression_ratio = min(x.compression_params.compression_ratio for x in seqs)
cr_scaled = int(round(float(min_compression_ratio) * 1_000_000.0))
cr_scaled = max(min(cr_scaled, 2_000_000_000), -2_000_000_000)
header_host = torch.tensor(
[
B,
N,
1 if do_compression else 0,
batch_compression_params.compression_method.value,
compression_chunk_size,
cr_scaled,
],
dtype=torch.int32,
device="cpu",
pin_memory=True,
)
self.packed_context_i32[pack_slices_32["retain"]].copy_(
retain, non_blocking=True
)
self.packed_context_i32[pack_slices_32["header"]].copy_(
header_host, non_blocking=True
)
max_seq_qk = max(Ls)
cu = torch.tensor(
list(itertools.accumulate(Ls, initial=0)),
dtype=torch.int32,
device="cpu",
pin_memory=True,
)
self.packed_context_i32[pack_slices_32["cu_q"]].copy_(cu, non_blocking=True)
self.packed_context_i32[pack_slices_32["cu_k"]].copy_(cu, non_blocking=True)
self.packed_context_i32[pack_slices_32["context_lens"]].copy_(
cu.diff(), non_blocking=True
)
seq_ids = torch.tensor(
[x.seq_id for x in seqs], dtype=torch.int64, device="cpu", pin_memory=True
)
input_ids = torch.tensor(
[tid for x in seqs for tid in x.prompt_token_ids],
dtype=torch.int64,
device="cpu",
pin_memory=True,
)
self.packed_context_i64[pack_slices_64["seq_ids"]].copy_(
seq_ids, non_blocking=True
)
self.packed_context_i64[pack_slices_64["input_ids"]].copy_(
input_ids, non_blocking=True
)
positions = torch.cat(
[
torch.arange(L, dtype=torch.int64, device="cpu", pin_memory=True)
for L in Ls
]
)
self.packed_context_i64[pack_slices_64["positions"]].copy_(
positions, non_blocking=True
)
max_new_tokens = torch.tensor(
[seq.sampling_params.max_new_tokens for seq in seqs],
dtype=torch.int64,
device="cpu",
pin_memory=True,
)
self.packed_context_i64[pack_slices_64["max_new_tokens"]].copy_(
max_new_tokens, non_blocking=True
)
# `prefill_store_topk_kv(..., PAD_TO_PAGE_SIZE=True)` may scan beyond the
# top-k prefix to fill per-head lengths up to a page boundary. Using a
# full ranking (top_k = max_seq_len * HKV) makes `torch.topk` degenerate
# into a full sort, which is very expensive for long contexts.
#
# Instead, request only a prefix that is large enough for:
# 1) the maximum "keep" budget in the batch, plus
# 2) a conservative extra window for page-padding candidates.
max_seq_len = int(self.packed_context_i32[pack_slices_32["context_lens"]].max())
full_budget = max_seq_len * self.num_kv_heads
keep_budget = int(retain.max().item())
pad_search_budget = (self.page_size - 1) * (self.num_kv_heads**2)
max_retain = min(full_budget, keep_budget + pad_search_budget)
dist.broadcast(self.packed_context_i64, src=0)
dist.broadcast(self.packed_context_i32, src=0)
prefill_args = PrefillBatchArguments(
B=B,
N=N,
do_compression=do_compression,
compression_method=batch_compression_params.compression_method,
compression_chunk_size=compression_chunk_size,
seq_ids=self.packed_context_i64[pack_slices_64["seq_ids"]],
input_ids=self.packed_context_i64[pack_slices_64["input_ids"]],
positions=self.packed_context_i64[pack_slices_64["positions"]],
cu_seqlens_q=self.packed_context_i32[pack_slices_32["cu_q"]],
cu_seqlens_k=self.packed_context_i32[pack_slices_32["cu_k"]],
max_seqlen_q=max_seq_qk,
max_seqlen_k=max_seq_qk,
batch_tokens_to_retain=self.packed_context_i32[pack_slices_32["retain"]],
max_tokens_to_retain=max_retain,
PHI=self.PHI,
context_lens=self.packed_context_i32[pack_slices_32["context_lens"]],
max_new_tokens=self.packed_context_i64[pack_slices_64["max_new_tokens"]],
protected_first=protected_first_list,
protected_last=protected_last_list,
compression_ratio=min_compression_ratio,
)
return prefill_args
def _peer_receive_prefill(self) -> PrefillBatchArguments:
dist.broadcast(self.packed_context_i64, src=0)
dist.broadcast(self.packed_context_i32, src=0)
header = self.packed_context_i32[:6].tolist()
B, N = int(header[0]), int(header[1])
do_compression = bool(int(header[2]))
compression_method = CompressionMethod(int(header[3]))
compression_chunk_size = int(header[4])
compression_ratio = int(header[5]) / 1_000_000.0
pack_slices_64 = self.packed_i64_slices(B, N)
pack_slices_32 = self.packed_i32_slices(B)
max_seq_len = int(self.packed_context_i32[pack_slices_32["context_lens"]].max())
full_budget = max_seq_len * self.num_kv_heads
keep_budget = int(self.packed_context_i32[pack_slices_32["retain"]].max().item())
pad_search_budget = (self.page_size - 1) * (self.num_kv_heads**2)
max_retain = min(full_budget, keep_budget + pad_search_budget)
prefill_args = PrefillBatchArguments(
B=B,
N=N,
do_compression=do_compression,
compression_method=compression_method,
compression_chunk_size=compression_chunk_size,
seq_ids=self.packed_context_i64[pack_slices_64["seq_ids"]],
input_ids=self.packed_context_i64[pack_slices_64["input_ids"]],
positions=self.packed_context_i64[pack_slices_64["positions"]],
cu_seqlens_q=self.packed_context_i32[pack_slices_32["cu_q"]],
cu_seqlens_k=self.packed_context_i32[pack_slices_32["cu_k"]],
max_seqlen_q=int(self.packed_context_i32[pack_slices_32["cu_q"]].max()),
max_seqlen_k=int(self.packed_context_i32[pack_slices_32["cu_k"]].max()),
batch_tokens_to_retain=self.packed_context_i32[pack_slices_32["retain"]],
max_tokens_to_retain=max_retain,
PHI=self.PHI,
context_lens=self.packed_context_i32[pack_slices_32["context_lens"]],
max_new_tokens=self.packed_context_i64[pack_slices_64["max_new_tokens"]],
protected_first=self.packed_context_i32[
pack_slices_32["protected_first"]
].tolist(),
protected_last=self.packed_context_i32[
pack_slices_32["protected_last"]
].tolist(),
compression_ratio=compression_ratio,
)
return prefill_args
@torch.inference_mode()
def build_prefill_args(
self,
seqs: Optional[List[Sequence]] = None,
batch_compression_params: Optional[BatchCompressionParams] = None,
) -> PrefillBatchArguments:
if self.rank == 0:
return self._master_build_prefill(seqs, batch_compression_params)
return self._peer_receive_prefill()
def broadcast(self):
if self.world_size > 1:
return dist.broadcast(self.packed_context_i64, src=0)
return None
@staticmethod
def packed_i64_slices(B: int, N: int):
return {
"seq_ids": slice(0, B),
"input_ids": slice(B, B + N),
"positions": slice(B + N, B + 2 * N),
"max_new_tokens": slice(B + 2 * N, 2 * B + 2 * N),
}
@staticmethod
def packed_i32_slices(B: int):
h0, h1 = 0, 6
q0 = h1
q1 = q0 + (B + 1)
k0 = q1
k1 = k0 + (B + 1)
r0 = k1
r1 = r0 + B
c0 = r1
c1 = r1 + B
pf0 = c1
pf1 = c1 + B
pl0 = pf1
pl1 = pf1 + B
return {
"header": slice(h0, h1),
"cu_q": slice(q0, q1),
"cu_k": slice(k0, k1),
"retain": slice(r0, r1),
"context_lens": slice(c0, c1),
"protected_first": slice(pf0, pf1),
"protected_last": slice(pl0, pl1),
}
@dataclass
class DecodeBatchOutput:
output_tokens: Optional[torch.Tensor]
output_seq_ids: Optional[torch.Tensor]
@dataclass
class DecodeBatchArguments:
batch_mapping: Optional[torch.Tensor] = None
token_ids: Optional[torch.Tensor] = None
positions: Optional[torch.Tensor] = None
max_ctx_lens: Optional[torch.Tensor] = None
seq_ids: Optional[torch.Tensor] = None
temps: Optional[torch.Tensor] = None
desired_batch_occupancy: int = -1
num_stashed_batches: int = 0
def update(
self,
batch_mapping,
token_ids,
positions,
max_ctx_lens,
seq_ids,
temps=None,
desired_batch_occupancy: int = None,
):
if self.batch_mapping is not None:
self.batch_mapping = torch.cat([self.batch_mapping, batch_mapping], dim=0)
else:
self.batch_mapping = batch_mapping.clone()
if self.token_ids is not None:
self.token_ids = torch.cat([self.token_ids, token_ids], dim=0)
else:
self.token_ids = token_ids.clone()
if self.positions is not None:
self.positions = torch.cat([self.positions, positions], dim=0)
else:
self.positions = positions.clone()
if self.max_ctx_lens is not None:
self.max_ctx_lens = torch.cat([self.max_ctx_lens, max_ctx_lens], dim=0)
else:
self.max_ctx_lens = max_ctx_lens.clone()
if self.seq_ids is not None:
self.seq_ids = torch.cat([self.seq_ids, seq_ids], dim=0)
else:
self.seq_ids = seq_ids.clone()
if self.temps is not None and temps is not None:
self.temps = torch.cat([self.temps, temps], dim=0)
elif temps is not None:
self.temps = temps.clone()
if desired_batch_occupancy is not None:
self.desired_batch_occupancy = desired_batch_occupancy
return self
from dataclasses import dataclass
from typing import List, Optional
import torch
from compactor_vllm.compression import CompressionMethod
from compactor_vllm.config.engine_config import AttentionBackend
@dataclass
class CompressionContext:
compression_method: CompressionMethod = CompressionMethod.COMPACTOR
compression_chunk_size: int = -1
batch_tokens_to_retain: torch.Tensor | None = None
max_tokens_to_retain: int = 0
context_lens: List[int] | None = None
PHI: torch.Tensor | None = None
# Compactor(与 kvpress ``CompactorPress`` 对齐的可选超参)
sketch_dimension: int = 48
sink_size_start: int = 8
sink_size_end: int = 4
compactor_blending: Optional[float] = None
# 与 kvpress 一致:未设 ``compactor_blending`` 时用该值(来自请求的 compression_ratio)
compression_ratio: Optional[float] = None
protected_first_tokens: List[int] | None = None
protected_last_tokens: List[int] | None = None
# CriticalAdaKV
wo_weight: Optional[torch.Tensor] = None
critical_ada_epsilon: float = 1e-4
critical_ada_first_stage_ratio: float = 0.5
critical_ada_alpha_safeguard: float = 0.2
@dataclass
class Context:
is_prefill: bool = False
do_compression: bool = False
cu_seqlens_q: torch.Tensor | None = None
cu_seqlens_k: torch.Tensor | None = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
batch_mapping: torch.Tensor | None = None
max_bh_len: int = 0
compression_context: CompressionContext | None = None
STORE_STREAM: torch.cuda.Stream | None = None
key_split: int | None = None
attention_backend: AttentionBackend = AttentionBackend.COMPACTOR_TRITON
_CONTEXT = Context()
def get_context():
return _CONTEXT
def set_context(
*,
is_prefill,
do_compression=False,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seqlen_q=0,
max_seqlen_k=0,
batch_mapping=None,
max_bh_len=0,
compression_context: CompressionContext = None,
STORE_STREAM=None,
key_split=None,
attention_backend=AttentionBackend.COMPACTOR_TRITON,
):
global _CONTEXT
_CONTEXT = Context(
is_prefill,
do_compression,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
batch_mapping,
max_bh_len,
compression_context,
STORE_STREAM,
key_split,
attention_backend,
)
def reset_context():
global _CONTEXT
_CONTEXT = Context()
from collections.abc import Callable
import torch
def maybe_execute_in_stream(
fn: Callable, *args, STORE_STREAM: torch.cuda.Stream = None, **kwargs
):
if STORE_STREAM is not None:
tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
tensors += [val for val in kwargs.values() if isinstance(val, torch.Tensor)]
obj = getattr(fn, "__self__", None)
if isinstance(obj, torch.Tensor):
tensors.append(obj)
STORE_STREAM.wait_stream(torch.cuda.default_stream())
# Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
# The portable API is `torch.cuda.stream(stream)`.
stream_ctx = (
STORE_STREAM
if hasattr(STORE_STREAM, "__enter__")
else torch.cuda.stream(STORE_STREAM)
)
with stream_ctx:
output = fn(*args, **kwargs)
for t in tensors:
t.record_stream(STORE_STREAM)
if isinstance(output, tuple):
for o in output:
if isinstance(o, torch.Tensor):
o.record_stream(torch.cuda.default_stream())
elif isinstance(output, torch.Tensor):
output.record_stream(torch.cuda.default_stream())
return output
else:
return fn(*args, **kwargs)
from dataclasses import dataclass, field
from enum import Enum, auto
from itertools import count
from typing import List
from compactor_vllm.compression.compression_config import SequenceCompressionParams
from compactor_vllm.config.sampling_params import SamplingParams
class SequenceStatus(Enum):
WAITING = auto()
RUNNING = auto()
FINISHED = auto()
@dataclass
class Sequence:
"""
Represents a single user request / sequence being generated.
"""
_counter = count()
prompt_token_ids: List[int]
completion_token_ids: List[int] = field(default_factory=list)
sampling_params: SamplingParams = field(default_factory=SamplingParams)
compression_params: SequenceCompressionParams = field(
default_factory=SequenceCompressionParams
)
status: SequenceStatus = SequenceStatus.WAITING
seq_id: int = field(default_factory=lambda: next(Sequence._counter), init=False)
num_tokens_processed: int = 0
@property
def num_prompt_tokens(self) -> int:
return len(self.prompt_token_ids)
@property
def num_generated_tokens(self) -> int:
return len(self.completion_token_ids)
def add_new_token(self, token_id: int) -> None:
if len(self.completion_token_ids) == 0:
self.num_tokens_processed += self.num_prompt_tokens
self.completion_token_ids.append(token_id)
self.num_tokens_processed += 1
def tokens_to_retain_per_layer(self, num_kv_heads: int) -> int:
n = int(
self.compression_params.compression_ratio
* self.num_prompt_tokens
* num_kv_heads
)
return max(1, n)
def __getstate__(self):
return dict(
prompt_token_ids=list(self.prompt_token_ids),
completion_token_ids=list(self.completion_token_ids),
sampling_params=self.sampling_params,
compression_params=self.compression_params,
status=self.status,
seq_id=self.seq_id,
num_tokens_processed=self.num_tokens_processed,
)
def __setstate__(self, state):
self.prompt_token_ids = list(state["prompt_token_ids"])
self.completion_token_ids = list(state["completion_token_ids"])
self.sampling_params = state["sampling_params"]
self.compression_params = state["compression_params"]
self.status = state["status"]
self.seq_id = state["seq_id"]
self.num_tokens_processed = state["num_tokens_processed"]
@property
def prompt_len(self) -> int:
return len(self.prompt_token_ids)
@property
def completion_len(self) -> int:
return len(self.completion_token_ids)
from __future__ import annotations
import inspect
from typing import Any, Callable, Mapping
import torch
def _filter_kwargs_for_callable(
fn: Callable[..., Any], kwargs: Mapping[str, Any]
) -> dict[str, Any]:
try:
params = inspect.signature(fn).parameters
except (TypeError, ValueError):
return dict(kwargs)
return {k: v for k, v in kwargs.items() if k in params}
def autotune(*, configs, key, **kwargs):
"""
Compatibility wrapper around `triton.autotune`.
Some Triton builds (e.g., custom vendor builds) may not support newer
keyword arguments like `cache_results`. This wrapper filters unsupported
kwargs based on the runtime `triton.autotune` signature.
"""
import triton
filtered = _filter_kwargs_for_callable(triton.autotune, kwargs)
return triton.autotune(configs=configs, key=key, **filtered)
def maybe_set_allocator(alloc_fn: Callable[[int, int, int | None], Any]) -> bool:
"""
Call `triton.set_allocator(alloc_fn)` if present; otherwise no-op.
Returns True if the allocator was set.
"""
import triton
setter = getattr(triton, "set_allocator", None)
if setter is None:
return False
setter(alloc_fn)
return True
def cuda_capability_geq(major: int, minor: int = 0, device: int | None = None) -> bool:
"""
Host-side CUDA capability check that works even when `tl.target_info` is absent.
"""
if not torch.cuda.is_available():
return False
if device is None:
try:
device = torch.cuda.current_device()
except Exception:
device = 0
cap = torch.cuda.get_device_capability(device)
return cap >= (major, minor)
import collections
import logging
from dataclasses import dataclass
from typing import List
import pytest
import torch
import triton
from compactor_vllm.compression.common import scores_to_retain_indices
from src.compactor_vllm.kv_cache.store_kv_cache import prefill_store_topk_kv
logger = logging.getLogger(__name__)
@dataclass
class Workload:
name: str
batch_size: int
nk_heads: int
head_dim: int
frac: float # per-sequence cached context length fractionf
page_size: int
cache_lens: List[int] # per-sequence cached context length
WORKLOADS: List[Workload] = [
Workload(
name=f"batch_size={BATCH} kv_cache_len={cache_lens} "
f"FRAC={frac} HKV={NK_HEADS} HEAD_DIM={HEAD_DIM}",
batch_size=BATCH,
nk_heads=NK_HEADS,
head_dim=HEAD_DIM,
cache_lens=[cache_lens] * BATCH,
frac=frac,
page_size=ps,
)
for BATCH in [1, 2, 3, 8]
for frac in [0.10, 0.20, 0.30, 0.40]
for NK_HEADS in [2, 4, 8]
for HEAD_DIM in [32, 64, 128]
for cache_lens in [10, 20, 30, 70, 1000]
for ps in [128, 256]
]
@pytest.mark.parametrize("workload", WORKLOADS, ids=lambda wl: wl.name)
def test_prefill_store_topk_kv(workload: Workload):
B = workload.batch_size
H = workload.nk_heads
D = workload.head_dim
TOP_K = int(workload.cache_lens[0] * workload.nk_heads * workload.frac)
PAGE_SIZE = workload.page_size
dtype = torch.float16
device = triton.runtime.driver.active.get_active_torch_device()
lens = torch.tensor(workload.cache_lens, dtype=torch.int32, device=device)
cu = torch.zeros(B + 1, dtype=torch.int32, device=device)
cu[1:] = torch.cumsum(lens, dim=0)
N_total = int(cu[-1].item())
keys = torch.randn((N_total, H, D), dtype=dtype, device=device)
vals = torch.randn_like(keys)
scores_flat = torch.randn((N_total, H), dtype=torch.float32, device=device)
top_k_eff = max(0, min(TOP_K, int(lens.max().item()) * H))
max_k_len = cu.diff().max().item()
indices = scores_to_retain_indices(
scores_flat, cu, max_k_len, top_k_eff, H
) # [B, TOP_K]
LP = max(1, (top_k_eff + PAGE_SIZE - 1) // PAGE_SIZE)
N_LOGICAL_PAGES_MAX = LP
N_PAGES = B * H * LP + 32
S_LARGE = N_PAGES * PAGE_SIZE
k_cache = torch.empty((S_LARGE, D), dtype=dtype, device=device)
v_cache = torch.empty_like(k_cache)
page_table = torch.empty(
(B, H, N_LOGICAL_PAGES_MAX), dtype=torch.int32, device=device
)
phys = 0
for b in range(B):
for h in range(H):
for lp in range(LP):
page_table[b, h, lp] = phys
phys += 1
assert phys <= N_PAGES, "Not enough physical pages"
local_lens = torch.zeros((B, H), dtype=torch.int32, device=device)
batch_mapping = torch.arange(B, dtype=torch.int32, device=device)
num_to_retain = torch.full((B,), top_k_eff, dtype=torch.int32, device=device)
prefill_store_topk_kv(
new_keys=keys,
new_vals=vals,
indices_topk=indices,
num_tokens_to_retain=num_to_retain,
page_table=page_table,
batch_mapping=batch_mapping,
bh_lens=local_lens,
PAGE_SIZE=PAGE_SIZE,
k_cache=k_cache,
v_cache=v_cache,
PAD_TO_PAGE_SIZE=False,
TRITON_RESERVED_BATCH=-1,
)
torch.cuda.synchronize()
local_lens_cpu = local_lens.cpu()
page_table_cpu = page_table.cpu()
k_cache_cpu = k_cache.cpu()
v_cache_cpu = v_cache.cpu()
keys_cpu = keys.cpu()
vals_cpu = vals.cpu()
indices_cpu = indices.cpu()
for b in range(B):
hed = (indices_cpu[b] % H).numpy()
counts = collections.Counter(hed.tolist())
for h in range(H):
expected = counts.get(h, 0) # type: ignore
got = int(local_lens_cpu[b, h].item())
assert got == expected, (
f"Length mismatch at (b={b}, h={h}): got {got}, expected {expected}"
)
def rows_for_head(b, h, L):
"""Return the list of cache row indices storing the first L logical positions for (b,h)."""
rows = []
for pos in range(L):
lp = pos // PAGE_SIZE
off = pos % PAGE_SIZE
phys = int(page_table_cpu[b, h, lp].item())
rows.append(phys * PAGE_SIZE + off)
return rows
for b in range(B):
# which tokens per head were selected for this batch?
tok = (indices_cpu[b] // H).numpy()
hed = (indices_cpu[b] % H).numpy()
per_head = collections.defaultdict(list)
for t, h in zip(tok, hed):
per_head[int(h)].append(int(t))
for h in range(H):
L = int(local_lens_cpu[b, h].item())
if L == 0:
continue
# expected vectors (unordered) from source
toks_h = per_head.get(h, [])
assert len(toks_h) == L
expK = keys_cpu[toks_h, h, :].contiguous().view(L, -1)
expV = vals_cpu[toks_h, h, :].contiguous().view(L, -1)
# actual vectors read back from cache rows
rows = rows_for_head(b, h, L)
actK = k_cache_cpu[rows, :].contiguous().view(L, -1)
actV = v_cache_cpu[rows, :].contiguous().view(L, -1)
expK_tuples = [tuple(row) for row in expK.numpy().tolist()]
actK_tuples = [tuple(row) for row in actK.numpy().tolist()]
expV_tuples = [tuple(row) for row in expV.numpy().tolist()]
actV_tuples = [tuple(row) for row in actV.numpy().tolist()]
assert collections.Counter(expK_tuples) == collections.Counter(
actK_tuples
), f"K content mismatch at (b={b}, h={h})"
assert collections.Counter(expV_tuples) == collections.Counter(
actV_tuples
), f"V content mismatch at (b={b}, h={h})"
def test_prefill_store_topk_kv_pad_to_page_size():
torch.manual_seed(0)
B, H, D = 2, 2, 64
PAGE_SIZE = 128
RETAIN = 64
dtype = torch.float16
device = triton.runtime.driver.active.get_active_torch_device()
lens = torch.full((B,), 256, dtype=torch.int32, device=device)
cu = torch.zeros(B + 1, dtype=torch.int32, device=device)
cu[1:] = torch.cumsum(lens, dim=0)
N_total = int(cu[-1].item())
keys = torch.randn((N_total, H, D), dtype=dtype, device=device)
vals = torch.randn_like(keys)
scores_flat = torch.randn((N_total, H), dtype=torch.float32, device=device)
max_k_len = int(lens.max().item())
max_sel = max_k_len * H
indices = scores_to_retain_indices(scores_flat, cu, max_k_len, max_sel, H)
N_LOGICAL_PAGES_MAX = 2
N_PAGES = B * H * N_LOGICAL_PAGES_MAX + 32
S_LARGE = N_PAGES * PAGE_SIZE
k_cache = torch.empty((S_LARGE, D), dtype=dtype, device=device)
v_cache = torch.empty_like(k_cache)
page_table = torch.empty(
(B, H, N_LOGICAL_PAGES_MAX), dtype=torch.int32, device=device
)
phys = 0
for b in range(B):
for h in range(H):
for lp in range(N_LOGICAL_PAGES_MAX):
page_table[b, h, lp] = phys
phys += 1
assert phys <= N_PAGES, "Not enough physical pages"
local_lens = torch.zeros((B, H), dtype=torch.int32, device=device)
batch_mapping = torch.arange(B, dtype=torch.int32, device=device)
num_to_retain = torch.full((B,), RETAIN, dtype=torch.int32, device=device)
prefill_store_topk_kv(
new_keys=keys,
new_vals=vals,
indices_topk=indices,
num_tokens_to_retain=num_to_retain,
page_table=page_table,
batch_mapping=batch_mapping,
bh_lens=local_lens,
PAGE_SIZE=PAGE_SIZE,
k_cache=k_cache,
v_cache=v_cache,
PAD_TO_PAGE_SIZE=True,
cu_seqlens_k=cu,
TRITON_RESERVED_BATCH=-1,
)
torch.cuda.synchronize()
local_lens_cpu = local_lens.cpu()
lens_cpu = lens.cpu()
assert (local_lens_cpu % PAGE_SIZE == 0).all()
assert (local_lens_cpu <= lens_cpu[:, None]).all()
import logging
import math
from dataclasses import dataclass
from typing import List
import pytest
import torch
import triton
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
from compactor_vllm.attention.sparse_decode_kernel import head_sparse_decode_attention
from compactor_vllm.attention.sparse_varlen_kernel import (
causal_sparse_varlen_with_cache,
)
logger = logging.getLogger(__name__)
@dataclass
class Workload:
name: str
batch_size: int
nq_heads: int
nk_heads: int
head_dim: int
cache_lens: List[int] # per-sequence cached context length
append_lens: List[int] # per-sequence new tokens this step (Q_app, K_app, V_app)
WORKLOADS: List[Workload] = [
Workload(
name=f"batch_size={BATCH} kv_cache_len={cache_lens} append_len={append_lens} "
f"HQ={NQ_HEADS} HKV={NK_HEADS} HEAD_DIM={HEAD_DIM}",
batch_size=BATCH,
nq_heads=NQ_HEADS,
nk_heads=NK_HEADS,
head_dim=HEAD_DIM,
cache_lens=[cache_lens] * BATCH,
append_lens=[append_lens] * BATCH,
)
for BATCH in [1, 2, 3, 8]
for NQ_HEADS in [32]
for NK_HEADS in [8]
for HEAD_DIM in [128]
for cache_lens in [0, 1, 70, 128, 8193]
for append_lens in [1, 2, 13, 8000]
]
WORKLOADS_DECODE: List[Workload] = [
Workload(
name=f"batch_size={BATCH} kv_cache_len={cache_lens}"
f"HQ={NQ_HEADS} HKV={NK_HEADS} HEAD_DIM={HEAD_DIM}",
batch_size=BATCH,
nq_heads=NQ_HEADS,
nk_heads=NK_HEADS,
head_dim=HEAD_DIM,
cache_lens=[cache_lens] * BATCH,
append_lens=[1] * BATCH,
)
for BATCH in [1, 2, 3, 8]
for NQ_HEADS in [32]
for NK_HEADS in [8]
for HEAD_DIM in [128]
for cache_lens in [1, 2, 70, 128, 8000]
]
def build_paged_cache_from_lengths(
B,
H_kv,
D,
PAGE_SIZE,
N_LOGICAL_PAGES_MAX,
L_cache_per_b, # int32 [B], per-batch cache length
device,
dtype,
):
"""
Construct:
- seq_lens_bh[b, h] = L_cache_per_b[b]
- page_table[b, h, lp] giving physical page ids
- K_cache, V_cache filled for valid cached tokens
Physical layout:
physical_page_id = (b * H_kv + h) * N_LOGICAL_PAGES_MAX + lp
CACHE_SIZE = num_phys_pages * PAGE_SIZE
"""
assert L_cache_per_b.shape[0] == B
max_len = PAGE_SIZE * N_LOGICAL_PAGES_MAX
assert (L_cache_per_b <= max_len).all()
seq_lens_bh = torch.empty((B, H_kv), dtype=torch.int32, device=device)
for b in range(B):
seq_lens_bh[b, :].fill_(L_cache_per_b[b])
num_phys_pages = B * H_kv * N_LOGICAL_PAGES_MAX
CACHE_SIZE = num_phys_pages * PAGE_SIZE
K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
page_table = torch.empty(
(B, H_kv, N_LOGICAL_PAGES_MAX), device=device, dtype=torch.int32
)
# assign unique physical pages per (b, h, lp)
phys_page = 0
for b in range(B):
for h in range(H_kv):
for lp in range(N_LOGICAL_PAGES_MAX):
page_table[b, h, lp] = phys_page
phys_page += 1
# fill cached tokens
g = torch.Generator(device=device).manual_seed(1234)
for b in range(B):
Lc = int(L_cache_per_b[b].item())
for h in range(H_kv):
for i in range(Lc):
lp = i // PAGE_SIZE
off = i % PAGE_SIZE
phys = int(page_table[b, h, lp].item())
idx = phys * PAGE_SIZE + off
K_cache[idx] = torch.randn(D, device=device, dtype=dtype, generator=g)
V_cache[idx] = torch.randn(D, device=device, dtype=dtype, generator=g)
return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
def materialize_kv_for_flash_mixed(
K_cache,
V_cache,
page_table,
L_cache_per_b, # [B]
k_append_raw, # [N, H_kv, D]
v_append_raw, # [N, H_kv, D]
cu_seqlens_qk, # [B+1]
H_kv,
PAGE_SIZE,
):
"""
Build (K_total, V_total, cu_seqlens_k) for flash_attn_varlen_func such that:
For each batch b:
seqlen_q[b] = L_app[b] = cu[b+1] - cu[b]
seqlen_k[b] = L_cache_per_b[b] + L_app[b]
Keys:
- first L_cache_per_b[b] positions from paged cache
- next L_app[b] positions from k_append_raw for that batch
"""
device = K_cache.device
dtype = K_cache.dtype
B = cu_seqlens_qk.numel() - 1
N, H_kv_raw, D = k_append_raw.shape
assert H_kv_raw == H_kv
# appended lengths
L_app = (cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).to(torch.int32) # [B]
seqlen_k = L_cache_per_b + L_app # [B]
cu_seqlens_k = torch.empty(B + 1, device=device, dtype=torch.int32)
cu_seqlens_k[0] = 0
total_k = int(seqlen_k.sum().item())
K_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
V_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
for b in range(B):
offset_k = int(cu_seqlens_k[b].item())
Lc = int(L_cache_per_b[b].item())
La = int(L_app[b].item())
q_start = int(cu_seqlens_qk[b].item())
# cache segment
for g in range(H_kv):
for i in range(Lc):
lp = i // PAGE_SIZE
off = i % PAGE_SIZE
phys = int(page_table[b, g, lp].item())
idx = phys * PAGE_SIZE + off
K_total[offset_k + i, g] = K_cache[idx]
V_total[offset_k + i, g] = V_cache[idx]
# appended segment
if k_append_raw.numel() > 0:
for g in range(H_kv):
for j in range(La):
src = q_start + j
dst = offset_k + Lc + j
K_total[dst, g] = k_append_raw[src, g]
V_total[dst, g] = v_append_raw[src, g]
cu_seqlens_k[b + 1] = cu_seqlens_k[b] + (Lc + La)
return K_total, V_total, cu_seqlens_k
@pytest.mark.parametrize("workload", WORKLOADS, ids=lambda wl: wl.name)
def test_causal_sparse_varlen_with_cache(workload: Workload):
dtype = torch.float16
device = triton.runtime.driver.active.get_active_torch_device()
DEFAULT_PAGE_SIZE = 256
N_LOGICAL_PAGES_MAX = 256
L_cache_per_b = torch.as_tensor(
workload.cache_lens, device=device, dtype=torch.int32
)
K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE = (
build_paged_cache_from_lengths(
B=workload.batch_size,
H_kv=workload.nk_heads,
D=workload.head_dim,
PAGE_SIZE=DEFAULT_PAGE_SIZE,
N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
L_cache_per_b=L_cache_per_b,
device=device,
dtype=dtype,
)
)
assert len(workload.append_lens) == workload.batch_size
cu = [0]
for L in workload.append_lens:
cu.append(cu[-1] + L)
cu_seqlens_qk = torch.tensor(cu, dtype=torch.int32, device=device)
N = int(cu_seqlens_qk[-1].item())
q_raw = torch.randn(
N, workload.nq_heads, workload.head_dim, device=device, dtype=dtype
)
k_append_raw = torch.randn(
N, workload.nk_heads, workload.head_dim, device=device, dtype=dtype
)
v_append_raw = torch.randn_like(k_append_raw)
batch_mapping = torch.arange(workload.batch_size, device=device, dtype=torch.int32)
sm_scale = 1.0 / math.sqrt(workload.head_dim)
K_total, V_total, cu_seqlens_k = materialize_kv_for_flash_mixed(
K_cache=K_cache,
V_cache=V_cache,
page_table=page_table,
L_cache_per_b=L_cache_per_b,
k_append_raw=k_append_raw,
v_append_raw=v_append_raw,
cu_seqlens_qk=cu_seqlens_qk,
H_kv=workload.nk_heads,
PAGE_SIZE=DEFAULT_PAGE_SIZE,
)
max_seqlen_q = int((cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).max().item())
max_seqlen_k = int((cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item())
max_seqlen_k_triton = seq_lens_bh.max().item()
out_triton = causal_sparse_varlen_with_cache(
q=q_raw,
k_cache=K_cache,
v_cache=V_cache,
k=k_append_raw,
v=v_append_raw,
seq_lens_bh=seq_lens_bh,
global_page_table=page_table,
batch_mapping=batch_mapping,
cu_seqlens_q=cu_seqlens_qk,
HKV=workload.nk_heads,
PAGE_SIZE=DEFAULT_PAGE_SIZE,
sm_scale=sm_scale,
max_seqlen_q=max_seqlen_q,
max_seqlen_k_cache=max_seqlen_k_triton,
)
out_flash = flash_attn_varlen_func(
q=q_raw,
k=K_total,
v=V_total,
cu_seqlens_q=cu_seqlens_qk,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=sm_scale,
causal=True,
)
assert torch.allclose(out_triton, out_flash, rtol=1e-6, atol=3e-3)
max_diff = (out_triton - out_flash).abs().max().item()
logger.info(
f"[causal_sparse_varlen_with_cache: {workload.name}]: max abs diff={max_diff: .5f}"
)
def materialize_kv_cache_for_flash_decode(
K_cache,
V_cache,
page_table,
L_cache_per_b, # [B] int32
H_kv: int,
PAGE_SIZE: int,
):
"""
Build (K_flash, V_flash) suitable for flash_attn_with_kvcache, with shape:
(B, seqlen_cache_max, H_kv, D)
For each batch b:
- cache_seqlen[b] = L_cache_per_b[b]
- K_flash[b, :cache_seqlen[b], g] and V_flash[...] are filled from the paged KV cache.
- Tokens beyond cache_seqlen[b] (if any) are left as zeros and will be masked out
by flash_attn_with_kvcache via cache_seqlens.
"""
device = K_cache.device
dtype = K_cache.dtype
B = L_cache_per_b.shape[0]
D = K_cache.shape[1]
seqlen_cache_max = int(L_cache_per_b.max().item())
K_flash = torch.zeros((B, seqlen_cache_max, H_kv, D), device=device, dtype=dtype)
V_flash = torch.zeros_like(K_flash)
for b in range(B):
Lc = int(L_cache_per_b[b].item())
if Lc == 0:
continue
for g in range(H_kv):
for i in range(Lc):
lp = i // PAGE_SIZE
off = i % PAGE_SIZE
phys = int(page_table[b, g, lp].item())
idx = phys * PAGE_SIZE + off
K_flash[b, i, g] = K_cache[idx]
V_flash[b, i, g] = V_cache[idx]
return K_flash, V_flash
@pytest.mark.parametrize("workload", WORKLOADS_DECODE, ids=lambda wl: wl.name)
def test_sparse_decode_attention(workload: Workload):
dtype = torch.float16
device = triton.runtime.driver.active.get_active_torch_device()
DEFAULT_PAGE_SIZE = 256
N_LOGICAL_PAGES_MAX = 256
# per-sequence cache lengths (all equal for WORKLOADS_DECODE)
L_cache_per_b = torch.as_tensor(
workload.cache_lens, device=device, dtype=torch.int32
)
# build paged KV cache used by the Triton kernel
K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE = (
build_paged_cache_from_lengths(
B=workload.batch_size,
H_kv=workload.nk_heads,
D=workload.head_dim,
PAGE_SIZE=DEFAULT_PAGE_SIZE,
N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
L_cache_per_b=L_cache_per_b,
device=device,
dtype=dtype,
)
)
B = workload.batch_size
HQ = workload.nq_heads
HKV = workload.nk_heads
D = workload.head_dim
# Triton kernel expects q: [B, HQ, D]
q_triton = torch.randn(B, HQ, D, device=device, dtype=dtype)
batch_mapping = torch.arange(B, device=device, dtype=torch.int32)
sm_scale = 1.0 / math.sqrt(D)
out_triton = head_sparse_decode_attention(
q=q_triton,
k=K_cache,
v=V_cache,
seq_lens_bh=seq_lens_bh,
global_page_table=page_table,
batch_mapping=batch_mapping,
HKV=HKV,
PAGE_SIZE=DEFAULT_PAGE_SIZE,
sm_scale=sm_scale,
) # [B, HQ, D]
# materialize contiguous KV cache with shape [B, seqlen_cache_max, HKV, D]
K_flash, V_flash = materialize_kv_cache_for_flash_decode(
K_cache=K_cache,
V_cache=V_cache,
page_table=page_table,
L_cache_per_b=L_cache_per_b,
H_kv=HKV,
PAGE_SIZE=DEFAULT_PAGE_SIZE,
)
# flash_attn_with_kvcache expects q: [B, seqlen_q, HQ, D]
q_flash = q_triton.unsqueeze(1) # seqlen_q = 1
out_flash = flash_attn_with_kvcache(
q=q_flash,
k_cache=K_flash,
v_cache=V_flash,
cache_seqlens=L_cache_per_b,
softmax_scale=sm_scale,
causal=True,
).squeeze(1) # [B, 1, HQ, D]
assert torch.allclose(out_triton, out_flash, rtol=1e-6, atol=3e-3)
max_diff = (out_triton - out_flash).abs().max().item()
logger.info(
f"[head_sparse_decode_attention: {workload.name}]: max abs diff={max_diff: .5f}"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment