"tests/vscode:/vscode.git/clone" did not exist on "10d765482d19abfab6c66b5f815720a66aa9de42"
Commit d29c39ca authored by chenzk's avatar chenzk
Browse files

vllm kvprune wo:v1.1.0

parent f81ce56b
from dataclasses import dataclass, fields
from typing import Type
import torch
from triton.tools.tensor_descriptor import TensorDescriptor
from triton.tools.ragged_tma import create_ragged_descriptor
from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
from .target_info import cuda_capability_geq
from .tensor_details.layout import Layout, StridedLayout
@dataclass
class Storage:
data: torch.Tensor
layout: Layout = None
def __post_init__(self):
assert isinstance(self.data, torch.Tensor)
if self.layout is None:
self.layout = StridedLayout(self.data.shape)
@property
def device(self):
return self.data.device
def is_tma_compliant(self):
# TMAs didn't exist until Hopper
if not cuda_capability_geq(9, 0):
return False
# TMAs only exist for 2D, 3D, 5D inputs
if len(self.data.shape) not in [2, 3, 5]:
return False
# TMAs need at most one stride equal to 1
# and all other strides divisble by 16
strides = list(self.data.stride())
try:
major_dim = strides.index(1)
except ValueError:
major_dim = -1
ndim = self.data.ndim
bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
compliant = [
strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim
]
return all(compliant)
def make_dense_tma(self, block_shape, transpose=False):
strides = list(self.data.stride())
shape = list(self.data.shape)
transpose = self.data.stride()[-1] != 1
if transpose:
block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
shape = shape[:-2] + [shape[-1], shape[-2]]
strides = strides[:-2] + [strides[-1], strides[-2]]
if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
indx = strides.index(1)
block_shape[indx] = block_shape[indx] // 2
if shape[-1] % 128 != 0:
raise ValueError(
"inner shape need to be multiple of 128 for "
"mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs."
)
block_shape = self.layout.swizzle_block_shape(block_shape)
return TensorDescriptor(self.data, shape, strides, block_shape)
def make_tma(self, block_shape, mode, transpose=False):
if mode in ["dense", "gather", "scatter"]:
return self.make_dense_tma(block_shape, transpose)
assert mode == "ragged"
ragged_dim = len(self.data.shape) - 2
return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
@dataclass
class IntegerType:
bitwidth: int
@dataclass
class FloatType:
bitwidth_exponent: int
bitwidth_mantissa: int
is_signed: bool
def __post_init__(self):
self.bitwidth = (
int(self.is_signed) + self.bitwidth_exponent + self.bitwidth_mantissa
)
BIT = IntegerType(1)
FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)
def bitwidth(type: IntegerType | FloatType | torch.dtype):
if isinstance(type, torch.dtype):
return type.itemsize * 8
return type.bitwidth
@dataclass
class Tensor:
storage: Storage | torch.Tensor
dtype: IntegerType | FloatType | torch.dtype = None
shape: list[int] | None = None
shape_max: list[int] | None = None
def __post_init__(self):
# set storage
if isinstance(self.storage, torch.Tensor):
self.storage = Storage(self.storage)
# initialize dtype
if self.dtype is None:
self.dtype = self.storage.data.dtype
if bitwidth(self.dtype) < 8 and self.shape is None:
raise ValueError("shape must be provided for sub-byte types")
# initialize shape
if self.shape is None:
self.shape = list(self.storage.data.shape)
# validate shape: all elements must be `int` or numel-1 `torch.Tensor`
is_int = lambda s: isinstance(s, int)
is_item = lambda s: hasattr(s, "numel") and s.numel() == 1
assert all(map(lambda s: is_int(s) or is_item(s), self.shape))
# initialize shape_max
if self.shape_max is None:
self.shape_max = [None] * len(self.shape)
for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
if smax is not None and not is_int(smax):
raise ValueError(
f"shape_max[{i}] must be `int` or `None`; got {type(smax)}"
)
if smax is None:
self.shape_max[i] = s
# validate shape_max: all elements must be `int`
assert all(map(is_int, self.shape_max))
# torch compatibility layer
@property
def ndim(self):
return len(self.shape)
@property
def device(self):
return self.storage.device
def stride(self, i=None):
return self.storage.data.stride() if i is None else self.storage.data.stride(i)
def data_ptr(self):
return self.storage.data.data_ptr()
def numel(self):
return self.storage.data.numel()
def element_size(self):
return bitwidth(self.dtype) // 8
@property
def data(self):
t = self.storage
return t.data if isinstance(t, Storage) else t
def dim(self):
return self.ndim
def size(self, i=None):
if i is None:
return self.shape
return self.shape[i]
@dataclass
class Bitmatrix(Tensor):
"""
Represents a boolean matrix in a packed format where each element occupies
a single bit of memory.
_scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along
with the actual bitmatrix to avoid having to launch a separate memset
kernel when we call Bitmatrix::sum().
"""
scratchpad: torch.Tensor = None
def __init__(self, storage, shape, shape_max=None, scratchpad=None):
super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
self.scratchpad = scratchpad
def sum(self, partials_block_size):
_, n_cols = self.shape
dev = self.device
if self.scratchpad is None:
self.scratchpad = clear_sums(n_cols, dev)
out_ret = self.scratchpad[:n_cols]
self.scratchpad = None # throw error if we try to sum again
return sum_bitmatrix_rows(self, out_ret, partials_block_size)
def get_layout(tensor: torch.Tensor | Tensor | None):
if tensor is None:
return None
if isinstance(tensor, Tensor):
return tensor.storage.layout
return StridedLayout
def wrap_torch_tensor(torch_tensor, dtype=None):
if dtype is None:
dtype = torch_tensor.dtype
shape = list(torch_tensor.shape)
shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(
dtype
)
return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)
def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
assert isinstance(tensor, Tensor)
old_storage = tensor.storage
old_data = old_storage.layout.unswizzle_data(old_storage.data)
new_layout = layout_cls(old_data.shape, **layout_kwargs)
new_data = new_layout.swizzle_data(old_data)
attrs = {
k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"
}
return Tensor(Storage(new_data, new_layout), **attrs)
from .layout_details.base import Layout
from .layout_details.blackwell_scale import BlackwellMXScaleLayout
from .layout_details.blackwell_value import BlackwellMXValueLayout
from .layout_details.hopper_scale import HopperMXScaleLayout
from .layout_details.hopper_value import HopperMXValueLayout
from .layout_details.cdna4_scale import CDNA4MXScaleLayout
from .layout_details.strided import StridedLayout
from ..target_info import cuda_capability_geq, is_hip_cdna4
__all__ = [
"Layout",
"BlackwellMXValueLayout",
"BlackwellMXScaleLayout",
"HopperMXScaleLayout",
"HopperMXValueLayout",
"CDNA4MXScaleLayout",
"StridedLayout",
]
def make_default_matmul_mxfp4_w_layout(mx_axis: int):
if cuda_capability_geq(10):
# return StridedLayout, dict()
return BlackwellMXValueLayout, dict()
elif cuda_capability_geq(9):
return HopperMXValueLayout, {"mx_axis": mx_axis}
else:
return StridedLayout, dict()
def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
if is_hip_cdna4():
return CDNA4MXScaleLayout, dict()
else:
if cuda_capability_geq(10):
return BlackwellMXScaleLayout, dict()
elif cuda_capability_geq(9):
return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
return StridedLayout, dict()
from abc import ABC, abstractmethod
class Layout(ABC):
def __init__(self, shape) -> None:
self.initial_shape = shape
@abstractmethod
def swizzle_data(self, data):
pass
@abstractmethod
def unswizzle_data(self, data):
pass
@abstractmethod
def swizzle_block_shape(self, block_shape):
pass
import math
import triton
import triton.language as tl
import torch
from .base import Layout
SWIZZLE_ALIGN_INNER = 8
SWIZZLE_SIZE_INNER = 4
SWIZZLE_SIZE_OUTER = 128
class BlackwellMXScaleLayout(Layout):
name: str = "BLACKWELL_SCALE"
def __init__(self, shape) -> None:
super().__init__(shape)
(
*self.leading_shape,
self.K,
self.N,
) = shape
self.B = math.prod(self.leading_shape)
self.ALIGN_K = 8
self.ALIGN_N = 128
self.SWIZZLE_K = 4
self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K
self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N
def swizzle_data(self, data):
data = torch.nn.functional.pad(
data, (0, self.N_pad - self.N, 0, self.K_pad - self.K)
)
data = data.transpose(-1, -2).contiguous()
data = data.reshape(
self.B,
self.N_pad // self.ALIGN_N,
self.ALIGN_N // 32,
32,
self.K_pad // self.SWIZZLE_K,
self.SWIZZLE_K,
)
data = data.transpose(2, 4).contiguous()
data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
return data
def unswizzle_data(self, data):
data = data.reshape(
self.B,
self.N_pad // self.ALIGN_N,
self.K_pad // self.SWIZZLE_K,
32,
self.ALIGN_N // 32,
self.SWIZZLE_K,
)
data = data.transpose(2, 4)
data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad)
data = data.transpose(-1, -2)
return data[..., : self.K, : self.N]
def swizzle_block_shape(self, block_shape):
MX_PACK_DIVISOR = 32
MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]
@triton.jit
def unswizzle_mx_scale_bw(
x,
SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER,
SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER,
ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER,
):
shape_0: tl.constexpr = x.shape[0]
shape_1: tl.constexpr = x.shape[1]
tl.static_assert(shape_1 % SIZE_OUTER == 0)
tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER)
x = x.reshape(
shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER
)
x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER)
return x
import torch
from .base import Layout
class BlackwellMXValueLayout(Layout):
name: str = "BLACKWELL_VALUE"
def __init__(self, shape) -> None:
super().__init__(shape)
self.shape = shape
def swizzle_data(self, data):
# permutation needed to make `data` row major
to_row_major = sorted(range(data.ndim), key=lambda d: (data.stride(d), d))[::-1]
# permutation needed to retrieve original order
inv = [0] * data.ndim
for i, d in enumerate(to_row_major):
inv[d] = i
# leading dimension must be padded to be aligned to 128
align_dim = lambda x: (x + 128 - 1) // 128 * 128
major_dim = data.stride().index(1)
pad = align_dim(data.shape[major_dim]) - data.shape[major_dim]
data = torch.nn.functional.pad(data.permute(to_row_major), (0, pad)).permute(
inv
)
return data
def unswizzle_data(self, data: torch.Tensor):
# Trim padding along all dims back to the original shape recorded at init.
assert data.ndim == len(self.shape), (
"Rank mismatch between data and recorded shape"
)
sizes = [min(data.size(i), self.shape[i]) for i in range(data.ndim)]
return data[tuple(slice(0, s) for s in sizes)]
def swizzle_block_shape(self, block_shape):
return block_shape
import triton
import triton.language as tl
from .base import Layout
NON_K_PRESHUFFLE_BLOCK_SIZE = 32
class CDNA4MXScaleLayout(Layout):
name: str = "CDNA4_SCALE"
def __init__(self, shape) -> None:
super().__init__(shape)
def swizzle_data(self, data):
block_shape = data.shape
SCALE_K = block_shape[-2]
N = block_shape[-1]
data = data.transpose(-1, -2)
data = data.view(
-1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1
)
data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
if len(block_shape) == 3:
E = block_shape[0]
data = data.reshape(E, N // 32, SCALE_K * 32)
else:
assert len(block_shape) == 2
data = data.reshape(N // 32, SCALE_K * 32)
return data.transpose(-1, -2)
def unswizzle_data(self, data):
raise NotImplementedError()
def swizzle_block_shape(self, block_shape):
SCALE_K = block_shape[-2]
N = block_shape[-1]
return block_shape[:-2] + [N // 32, SCALE_K * 32]
@triton.jit
def unswizzle_mx_scale_cdna4(
x,
BLOCK_N: tl.constexpr,
MX_SCALE_BLOCK_K: tl.constexpr,
N_PRESHUFFLE_FACTOR: tl.constexpr = NON_K_PRESHUFFLE_BLOCK_SIZE,
):
x = x.reshape(BLOCK_N // N_PRESHUFFLE_FACTOR, MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2, 1)
x = x.permute(0, 5, 3, 1, 4, 2, 6)
x = x.reshape(BLOCK_N, MX_SCALE_BLOCK_K)
return x
import torch
import triton
import triton.language as tl
from .base import Layout
class HopperMXScaleLayout(Layout):
name: str = "HOPPER_SCALE"
def __init__(self, shape, mx_axis, num_warps=8) -> None:
assert num_warps & (num_warps - 1) == 0, "warps_n must be a power of 2"
super().__init__(shape)
self.mx_axis = mx_axis
self.num_warps = num_warps
*self.leading_shape, _, _ = shape
def _maybe_mT(self, data):
if self.mx_axis == len(self.leading_shape):
return data.contiguous().mT
return data
def swizzle_data(self, data):
data = self._maybe_mT(data).contiguous()
*batch, M, K = data.shape
SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8
SWIZZLE_ALIGN_K = 2
pad_m = (SWIZZLE_ALIGN_M - (M % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M
pad_k = (SWIZZLE_ALIGN_K - (K % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K
data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
*batch, M, K = data.shape
assert data.is_contiguous()
assert M % (2 * self.num_warps * 2 * 8) == 0 and K % 2 == 0, (
f"Input tensor must have a subtile of shape (..., {2 * self.num_warps * 2 * 8}, 2)"
)
b = len(batch)
data = data.reshape(
*batch,
M // (2 * self.num_warps * 2 * 8),
2,
self.num_warps,
2,
8,
K // 2,
2,
)
perm = [0, 2, 5, 1, 4, 6, 3]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.flatten(-5, -1)
data = data.flatten(-3, -2)
assert data.shape[-2] == M // 32
assert data.shape[-1] == K * 32
data = self._maybe_mT(data)
return data
def unswizzle_data(self, data):
data = self._maybe_mT(data)
*batch, M, K = data.shape
b = len(batch)
data = data.reshape(
*batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2
)
perm = [0, 3, 1, 6, 4, 2, 5]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.reshape(*batch, M * 32, K // 32)
data = self._maybe_mT(data)
return data
def swizzle_block_shape(self, block_shape):
return block_shape
@triton.jit
def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr):
"""
Triton inverse of swizzle_mxfp4_scale_hopper
"""
tl.static_assert(len(x.shape) == 2, "NYI")
# implementation assumes mxfp data is packed along the last dimension
x = x.trans() if mx_axis == 0 else x
M: tl.constexpr = x.shape[0]
K: tl.constexpr = x.shape[1]
tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}")
tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}")
x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2)
x = x.trans(0, 3, 1, 6, 4, 2, 5)
x = x.reshape(M * 32, K // 32)
# implementation assumed mxfp data is packed along the last dimension
x = x.trans() if mx_axis == 0 else x
return x
import torch
import triton
import triton.language as tl
from .base import Layout
def right_shift_unsigned(x, shift):
return (x >> shift) & ((1 << (32 - shift)) - 1)
# -----------------------------------------------------------------------
# Interleave the bits of four consecutive fp4 values (i.e. 16-bits) as:
# 1000000111000000 (first fp4)
# 1000000111000000 (second fp4)
# 1000000111000000 (third fp4)
# 0110110000000000 (fourth fp4)
# This is done so that dequantization can be done in 14 SASS instructions
# -----------------------------------------------------------------------
def _compress_fp4(x):
x = x.to(torch.int32)
return ((x & 0x8) << 12) | ((x & 0x7) << 6)
def _compress_fourth(x):
x = x.to(torch.int32)
return ((x & 0x8) << 11) | ((x & 0x6) << 9) | ((x & 0x1) << 13)
def _pack_bits(x: torch.Tensor, mx_axis: int):
x = x.contiguous()
assert x.shape[-1] % 4 == 0, (
"Input tensor must have a last dimension divisible by 4"
)
x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))
first = _compress_fp4(x[..., 0]) | (_compress_fp4(x[..., 0] >> 4) << 16)
second = _compress_fp4(x[..., 1]) | (_compress_fp4(x[..., 1] >> 4) << 16)
third = _compress_fp4(x[..., 2]) | (_compress_fp4(x[..., 2] >> 4) << 16)
fourth = _compress_fourth(x[..., 3]) | (_compress_fourth(x[..., 3] >> 4) << 16)
x = (
first
| right_shift_unsigned(second, 3)
| right_shift_unsigned(third, 6)
| fourth
)
assert x.is_contiguous()
x = x.view(torch.uint8)
return x
# -----------------------------------------------------------------------
# inverse operation of _pack_bits
# -----------------------------------------------------------------------
def _bf16_to_fp4e2m1(x):
# 0bAxxxxxxBCDxxxxxx (int16) -> 0b0000ABCD (uint8)
assert x.dtype == torch.int16
s = (right_shift_unsigned(x, 15) & 0x1) << 3
em = right_shift_unsigned(x, 6) & 0x7
return (s | em).to(torch.uint8)
def _bf16x2_to_fp4e2m1x2(x):
# 0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx (int32) -> 0bABCD_EFGH (uint8)
assert x.dtype == torch.int32
lo = (x & 0xFFFF).to(torch.int16)
hi = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16)
ret_lo = _bf16_to_fp4e2m1(lo)
ret_hi = _bf16_to_fp4e2m1(hi)
return ret_lo | (ret_hi << 4)
def _unpack_bits(x, mx_axis: int):
x = x.view(torch.int32)
m = 0b10000001110000001000000111000000
a = (x << 1) & 0b10000000000000001000000000000000
b = right_shift_unsigned(x, 3) & 0b00000001100000000000000110000000
c = right_shift_unsigned(x, 7) & 0b00000000010000000000000001000000
unpacked = [x & m, (x << 3) & m, (x << 6) & m, (a | b) | c]
x = torch.stack(unpacked, dim=-1)
x = x.flatten(-2, -1)
x = _bf16x2_to_fp4e2m1x2(x)
return x
# -----------------------------------------------------------------------
class HopperMXValueLayout(Layout):
name: str = "HOPPER_VALUE"
def __init__(self, shape, mx_axis, mma_version=3):
super().__init__(shape)
assert mx_axis in range(len(shape))
self.mx_axis = mx_axis
self.mma_version = mma_version
(
*self.leading_shape,
self.K,
self.N,
) = shape
def _maybe_mT(self, data):
if self.mx_axis == len(self.leading_shape):
return data.mT
return data
def swizzle_data(self, data):
"""
Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
(*, M // 4, K * 4) such that:
1) Groups contiguously all the elements owned by the same thread of 4
mma tiles along the K axis. The following animation shows a similar
grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
as done here:
https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif
2) Moves the elements belonging to thread 4-7 to be contiguous with those
from thread 0-3. This is done to get a full cache line when loading them
from HBM.
mx_axis selects the lhs or rhs of the matmul.
WARNING: Assumes that the matmul will be done in bf16 or fp16!
Implementing it for fp8 is as easy as making the tile size (8, 8)
"""
batch = data.ndim - 2
assert batch >= 0
assert self.mma_version in (2, 3)
data = self._maybe_mT(data)
init_shape = data.shape
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
# Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
contig = (1, u8_kwidth)
scott_trick = (2, 1)
threads = (4, 4)
warp_tile = (2, 2)
k_tile = (1, 4 // u8_kwidth)
sizes = list(data.shape[:-2])
pads = []
# [rest, K, tile, threads] per dimension
for i, (a, b, c, s, d) in enumerate(
zip(k_tile, warp_tile, threads, scott_trick, contig)
):
pack = a * b * c * s * d
size = data.shape[batch + i]
pad = (pack - size % pack) % pack
pads += [(0, pad)]
sizes.append((size + pad) // pack)
sizes += [a, b, c, s, d]
pads = tuple(x for t in pads[::-1] for x in t)
data = torch.nn.functional.pad(data, pads)
init_shape = data.shape
# 0: rest[0]
# 1: k_tile[0]
# 2: warp_tile[0]
# 3: threads[0]
# 4: scott_trick[0]
# 5: contig[0]
# 6: rest[1]
# 7: k_tile[1]
# 8: warp_tile[1]
# 9: threads[1]
# 10: scott_trick[1]
# 11: contig[1]
data = data.view(*sizes)
# Want [rest[0], threads[0], rest[1], scott_trick[0], scott_trick[0], threads[1], contig[1], contig[0], k_tile[1], k_tile[0], warp_tile[1], warp_tile[0]]
perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11]
perm = list(range(batch)) + [batch + p for p in perm]
data = data.permute(*perm).contiguous()
# These are views
data = data.flatten(-10, -1)
data = data.flatten(-3, -2)
assert data.is_contiguous()
assert data.shape[-2] == init_shape[-2] // 4
assert data.shape[-1] == init_shape[-1] * 4
# twiddle the bits
data = _pack_bits(data, self.mx_axis)
data = self._maybe_mT(data)
return data
def unswizzle_data(self, data):
data = self._maybe_mT(data)
data = _unpack_bits(data, self.mx_axis)
*batch, M, K = data.shape
# We have two times the elements if we already upcasted to bfloat16
mult = 2 if data.dtype == torch.bfloat16 else 1
assert M % 4 == 0, "M must be divisible by 4"
assert K % (4 * 8 * 2 * 2 * mult) == 0, (
f"K must be divisible by {4 * 8 * 2 * 2 * mult}"
)
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
data = data.reshape(
*batch,
M // 4,
4,
K // (4 * 8 * 2 * 2 * mult),
2,
4,
8 // u8_kwidth,
2,
u8_kwidth * mult,
)
b = len(batch)
perm = [0, 6, 1, 3, 2, 5, 4, 7]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.reshape(*batch, M * 4, K // 4)
data = self._maybe_mT(data)
return data[..., : self.K, : self.N]
def swizzle_block_shape(self, block_shape):
return block_shape
@triton.jit
def _unshuffle_triton(x, mma_version: tl.constexpr):
"""
Triton inverse of swizzle_mxfp4_value_hopper
"""
tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3")
# if mx_axis == 0:
# x = x.trans()
# We have two times the elements if we already upcasted to bfloat16
mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1
M: tl.constexpr = x.shape[0]
K: tl.constexpr = x.shape[1]
tl.static_assert(M % 4 == 0, "M must be divisible by 4")
tl.static_assert(
K % (4 * 8 * 2 * 2 * mult) == 0,
f"K must be divisible by {4 * 8 * 2 * 2 * mult}",
)
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1
x = x.reshape(
M // 4,
4,
K // (4 * 8 * 2 * 2 * mult),
2,
4,
8 // u8_kwidth,
2,
u8_kwidth * mult,
)
x = x.trans(0, 6, 1, 3, 2, 5, 4, 7)
x = x.reshape(M * 4, K // 4)
# if mx_axis == 0:
# x = x.trans()
return x
@triton.jit
def _unpack_fp4_to_bf16_triton(x):
# For now we implement just H100 support (mul.bf16x2)
# A100 support is possible via fma
r0, r1 = tl.inline_asm_elementwise(
r"""
{
.reg .b32 b, c, d<7>, scale;
.reg .b32 bias;
mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
// We add the missing bias to the scale directly
and.b32 $0, $4, 0b10000001110000001000000111000000;
mul.bf16x2 $0, $0, bias;
shl.b32 b, $4, 3;
and.b32 $1, b, 0b10000001110000001000000111000000;
mul.bf16x2 $1, $1, bias;
shl.b32 c, $4, 6;
and.b32 $2, c, 0b10000001110000001000000111000000;
mul.bf16x2 $2, $2, bias;
// Unpack last two elements
shl.b32 d0, $4, 1;
and.b32 d1, d0, 0b10000000000000001000000000000000;
shr.b32 d2, $4, 3;
and.b32 d3, d2, 0b00000001100000000000000110000000;
or.b32 d4, d1, d3;
shr.b32 d5, $4, 7;
and.b32 d6, d5, 0b00000000010000000000000001000000;
or.b32 $3, d4, d6;
mul.bf16x2 $3, $3, bias;
}
""",
constraints="=r,=r,=r,=r,r",
args=[x],
dtype=(tl.bfloat16, tl.bfloat16),
is_pure=True,
pack=4,
)
# Concat each pack of 4
x = tl.join(r0, r1)
x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2])
x = x.trans(0, 1, 3, 2)
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
return x
@triton.jit
def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
"""
Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
(x << 0) & 0b1000000111000000
(x << 3) & 0b1000000111000000
(x << 6) & 0b1000000111000000
((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)
"""
# upcast values to bfloat16
tl.static_assert(len(x.shape) == 2)
tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1")
tl.static_assert(x.shape[1] % 4 == 0)
tl.static_assert(x.dtype == tl.uint8)
if mx_axis == 0:
x = x.trans()
x = _unpack_fp4_to_bf16_triton(x)
x = _unshuffle_triton(x, mma_version=3)
if mx_axis == 0:
x = x.trans()
# upcast scale to bfloat16
# Add bias missing from the bf16 upcasting sequence
# triton / LLVM generates terrible code for this sequence
# scale = scale.to(tl.uint16)
# scale = scale << 7
# scale = scale.to(tl.bfloat16, bitcast=True)
scale = tl.inline_asm_elementwise(
r"""
{
prmt.b32 $0, $2, 0, 0x5140;
shl.b32 $0, $0, 7;
prmt.b32 $1, $2, 0, 0x7362;
shl.b32 $1, $1, 7;
}
""",
constraints="=r,=r,r",
args=[scale],
dtype=tl.bfloat16,
is_pure=True,
pack=4,
)
# Broadcast scale
scale = scale.expand_dims(mx_axis + 1)
scale = scale.broadcast_to(
scale.shape[: mx_axis + 1] + [32] + scale.shape[mx_axis + 2 :]
)
scale = scale.reshape(x.shape)
# Combine scale and x
x = x * scale
return x
from .base import Layout
class StridedLayout(Layout):
name: str = None
def __init__(self, shape) -> None:
super().__init__(shape)
def swizzle_data(self, data):
return data
def unswizzle_data(self, data):
return data
def swizzle_block_shape(self, block_shape):
return block_shape
import enum
import functools
import os
import subprocess
import sys
import torch
from vllm.kvprune.triton_kernels.numerics import (
MAX_FINITE_FLOAT8E4B8,
MAX_FINITE_FLOAT8E4NV,
MAX_FINITE_FLOAT8E5,
)
def assert_equal(ref, tri):
if isinstance(ref, torch.Tensor):
assert torch.all(ref == tri)
else:
assert ref == tri
def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
if tri.dtype.itemsize == 1:
ref_as_type = ref.to(tri.dtype)
if ref.dtype == tri.dtype:
assert torch.all(ref_as_type == tri)
return
ref = ref_as_type
if ref.numel() == 0:
return
if maxtol is None:
maxtol = 2e-2
if rmstol is None:
rmstol = 4e-3
"""
Compare reference values against obtained values.
"""
# cast to float32:
ref = ref.to(torch.float32).detach()
tri = tri.to(torch.float32).detach()
assert ref.shape == tri.shape, (
f"Tensors must have same size {ref.shape=} {tri.shape=}"
)
# deal with infinite elements:
inf_mask_ref = torch.isinf(ref)
inf_mask_tri = torch.isinf(tri)
assert torch.equal(inf_mask_ref, inf_mask_tri), (
"Tensor must have same infinite elements"
)
refn = torch.where(inf_mask_ref, 0, ref)
trin = torch.where(inf_mask_tri, 0, tri)
# normalise so that RMS calculation doesn't overflow:
eps = 1.0e-30
multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
refn *= multiplier
trin *= multiplier
ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
max_err = torch.max(rel_err).item()
rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
if verbose:
print(
"%s maximum relative error = %s (threshold = %s)"
% (description, max_err, maxtol)
)
print(
"%s RMS relative error = %s (threshold = %s)"
% (description, rms_err, rmstol)
)
if max_err > maxtol:
bad_idxs = torch.nonzero(rel_err > maxtol)
num_nonzero = bad_idxs.size(0)
bad_idxs = bad_idxs[:1000]
print(
"%d / %d mismatched elements (shape = %s) at coords %s"
% (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())
)
bad_idxs = bad_idxs.unbind(-1)
print("ref values: ", ref[tuple(bad_idxs)].cpu())
print("tri values: ", tri[tuple(bad_idxs)].cpu())
assert max_err <= maxtol
assert rms_err <= rmstol
class ComputeSanitizerTool(enum.Enum):
MEMCHECK = "memcheck"
RACECHECK = "racecheck"
SYNCCHECK = "synccheck"
INITCHECK = "initcheck"
def compute_sanitizer(**target_kwargs):
"""
Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
to expose potential memory access errors.
This decorator requires the `request` fixture to be present.
If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
Running tests under compute sanitizer requires launching subprocess and is slow,
so use sparingly
"""
def decorator(test_fn):
@functools.wraps(test_fn)
def wrapper(*args, **kwargs):
if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
test_fn(*args, **kwargs)
return
import psutil
if target_kwargs.pop("clear_torch_cache", False):
# If we don't pop clear_torch_cache, it won't pass
# target_kwargs.items() <= kwargs.items() condition below.
torch.cuda.empty_cache()
tools_to_check = target_kwargs.pop(
"tools_to_check", [ComputeSanitizerTool.MEMCHECK]
)
assert isinstance(tools_to_check, list), f"{tools_to_check=}"
assert all(tool in ComputeSanitizerTool for tool in tools_to_check), (
f"{(tool for tool in tools_to_check if tool not in ComputeSanitizerTool)=}"
)
ppid_name = psutil.Process(os.getppid()).exe()
run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
if "run_sanitizer" in kwargs:
run_compute_sanitizer &= kwargs["run_sanitizer"]
if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
for tool in tools_to_check:
path = os.path.realpath(test_fn.__globals__["__file__"])
# get path of current file
env = {
"PATH": os.environ["PATH"],
"PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
"TORCH_SHOW_CPP_STACKTRACES": "1",
"CUDA_LAUNCH_BLOCKING": "1",
}
if "CUDA_VISIBLE_DEVICES" in os.environ:
env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
assert "request_fixture" in kwargs, (
"memcheck'ed test must have a (possibly unused) `request` fixture"
)
test_id = kwargs["request_fixture"].node.callspec.id
cmd = f"{path}::{test_fn.__name__}[{test_id}]"
cmd = [
"compute-sanitizer",
"--target-processes=application-only",
"--destroy-on-device-error=context",
f"--tool={tool.value}",
sys.executable,
"-m",
"pytest",
"-vsx",
cmd,
]
for opt in ["--update_checksum", "--ignore_checksum_error"]:
if opt in sys.argv:
cmd.append(opt)
out = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env=env,
)
sanitizer_ok = "ERROR SUMMARY: 0 errors" in str(
out.stdout
) or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout)
test_output = out.stdout
if type(test_output) is bytes:
test_output = test_output.decode()
fail = False
if not sanitizer_ok:
print("compute-sanitizer returned an error")
fail = True
elif out.returncode != 0:
print(
"The test failed due to some other reason: consider running without compute-sanitizer to verify."
)
print(f"{out.returncode=}")
fail = True
if fail:
print("*****************************************************")
print("******************** TEST OUTPUT ********************")
print("*****************************************************")
print(test_output)
print("*****************************************************")
print("****************** TEST OUTPUT END ******************")
print("*****************************************************")
assert None
else:
test_fn(*args, **kwargs)
return wrapper
return decorator
def compute_actual_scale(x, dtype):
max_finite = {
torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
}[dtype]
return x.abs().max() / max_finite
import torch
import triton
from vllm.kvprune.triton_kernels.topk_details._topk_forward import _topk_forward
from vllm.kvprune.triton_kernels.topk_details import _topk_backward
from vllm.kvprune.triton_kernels.tensor import Tensor, Bitmatrix
from typing import Optional, Union
def topk_forward(
x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None
):
if not isinstance(x, Tensor):
x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]]
x_shape_max = [x.shape[0], x.shape[1]]
x = Tensor(x, shape=x_shape, shape_max=x_shape_max)
cdiv = lambda a, b: (a + b - 1) // b
BLOCK_M = 32
BLOCK_N = 32
BLOCK_S = 128
assert len(x.shape) == 2
assert x.shape_max[-1] < 32768
assert dim == 1
assert return_bitmatrix
n_rows, n_cols = x.shape
n_rows_max, _ = x.shape_max
dev = x.device
# scratchpad tensors
# NOTE: these are not returned
y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
if y_indx is not None:
use_provided_indx = True
else:
y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)
use_provided_indx = False
# create bitmatrix in transposed memory layout:
n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
n_cols_words = n_cols_pad // 32
bitmatrix = torch.empty(
(n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev
)
bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max]
s_blocks = cdiv(n_cols, BLOCK_S)
s_cols = s_blocks * BLOCK_S
scratchpad = torch.empty((s_cols,), dtype=torch.int32, device=dev)
pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks)
_topk_forward[(pids,)](
x,
x.stride(0), # inputs
y_vals,
y_indx,
y_vals.stride(0),
use_provided_indx, # output [topk]
bitmatrix,
bitmatrix.stride(0),
bitmatrix.stride(1), # output [bitmatrix]
n_rows,
n_cols, # shapes
scratchpad,
BLOCK_S,
s_blocks, # thing to memset to zero
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, # tunable parameter
APPLY_SOFTMAX=apply_softmax,
N_EXPTS_PAD=n_cols_pad,
N_EXPTS_ACT=k, # constants
)
bitmatrix_shape = [n_rows, n_cols_words * 32]
bitmatrix_shape_max = [n_rows_max, None]
bitmatrix = Bitmatrix(
bitmatrix,
shape=bitmatrix_shape,
shape_max=bitmatrix_shape_max,
scratchpad=scratchpad,
)
return y_vals, y_indx, bitmatrix
def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
assert dy_vals.shape[-1] == k
n_expts_pad = triton.next_power_of_2(x.shape[-1])
dx = torch.empty_like(x)
_topk_backward[(dy_vals.shape[0],)](
y_indx,
y_indx.stride(0),
dy_vals,
dy_vals.stride(0),
x,
x.stride(0), # inputs
dx, # outputs
dx.stride(0),
x.shape[0],
n_rows,
x.shape[-1],
APPLY_SOFTMAX=apply_softmax,
N_EXPTS_ACT=k,
N_EXPTS_PAD=n_expts_pad,
)
return dx
class TopK(torch.autograd.Function):
@staticmethod
def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
y_vals, y_indx, bitmatrix = topk_forward(
x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows
)
ctx.save_for_backward(x, y_indx)
ctx.apply_softmax = apply_softmax
ctx.k = k
ctx.n_rows = n_rows
return y_vals, y_indx, bitmatrix
@staticmethod
def backward(ctx, dy_vals, _0, _1):
x, y_indx = ctx.saved_tensors
dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax)
return dx, None, None, None, None, None, None
def topk(
x: Union[Tensor, torch.Tensor],
k: int,
apply_softmax: bool = True,
dim: int = 1,
return_bitmatrix: bool = True,
y_indx: Optional[torch.Tensor] = None,
n_rows: Optional[int] = None,
):
"""
Computes the top-k values and indices along a specified dimension of a tensor.
Note that the input can be either a `Tensor` or a `torch.Tensor`, but the output will always be a `torch.Tensor`.
Parameters
----------
x : Union[triton_kernels.Tensor, torch.Tensor]
Input tensor of shape (n_tokens, n_expts).
k : int
Number of top elements to retrieve.
apply_softmax : bool, default True
Whether to apply softmax to the input tensor before computing top-k.
dim : int, default 1
Dimension along which to compute top-k.
return_bitmatrix : bool, default True
A bitmatrix of shape (n_tokens, cdiv(n_expts, 32)).
Each bit on [t, b] indicates whether the b-th expert was selected for the t-th token.
y_indx : torch.Tensor, optional
Pre-allocated tensor for storing indices of top-k elements with shape (n_tokens, k).
If provided, we skip the computation of top-k indices and use this tensor instead.
n_rows : int, optional
Number of rows to apply top-k on. If None, we consider all rows in `x`.
Returns
-------
(expt_scal, expt_indx, bitmatrix) : Tuple[torch.Tensor, torch.Tensor, Bitmatrix]
"""
ret = TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
return ret
import triton
import triton.language as tl
@triton.jit
def _topk_backward(
Yi,
stride_ym, # topk indices
DY,
stride_dym, # output gradient values
X,
stride_xm, # input values
DX,
stride_dxm, # input gradient values
n_rows,
NRows,
n_expts_tot,
APPLY_SOFTMAX: tl.constexpr,
N_EXPTS_ACT: tl.constexpr,
N_EXPTS_PAD: tl.constexpr,
):
pid_m = tl.program_id(0)
if NRows is not None:
n_rows = tl.load(NRows)
if pid_m >= n_rows:
return
Yi += pid_m * stride_ym
DY += pid_m * stride_dym
X += pid_m * stride_xm
DX += pid_m * stride_dxm
# --
offs_xn = tl.arange(0, N_EXPTS_PAD)
offs_yn = tl.arange(0, N_EXPTS_ACT)
mask_xn = offs_xn < n_expts_tot
# recompute softmax
y_indx = tl.load(Yi + offs_yn)
x = tl.load(X + y_indx)
x = x.to(tl.float32)
y = tl.softmax(x)
# compute input-gradient
dy = tl.load(DY + offs_yn)
dy = dy.to(tl.float32)
s = tl.sum(y * dy, 0)
# write-back input gradient
tl.store(DX + offs_xn, 0, mask=mask_xn)
tl.debug_barrier()
if APPLY_SOFTMAX:
dx = y * (dy - s)
else:
dx = dy
tl.store(DX + y_indx, dx)
import triton
import triton.language as tl
@triton.jit
def get_topmask_and_fullmask(x):
tl.static_assert(
x.dtype.is_int_unsigned(), "floating-point value must be passed as bits"
)
tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
return tm_arr, fm_arr
@triton.jit
def fpval_to_key(x):
tm, fm = get_topmask_and_fullmask(x)
return x ^ tl.where((x & tm) != 0, fm, tm)
@triton.jit
def key_to_fpval(x):
tm, fm = get_topmask_and_fullmask(x)
return x ^ tl.where((x & tm) == 0, fm, tm)
# stable top-k tie-breaks to value with smaller index
@triton.jit
def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr):
return N_EXPTS_PAD - indx
@triton.jit
def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr):
return N_EXPTS_PAD - indx
@triton.jit
def streaming_topk(
X,
stride_xm,
n_expts_tot,
offs_m,
mask_m,
N_EXPTS_PAD: tl.constexpr,
N_EXPTS_ACT: tl.constexpr,
BLOCK_N: tl.constexpr,
):
x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
if x_nbits < 16:
# this ensures that we leave at least 16 bits for expert index
# even if the input dtype is smaller than 16 bits:
y_nbits: tl.constexpr = 32
else:
y_nbits: tl.constexpr = x_nbits * 2
x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}")
x_dtype: tl.constexpr = X.dtype.element_ty
# subtract 1 from loop iterations because we peel the first (masked) iteration:
loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_x_n[None, :] < n_expts_tot
# first iteration:
X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
x = fpval_to_key(x.to(x_utype, bitcast=True))
x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
acc = tl.topk(x, N_EXPTS_ACT, dim=1)
# subsequent iterations:
for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge
X_ptrs -= BLOCK_N
offs_x_n -= BLOCK_N
x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
x = fpval_to_key(x.to(x_utype, bitcast=True))
x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))
# rotate expert index into upper 16 bits:
# 0000vvvvvvvviiii --> iiii0000vvvvvvvv
acc = (acc << (y_nbits - 16)) | (acc >> 16)
# sort in ascending order of expert (descending order of key)
acc = tl.sort(acc, dim=1, descending=True)
# iiii0000vvvvvvvv --> 0000iiii:
y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
# iiii0000vvvvvvvv --> vvvvvvvv:
y_values_raw = acc.to(x_utype)
y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)
return y_values, y_indices
@triton.jit
def _topk_forward(
X,
stride_xm, # inputs
Yv,
Yi,
stride_ym, # topk values/indices
USE_PROVIDED_INDX: tl.constexpr,
Bits,
stride_rm: tl.constexpr,
stride_rn: tl.constexpr, # bitmatrix
n_rows,
n_expts_tot, # shape
S,
BLOCK_S: tl.constexpr,
s_blocks, # thing to memset
APPLY_SOFTMAX: tl.constexpr, # constant
BLOCK_M: tl.constexpr,
N_EXPTS_PAD: tl.constexpr,
N_EXPTS_ACT: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid = tl.program_id(0)
if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
n_rows = tl.load(n_rows)
if pid < s_blocks:
tl.store(
S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32)
)
if pid * BLOCK_M >= n_rows:
# early exit:
return
tl.static_assert(BLOCK_N % 32 == 0)
tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
x_dtype: tl.constexpr = X.dtype.element_ty
# load logits
offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
offs_y_n = tl.arange(0, N_EXPTS_ACT)
mask_m = offs_m[:, None] < n_rows
if USE_PROVIDED_INDX:
Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
y_indices = tl.load(Yi_ptrs, mask=mask_m)
Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
y_values = tl.load(Xv_ptrs, mask=mask_m)
else:
y_values, y_indices = streaming_topk(
X,
stride_xm,
n_expts_tot,
offs_m,
mask_m, #
N_EXPTS_PAD,
N_EXPTS_ACT,
BLOCK_N,
)
# normalize selected values
if APPLY_SOFTMAX:
y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(
x_dtype
)
# write back
Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
tl.store(Yv_ptrs, y_values, mask=mask_m)
if not USE_PROVIDED_INDX:
Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
tl.store(Yi_ptrs, y_indices, mask=mask_m)
# pack into bitmatrix
y_div = y_indices // 32
y_rem = y_indices % 32
loop_iterations = N_EXPTS_PAD // BLOCK_N
for i in range(loop_iterations):
offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)
y2 = tl.where(
y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0
)
r = tl.reduce_or(y2, axis=1)
BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
tl.store(BitsPtrs, r, mask=mask_m)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Shared helpers: Triton compat, layout bridge, context, sequences."""
from vllm.kvprune.utils.layout_bridge import (
block_table_to_global_page_table,
build_batch_mapping,
build_page_table_head_major,
flatten_kv_cache_head_major,
flatten_kv_cache_plane,
write_head_major_flat_to_interleaved,
)
from vllm.kvprune.utils.triton_compat import (
autotune as triton_autotune,
cuda_capability_geq,
maybe_set_allocator,
)
__all__ = [
"block_table_to_global_page_table",
"build_batch_mapping",
"build_page_table_head_major",
"cuda_capability_geq",
"flatten_kv_cache_head_major",
"flatten_kv_cache_plane",
"write_head_major_flat_to_interleaved",
"maybe_set_allocator",
"triton_autotune",
]
import itertools
import math
from dataclasses import dataclass
from typing import List, Optional
import torch
from vllm.kvprune.compression import CompressionMethod
from vllm.kvprune.compression.compression_config import BatchCompressionParams
from vllm.kvprune.config.engine_config import LLMConfig
from vllm.kvprune.utils.sequence import Sequence
from vllm.kvprune.utils.kv_dist import broadcast_from_tp_rank0
from vllm.kvprune.utils.tp_utils import kv_heads_shard_divisor
@dataclass
class PrefillBatchArguments:
B: int
N: int
do_compression: bool
compression_method: CompressionMethod
compression_chunk_size: int
seq_ids: torch.Tensor
input_ids: torch.Tensor
positions: torch.Tensor
cu_seqlens_q: torch.Tensor
cu_seqlens_k: torch.Tensor
max_seqlen_q: int
max_seqlen_k: int
batch_tokens_to_retain: Optional[torch.Tensor]
max_tokens_to_retain: Optional[int]
protected_first: Optional[List[int]]
protected_last: Optional[List[int]]
PHI: Optional[torch.Tensor]
# args needed for memory reservation
context_lens: torch.Tensor
max_new_tokens: torch.Tensor
# 与 kvpress ``CompactorPress`` blending 默认(未显式指定时用 compression_ratio)对齐
compression_ratio: float = 1.0
class PackedTensorArguments:
def __init__(
self,
rank: int,
max_batched_tokens: int,
config: LLMConfig,
seed: int = 42,
*,
device: torch.device | None = None,
use_tp_group_for_collectives: bool = False,
) -> None:
hf_config = config.hf_config
self.rank = rank
self.device = device if device is not None else torch.device(f"cuda:{rank}")
self._use_tp_group = use_tp_group_for_collectives
self.max_num_batches = config.max_num_seqs
self.max_batched_tokens = max_batched_tokens
_ws = kv_heads_shard_divisor()
self.num_kv_heads = hf_config.num_key_value_heads // _ws
self.world_size = config.tensor_parallel_size
self.page_size = int(config.kvcache_page_size)
self.head_dim = getattr(hf_config, "head_dim", None)
self.sketch_dim = config.leverage_sketch_size
self.model_dtype = hf_config.torch_dtype
# i64 pack = [seq_ids (BMAX)] || [input_ids (NMAX)] || [positions (NMAX)] || max_new_tok (BMAX)
self.i64_len_max = (
self.max_num_batches + 2 * self.max_batched_tokens + self.max_num_batches
)
self.packed_context_i64 = torch.empty(
self.i64_len_max, dtype=torch.int64, device=self.device
)
# i32 pack = [header (6): ... + compression_ratio*1e6] || [cu_q (BMAX+1)] || ...
# || [protected_first_tokens (BMAX)] || [protected_last_tokens (BMAX)]
self.i32_len_max = (
6
+ (self.max_num_batches + 1)
+ (self.max_num_batches + 1)
+ self.max_num_batches
+ self.max_num_batches
+ self.max_num_batches
+ self.max_num_batches
)
self.packed_context_i32 = torch.empty(
self.i32_len_max, dtype=torch.int32, device=self.device
)
self.generator = torch.Generator(device=self.device).manual_seed(seed)
self.PHI = torch.randn(
(self.head_dim, self.sketch_dim),
device=self.packed_context_i32.device,
generator=self.generator,
).to(self.model_dtype) * (1 / math.sqrt(self.sketch_dim))
def _master_build_prefill(
self, seqs: List[Sequence], batch_compression_params: BatchCompressionParams
) -> PrefillBatchArguments:
B = len(seqs)
if B == 0:
raise ValueError(
"prefill batch is empty (scheduler should not call build_prefill with "
"no sequences)"
)
Ls = [x.prompt_len for x in seqs]
N = sum(Ls)
assert N <= self.max_batched_tokens
do_compression = any(x.compression_params.compression_ratio < 1.0 for x in seqs)
do_compression = (
do_compression
and batch_compression_params.compression_method != CompressionMethod.NONE
)
pack_slices_64 = self.packed_i64_slices(B, N)
pack_slices_32 = self.packed_i32_slices(B)
# max_retain = max(retain)
protected_first_list = [
x.compression_params.protected_first_tokens for x in seqs
]
protected_last_list = [x.compression_params.protected_last_tokens for x in seqs]
retain = [
max(
int(
round(
x.compression_params.compression_ratio
* (L - s - e)
* self.num_kv_heads
)
),
1,
)
for s, e, L, x in zip(protected_first_list, protected_last_list, Ls, seqs)
]
retain = torch.tensor(retain, dtype=torch.int32, device="cpu", pin_memory=True)
protected_first = torch.tensor(
protected_first_list, dtype=torch.int32, device="cpu", pin_memory=True
)
protected_last = torch.tensor(
protected_last_list, dtype=torch.int32, device="cpu", pin_memory=True
)
self.packed_context_i32[pack_slices_32["protected_first"]].copy_(
protected_first, non_blocking=True
)
self.packed_context_i32[pack_slices_32["protected_last"]].copy_(
protected_last, non_blocking=True
)
compression_chunk_size = (
batch_compression_params.chunk_size
if batch_compression_params.do_chunked_compression
else -1
)
min_compression_ratio = min(x.compression_params.compression_ratio for x in seqs)
cr_scaled = int(round(float(min_compression_ratio) * 1_000_000.0))
cr_scaled = max(min(cr_scaled, 2_000_000_000), -2_000_000_000)
header_host = torch.tensor(
[
B,
N,
1 if do_compression else 0,
batch_compression_params.compression_method.value,
compression_chunk_size,
cr_scaled,
],
dtype=torch.int32,
device="cpu",
pin_memory=True,
)
self.packed_context_i32[pack_slices_32["retain"]].copy_(
retain, non_blocking=True
)
self.packed_context_i32[pack_slices_32["header"]].copy_(
header_host, non_blocking=True
)
max_seq_qk = max(Ls)
cu = torch.tensor(
list(itertools.accumulate(Ls, initial=0)),
dtype=torch.int32,
device="cpu",
pin_memory=True,
)
self.packed_context_i32[pack_slices_32["cu_q"]].copy_(cu, non_blocking=True)
self.packed_context_i32[pack_slices_32["cu_k"]].copy_(cu, non_blocking=True)
self.packed_context_i32[pack_slices_32["context_lens"]].copy_(
cu.diff(), non_blocking=True
)
seq_ids = torch.tensor(
[x.seq_id for x in seqs], dtype=torch.int64, device="cpu", pin_memory=True
)
input_ids = torch.tensor(
[tid for x in seqs for tid in x.prompt_token_ids],
dtype=torch.int64,
device="cpu",
pin_memory=True,
)
self.packed_context_i64[pack_slices_64["seq_ids"]].copy_(
seq_ids, non_blocking=True
)
self.packed_context_i64[pack_slices_64["input_ids"]].copy_(
input_ids, non_blocking=True
)
positions = torch.cat(
[
torch.arange(L, dtype=torch.int64, device="cpu", pin_memory=True)
for L in Ls
]
)
self.packed_context_i64[pack_slices_64["positions"]].copy_(
positions, non_blocking=True
)
max_new_tokens = torch.tensor(
[seq.sampling_params.max_new_tokens for seq in seqs],
dtype=torch.int64,
device="cpu",
pin_memory=True,
)
self.packed_context_i64[pack_slices_64["max_new_tokens"]].copy_(
max_new_tokens, non_blocking=True
)
# `prefill_store_topk_kv(..., PAD_TO_PAGE_SIZE=True)` may scan beyond the
# top-k prefix to fill per-head lengths up to a page boundary. Using a
# full ranking (top_k = max_seq_len * HKV) makes `torch.topk` degenerate
# into a full sort, which is very expensive for long contexts.
#
# Instead, request only a prefix that is large enough for:
# 1) the maximum "keep" budget in the batch, plus
# 2) a conservative extra window for page-padding candidates.
max_seq_len = int(self.packed_context_i32[pack_slices_32["context_lens"]].max())
full_budget = max_seq_len * self.num_kv_heads
keep_budget = int(retain.max().item())
pad_search_budget = (self.page_size - 1) * (self.num_kv_heads**2)
max_retain = min(full_budget, keep_budget + pad_search_budget)
# Non-blocking H2D copies above must finish before NCCL broadcast, or peers can
# receive stale/garbage packed buffers → wrong prefill → garbage tokens on TP>1.
if self.packed_context_i64.is_cuda:
torch.cuda.synchronize()
# PHI: rank 0's sketch matrix is broadcast so all TP ranks share one PHI for
# leverage / compactor scores (same order as packed_context: i64, i32, PHI).
broadcast_from_tp_rank0(
self.packed_context_i64, use_tp_group=self._use_tp_group
)
broadcast_from_tp_rank0(
self.packed_context_i32, use_tp_group=self._use_tp_group
)
if self.world_size > 1:
broadcast_from_tp_rank0(self.PHI, use_tp_group=self._use_tp_group)
prefill_args = PrefillBatchArguments(
B=B,
N=N,
do_compression=do_compression,
compression_method=batch_compression_params.compression_method,
compression_chunk_size=compression_chunk_size,
seq_ids=self.packed_context_i64[pack_slices_64["seq_ids"]],
input_ids=self.packed_context_i64[pack_slices_64["input_ids"]],
positions=self.packed_context_i64[pack_slices_64["positions"]],
cu_seqlens_q=self.packed_context_i32[pack_slices_32["cu_q"]],
cu_seqlens_k=self.packed_context_i32[pack_slices_32["cu_k"]],
max_seqlen_q=max_seq_qk,
max_seqlen_k=max_seq_qk,
batch_tokens_to_retain=self.packed_context_i32[pack_slices_32["retain"]],
max_tokens_to_retain=max_retain,
PHI=self.PHI,
context_lens=self.packed_context_i32[pack_slices_32["context_lens"]],
max_new_tokens=self.packed_context_i64[pack_slices_64["max_new_tokens"]],
protected_first=protected_first_list,
protected_last=protected_last_list,
compression_ratio=min_compression_ratio,
)
return prefill_args
def _peer_receive_prefill(self) -> PrefillBatchArguments:
broadcast_from_tp_rank0(
self.packed_context_i64, use_tp_group=self._use_tp_group
)
broadcast_from_tp_rank0(
self.packed_context_i32, use_tp_group=self._use_tp_group
)
if self.world_size > 1:
broadcast_from_tp_rank0(self.PHI, use_tp_group=self._use_tp_group)
# Header is 6 fields (B, N, do_compression, method, chunk_size, cr_scaled); must match
# packed_i32_slices(B)["header"] for any B.
header = self.packed_context_i32[:6].tolist()
B, N = int(header[0]), int(header[1])
do_compression = bool(int(header[2]))
compression_method = CompressionMethod(int(header[3]))
compression_chunk_size = int(header[4])
compression_ratio = int(header[5]) / 1_000_000.0
pack_slices_64 = self.packed_i64_slices(B, N)
pack_slices_32 = self.packed_i32_slices(B)
max_seq_len = int(self.packed_context_i32[pack_slices_32["context_lens"]].max())
# Must match _master_build_prefill: max_seqlen_{q,k} = max(Ls), not cu_q.max()
# (which equals total batch tokens N and breaks varlen attention on peers).
full_budget = max_seq_len * self.num_kv_heads
keep_budget = int(self.packed_context_i32[pack_slices_32["retain"]].max().item())
pad_search_budget = (self.page_size - 1) * (self.num_kv_heads**2)
max_retain = min(full_budget, keep_budget + pad_search_budget)
prefill_args = PrefillBatchArguments(
B=B,
N=N,
do_compression=do_compression,
compression_method=compression_method,
compression_chunk_size=compression_chunk_size,
seq_ids=self.packed_context_i64[pack_slices_64["seq_ids"]],
input_ids=self.packed_context_i64[pack_slices_64["input_ids"]],
positions=self.packed_context_i64[pack_slices_64["positions"]],
cu_seqlens_q=self.packed_context_i32[pack_slices_32["cu_q"]],
cu_seqlens_k=self.packed_context_i32[pack_slices_32["cu_k"]],
max_seqlen_q=max_seq_len,
max_seqlen_k=max_seq_len,
batch_tokens_to_retain=self.packed_context_i32[pack_slices_32["retain"]],
max_tokens_to_retain=max_retain,
PHI=self.PHI,
context_lens=self.packed_context_i32[pack_slices_32["context_lens"]],
max_new_tokens=self.packed_context_i64[pack_slices_64["max_new_tokens"]],
protected_first=self.packed_context_i32[
pack_slices_32["protected_first"]
].tolist(),
protected_last=self.packed_context_i32[
pack_slices_32["protected_last"]
].tolist(),
compression_ratio=compression_ratio,
)
return prefill_args
@torch.inference_mode()
def build_prefill_args(
self,
seqs: Optional[List[Sequence]] = None,
batch_compression_params: Optional[BatchCompressionParams] = None,
) -> PrefillBatchArguments:
if self.rank == 0:
return self._master_build_prefill(seqs, batch_compression_params)
return self._peer_receive_prefill()
def broadcast(self):
if self.world_size > 1:
return broadcast_from_tp_rank0(
self.packed_context_i64, use_tp_group=self._use_tp_group
)
return None
@staticmethod
def packed_i64_slices(B: int, N: int):
return {
"seq_ids": slice(0, B),
"input_ids": slice(B, B + N),
"positions": slice(B + N, B + 2 * N),
"max_new_tokens": slice(B + 2 * N, 2 * B + 2 * N),
}
@staticmethod
def packed_i32_slices(B: int):
h0, h1 = 0, 6
q0 = h1
q1 = q0 + (B + 1)
k0 = q1
k1 = k0 + (B + 1)
r0 = k1
r1 = r0 + B
c0 = r1
c1 = r1 + B
pf0 = c1
pf1 = c1 + B
pl0 = pf1
pl1 = pf1 + B
return {
"header": slice(h0, h1),
"cu_q": slice(q0, q1),
"cu_k": slice(k0, k1),
"retain": slice(r0, r1),
"context_lens": slice(c0, c1),
"protected_first": slice(pf0, pf1),
"protected_last": slice(pl0, pl1),
}
@dataclass
class DecodeBatchOutput:
output_tokens: Optional[torch.Tensor]
output_seq_ids: Optional[torch.Tensor]
@dataclass
class DecodeBatchArguments:
batch_mapping: Optional[torch.Tensor] = None
token_ids: Optional[torch.Tensor] = None
positions: Optional[torch.Tensor] = None
max_ctx_lens: Optional[torch.Tensor] = None
seq_ids: Optional[torch.Tensor] = None
temps: Optional[torch.Tensor] = None
desired_batch_occupancy: int = -1
num_stashed_batches: int = 0
def update(
self,
batch_mapping,
token_ids,
positions,
max_ctx_lens,
seq_ids,
temps=None,
desired_batch_occupancy: int = None,
):
if self.batch_mapping is not None:
self.batch_mapping = torch.cat([self.batch_mapping, batch_mapping], dim=0)
else:
self.batch_mapping = batch_mapping.clone()
if self.token_ids is not None:
self.token_ids = torch.cat([self.token_ids, token_ids], dim=0)
else:
self.token_ids = token_ids.clone()
if self.positions is not None:
self.positions = torch.cat([self.positions, positions], dim=0)
else:
self.positions = positions.clone()
if self.max_ctx_lens is not None:
self.max_ctx_lens = torch.cat([self.max_ctx_lens, max_ctx_lens], dim=0)
else:
self.max_ctx_lens = max_ctx_lens.clone()
if self.seq_ids is not None:
self.seq_ids = torch.cat([self.seq_ids, seq_ids], dim=0)
else:
self.seq_ids = seq_ids.clone()
if self.temps is not None and temps is not None:
self.temps = torch.cat([self.temps, temps], dim=0)
elif temps is not None:
self.temps = temps.clone()
if desired_batch_occupancy is not None:
self.desired_batch_occupancy = desired_batch_occupancy
return self
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
# Import from compression_config, not compression.__init__, to avoid circular imports
# (compression -> compactor -> context -> compression).
from vllm.kvprune.compression.compression_config import CompressionMethod
from vllm.kvprune.config.engine_config import KvpruneAttentionSchedule
@dataclass
class CompressionContext:
compression_method: CompressionMethod = CompressionMethod.COMPACTOR
compression_chunk_size: int = -1
batch_tokens_to_retain: torch.Tensor | None = None
max_tokens_to_retain: int = 0
context_lens: List[int] | None = None
PHI: torch.Tensor | None = None
# Compactor(与 kvpress ``CompactorPress`` 对齐的可选超参)
sketch_dimension: int = 48
sink_size_start: int = 8
sink_size_end: int = 4
compactor_blending: Optional[float] = None
# 与 kvpress 一致:未设 ``compactor_blending`` 时用该值(来自请求的 compression_ratio)
compression_ratio: Optional[float] = None
protected_first_tokens: List[int] | None = None
protected_last_tokens: List[int] | None = None
# CriticalAdaKV
wo_weight: Optional[torch.Tensor] = None
critical_ada_epsilon: float = 1e-4
critical_ada_first_stage_ratio: float = 0.5
critical_ada_alpha_safeguard: float = 0.2
@dataclass
class Context:
is_prefill: bool = False
do_compression: bool = False
cu_seqlens_q: torch.Tensor | None = None
cu_seqlens_k: torch.Tensor | None = None
# Set in ModelRunner.run_prefill before forward — avoids D2H inside compactor kernels.
cu_seqlens_q_host: Optional[Tuple[int, ...]] = None
cu_seqlens_k_host: Optional[Tuple[int, ...]] = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
batch_mapping: torch.Tensor | None = None
max_bh_len: int = 0
compression_context: CompressionContext | None = None
STORE_STREAM: torch.cuda.Stream | None = None
key_split: int | None = None
attention_schedule: KvpruneAttentionSchedule = (
KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
)
_CONTEXT = Context()
def get_context():
return _CONTEXT
def set_context(
*,
is_prefill,
do_compression=False,
cu_seqlens_q=None,
cu_seqlens_k=None,
cu_seqlens_q_host: Optional[Tuple[int, ...]] = None,
cu_seqlens_k_host: Optional[Tuple[int, ...]] = None,
max_seqlen_q=0,
max_seqlen_k=0,
batch_mapping=None,
max_bh_len=0,
compression_context: CompressionContext = None,
STORE_STREAM=None,
key_split=None,
attention_schedule=KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE,
):
global _CONTEXT
_CONTEXT = Context(
is_prefill,
do_compression,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_host,
cu_seqlens_k_host,
max_seqlen_q,
max_seqlen_k,
batch_mapping,
max_bh_len,
compression_context,
STORE_STREAM,
key_split,
attention_schedule,
)
def reset_context():
global _CONTEXT
_CONTEXT = Context()
from collections.abc import Callable
import torch
def maybe_execute_in_stream(
fn: Callable, *args, STORE_STREAM: torch.cuda.Stream = None, **kwargs
):
if STORE_STREAM is not None:
tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
tensors += [val for val in kwargs.values() if isinstance(val, torch.Tensor)]
obj = getattr(fn, "__self__", None)
if isinstance(obj, torch.Tensor):
tensors.append(obj)
STORE_STREAM.wait_stream(torch.cuda.default_stream())
# Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
# The portable API is `torch.cuda.stream(stream)`.
stream_ctx = (
STORE_STREAM
if hasattr(STORE_STREAM, "__enter__")
else torch.cuda.stream(STORE_STREAM)
)
with stream_ctx:
output = fn(*args, **kwargs)
for t in tensors:
t.record_stream(STORE_STREAM)
if isinstance(output, tuple):
for o in output:
if isinstance(o, torch.Tensor):
o.record_stream(torch.cuda.default_stream())
elif isinstance(output, torch.Tensor):
output.record_stream(torch.cuda.default_stream())
return output
else:
return fn(*args, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment