"vllm/vscode:/vscode.git/clone" did not exist on "404422f42ed9c59ee816dacd9b54196a59ae65b2"
Commit f81ce56b authored by chenzk's avatar chenzk
Browse files

vllm kvprune:v1.0.1

parent 2b7160c6
import torch
import triton
import triton.language as tl
@triton.jit
def vpopc(x):
"""
Vertical popcount
Input x : uint32[..., N]
Output y : uint32[..., 32]
semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
credits: @apgoucher
"""
tl.static_assert(
x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers"
)
BLOCK_N: tl.constexpr = x.shape[-1] # summation axis
BATCHES: tl.constexpr = x.numel // BLOCK_N # number of batches
if BLOCK_N >= 8:
sa1: tl.constexpr = 8
else:
sa1: tl.constexpr = BLOCK_N
# create 8-way sums in 4-bit fields:
y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1])
y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111
y = tl.sum(y, 2) # [BATCHES, BLOCK_N // sa1, 4]
if BLOCK_N >= 128:
sa2: tl.constexpr = 16
else:
sa2: tl.constexpr = BLOCK_N // sa1
# create 128-way sums in 8-bit fields:
y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4])
y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0F0F0F0F
y = tl.sum(y, 2) # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
sa3: tl.constexpr = BLOCK_N // (sa1 * sa2)
# create N-way sums in 32-bit fields:
y = tl.reshape(y, [BATCHES, 1, sa3, 8])
y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000FF
y = tl.sum(y, 2) # [BATCHES, 4, 8]
y = tl.reshape(y, x.shape[:-1] + [32])
return y
@triton.jit
def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
tl.store(Ret + offs, 0)
@triton.jit
def _sum_bitmatrix_rows(
B,
shape_bm,
stride_bm: tl.constexpr,
stride_bn: tl.constexpr, # input bitmatrix
Ret,
Partials,
stride_pm: tl.constexpr,
stride_pn,
shape_pn, # outputs
BLOCK_MM: tl.constexpr,
BLOCK_M: tl.constexpr,
):
tl.static_assert(BLOCK_MM % BLOCK_M == 0)
TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M
if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr():
shape_bm = tl.load(shape_bm)
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM)
offs_n = pid_n * 32 + tl.arange(0, 32)
n_rows = shape_bm
bits = tl.load(
B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0
)
bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M])
ret = vpopc(bits) # [TILE_SIZE, 32]
offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed")
tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)
def clear_sums(n_cols, device, MEMSET_BLOCK=512):
cdiv = triton.cdiv
blocks = cdiv(n_cols, MEMSET_BLOCK)
out_ret = torch.empty((blocks * MEMSET_BLOCK,), device=device, dtype=torch.int32)
_sum_bitmatrix_memset[(blocks,)](out_ret, MEMSET_BLOCK)
return out_ret
def sum_bitmatrix_rows(x, out_ret, partials_block_size=None):
assert partials_block_size is not None
cdiv = triton.cdiv
PARTIALS_BLOCK_M = partials_block_size
n_rows, n_cols = x.shape
n_rows_max = x.shape_max[0]
assert out_ret.shape == (n_cols,)
TILE_SIZE = max(1, 128 // PARTIALS_BLOCK_M)
BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE
pids_x = cdiv(n_rows_max, BLOCK_MM)
pids_y = cdiv(n_cols, 32)
out_partials = torch.empty(
(pids_y * 32, pids_x * TILE_SIZE), device=out_ret.device, dtype=torch.int32
)
out_partials = torch.transpose(out_partials, 0, 1)
# output tensors
_sum_bitmatrix_rows[(pids_x, pids_y)](
x.storage.data,
n_rows,
x.stride(0),
x.stride(1), # input
out_ret, # output [final reduction]
out_partials,
out_partials.stride(0),
out_partials.stride(1),
out_partials.shape[1], # output [partial reductions]
BLOCK_M=PARTIALS_BLOCK_M,
BLOCK_MM=BLOCK_MM, # constants
num_warps=8,
)
out_partials = out_partials[: cdiv(n_rows_max, PARTIALS_BLOCK_M), :]
return out_ret, out_partials
import torch
import triton
from dataclasses import dataclass, field
from .routing_details._routing_compute import _combined_routing_compute
from .routing_details._routing_compute import _combined_routing_memset
from .routing_details._routing_compute import _routing_clear_bitmatrix
from .routing_details._expt_data import _expt_data_memset
from .routing_details._expt_data import _expt_data_compute
from .target_info import is_hip
@dataclass
class GatherIndx:
"""
Indices for an operation that performs:
Y = X[src_idx, :]
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx: torch.Tensor
dst_indx: torch.Tensor
@dataclass
class ScatterIndx:
"""
Indices for an operation that performs:
Y[dst_idx, :] = X
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx: torch.Tensor
dst_indx: torch.Tensor
@dataclass
class ExptData:
# hist[i] is the number of tokens routed to expert i
hist: torch.Tensor
# token_offs_raw[i] is the offset of the first token routed
# to expert i in an expert-sorted array
token_offs_raw: torch.Tensor
# token_offs_pad[block][i] is the offset of the first token routed
# to expert i in an expert-sorted array, assuming histogram
# rounded to the next multiple of `block`
token_offs_pad: dict[int, torch.Tensor]
# block_id_map[block] contain one value for each `pid`` launched by
# the matrix multiplication kernel launched with BLOCK_M=block:
# - the value is -1 if the `pid` has no work to do
# - otherwise, the value is two int16 (packed as an int32) that
# correspond respectively to (1) the expert assigned to
# the tokens processed by this pid; (2) the block assigned to the
# tokens processed by this pid (think `pid_m` in a regular matmul)
# see `test_routing.py` for a reference implementation and more details
block_pid_map: dict[int, torch.Tensor]
def __post_init__(self):
if self.hist is not None:
assert self.hist.dtype == torch.int32
if self.token_offs_raw is not None:
assert self.token_offs_raw.dtype == torch.int32
if self.token_offs_pad is not None:
for v in self.token_offs_pad.values():
assert v.dtype == torch.int32
if self.block_pid_map is not None:
for v in self.block_pid_map.values():
assert v.dtype == torch.int32
@dataclass
class RoutingData:
gate_scal: torch.Tensor = field()
expt_hist: torch.Tensor = field()
n_expts_tot: int = field()
n_expts_act: int = field()
expt_data: ExptData = None
# Used to make perf annotation cleaner: when we use expert sharding, we can
# use this to tell the "expected" number of local tokens per expert, because
# the actual number can vary per each input.
expected_tokens_per_expt: int = field(default=None)
def n_blocks(self, n_rows, block_m):
if n_rows <= self.n_expts_tot:
return n_rows
else:
return (
triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m)
+ self.n_expts_tot
- 1
)
# --------------------------
# sort tokens by expert
# --------------------------
class SortTokens(torch.autograd.Function):
@staticmethod
def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
HIST_BLOCK_M = 32
INDX_OFFS_BLOCK_M = 512
MEMSET_BLOCK = 1024
cdiv = triton.cdiv
device = expt_scal.device
dtype = expt_scal.dtype
n_tokens_raw, _ = bitmatrix.shape
n_tokens_pad, n_expts_act = expt_scal.shape
n_gates_pad = n_tokens_pad * n_expts_act
hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
hist = hist[:n_expts_tot]
assert hist.dtype == torch.int32
# scratchpad
expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
# output
topk_indx = combined_indx[:n_gates_pad]
gate_indx = combined_indx[n_gates_pad:]
gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
(
token_offs_combined,
token_offs_raw,
token_offs_pad,
block_pid_map,
blocks1a,
blocks2a,
MEMSET_BLOCK_A,
HIST2_BLOCK_M,
block_m_log2_start,
block_m_num,
) = _compute_expt_data_internal(hist, n_expts_tot, n_gates_pad)
blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
_combined_routing_memset[(blocks1a + blocks1b,)](
combined_indx,
n_gates_pad * 2,
-1,
MEMSET_BLOCK,
hist, #
expt_offs,
hist.shape[0],
n_expts_tot,
partial_hist, # inputs
partial_hist.shape[0],
partial_hist.stride(0),
partial_hist.stride(1), # outputs
token_offs_combined,
token_offs_combined.stride(0), #
blocks1a,
block_pid_map, #
block_m_log2_start,
SIZES=block_m_num,
BLOCK_A=MEMSET_BLOCK_A, # optimization parameters
BLOCK_N=512,
BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters
)
indx_offs = partial_hist
_combined_routing_compute[(blocks2a + blocks2b,)](
topk_indx,
gate_indx,
gate_scal, # outputs
expt_scal,
expt_indx,
indx_offs,
indx_offs.stride(0),
indx_offs.stride(1), # inputs
expt_offs,
n_tokens_raw, # input shape
HIST_BLOCK_M,
n_expts_act, # constants
hist,
token_offs_pad,
token_offs_pad.stride(0),
block_pid_map,
block_pid_map.stride(0), # outputs
block_m_log2_start,
block_m_num,
HIST2_BLOCK_M,
blocks2a, # etc.
)
ctx.n_tokens_raw = n_tokens_raw
ctx.n_tokens_pad = n_tokens_pad
ctx.n_expts_act = n_expts_act
ctx.save_for_backward(gate_indx)
return (
hist,
topk_indx,
gate_indx,
gate_scal,
token_offs_raw,
token_offs_pad,
block_pid_map,
)
@staticmethod
def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
(gate_indx,) = ctx.saved_tensors
dgate_scal = dgate_scal[gate_indx]
dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
return dgate_scal, None, None, None
def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix):
return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix)
# --------------------------
# prune routing
# --------------------------
class PruneRouting(torch.autograd.Function):
@staticmethod
def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
from .compaction import compaction
n_tokens_pad = expt_scal.shape[0]
assert n_expts_tot % simulated_ep == 0
_routing_clear_bitmatrix[(n_tokens_pad,)](
bitmatrix.storage.data,
bitmatrix.storage.data.stride(0),
bitmatrix.storage.data.stride(1),
bitmatrix.storage.data.shape[1],
n_expts_tot // simulated_ep,
BLOCK_N=512,
)
# perform compaction to update expt_scal / expt_indx
expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix)
n_expts_tot = n_expts_tot // simulated_ep
bitmatrix.shape[-1] = n_expts_tot
return expt_scal, expt_indx, bitmatrix
def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
return PruneRouting.apply(
expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep
)
# --------------------------
# expt_data
# --------------------------
def log2_power_of_two(x):
assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two"
return x.bit_length() - 1
block_m_log2_start = 4
def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
MEMSET_BLOCK = 512
HIST2_BLOCK_M = 512
device = expt_hist.device
n_expts_tot = n_expts_tot
cdiv = triton.cdiv
# block_ms are all powers-of-two between 16 and 128 (inclusive)
block_m_log2_end = 9 if is_hip() else 8
block_m_num = block_m_log2_end - block_m_log2_start
if n_gates <= n_expts_tot:
max_n_tiles = n_gates
else:
max_n_tiles = (
n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start)
)
# allocate memory
pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK
dtype = torch.int32
token_offs_combined = torch.empty(
(block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device
)
token_offs_raw = token_offs_combined[0][: n_expts_tot + 1]
token_offs_pad = token_offs_combined[1:]
block_pid_map = torch.empty(
(block_m_num, pad(max_n_tiles)), dtype=dtype, device=device
)
memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK # exact division
# compute outputs
token_offs_pad = token_offs_pad[:, : n_expts_tot + 1]
block_pid_map = block_pid_map[:, :max_n_tiles]
blocks1 = memset_grid + block_m_num + 1
blocks2 = n_expts_tot * block_m_num
return (
token_offs_combined,
token_offs_raw,
token_offs_pad,
block_pid_map,
blocks1,
blocks2,
MEMSET_BLOCK,
HIST2_BLOCK_M,
block_m_log2_start,
block_m_num,
)
def _unpack_into_dict(x):
block_m_log2_end = block_m_log2_start + x.shape[0]
x = {
2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))
}
return x
def compute_expt_data(expt_hist, n_expts_tot, n_gates):
if expt_hist is None:
return ExptData(None, None, None, None)
# this just computes the kernel arguments:
(
token_offs_combined,
token_offs_raw,
token_offs_pad,
block_pid_map,
blocks1,
blocks2,
MEMSET_BLOCK,
HIST2_BLOCK_M,
block_m_log2_start,
block_m_num,
) = _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates)
_expt_data_memset[(blocks1,)](
expt_hist,
n_expts_tot, #
token_offs_combined,
token_offs_combined.stride(0), #
block_pid_map, #
block_m_log2_start,
SIZES=block_m_num,
BLOCK=MEMSET_BLOCK, # optimization parameters
num_warps=4,
)
_expt_data_compute[(blocks2,)](
expt_hist,
token_offs_pad,
token_offs_pad.stride(0),
block_pid_map,
block_pid_map.stride(0), # outputs
block_m_log2_start,
SIZES=block_m_num,
BLOCK=HIST2_BLOCK_M, # optimization parameters
num_warps=4,
)
token_offs_pad = _unpack_into_dict(token_offs_pad)
block_pid_map = _unpack_into_dict(block_pid_map)
return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)
# --------------------------
# routing
# --------------------------
def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
(
hist,
topk_indx,
gate_indx,
gate_scal,
token_offs_raw,
token_offs_pad,
block_pid_map,
) = sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix)
token_offs_pad = _unpack_into_dict(token_offs_pad)
block_pid_map = _unpack_into_dict(block_pid_map)
expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
# pack the matmul data structure
gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
return (
RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data),
gather_indx,
scatter_indx,
)
def routing(
logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None
):
from .topk import topk
if sm_first:
logits = torch.softmax(logits, dim=-1)
expt_scal, expt_indx, bitmatrix = topk(
logits,
n_expts_act, #
apply_softmax=not sm_first,
y_indx=expt_indx,
n_rows=n_rows,
)
n_expts_tot = logits.shape[-1] // simulated_ep
# mutate bitmatrix
if simulated_ep > 1:
expt_scal, expt_indx, bitmatrix = prune_routing(
expt_scal, expt_indx, bitmatrix, logits.shape[-1], simulated_ep
)
return routing_from_bitmatrix(
bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
)
# --------------------------
# torch reference
# --------------------------
def compute_expt_data_torch(hist, n_expts_tot, n_gates):
# offset for each experts
device = hist.device
token_offs_raw = torch.cumsum(hist, dim=0)
token_offs_raw = torch.cat((torch.zeros(1, device=device), token_offs_raw))
token_offs_raw = token_offs_raw.int()
# maximum number of tiles for all values of `block_m` considered
block_ms = [16, 32, 64, 128]
if is_hip():
block_ms.append(256)
if n_gates <= n_expts_tot:
max_n_tiles = n_gates
else:
# ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
# ceil_div(x, y): -(-x // y)
max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms))
# fill up tile offset/infos for each block
token_offs_pad = dict()
block_pid_map = dict()
for block_m in block_ms:
n_tiles = (hist + block_m - 1) // block_m # matmul blocks needed
token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0)
token_offs_pad[block_m] = torch.cat(
(torch.zeros(1, device=device), token_offs_pad[block_m])
)
token_offs_pad[block_m] = token_offs_pad[block_m].int()
# compute data required to drive ragged batch matmul
block_pid_map[block_m] = -torch.ones(
max_n_tiles, dtype=torch.int32, device=device
)
# for e in range(n_expts_tot):
# offset = token_offs_pad[block_m][e]
# for b in range(n_tiles[e]):
# block_pid_map[block_m][offset + b] = (b << 16) + e
col = torch.arange(max_n_tiles, device=device)
map_vals = (
torch.arange(n_expts_tot, device=device)[:, None] + (col << 16)[None, :]
)
map_idxs = token_offs_pad[block_m][:-1, None] + col[None, :]
mask = col[None, :] < n_tiles[:, None]
block_pid_map[block_m].index_put_((map_idxs[mask],), map_vals.int()[mask])
return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
def topk_torch(vals, k, expt_indx, has_user_provided_indx=False):
# topk of experts
if has_user_provided_indx:
tk_indx = expt_indx
else:
tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
tk_indx = tk_indx.long()
tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
tk_indx = tk_indx.int()
return tk_val, tk_indx
def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
has_user_provided_indx = expt_indx is not None
n_gates_pad = logits.shape[0] * n_expts_act
if n_rows is not None:
logits = logits[:n_rows, :]
_, n_expts_tot = logits.shape
if sm_first:
logits = torch.softmax(logits, dim=-1)
expt_scal, expt_indx = topk_torch(
logits, n_expts_act, expt_indx, has_user_provided_indx=has_user_provided_indx
)
if not sm_first:
expt_scal = torch.softmax(expt_scal, dim=-1)
# sort each token's selections by expert
if not has_user_provided_indx:
expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
expt_scal = torch.gather(expt_scal, 1, sort_indices)
# flatten topk data
expt_scal = expt_scal.reshape(-1)
expt_indx = expt_indx.reshape(-1).to(torch.int32)
# sort by expert_id so experts are contiguous for the matmul
topk_indx = torch.argsort(expt_indx, stable=True)
gate_indx = torch.argsort(topk_indx, stable=True)
gate_scal = expt_scal[topk_indx]
hist = torch.histc(
expt_indx, bins=n_expts_tot, max=n_expts_tot - 1
).int() # histogram of tokens over experts
# pack the matmul data structure
gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
# compute expt_data
expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad)
return (
RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data),
gather_indx,
scatter_indx,
)
import triton
import triton.language as tl
@triton.jit
def _cdiv_pow2(n, log2_k):
return (n + ((1 << log2_k) - 1)) >> log2_k
@triton.jit
def _expt_data_memset(
Hist,
n_expts_tot,
MDStarts,
tile_starts_stridem,
MDTileInfo,
first_tile_dim_log2,
SIZES: tl.constexpr,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
if pid <= SIZES:
MDStarts += pid * tile_starts_stridem
x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty)
Tile_ptrs = MDStarts + tl.arange(0, BLOCK)
tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1)
for i in range(0, n_expts_tot + 1, BLOCK):
offs_n = tl.arange(0, BLOCK) + i
mask_n0 = offs_n < n_expts_tot
hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0)
hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2)
tile_starts = tl.cumsum(hist_tile, 0) + x_tile
x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty)
tl.store(Tile_ptrs, tile_starts - hist_tile)
Tile_ptrs += BLOCK
else:
pid -= SIZES + 1
TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK)
tl.store(TileInfoOut, 0xFFFFFFFF)
@triton.jit
def _expt_data_compute(
Hist,
MDTileStarts,
tile_starts_stridem,
MDTileInfo,
tile_info_stridem,
first_tile_dim_log2,
SIZES: tl.constexpr,
BLOCK: tl.constexpr,
):
pid = tl.program_id(0)
expt_id = pid // SIZES
buff_id = pid % SIZES
MDTileStarts += buff_id * tile_starts_stridem
MDTileInfo += buff_id * tile_info_stridem
n_tokens = tl.load(Hist + expt_id)
tile_dim_log2 = first_tile_dim_log2 + buff_id
n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2)
tile_off = tl.load(MDTileStarts + expt_id)
MDTileInfo += tile_off
for block_off in range(0, n_blocks, BLOCK):
block_offs = block_off + tl.arange(0, BLOCK)
data = (block_offs << 16) + expt_id
tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks)
import triton
import triton.language as tl
from ._expt_data import _expt_data_compute, _expt_data_memset
@triton.jit
def _routing_compute_expt_offs(
ExpertHist,
FinalExpertOffs,
hist_size, # histogram
BLOCK_N: tl.constexpr,
):
loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N
x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty)
for i in range(loop_iterations):
offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_n < hist_size
hist2 = tl.load(ExpertHist + offs_n, mask=mask_n)
tok_starts = tl.cumsum(hist2, 0) - hist2 + x
x += tl.sum(hist2, 0)
tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
offs_n += BLOCK_N
@triton.jit
def _routing_compute_indx_offs(
PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id
):
offs_m = tl.arange(0, BLOCK_M)
# iterate over input data
curr_sum = 0
for _ in range(0, shape_pm, BLOCK_M):
offs = offs_m * stride_pm + expt_id * stride_pn
curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)
out = tl.cumsum(curr, 0) + curr_sum
curr_sum += tl.sum(curr, 0)
tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm)
offs_m += BLOCK_M
@triton.jit
def _keyed_add(x, y):
# we keep the key in the upper 16 bits of a uint32:
key_mask: tl.constexpr = 0xFFFF0000
kx = x & key_mask
ky = y & key_mask
z = tl.where(kx == ky, x + y - kx, y)
return z
@triton.jit
def _routing_compute_indx(
pid_m,
GatherIndx,
ScatterIndx,
GateScal,
ExptScal,
ExptIndx,
PartialOffs,
stride_pm,
stride_pn,
TokensStart,
n_tokens,
BLOCK_M: tl.constexpr,
N_EXPTS_ACT: tl.constexpr,
):
if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
n_tokens = tl.load(n_tokens)
n_gates = n_tokens * N_EXPTS_ACT
tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)
local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)
# stable-sort by expert ID:
kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
kv_pairs = tl.sort(kv_pairs, 0)
expert = kv_pairs >> 16
offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xFFFF)
mask = expert != 0xFFFF
gate_scal = tl.load(ExptScal + offs, mask=mask)
# compute run lengths in expert-sorted order:
x = kv_pairs & 0xFFFF0000 | 0x00000001
expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xFFFF
gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask)
gates += tl.load(TokensStart + expert, mask=mask)
gates += exclusive_run_lengths
tl.store(ScatterIndx + offs, gates, mask=mask)
tl.store(GatherIndx + gates, offs, mask=mask)
tl.store(GateScal + gates, gate_scal, mask=mask)
@triton.jit
def _combined_routing_compute(
GatherIndx,
ScatterIndx,
GateScal,
ExptScal,
ExptIndx,
PartialOffs,
stride_pm,
stride_pn,
TokensStart,
n_tokens,
BLOCK_M: tl.constexpr,
N_EXPTS_ACT: tl.constexpr,
Hist,
MDTileStarts,
tile_starts_stridem,
MDTileInfo,
tile_info_stridem,
first_tile_dim_log2,
SIZES: tl.constexpr,
BLOCK: tl.constexpr,
blocks2a,
):
pid = tl.program_id(0)
if pid < blocks2a:
_expt_data_compute(
Hist,
MDTileStarts,
tile_starts_stridem,
MDTileInfo,
tile_info_stridem,
first_tile_dim_log2,
SIZES,
BLOCK,
)
else:
pid -= blocks2a
_routing_compute_indx(
pid,
GatherIndx,
ScatterIndx,
GateScal,
ExptScal,
ExptIndx,
PartialOffs,
stride_pm,
stride_pn,
TokensStart,
n_tokens,
BLOCK_M,
N_EXPTS_ACT,
)
@triton.jit
def _routing_clear_bitmatrix(
Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr
):
pid_m = tl.program_id(0)
cutoff_word = cutoff // 32
cutoff_bit = cutoff % 32
cutoff_mask = (1 << (cutoff_bit)) - 1
for start_n in range(0, shape_bn, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
values = tl.load(
Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn
)
values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)
values = tl.where(offs_n > cutoff_word, 0, values)
tl.store(
Bitmatrix + pid_m * stride_bm + offs_n * stride_bn,
values,
mask=offs_n < shape_bn,
)
@triton.jit
def _combined_routing_memset(
Indx,
size,
sentinel,
BLOCK: tl.constexpr,
ExpertHist,
FinalExpertOffs,
hist_size,
n_expts_tot,
PartialHist,
shape_pm,
stride_pm,
stride_pn,
MDStarts,
tile_starts_stridem,
blocks1a,
MDTileInfo,
first_tile_dim_log2,
SIZES: tl.constexpr,
BLOCK_A: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_M: tl.constexpr,
):
"""
This kernel essentially combines 6 different pieces of functionality,
statically branching on the value of tl.program_id(0) to decide which
codepath to take.
pid == 0: create the token cumsum
1 <= pid <= SIZES: create a tile cumsum
SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
pid == blocks1a + n_expts_tot: compute_expt_offs
pid > blocks1a + n_expts_tot: initialise Indx to sentinel
As each of these is a relatively trivial workload, launching them from
this single trampoline is beneficial as they can execute on different
streaming multiprocesses in parallel.
"""
pid = tl.program_id(0)
if pid < blocks1a:
_expt_data_memset(
ExpertHist,
n_expts_tot,
MDStarts,
tile_starts_stridem,
MDTileInfo,
first_tile_dim_log2,
SIZES,
BLOCK_A,
)
elif pid == n_expts_tot + blocks1a:
_routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
elif pid < n_expts_tot + blocks1a:
_routing_compute_indx_offs(
PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a
)
else:
offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK)
mask = offs < size
tl.store(Indx + offs, sentinel, mask=mask)
import inspect
import re
import textwrap
import types
import triton
def cacheable(f):
"""
A decorator that allow you to write something of the form:
@cacheable
def my_kernel(): return (expression dynamically defining a kernel)
such that it interacts gracefully with triton cache and preload.
"""
g = f()
g.fn.__name__ = f.__name__
g.fn.__module__ = f.__module__
g.fn.__qualname__ = f.__qualname__
g.__name__ = f.__name__
g.__module__ = f.__module__
g.__qualname__ = f.__qualname__
g._fn_name = f"{f.__module__}.{f.__qualname__}"
return g
def define_kernel(src, module, attrs=None, **extra_globals):
"""
Dynamically create a Triton function or kernel from a src string,
linking any symbols in the kernel to objects specified by extra_globals.
"""
# create templace function
def _empty_fn():
pass
gdict = dict(**(_empty_fn.__globals__))
gdict.update(extra_globals)
f = types.FunctionType(_empty_fn.__code__, gdict)
f.__module__ = module.__name__
src = textwrap.dedent(src)
src = src[src.find("def ") :]
stored_functions = []
function_name = src[4:].split("(")[0].strip()
exec_globals = gdict
exec_globals.update({"stored_functions": stored_functions})
exec(src + "\n\nstored_functions.append(" + function_name + ")\n", exec_globals)
f.__signature__ = inspect.signature(stored_functions[0])
f.__name__ = function_name
f.__doc__ = stored_functions[0].__doc__
if attrs is None:
attrs = dict()
f = triton.JITFunction(f, **attrs)
f._unsafe_update_src(src)
return f
def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
assert isinstance(fn, triton.runtime.jit.JITFunction)
if name is None:
name = f"{fn.__name__}"
# Get original source code
src = inspect.getsource(fn.fn)
src = textwrap.dedent(src)
lines = src.split("\n")
# Skip decorator and def line
def_idx = next(i for i, line in enumerate(lines) if line.strip().startswith("def"))
# separate header vs body LOC
header_end = def_idx
while not lines[header_end].rstrip().endswith(":"):
header_end += 1
body_lines = lines[header_end + 1 :]
header_lines = lines[def_idx : header_end + 1]
# clean-up header
header_clean = [
l.split("#", 1)[0].strip() # keep code, discard comment
for l in header_lines
if l.split("#", 1)[0].strip() # skip blank‑after‑comment lines
]
# decompose arguments
header_src = " ".join(header_clean) # turn it into a single line
m = re.search(r"\((.*)\)\s*:", header_src)
if not m:
raise ValueError("Could not parse function header")
args_str = m.group(1)
args = [arg.strip() for arg in args_str.split(",") if arg.strip()]
non_specialized_args = []
for arg in args:
arg_key = arg.split(":")[0].split("=")[0].strip()
new_args = tuples.get(arg_key, [arg])
if arg_key not in constants:
non_specialized_args += new_args
# add global symbols
spec_fns = {
v.__name__: v
for k, v in constants.items()
if isinstance(v, triton.runtime.jit.JITFunction)
}
globals = spec_fns | fn.get_capture_scope()
# build new source code and define kernel dynamically
new_signature = f"def {name}({', '.join(non_specialized_args)}):"
constexpr_lines = [
f" {key}: tl.constexpr = {value.__name__ if callable(value) else value}"
for key, value in constants.items()
]
tuple_lines = [
f" {key} = {'(' + ','.join(value) + (',' if len(value) >= 1 else '') + ')'}"
for key, value in tuples.items()
]
new_src = "\n".join(
["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines
)
# find function parameters
sig = inspect.signature(triton.runtime.jit.JITFunction.__init__)
params = list(sig.parameters.values())[2:]
attrs = {param.name: getattr(fn, param.name, param.default) for param in params}
# make a new repr which appends the repr of the specialized functions.
base_repr = attrs["repr"]
def new_repr(specialization):
ret = base_repr(specialization)
for spec_fn in spec_fns.values():
spec_repr = spec_fn.repr(None)
if spec_repr:
spec_repr = spec_repr.strip("_")
if spec_repr:
ret += f"_{spec_repr}"
return ret
attrs["repr"] = new_repr
if do_not_specialize:
attrs["do_not_specialize"] = do_not_specialize
ret = define_kernel(new_src, module, attrs, **globals)
return ret
from dataclasses import dataclass
from vllm.kvprune.triton_kernels.numerics import InFlexData, OutFlexData
import torch
import triton
from .swiglu_details._swiglu import _swiglu, _swiglu_fn
from vllm.kvprune.triton_kernels import target_info
@dataclass(frozen=True)
class FlexCtx:
out_data: OutFlexData = OutFlexData()
inp_data: InFlexData = InFlexData()
saturate_inf: bool = False
@dataclass(frozen=True)
class PrecisionConfig:
limit: float
flex_ctx: FlexCtx = FlexCtx()
swiglu_fn = _swiglu_fn
class SwiGLU(torch.autograd.Function):
@staticmethod
def forward(ctx, a, alpha, precision_config, routing_data):
N = a.shape[-1]
M = a.numel() // N
assert a.stride()[-1] == 1
assert a.shape[-1] % 2 == 0
out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
flex_ctx = precision_config.flex_ctx
# optimization hyperparameters
BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
num_warps = 4
kwargs = {"maxnreg": 64} if not target_info.is_hip() else {}
# launch semi-persistent kernel
N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
num_sms = target_info.num_sms()
if routing_data is not None:
waves_per_sm = 32 if target_info.is_hip() else 128
num_pid = num_sms * (waves_per_sm // num_warps)
M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms),)
else:
M_BLOCKS = triton.cdiv(M, BLOCK_M)
if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
grid = (8 * num_sms,)
else:
grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms),)
n_tokens = None
if routing_data is not None:
n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]
_swiglu[grid](
flex_ctx.out_data.reinterpret(out),
flex_ctx.out_data.expected_scale,
flex_ctx.out_data.actual_scale,
flex_ctx.out_data.checksum_scale,
flex_ctx.inp_data.reinterpret(a),
flex_ctx.inp_data.scale,
alpha,
M,
N // 2,
a.shape[-1],
1,
out.shape[-1],
1,
precision_config.limit,
n_tokens,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
EVEN_N=(N // 2) % BLOCK_N == 0,
M_BLOCKS=M_BLOCKS,
N_BLOCKS=N_BLOCKS,
flexpoint_saturate_inf=flex_ctx.saturate_inf,
num_warps=num_warps,
**kwargs,
)
out = out.view(a.shape[:-1] + out.shape[-1:])
return out
def swiglu(a, alpha, precision_config, routing_data=None):
return SwiGLU.apply(a, alpha, precision_config, routing_data)
def swiglu_torch(a, alpha, precision_config):
limit = precision_config.limit
a_gelu = a[..., ::2]
if limit is not None:
a_gelu = a_gelu.clamp(max=limit)
a_linear = a[..., 1::2]
if limit is not None:
a_linear = a_linear.clamp(min=-limit, max=limit)
out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
out = out_gelu * (a_linear + 1)
return out
from vllm.kvprune.triton_kernels.numerics_details.flexpoint import (
load_scale,
float_to_flex,
update_scale,
)
import triton
import triton.language as tl
@triton.jit
def clip(x, limit, clip_lower: tl.constexpr):
res = tl.minimum(x, limit)
if clip_lower:
res = tl.maximum(-limit, res)
return res
@triton.jit
def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
return tl.max(
tl.reshape(
tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True
),
axis=1,
)
def swiglu_repr(specialization):
signature = specialization.signature
constants = specialization.constants
convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype
dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in ["Out", "A"]])
blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N"]])
return f"_swiglu_{dtypes}_{blocks}"
def swiglu_launch_metadata(grid, kernel, args):
M, N = args["M"], args["N"]
ret = dict()
ret["name"] = f"{kernel.name} [M = {M}, N = {N}]"
A, Out = args["A"], args["Out"]
ret["bytes"] = Out.numel() * Out.element_size() + A.numel() * A.element_size()
return ret
@triton.jit
def compute_swiglu(gelu, linear, scale, alpha, limit):
gelu = gelu.to(tl.float32) * scale
if limit is not None:
gelu = clip(gelu, limit, clip_lower=False)
linear = linear.to(tl.float32) * scale
if limit is not None:
linear = clip(linear, limit, clip_lower=True)
s = gelu / (1 + tl.exp(-alpha * gelu))
return tl.fma(s, linear, s) # (s * (linear + 1))
@triton.jit(repr=lambda _: "_swiglu")
def _swiglu_fn(input, alpha, limit):
gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
return compute_swiglu(gelu, linear, 1.0, alpha, limit)
@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
def _swiglu(
Out,
OutExpectedScale,
OutActualScale,
OutChecksumScale,
A,
AScale,
alpha,
M,
N,
stride_am,
stride_an,
stride_outm,
stride_outn,
limit: tl.constexpr,
NTokens,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
EVEN_N: tl.constexpr,
M_BLOCKS,
N_BLOCKS,
flexpoint_saturate_inf: tl.constexpr,
):
if NTokens is not None:
M = tl.load(NTokens)
M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M
local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32)
a_scale = load_scale(AScale)
out_expected_scale = load_scale(OutExpectedScale)
for pid in tl.range(
tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2
):
pid_m = pid // N_BLOCKS
pid_n = pid % N_BLOCKS
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
mask_m = off_m < M
mask_n = off_n < N
packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2
packed_mask_n = packed_off_n < N
packed_mask_n = tl.max_constancy(packed_mask_n, [16])
# load a
packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N)
packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an
if EVEN_N:
a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
else:
if pid_n * BLOCK_N + BLOCK_N <= N:
a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
else:
packed_mask = mask_m[:, None] & packed_mask_n[None, :]
a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.0)
a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
# update flexpoint stats and divide by scale
# we don't need masking because of the `other` when loading `A`
if OutActualScale is not None:
absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads())
local_max = tl.maximum(local_max, absmax)
out = float_to_flex(
out,
out_expected_scale,
None, # ActualScale: local absmax is tracked and updated after the loop
OutChecksumScale,
None,
Out,
flexpoint_saturate_inf,
)
mask = mask_m[:, None] if EVEN_N else mask_m[:, None] & mask_n[None, :]
tl.store(
Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask
)
update_scale(local_max, OutActualScale, Out)
import torch
import triton
# ``constexpr_function`` moved across Triton versions; ROCm/vendor wheels often
# only expose ``triton.constexpr_function`` (not ``triton.runtime.jit``).
def _resolve_constexpr_function():
fn = getattr(triton, "constexpr_function", None)
if fn is not None:
return fn
try:
from triton.runtime.jit import constexpr_function as _fn
return _fn
except ImportError:
pass
_jit = getattr(triton, "jit", None)
if _jit is not None:
fn = getattr(_jit, "constexpr_function", None)
if fn is not None:
return fn
raise ImportError(
"Cannot resolve Triton constexpr_function (try: pip install -U triton)"
)
constexpr_function = _resolve_constexpr_function()
__all__ = [
"cuda_capability_geq",
"get_cdna_version",
"has_tma_gather",
"has_native_mxfp",
"is_cuda",
"is_hip",
"is_hip_cdna3",
"is_hip_cdna4",
"num_sms",
]
try:
from triton.language.target_info import (
cuda_capability_geq,
current_target,
is_cuda,
is_hip,
is_hip_cdna3,
is_hip_cdna4,
)
except ImportError:
# Some ROCm / vendor Triton wheels omit ``triton.language.target_info``.
# Mirror upstream Triton (see triton/language/target_info.py) via runtime.
from triton.runtime import driver
def current_target():
try:
active_driver = driver.active
except RuntimeError:
return None
return active_driver.get_current_target()
@constexpr_function
def is_cuda():
target = current_target()
return target is not None and target.backend == "cuda"
@constexpr_function
def is_hip():
target = current_target()
return target is not None and target.backend == "hip"
@constexpr_function
def cuda_capability_geq(major, minor=0):
target = current_target()
if target is None or target.backend != "cuda":
return False
assert isinstance(target.arch, int)
return target.arch >= major * 10 + minor
@constexpr_function
def is_hip_cdna3():
target = current_target()
return target is not None and target.arch == "gfx942"
@constexpr_function
def is_hip_cdna4():
target = current_target()
return target is not None and target.arch == "gfx950"
@constexpr_function
def get_cdna_version():
"""
AMD CDNA generation: 3 (gfx942) or 4 (gfx950); -1 if unknown / non-HIP.
"""
target = current_target()
if target is None or target.backend != "hip":
return -1
if target.arch == "gfx942":
return 3
if target.arch == "gfx950":
return 4
return -1
@constexpr_function
def has_tma_gather():
return cuda_capability_geq(10, 0)
@constexpr_function
def has_native_mxfp():
return cuda_capability_geq(10, 0)
def num_sms():
return torch.cuda.get_device_properties(0).multi_processor_count
from dataclasses import dataclass, fields
from typing import Type
import torch
from triton.tools.tensor_descriptor import TensorDescriptor
from triton.tools.ragged_tma import create_ragged_descriptor
from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
from .target_info import cuda_capability_geq
from .tensor_details.layout import Layout, StridedLayout
@dataclass
class Storage:
data: torch.Tensor
layout: Layout = None
def __post_init__(self):
assert isinstance(self.data, torch.Tensor)
if self.layout is None:
self.layout = StridedLayout(self.data.shape)
@property
def device(self):
return self.data.device
def is_tma_compliant(self):
# TMAs didn't exist until Hopper
if not cuda_capability_geq(9, 0):
return False
# TMAs only exist for 2D, 3D, 5D inputs
if len(self.data.shape) not in [2, 3, 5]:
return False
# TMAs need at most one stride equal to 1
# and all other strides divisble by 16
strides = list(self.data.stride())
try:
major_dim = strides.index(1)
except ValueError:
major_dim = -1
ndim = self.data.ndim
bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
compliant = [
strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim
]
return all(compliant)
def make_dense_tma(self, block_shape, transpose=False):
strides = list(self.data.stride())
shape = list(self.data.shape)
transpose = self.data.stride()[-1] != 1
if transpose:
block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
shape = shape[:-2] + [shape[-1], shape[-2]]
strides = strides[:-2] + [strides[-1], strides[-2]]
if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
indx = strides.index(1)
block_shape[indx] = block_shape[indx] // 2
if shape[-1] % 128 != 0:
raise ValueError(
"inner shape need to be multiple of 128 for "
"mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs."
)
block_shape = self.layout.swizzle_block_shape(block_shape)
return TensorDescriptor(self.data, shape, strides, block_shape)
def make_tma(self, block_shape, mode, transpose=False):
if mode in ["dense", "gather", "scatter"]:
return self.make_dense_tma(block_shape, transpose)
assert mode == "ragged"
ragged_dim = len(self.data.shape) - 2
return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
@dataclass
class IntegerType:
bitwidth: int
@dataclass
class FloatType:
bitwidth_exponent: int
bitwidth_mantissa: int
is_signed: bool
def __post_init__(self):
self.bitwidth = (
int(self.is_signed) + self.bitwidth_exponent + self.bitwidth_mantissa
)
BIT = IntegerType(1)
FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)
def bitwidth(type: IntegerType | FloatType | torch.dtype):
if isinstance(type, torch.dtype):
return type.itemsize * 8
return type.bitwidth
@dataclass
class Tensor:
storage: Storage | torch.Tensor
dtype: IntegerType | FloatType | torch.dtype = None
shape: list[int] | None = None
shape_max: list[int] | None = None
def __post_init__(self):
# set storage
if isinstance(self.storage, torch.Tensor):
self.storage = Storage(self.storage)
# initialize dtype
if self.dtype is None:
self.dtype = self.storage.data.dtype
if bitwidth(self.dtype) < 8 and self.shape is None:
raise ValueError("shape must be provided for sub-byte types")
# initialize shape
if self.shape is None:
self.shape = list(self.storage.data.shape)
# validate shape: all elements must be `int` or numel-1 `torch.Tensor`
is_int = lambda s: isinstance(s, int)
is_item = lambda s: hasattr(s, "numel") and s.numel() == 1
assert all(map(lambda s: is_int(s) or is_item(s), self.shape))
# initialize shape_max
if self.shape_max is None:
self.shape_max = [None] * len(self.shape)
for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
if smax is not None and not is_int(smax):
raise ValueError(
f"shape_max[{i}] must be `int` or `None`; got {type(smax)}"
)
if smax is None:
self.shape_max[i] = s
# validate shape_max: all elements must be `int`
assert all(map(is_int, self.shape_max))
# torch compatibility layer
@property
def ndim(self):
return len(self.shape)
@property
def device(self):
return self.storage.device
def stride(self, i=None):
return self.storage.data.stride() if i is None else self.storage.data.stride(i)
def data_ptr(self):
return self.storage.data.data_ptr()
def numel(self):
return self.storage.data.numel()
def element_size(self):
return bitwidth(self.dtype) // 8
@property
def data(self):
t = self.storage
return t.data if isinstance(t, Storage) else t
def dim(self):
return self.ndim
def size(self, i=None):
if i is None:
return self.shape
return self.shape[i]
@dataclass
class Bitmatrix(Tensor):
"""
Represents a boolean matrix in a packed format where each element occupies
a single bit of memory.
_scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along
with the actual bitmatrix to avoid having to launch a separate memset
kernel when we call Bitmatrix::sum().
"""
scratchpad: torch.Tensor = None
def __init__(self, storage, shape, shape_max=None, scratchpad=None):
super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
self.scratchpad = scratchpad
def sum(self, partials_block_size):
_, n_cols = self.shape
dev = self.device
if self.scratchpad is None:
self.scratchpad = clear_sums(n_cols, dev)
out_ret = self.scratchpad[:n_cols]
self.scratchpad = None # throw error if we try to sum again
return sum_bitmatrix_rows(self, out_ret, partials_block_size)
def get_layout(tensor: torch.Tensor | Tensor | None):
if tensor is None:
return None
if isinstance(tensor, Tensor):
return tensor.storage.layout
return StridedLayout
def wrap_torch_tensor(torch_tensor, dtype=None):
if dtype is None:
dtype = torch_tensor.dtype
shape = list(torch_tensor.shape)
shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(
dtype
)
return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)
def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
assert isinstance(tensor, Tensor)
old_storage = tensor.storage
old_data = old_storage.layout.unswizzle_data(old_storage.data)
new_layout = layout_cls(old_data.shape, **layout_kwargs)
new_data = new_layout.swizzle_data(old_data)
attrs = {
k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"
}
return Tensor(Storage(new_data, new_layout), **attrs)
from .layout_details.base import Layout
from .layout_details.blackwell_scale import BlackwellMXScaleLayout
from .layout_details.blackwell_value import BlackwellMXValueLayout
from .layout_details.hopper_scale import HopperMXScaleLayout
from .layout_details.hopper_value import HopperMXValueLayout
from .layout_details.cdna4_scale import CDNA4MXScaleLayout
from .layout_details.strided import StridedLayout
from ..target_info import cuda_capability_geq, is_hip_cdna4
__all__ = [
"Layout",
"BlackwellMXValueLayout",
"BlackwellMXScaleLayout",
"HopperMXScaleLayout",
"HopperMXValueLayout",
"CDNA4MXScaleLayout",
"StridedLayout",
]
def make_default_matmul_mxfp4_w_layout(mx_axis: int):
if cuda_capability_geq(10):
# return StridedLayout, dict()
return BlackwellMXValueLayout, dict()
elif cuda_capability_geq(9):
return HopperMXValueLayout, {"mx_axis": mx_axis}
else:
return StridedLayout, dict()
def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
if is_hip_cdna4():
return CDNA4MXScaleLayout, dict()
else:
if cuda_capability_geq(10):
return BlackwellMXScaleLayout, dict()
elif cuda_capability_geq(9):
return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
return StridedLayout, dict()
from abc import ABC, abstractmethod
class Layout(ABC):
def __init__(self, shape) -> None:
self.initial_shape = shape
@abstractmethod
def swizzle_data(self, data):
pass
@abstractmethod
def unswizzle_data(self, data):
pass
@abstractmethod
def swizzle_block_shape(self, block_shape):
pass
import math
import triton
import triton.language as tl
import torch
from .base import Layout
SWIZZLE_ALIGN_INNER = 8
SWIZZLE_SIZE_INNER = 4
SWIZZLE_SIZE_OUTER = 128
class BlackwellMXScaleLayout(Layout):
name: str = "BLACKWELL_SCALE"
def __init__(self, shape) -> None:
super().__init__(shape)
(
*self.leading_shape,
self.K,
self.N,
) = shape
self.B = math.prod(self.leading_shape)
self.ALIGN_K = 8
self.ALIGN_N = 128
self.SWIZZLE_K = 4
self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K
self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N
def swizzle_data(self, data):
data = torch.nn.functional.pad(
data, (0, self.N_pad - self.N, 0, self.K_pad - self.K)
)
data = data.transpose(-1, -2).contiguous()
data = data.reshape(
self.B,
self.N_pad // self.ALIGN_N,
self.ALIGN_N // 32,
32,
self.K_pad // self.SWIZZLE_K,
self.SWIZZLE_K,
)
data = data.transpose(2, 4).contiguous()
data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
return data
def unswizzle_data(self, data):
data = data.reshape(
self.B,
self.N_pad // self.ALIGN_N,
self.K_pad // self.SWIZZLE_K,
32,
self.ALIGN_N // 32,
self.SWIZZLE_K,
)
data = data.transpose(2, 4)
data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad)
data = data.transpose(-1, -2)
return data[..., : self.K, : self.N]
def swizzle_block_shape(self, block_shape):
MX_PACK_DIVISOR = 32
MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]
@triton.jit
def unswizzle_mx_scale_bw(
x,
SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER,
SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER,
ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER,
):
shape_0: tl.constexpr = x.shape[0]
shape_1: tl.constexpr = x.shape[1]
tl.static_assert(shape_1 % SIZE_OUTER == 0)
tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER)
x = x.reshape(
shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER
)
x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER)
return x
import torch
from .base import Layout
class BlackwellMXValueLayout(Layout):
name: str = "BLACKWELL_VALUE"
def __init__(self, shape) -> None:
super().__init__(shape)
self.shape = shape
def swizzle_data(self, data):
# permutation needed to make `data` row major
to_row_major = sorted(range(data.ndim), key=lambda d: (data.stride(d), d))[::-1]
# permutation needed to retrieve original order
inv = [0] * data.ndim
for i, d in enumerate(to_row_major):
inv[d] = i
# leading dimension must be padded to be aligned to 128
align_dim = lambda x: (x + 128 - 1) // 128 * 128
major_dim = data.stride().index(1)
pad = align_dim(data.shape[major_dim]) - data.shape[major_dim]
data = torch.nn.functional.pad(data.permute(to_row_major), (0, pad)).permute(
inv
)
return data
def unswizzle_data(self, data: torch.Tensor):
# Trim padding along all dims back to the original shape recorded at init.
assert data.ndim == len(self.shape), (
"Rank mismatch between data and recorded shape"
)
sizes = [min(data.size(i), self.shape[i]) for i in range(data.ndim)]
return data[tuple(slice(0, s) for s in sizes)]
def swizzle_block_shape(self, block_shape):
return block_shape
import triton
import triton.language as tl
from .base import Layout
NON_K_PRESHUFFLE_BLOCK_SIZE = 32
class CDNA4MXScaleLayout(Layout):
name: str = "CDNA4_SCALE"
def __init__(self, shape) -> None:
super().__init__(shape)
def swizzle_data(self, data):
block_shape = data.shape
SCALE_K = block_shape[-2]
N = block_shape[-1]
data = data.transpose(-1, -2)
data = data.view(
-1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1
)
data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
if len(block_shape) == 3:
E = block_shape[0]
data = data.reshape(E, N // 32, SCALE_K * 32)
else:
assert len(block_shape) == 2
data = data.reshape(N // 32, SCALE_K * 32)
return data.transpose(-1, -2)
def unswizzle_data(self, data):
raise NotImplementedError()
def swizzle_block_shape(self, block_shape):
SCALE_K = block_shape[-2]
N = block_shape[-1]
return block_shape[:-2] + [N // 32, SCALE_K * 32]
@triton.jit
def unswizzle_mx_scale_cdna4(
x,
BLOCK_N: tl.constexpr,
MX_SCALE_BLOCK_K: tl.constexpr,
N_PRESHUFFLE_FACTOR: tl.constexpr = NON_K_PRESHUFFLE_BLOCK_SIZE,
):
x = x.reshape(BLOCK_N // N_PRESHUFFLE_FACTOR, MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2, 1)
x = x.permute(0, 5, 3, 1, 4, 2, 6)
x = x.reshape(BLOCK_N, MX_SCALE_BLOCK_K)
return x
import torch
import triton
import triton.language as tl
from .base import Layout
class HopperMXScaleLayout(Layout):
name: str = "HOPPER_SCALE"
def __init__(self, shape, mx_axis, num_warps=8) -> None:
assert num_warps & (num_warps - 1) == 0, "warps_n must be a power of 2"
super().__init__(shape)
self.mx_axis = mx_axis
self.num_warps = num_warps
*self.leading_shape, _, _ = shape
def _maybe_mT(self, data):
if self.mx_axis == len(self.leading_shape):
return data.contiguous().mT
return data
def swizzle_data(self, data):
data = self._maybe_mT(data).contiguous()
*batch, M, K = data.shape
SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8
SWIZZLE_ALIGN_K = 2
pad_m = (SWIZZLE_ALIGN_M - (M % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M
pad_k = (SWIZZLE_ALIGN_K - (K % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K
data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
*batch, M, K = data.shape
assert data.is_contiguous()
assert M % (2 * self.num_warps * 2 * 8) == 0 and K % 2 == 0, (
f"Input tensor must have a subtile of shape (..., {2 * self.num_warps * 2 * 8}, 2)"
)
b = len(batch)
data = data.reshape(
*batch,
M // (2 * self.num_warps * 2 * 8),
2,
self.num_warps,
2,
8,
K // 2,
2,
)
perm = [0, 2, 5, 1, 4, 6, 3]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.flatten(-5, -1)
data = data.flatten(-3, -2)
assert data.shape[-2] == M // 32
assert data.shape[-1] == K * 32
data = self._maybe_mT(data)
return data
def unswizzle_data(self, data):
data = self._maybe_mT(data)
*batch, M, K = data.shape
b = len(batch)
data = data.reshape(
*batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2
)
perm = [0, 3, 1, 6, 4, 2, 5]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.reshape(*batch, M * 32, K // 32)
data = self._maybe_mT(data)
return data
def swizzle_block_shape(self, block_shape):
return block_shape
@triton.jit
def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr):
"""
Triton inverse of swizzle_mxfp4_scale_hopper
"""
tl.static_assert(len(x.shape) == 2, "NYI")
# implementation assumes mxfp data is packed along the last dimension
x = x.trans() if mx_axis == 0 else x
M: tl.constexpr = x.shape[0]
K: tl.constexpr = x.shape[1]
tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}")
tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}")
x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2)
x = x.trans(0, 3, 1, 6, 4, 2, 5)
x = x.reshape(M * 32, K // 32)
# implementation assumed mxfp data is packed along the last dimension
x = x.trans() if mx_axis == 0 else x
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment