Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
......@@ -23,8 +23,7 @@ def _is_equal(a, b):
if isinstance(a, torch.Tensor):
return a is b
# Whitelist of types that are safe to compare by value for caching.
if isinstance(a, (int, float, str, bool, type(None))) and isinstance(
b, (int, float, str, bool, type(None))):
if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))):
return a == b
# For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check.
return False
......@@ -58,9 +57,11 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
# For Tensors, check for object identity. For other types, check for equality.
# Python caches small integers, so `is` works for them but not for large integers like 4096.
if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \
set(kwargs.keys()) == set(last_kwargs.keys()) and \
all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()):
if (
all(_is_equal(a, b) for a, b in zip(args, last_args))
and set(kwargs.keys()) == set(last_kwargs.keys())
and all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items())
):
return last_result
result = fn(*args, **kwargs)
......@@ -79,73 +80,68 @@ def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int):
@tensor_cache
def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
seq_len: int) -> torch.IntTensor:
seq_idx_for_q = torch.full((seq_len,),
len(cu_seqlens_qs),
dtype=torch.int32,
device=cu_seqlens_qs.device)
def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor:
seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device)
for i in range(len(cu_seqlens_qs)):
seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i
seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i
return seq_idx_for_q
@tensor_cache
def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor:
def cal_cu_seqlen_ks_for_q(
cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int
) -> torch.IntTensor:
cu_seqlen_ks_for_each_q = torch.gather(
input=torch.cat([
cu_seqlens_ks,
torch.full((1,),
torch.iinfo(torch.int32).max,
dtype=torch.int32,
device=cu_seqlens_qs.device)
]),
input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]),
dim=0,
index=cal_seq_idx_for_q(
cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long())
index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(),
)
return cu_seqlen_ks_for_each_q.int()
@tensor_cache
def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor,
q_start_idxs: torch.LongTensor, seq_len: int,
kv_stride: int) -> torch.IntTensor:
def cal_cu_seqlen_ke_for_q(
cu_seqlens_qs: torch.LongTensor,
cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor,
cu_seqlens_ke: torch.LongTensor,
q_start_idxs: torch.LongTensor,
seq_len: int,
kv_stride: int,
) -> torch.IntTensor:
cu_seqlen_ke_for_each_q = torch.gather(
input=torch.cat(
[cu_seqlens_ke,
torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]),
input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]),
dim=0,
index=cal_seq_idx_for_q(
cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long())
casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,),
dtype=torch.int32,
device=cu_seqlens_qs.device)
index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(),
)
casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device)
for i in range(len(cu_seqlens_qs)):
casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange(
q_start_idxs[i],
q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i],
dtype=torch.int32,
device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i]
casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = (
torch.arange(
q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device
)
+ 1
) // kv_stride + cu_seqlens_ks[i]
cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q)
return cu_seqlen_ke_for_each_q.int()
@tensor_cache
def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor,
cu_seqlens_k: torch.LongTensor = None,
offs_q: torch.LongTensor = None,
*,
seq_len: int,
kv_stride: int = 1,
cp_rank: int = 0,
cp_size: int = 1,
balanced_cp=False):
'''
def cal_ks_ke_from_cu_seqlen_qk(
cu_seqlens_q: torch.LongTensor,
cu_seqlens_k: torch.LongTensor = None,
offs_q: torch.LongTensor = None,
*,
seq_len: int,
kv_stride: int = 1,
cp_rank: int = 0,
cp_size: int = 1,
balanced_cp=False,
):
"""
seq_len: seq len per cp rank
balanced cp slice assignment: 0 1 2 3 3 2 1 0
'''
"""
n_seq = len(cu_seqlens_q) - 1
assert n_seq > 0
assert cu_seqlens_q.shape == (n_seq + 1,)
......@@ -170,10 +166,12 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor,
def f(x: torch.Tensor):
chunks = x.chunk(cp_size * 2)
return torch.cat([
chunks[cp_rank],
chunks[cp_size - cp_rank - 1],
])
return torch.cat(
[
chunks[cp_rank],
chunks[cp_size - cp_rank - 1],
]
)
ks = f(ks)
ke = f(ke)
......@@ -189,8 +187,7 @@ def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int],
use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
......@@ -239,14 +236,18 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1,
total_seqlen - (cp_rank + 1) * per_chunk_seqlen,
total_seqlen - cp_rank * per_chunk_seqlen,
)
ks = torch.cat([
cu_seqlens_ks_for_each_q[slice_short],
cu_seqlens_ks_for_each_q[slice_long],
])
ke = torch.cat([
cu_seqlens_ke_for_each_q[slice_short],
cu_seqlens_ke_for_each_q[slice_long],
])
ks = torch.cat(
[
cu_seqlens_ks_for_each_q[slice_short],
cu_seqlens_ks_for_each_q[slice_long],
]
)
ke = torch.cat(
[
cu_seqlens_ke_for_each_q[slice_short],
cu_seqlens_ke_for_each_q[slice_long],
]
)
assert len(ks) == len(ke) == per_cp_seqlen
return ks, ke
......@@ -302,11 +303,9 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
raise_assert: Whether to raise assertion error on failure
"""
sim = calculate_tensor_similarity(x, y, name)
diff = 1. - sim
diff = 1.0 - sim
if not (0 <= diff <= eps):
print(
f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m"
)
print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m")
if raise_assert:
assert False # noqa: B011
......@@ -316,11 +315,8 @@ if __name__ == "__main__":
cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda")
last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0]
cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0)
cu_seqlens_qs = torch.cat(
[torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum])
cu_seqlens_qe = torch.cat(
[cu_seqlens_cumsum,
torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len])
cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum])
cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len])
from tilelang.profiler import do_bench
......
......@@ -39,12 +39,10 @@ def torch_convert_bit_twiddling(tensor):
res0 = val_concat_expanded & mask
res1 = (val_concat_expanded << 3) & mask
res2 = (val_concat_expanded << 6) & mask
res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | (
(val_concat_expanded >> 7) & mask3)
res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ((val_concat_expanded >> 7) & mask3)
# Select the correct result based on position
bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1,
torch.where(pos == 2, res2, res3)))
bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3)))
# Convert to uint16 for .view(torch.bfloat16)
bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
......@@ -110,7 +108,7 @@ def print_bit(name, val):
val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown.
"""
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
binary_repr = f"{val_cpu:032b}"
print(name, binary_repr)
......@@ -122,7 +120,7 @@ def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
print_red_warning(f"{name} all zero")
return 1
sim = 2 * (x * y).sum() / denominator
return sim
......@@ -132,21 +130,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch')
print_red_warning(f"{name} Error: isfinite mask mismatch")
if raise_assert:
raise AssertionError
if not torch.isclose(
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
print_red_warning(f"{name} Error: nonfinite value mismatch")
if raise_assert:
raise AssertionError
x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name)
diff = (1. - sim).item()
print(f'{diff=}')
diff = (1.0 - sim).item()
print(f"{diff=}")
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff=}')
print_red_warning(f"{name} Error: {diff=}")
if raise_assert:
raise AssertionError
......@@ -24,6 +24,7 @@ def get_configs():
the parameter name to its chosen value.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
......@@ -32,65 +33,64 @@ def get_configs():
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
},
pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
fast_dequant=True,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
def matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format=T.uint32,
num_bits=4,
fast_dequant=True,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
):
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
......@@ -121,7 +121,7 @@ def matmul(M,
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin.
......@@ -131,13 +131,13 @@ def matmul(M,
- Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel.
Notes and preconditions:
- Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`.
- Asserts that `in_dtype == "fp4"` and `out_dtype == T.bfloat16`.
- The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel.
- The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly.
- The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
......@@ -189,12 +189,11 @@ def matmul(M,
# Finally, store the dequantized data to shared memory.
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16.
......@@ -205,7 +204,7 @@ def matmul(M,
- Writes the dequantized bfloat16 block into B_dequantize_shared.
Constraints:
- Supports only in_dtype="fp4" and out_dtype="bfloat16".
- Supports only in_dtype="fp4" and out_dtype=T.bfloat16.
- The helper assumes nbit == 4 and produces bfloat16 values.
- The macro uses a fixed test-scale of 0 (no per-element scaling) as written.
......@@ -213,49 +212,49 @@ def matmul(M,
A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr,
scale: tir.PrimExpr, dtype: str):
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
"""
Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
This helper extracts the 4-bit field located at the bit position `pos` within the
byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
Parameters:
nbit (int): Number of bits in the packed element; must be 4.
val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be "bfloat16".
Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes:
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
bit fields and clamps the computed exponent to fit into 8 bits.
Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
This helper extracts the 4-bit field located at the bit position `pos` within the
byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
Parameters:
nbit (int): Number of bits in the packed element; must be 4.
val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be T.bfloat16.
Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes:
- The function asserts `nbit == 4`, `dtype == T.bfloat16`, and that `val.dtype` is T.uint8.
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
bit fields and clamps the computed exponent to fit into 8 bits.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
assert dtype == T.bfloat16
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we use the max function to limit the exponential part to 8 bits
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16))
m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret(
"bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
T.bfloat16,
((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16),
)
return val_bf16
@T.macro
......@@ -292,32 +291,32 @@ def matmul(M,
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Pipelines over K in chunks of `block_K` for `num_stages` stages:
- Loads A and packed B tiles into shared memory.
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared.
Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`.
Side effects:
- Writes the computed output block into the global tensor `C`.
- Uses and updates shared memory buffers and per-thread accumulators.
No value is returned.
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Pipelines over K in chunks of `block_K` for `num_stages` stages:
- Loads A and packed B tiles into shared memory.
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared.
Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`.
Side effects:
- Writes the computed output block into the global tensor `C`.
- Uses and updates shared memory buffers and per-thread accumulators.
No value is returned.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -327,9 +326,11 @@ def matmul(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.annotate_layout(
{
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
......@@ -344,7 +345,7 @@ def matmul(M,
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])
return main
......@@ -363,7 +364,7 @@ def ref_program_twiddling(A, qB):
Returns:
torch.Tensor: Result matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
......@@ -383,7 +384,7 @@ def ref_program_simple(A, qB):
Returns:
torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
......@@ -409,16 +410,15 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
"""
total_flops = 2 * m * n * k
if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
kernel = matmul(m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, fast_dequant=fast_dequant)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=4,
fast_dequant=fast_dequant,
block_M=256,
......@@ -426,7 +426,8 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
block_K=128,
num_stages=2,
threads=256,
split=1)
split=1,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
......
......@@ -7,45 +7,45 @@ import torch
from dequantize_utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
"""
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
Parameters:
nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16").
Parameters:
nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be T.bfloat16).
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8.
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
assert dtype == T.bfloat16
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret(
T.bfloat16,
((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16),
)
return val_bf16
......@@ -65,6 +65,7 @@ def get_configs():
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
......@@ -73,70 +74,74 @@ def get_configs():
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1],)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(
out_idx=[-1],
)
def matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
):
"""
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
......@@ -150,6 +155,7 @@ def matmul(M,
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
......@@ -164,7 +170,7 @@ def matmul(M,
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
......@@ -175,12 +181,12 @@ def matmul(M,
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
......@@ -252,24 +258,23 @@ def matmul(M,
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype="bfloat16".
- Only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
......@@ -301,33 +306,32 @@ def matmul(M,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale[
bx * block_N + i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
bx * block_N + i, k * block_K // scale_size + j // scale_size
], # Scale is the exponential part, within the representation of uint8
dtype=out_dtype,
) * T.shift_left(
1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -337,23 +341,26 @@ def matmul(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
if with_bias:
T.annotate_layout({
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
})
T.annotate_layout(
{
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
}
)
if threads == 512:
T.disable_warp_group_reg_alloc()
if with_bias:
T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
Bias_shared)
T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], Bias_shared)
T.copy(Bias_shared, C_local)
else:
T.clear(C_local)
......@@ -368,7 +375,7 @@ def matmul(M,
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])
return main
......@@ -387,9 +394,9 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -410,9 +417,9 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -434,9 +441,9 @@ def ref_program_simple(A, qB, Scale, Bias=None):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -462,9 +469,9 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -491,24 +498,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias)
m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=4,
scale_size=scale_size,
block_M=256,
......@@ -518,7 +517,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
threads=256,
split=1,
fast_dequant=fast_dequant,
with_bias=with_bias)
with_bias=with_bias,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
......
......@@ -7,45 +7,45 @@ import torch
from dequantize_utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
    """
    Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.

    This helper extracts a 4-bit nibble from `val` at nibble position `pos`, interprets its
    bits as sign/exponent/mantissa in the 4-bit custom FP4 layout (1 sign bit, 2 exponent
    bits, 1 mantissa bit), rebiases the exponent for bfloat16, and assembles the
    corresponding bfloat16 bit pattern.

    Parameters:
        nbit (int): Number of bits in the packed field (must be 4).
        val (tir.PrimExpr): Packed input value of dtype uint8 containing one or more 4-bit fields.
        pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
        scale (tir.PrimExpr): Per-element exponent adjustment; currently unused here — the
            scale is applied by the caller (see the commented clamp below). TODO(review): confirm.
        dtype (str): Destination dtype (must be T.bfloat16).

    Returns:
        tir.PrimExpr: The resulting value reinterpreted as bfloat16.

    Notes:
        - Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16,
          and val.dtype == T.uint8.
    """
    assert nbit == 4
    assert dtype == T.bfloat16
    assert val.dtype == T.uint8
    mask = tir.const((1 << nbit) - 1, T.uint16)
    # Shift the requested nibble down to the low bits and mask it off.
    f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
    s = f4 >> tir.const(3, T.uint16)  # sign: bit 3 of the nibble
    e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)  # exponent: bits 2-1
    # Exponent bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
    e_bf16 = e_f4 + tir.const(126, T.uint16)
    # Scale is the exponential part, within the representation of uint8.
    # To handle the overflow, we may use the min function to limit the exponential part to 8 bits:
    # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16))
    m_f4 = f4 & tir.const(1, T.uint16)  # mantissa: bit 0
    # Assemble bf16 bit pattern: sign(1) | exponent(8) | mantissa(7), mantissa in the top bit.
    val_bf16 = tir.reinterpret(
        T.bfloat16,
        ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16),
    )
    return val_bf16
......@@ -65,6 +65,7 @@ def get_configs():
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
......@@ -73,70 +74,74 @@ def get_configs():
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1],)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(
out_idx=[-1],
)
def matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
):
"""
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
......@@ -150,6 +155,7 @@ def matmul(M,
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
......@@ -164,7 +170,7 @@ def matmul(M,
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
......@@ -175,12 +181,12 @@ def matmul(M,
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
......@@ -252,24 +258,23 @@ def matmul(M,
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype="bfloat16".
- Only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
......@@ -301,8 +306,8 @@ def matmul(M,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_shared[
i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
i, k * block_K // scale_size + j // scale_size
], # Scale is the exponential part, within the representation of uint8
dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
......@@ -311,22 +316,22 @@ def matmul(M,
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -339,16 +344,20 @@ def matmul(M,
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
if with_bias:
T.annotate_layout({
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
})
T.annotate_layout(
{
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
}
)
if threads == 512:
T.disable_warp_group_reg_alloc()
......@@ -357,26 +366,24 @@ def matmul(M,
# T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
# Bias_shared)
# T.copy(Bias_shared, C_local)
T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
C_local)
T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], C_local)
else:
T.clear(C_local)
# Use 1D TMA to load Scale
T.copy(Scale[bx * block_N:(bx + 1) * block_N, :], Scale_shared)
T.copy(Scale[bx * block_N : (bx + 1) * block_N, :], Scale_shared)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
k)
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])
return main
......@@ -395,11 +402,11 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -420,11 +427,11 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -446,11 +453,11 @@ def ref_program_simple(A, qB, Scale, Bias=None):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -476,11 +483,11 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -507,24 +514,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias)
m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=4,
scale_size=scale_size,
block_M=256,
......@@ -534,7 +533,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
threads=256,
split=1,
fast_dequant=fast_dequant,
with_bias=with_bias)
with_bias=with_bias,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
......
......@@ -24,8 +24,9 @@ def matmul(
num_bits=4,
):
from tilelang.quantize import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_dtype = T.int8
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
A_shape = (M, K)
......@@ -39,9 +40,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -58,21 +59,19 @@ def matmul(
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + tx * local_size_compressed + v
vi = index // (block_K // num_elems_per_byte)
vj = index % (block_K // num_elems_per_byte)
B_local[v] = B_shared[vi, vj]
for v in T.serial(0, local_size):
B_dequantize_local[v] = _tir_packed_to_unsigned_convert(
storage_type, storage_nbit)(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
vi = index // block_K
......@@ -121,9 +120,7 @@ def run_gemm(
def ref_program(A, qB):
import torch
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
......@@ -146,25 +143,27 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
):
from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform,)
TensorCoreIntrinEmitterWithLadderTransform,
)
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_dtype = T.int8
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
......@@ -183,7 +182,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = 32 if in_dtype == "float16" else 64
block_K = 32 if in_dtype == T.float16 else 64
chunk = block_K // reduce_k
is_smooth_a = False
......@@ -192,8 +191,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
pad_factor = 8
A_shape = (M, K)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y,
micro_size_k // num_elems_per_byte)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte)
A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
B_shared_shape = (
block_N // micro_size_y,
......@@ -228,7 +226,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
chunk=chunk,
reduce_k=reduce_k,
transform_kind_b=transform_b,
num_elems_per_byte=num_elems_per_byte)
num_elems_per_byte=num_elems_per_byte,
)
vec_load_qb = 16
if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
......@@ -236,14 +235,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
prelude=decode_i4_to_f16) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -255,40 +251,36 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
thread_binding = T.get_thread_binding(0)
rk = T.get_thread_binding(1)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
})
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
}
)
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, (block_K // reduce_k)):
vk = rk * (block_K // reduce_k) + k
A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
# TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
(threads * vec_load_qb)):
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb):
t = thread_binding
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
block_K // micro_size_k)
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
(block_K // micro_size_k)) % (
block_N // micro_size_y)
B_shared[vj, vk, vjj,
vkk] = B[bx * (block_N // micro_size_y) + vj,
ko * (block_K // micro_size_k) + vk, vjj, vkk]
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k)
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % (
block_N // micro_size_y
)
B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk]
for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......@@ -307,9 +299,13 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
for j in T.serial(warp_cols):
local_size_b = mma_emitter.local_size_b
T.call_extern('handle', 'decode_i4u_to_f16',
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]), 8)
T.call_extern(
"handle",
"decode_i4u_to_f16",
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]),
8,
)
mma_emitter.mma(A_local, B_dequantize_local, C_local)
......@@ -328,7 +324,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
reduced_accum_res[0],
rk,
dtype="handle",
))
)
)
if rk == 0:
C_local[n] = reduced_accum_res[0]
......@@ -340,9 +337,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
for i, j in T.Parallel(block_M, (block_N // reduce_k)):
vj = rk * (block_N // reduce_k) + j
C[by * block_M + i,
bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y,
i % micro_size_x, vj % micro_size_y]
C[by * block_M + i, bx * block_N + vj] = C_shared[
i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y
]
return main
......@@ -357,8 +354,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
transform_b,
):
import bitblas
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
kernel = tilelang.compile(matmul, out_idx=[2])
src_code = kernel.get_kernel_source()
......@@ -368,11 +365,10 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
assert src_code is not None
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_dtype = T.int8
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
......@@ -407,9 +403,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
# Ensure that the latency is not None
assert latency is not None
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
......@@ -423,14 +417,13 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
@tilelang.testing.requires_package("bitblas")
def test_run_dequantize_gemm():
run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, T.float16, T.float16, T.float16, 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, T.int8, T.int32, T.int32, 128, 128, 32, num_threads=128)
@tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
256, 1024, 512, "float16", "float16", "float16", 3)
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, T.float16, T.float16, T.float16, 3)
def main():
......
......@@ -9,30 +9,29 @@ import argparse
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float16"
assert val.dtype == "uint8"
assert dtype == T.float16
assert val.dtype == T.uint8
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14
# s1e2m1
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
e_f16 = e_f4 + tir.const(14, "uint16")
m_f4 = f4 & tir.const(1, "uint16")
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
e_f16 = e_f4 + tir.const(14, T.uint16)
m_f4 = f4 & tir.const(1, T.uint16)
m_f16 = m_f4
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")
| m_f16 << tir.const(9, "uint16")).astype("uint16"))
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
val_f16 = tir.reinterpret(
T.float16, ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16) | m_f16 << tir.const(9, T.uint16)).astype(T.uint16)
)
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, T.float16), val_f16)
return val_f16
def torch_convert(tensor):
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
binary_repr = f"{val_cpu:032b}"
print(name, binary_repr)
def _convert(val, pos):
......@@ -61,15 +60,15 @@ def torch_convert(tensor):
@tilelang.jit(out_idx=[1])
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
......@@ -99,7 +98,7 @@ def test_fp4_fp16_convert_close():
K,
block_N,
block_K,
"float16",
T.float16,
)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
......@@ -118,23 +117,15 @@ def get_configs():
splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[4],
'split': c[5]
} for c in _configs]
configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs]
return configs
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
......@@ -145,17 +136,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@T.prim_func
def main_split(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
SplitC = T.alloc_buffer([
split, (N + block_N - 1) // block_N * block_N,
(M + block_M - 1) // block_M * block_M
], out_dtype)
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
threads=threads) as (bx, by, bz):
SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype)
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
......@@ -164,10 +150,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
......@@ -183,8 +171,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc)
......@@ -195,12 +182,11 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
......@@ -209,10 +195,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
......@@ -229,8 +217,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
if split == 1:
return main
......@@ -241,12 +228,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[2])
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
threads=None,
split=None):
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func
return kernel()
......@@ -259,7 +241,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
def ref_program(A, qB):
dtypeC = "float16"
dtypeC = T.float16
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
......@@ -269,10 +251,10 @@ def ref_program(A, qB):
def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if (not tune):
kernel = matmul(
m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
if not tune:
kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......@@ -283,7 +265,7 @@ def main(m=256, n=256, k=256, tune=False):
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)
best_result = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
......@@ -293,10 +275,10 @@ def main(m=256, n=256, k=256, tune=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M')
parser.add_argument('--n', type=int, default=256, help='N')
parser.add_argument('--k', type=int, default=256, help='K')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--m", type=int, default=256, help="M")
parser.add_argument("--n", type=int, default=256, help="N")
parser.add_argument("--k", type=int, default=256, help="K")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
main(M, N, K, args.tune)
......@@ -9,15 +9,15 @@ import argparse
def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "int8"
assert val.dtype == "uint8"
assert dtype == T.int8
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, "uint8")
mask = tir.const((1 << nbit) - 1, T.uint8)
i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask
i4 = (val >> (pos.astype(T.uint8) * tir.const(nbit, T.uint8))) & mask
i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
i8 = i8_shifted >> tir.const(4, "int8")
i8_shifted = tir.reinterpret(T.int8, i4 << tir.const(4, T.uint8))
i8 = i8_shifted >> tir.const(4, T.int8)
return i8
......@@ -35,15 +35,15 @@ def get_configs():
@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
......@@ -66,13 +66,12 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
def torch_convert(tensor):
def _convert(val, pos):
assert val.dtype == torch.uint8
val = val.view(torch.int8)
mask = (1 << 4) - 1
i4_shifted = ((val >> (pos * 4)) & mask)
i4 = ((i4_shifted << 4) >> 4)
i4_shifted = (val >> (pos * 4)) & mask
i4 = (i4_shifted << 4) >> 4
return i4.view(torch.int8)
......@@ -86,7 +85,7 @@ def torch_convert(tensor):
def ref_program(A, qB):
dtypeC = "int32"
dtypeC = T.int32
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
......@@ -94,11 +93,10 @@ def ref_program(A, qB):
def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
......@@ -109,12 +107,11 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
......@@ -123,10 +120,12 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
......@@ -143,8 +142,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
return main
......@@ -167,10 +165,10 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
def main(m=128, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if (not tune):
kernel = matmul_int8xint4(
m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128)
if not tune:
kernel = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)(
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128
)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
print("All checks pass.")
......@@ -179,7 +177,7 @@ def main(m=128, n=256, k=256, tune=False):
print(f"Tilelang: {latency} ms")
else:
best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
best_result = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Bset latency: {best_latency}")
......
......@@ -4,7 +4,8 @@ from typing import Optional, Callable, Any
import torch
from tilelang import DataType
from tilelang.quantize import (
_tir_packed_int_to_int_convert,)
_tir_packed_int_to_int_convert,
)
@tilelang.jit
......@@ -16,7 +17,7 @@ def dequantize_gemv(
out_dtype: str,
accum_dtype: str,
num_bits: int = 4,
storage_dtype: str = "int8",
storage_dtype: T.dtype = T.int8,
source_format: str = "uint",
n_partition: int = 4,
reduce_thread: int = 32,
......@@ -26,11 +27,10 @@ def dequantize_gemv(
group_size: int = -1,
with_scaling: bool = False,
) -> Callable[..., Any]:
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"sch_outer_reduction_with_config is not implemented")
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented"
)
assert trans_A is False, "Dequantize only implement for trans_A=False currently"
assert trans_B is True, "Dequantize only implement for trans_B=TRue currently"
......@@ -51,7 +51,7 @@ def dequantize_gemv(
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
import_source: Optional[str] = None
func_name: str = ""
......@@ -81,12 +81,12 @@ def dequantize_gemv(
C: T.Tensor[C_shape, out_dtype],
):
with T.Kernel(
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
) as (
bx,
by,
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
......@@ -107,8 +107,7 @@ def dequantize_gemv(
for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[
bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) +
kr * micro_size_k_compressed + v,
ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v,
]
if fast_decoding:
......@@ -120,10 +119,9 @@ def dequantize_gemv(
)
else:
for ki in T.serial(micro_size_k):
B_dequantize_local[ki] = _tir_packed_int_to_int_convert(
storage_type,
storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte],
ki % num_elems_per_byte, in_dtype)
B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype
)
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
......@@ -137,9 +135,9 @@ def dequantize_gemv(
accum_res[0] += A_local[ki] * B_dequantize_local[ki]
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
......@@ -149,7 +147,8 @@ def dequantize_gemv(
reduced_accum_res[0],
kr,
dtype="handle",
))
)
)
if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
......@@ -160,11 +159,11 @@ def main() -> None:
M = 1
N = 1024
K = 1024
in_dtype = "float16"
out_dtype = "float16"
accum_dtype = "float16"
in_dtype = T.float16
out_dtype = T.float16
accum_dtype = T.float16
num_bits = 4
storage_dtype = "int8"
storage_dtype = T.int8
source_format = "uint"
n_partition = 4
reduce_thread = 32
......@@ -174,26 +173,39 @@ def main() -> None:
group_size = -1
with_scaling = False
kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
source_format, n_partition, reduce_thread, fast_decoding, trans_A,
trans_B, group_size, with_scaling)
kernel = dequantize_gemv(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
num_bits,
storage_dtype,
source_format,
n_partition,
reduce_thread,
fast_decoding,
trans_A,
trans_B,
group_size,
with_scaling,
)
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda()
if fast_decoding:
from tilelang.quantize.utils import interleave_weight
qB = interleave_weight(qB, num_bits, in_dtype)
kernel(A, qB, C)
# int4 reference
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device)
for j in range(B.shape[1]):
B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
......
......@@ -25,6 +25,7 @@ def get_configs():
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[128],
block_N=[64, 128, 256],
......@@ -33,33 +34,33 @@ def get_configs():
threads=[128, 256, 512],
split=[1],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def matmul(M,
N,
K,
topk,
E,
padding_M,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=128,
block_N=256,
block_K=128,
num_stages=2,
threads=256,
split=1):
def matmul(
M,
N,
K,
topk,
E,
padding_M,
in_dtype,
out_dtype,
accum_dtype,
source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=128,
block_N=256,
block_K=128,
num_stages=2,
threads=256,
split=1,
):
"""
Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype.
......@@ -82,8 +83,8 @@ def matmul(M,
topk (int): number of experts selected per token.
E (int): number of experts.
padding_M (int): padded number of tokens after grouping and block alignment.
in_dtype (str): element type of A (e.g., "bfloat16").
out_dtype (str): output tensor element type (e.g., "bfloat16").
in_dtype (str): element type of A (e.g., T.bfloat16).
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
......@@ -110,16 +111,17 @@ def matmul(M,
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = (block_N)
Bias_shared_shape = block_N
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
......@@ -135,7 +137,7 @@ def matmul(M,
import_source = import_source
# the dequant part is the same as in dequant_gemm
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
......@@ -145,12 +147,12 @@ def matmul(M,
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
......@@ -221,19 +223,16 @@ def matmul(M,
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
......@@ -244,8 +243,8 @@ def matmul(M,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_shared[
i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
i, k * block_K // scale_size + j // scale_size
], # Scale is the exponential part, within the representation of uint8
dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
......@@ -254,19 +253,17 @@ def matmul(M,
@T.prim_func
def main(
A: T.Tensor((M, K), in_dtype),
B: T.Tensor((E, N, QK), storage_dtype),
Scale: T.Tensor((E, N, K // scale_size), storage_dtype),
Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype),
sorted_token_ids: T.Tensor((padding_M), "int32"),
expert_ids: T.Tensor((padding_M // block_M), "int32"),
C: T.Tensor((M, topk, N), out_dtype),
A: T.Tensor((M, K), in_dtype),
B: T.Tensor((E, N, QK), storage_dtype),
Scale: T.Tensor((E, N, K // scale_size), storage_dtype),
Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype),
sorted_token_ids: T.Tensor((padding_M), T.int32),
expert_ids: T.Tensor((padding_M // block_M), T.int32),
C: T.Tensor((M, topk, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
......@@ -274,23 +271,25 @@ def matmul(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
topk_weights_shared = T.alloc_shared((block_M), out_dtype)
sorted_token_ids_shared = T.alloc_shared((block_M), "int32")
expert_id = T.alloc_local((1), "int32") # the expert id for the current block
sorted_token_ids_shared = T.alloc_shared((block_M), T.int32)
expert_id = T.alloc_local((1), T.int32) # the expert id for the current block
# To use 1D TMA, the last dim of Scale_shared must have stride=1
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
T.use_swizzle(10)
if threads == 512:
T.disable_warp_group_reg_alloc()
T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared)
T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared)
expert_id[0] = expert_ids[by]
# Get the topk weights of each token in the current block
......@@ -300,11 +299,11 @@ def matmul(M,
# Get bias and scale based on the expert id
if with_bias:
T.copy(Bias[expert_id[0], bx * block_N:(bx + 1) * block_N], Bias_shared)
T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared)
else:
T.clear(Bias_shared)
T.copy(Scale[expert_id[0], bx * block_N:(bx + 1) * block_N, :], Scale_shared)
T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared)
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = Bias_shared[j]
......@@ -317,14 +316,13 @@ def matmul(M,
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_K] != -1:
for copy_j in T.vectorized(16):
A_shared[base // block_K, base % block_K +
copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
k * block_K + base % block_K + copy_j]
A_shared[base // block_K, base % block_K + copy_j] = A[
sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j
]
T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
k)
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
......@@ -338,16 +336,17 @@ def matmul(M,
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_N] != -1:
for copy_j in T.vectorized(16):
C[sorted_token_ids_shared[base // block_N] // topk,
sorted_token_ids_shared[base // block_N] % topk, bx * block_N +
base % block_N + copy_j] = C_shared[base // block_N,
base % block_N + copy_j]
C[
sorted_token_ids_shared[base // block_N] // topk,
sorted_token_ids_shared[base // block_N] % topk,
bx * block_N + base % block_N + copy_j,
] = C_shared[base // block_N, base % block_N + copy_j]
return main
def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256):
dtypeC = "bfloat16"
dtypeC = T.bfloat16
M, K = A.shape
E, N, QK = qB.shape
topk = topk_weights.shape[0] // M
......@@ -355,7 +354,7 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
assert scale_size == 32 # MXFP4
# Initialize output tensor
C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device='cuda')
C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda")
# Iterate over sorted_token_ids
for idx in range(len(sorted_token_ids)): # padding_M
......@@ -370,14 +369,11 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
# Dequantize the expert weights
B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K)
B *= 2**(
Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(
torch.bfloat16))
B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16))
# Compute the output for this token-expert pair
# token_embedding @ B.T + bias
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(
torch.bfloat16)) + Bias[expert_id]
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id]
output = output.to(torch.__getattribute__(dtypeC))
# Apply the topk weight
......@@ -391,14 +387,12 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
def get_data(m, n, k, qk, scale_size, topk, E, block_M):
A = torch.empty(m, k, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
qB = torch.randint(
0, 256, (E, n, qk), dtype=torch.uint8,
device='cuda') # Quantized weight tensor for E experts.
Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device='cuda')
Bias = torch.empty(E, n, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
weights = torch.empty(m, E, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts.
Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda")
Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
# topk_weights: Router weights for the top-k experts for each token.
# Shape: (m, topk)
# tokens_experts: A flattened tensor of expert assignments for each token.
......@@ -420,10 +414,7 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt
if pad_len > 0:
# -1 for padding (`M` instead in vLLM moe_align_block_size())
group_token_ids = torch.cat([
group_token_ids,
torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device='cuda')
])
group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")])
padded_token_ids.append(group_token_ids)
expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M))
start = end
......@@ -431,21 +422,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
# sorted_token_ids: The final flattened and padded tensor of token indices.
sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,)
# expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`.
expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda') # (padding_M,)
expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M,)
padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding
return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M
def main(m=256,
n=256,
k=256,
scale_size=32,
topk=4,
E=32,
fast_dequant=True,
with_bias=False,
tune=False):
def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False):
# Tunable parameters
block_M, block_N, block_K = 128, 256, 128 # noqa: F841
num_stages = 1 # noqa: F841
......@@ -456,8 +439,7 @@ def main(m=256,
num_bits = 4
num_elems_per_byte = 8 // num_bits
qk = k // num_elems_per_byte
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(
m, n, k, qk, scale_size, topk, E, block_M)
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M)
if tune:
with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
......@@ -469,9 +451,9 @@ def main(m=256,
topk,
E,
padding_M,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
......@@ -485,9 +467,9 @@ def main(m=256,
topk,
E,
padding_M,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
......@@ -510,14 +492,11 @@ def main(m=256,
expert_ids,
)
print('Tilelang kernel run finished.')
print("Tilelang kernel run finished.")
ref_output = ref_moe(
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids,
block_M=block_M) # Maybe a little bit slow...
ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow...
latency = tilelang.profiler.do_bench(
lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100)
latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
......@@ -525,32 +504,19 @@ def main(m=256,
max_val = diff.max()
max_idx = diff.argmax()
print(f"max abs diff: {max_val} at index: {max_idx}")
assert_similar(
output, ref_output, name="output",
eps=2e-5) # We care about the similarity rather than abs. difference
assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference
print("All checks pass. ✅")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm
parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm
parser.add_argument("--N", type=int, default=5760, help="N")
parser.add_argument("--K", type=int, default=2944, help="K")
parser.add_argument("--scale_size", type=int, default=32, help="scale size")
parser.add_argument(
"--topk", type=int, default=4, help="topk") # experts activated for each token
parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token
parser.add_argument("--E", type=int, default=32, help="E") # number of experts
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(
args.M,
args.N,
args.K,
args.scale_size,
topk=args.topk,
E=args.E,
fast_dequant=True,
with_bias=True,
tune=args.tune)
main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune)
from typing import Optional
import torch
import torch.nn.functional as F
from indexer_topk_reducesum import indexer_topk_reducesum_interface
from indexer_bwd import indexer_bwd_interface
from sparse_mla_fwd import sparse_mla_fwd_interface
from sparse_mla_bwd import sparse_mla_bwd
from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface
from einops import einsum, repeat
from utils import get_abs_err, get_err_ratio
class RegsiterLossFunction(torch.autograd.Function):
    """Identity op that injects an auxiliary loss into the backward pass.

    ``forward(x, loss)`` returns ``x`` unchanged; ``backward`` passes the
    incoming gradient straight through to ``x`` and emits a constant gradient
    of ones for ``loss``, so the auxiliary loss contributes to parameter
    gradients without altering the forward output.
    """

    @staticmethod
    def forward(ctx, x, loss):
        # Keep `loss` around only to recover its dtype/device in backward.
        ctx.save_for_backward(loss)
        return x

    @staticmethod
    def backward(ctx, grad):
        (loss,) = ctx.saved_tensors
        loss_grad = torch.ones(1, dtype=loss.dtype, device=loss.device)
        return grad, loss_grad


register_loss = RegsiterLossFunction.apply
def ref_deepseek_sparse_attention_innner(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
    """Reference (eager, fp32) sparse attention for a single batched sequence.

    Computes indexer logits, selects the causal top-k keys per query, runs
    dense attention restricted to those keys, and attaches a KL auxiliary loss
    (indexer score vs. attention score) to the output via `register_loss`.

    Returns:
        (o, topk_indices): output cast back to the input dtype, and the
        selected key indices from the indexer logits.

    NOTE(review): the `index_sm_scale` parameter is unconditionally overwritten
    below, so the caller-supplied value is ignored — confirm this is intended.
    """
    dtype = q.dtype
    # All reference math is done in fp32 for accuracy; cast back at the end.
    q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights))
    index_sm_scale = index_q.shape[-1] ** -0.5
    b, s = index_q.shape[:2]
    # tl_topk_indices = tl_topk_indices.to(torch.int64)
    # tl_topk_indices[tl_topk_indices == -1] = s
    # Lower-triangular causal mask over (query, key) positions.
    casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
    # Indexer logits: per-head q·k, ReLU, then head-weighted sum and scaling.
    index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2")
    index_logits = F.relu(index_logits)
    index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale
    index_logits = torch.where(casual_mask, index_logits, float("-inf"))
    topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices
    # Pad one -inf column so out-of-range gathers (if any) score -inf.
    topk_logits = torch.gather(F.pad(index_logits, (0, 1), value=float("-inf")), dim=-1, index=topk_indices)
    topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32)
    index_topk_score = topk_score
    if sm_scale is None:
        sm_scale = kv.shape[-1] ** -0.5
    h = q.shape[-2]
    # Boolean mask of the selected top-k keys per query (extra column dropped).
    index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda").scatter_(
        dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool)
    )[:, :, :-1]
    mask = repeat(casual_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h)
    # kv doubles as keys; values are its first dim_v channels.
    k, v = kv, kv[..., :dim_v]
    logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale
    logits = torch.where(mask, logits, float("-inf"))
    attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
    o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d")
    attn_score = attn_score.sum(dim=-2)  # [b, s1, s2]
    attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices)
    attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True)
    # KL(index || attn) on the top-k scores; scores clipped for numerical safety.
    loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum")
    # Attach the auxiliary loss to `o` so it flows in backward.
    o = register_loss(o, loss)
    return o.to(dtype), topk_indices
def ref_deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
    """Varlen wrapper over the single-sequence reference implementation.

    Splits the packed tensors at `offsets`, runs the per-sequence reference
    kernel on each slice (with a singleton batch dim), and concatenates the
    outputs and top-k indices back into packed form.
    """
    outputs = []
    indices = []
    num_seqs = offsets.shape[0] - 1
    for seq in range(num_seqs):
        lo, hi = offsets[seq], offsets[seq + 1]
        seq_o, seq_topk = ref_deepseek_sparse_attention_innner(
            q[None, lo:hi],
            kv[None, lo:hi],
            index_q[None, lo:hi],
            index_k[None, lo:hi],
            weights[None, lo:hi],
            topk,
            dim_v,
            sm_scale,
            index_sm_scale,
        )
        outputs.append(seq_o.squeeze(0))
        indices.append(seq_topk.squeeze(0))
    o = torch.cat(outputs, dim=0)
    topk_indices = torch.cat(indices, dim=0)
    return o, topk_indices
class DSAFunction(torch.autograd.Function):
    """Autograd wrapper wiring the fused sparse-MLA/indexer kernels together.

    forward returns (o, topk_indices); backward produces gradients for
    q, kv, index_q, index_k and weights (None for the non-tensor args).
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        kv: torch.Tensor,
        index_q: torch.Tensor,
        index_k: torch.Tensor,
        weights: torch.Tensor,
        offsets: torch.Tensor,
        topk: int,
        dim_v: int,
        sm_scale: Optional[float] = None,
    ):
        # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk)
        # Indexer selects top-k keys per query; kernels expect an extra
        # singleton head dim on kv/topk_indices (hence unsqueeze(-2)).
        topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets)
        o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
        # Save everything backward needs, including lse for rescaling.
        ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets)
        ctx.topk = topk
        ctx.dim_v = dim_v
        ctx.sm_scale = sm_scale
        return o, topk_indices

    @staticmethod
    def backward(
        ctx,
        do: torch.Tensor,
        _1: torch.Tensor,  # gradient w.r.t. topk_indices (integer output, unused)
    ):
        q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors
        # Recompute reduced attention scores needed for the indexer backward.
        attn_score = sparse_mla_topk_reducesum_interface(
            q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v
        ).squeeze(-2)
        dq, dkv = sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale)
        dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets)
        # Nones correspond to offsets, topk, dim_v, sm_scale (non-differentiable).
        return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None
def deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
):
    """Public entry point: dispatch to the fused DSA autograd function.

    Returns (o, topk_indices) as produced by DSAFunction.forward.
    """
    args = (q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale)
    return DSAFunction.apply(*args)
def test_kernel(
    B=1,
    S=2048,
    H=16,
    D=512,
    tail_D=64,
    index_D=128,
    topk=64,
):
    """Numerical check of the fused DSA path against the eager reference.

    Runs forward+backward through both implementations on random bf16 CUDA
    inputs (two sequences packed via `offsets`), prints absolute-error and
    error-ratio stats for the output and every gradient, and finally prints
    the average top-k index overlap between the two implementations.
    """
    torch.manual_seed(42)
    # q/kv carry D value channels plus tail_D extra key-only channels.
    q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_()
    kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_()
    index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_()
    weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_()
    index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_()
    do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_()
    # Two equal-length sequences packed into one S-long batch.
    offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda()
    o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
    o.backward(do)
    # Stash the fused-kernel grads and reset .grad for the reference run.
    q_grad, q.grad = q.grad, None
    kv_grad, kv.grad = kv.grad, None
    index_q_grad, index_q.grad = index_q.grad, None
    index_k_grad, index_k.grad = index_k.grad, None
    weights_grad, weights.grad = weights.grad, None
    ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
    ref_o.backward(do)
    ref_q_grad, q.grad = q.grad, None
    ref_kv_grad, kv.grad = kv.grad, None
    ref_index_q_grad, index_q.grad = index_q.grad, None
    ref_index_k_grad, index_k.grad = index_k.grad, None
    ref_weights_grad, weights.grad = weights.grad, None
    print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}")
    print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}")
    print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}")
    # NOTE(review): only the first 64 tokens of index_q.grad are compared —
    # presumably to skip warmup/boundary rows; confirm the intent.
    print(
        f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}"
    )
    print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}")
    print(f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}")
    # Per-token overlap of selected top-k key sets (-1 marks padding slots).
    intersections = []
    for j in range(S):
        ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
        trt_np = topk_indices[j].cpu().to(torch.int32).numpy()
        mask = trt_np != -1
        set_ref = set(ref_np[mask])
        set_trt = set(trt_np[mask])
        intersection = set_ref & set_trt
        intersections.append(len(intersection) / len(set_ref))
    print("average intersections: {:.4f}".format(sum(intersections) / len(intersections)))


test_kernel()
# Modified from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py
import torch
import torch.nn.functional as F
import functools
from typing import Callable, Any
def tensor_cache(
fn: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent result of a function with tensor inputs.
This decorator will store the output of the decorated function for the most recent set of input tensors.
If the function is called again with the same input tensors, it will return the cached result.
Args:
fn (Callable[..., torch.Tensor]):
The function to be decorated. It should take tensor inputs and return tensor outputs.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with single-entry caching.
"""
last_args: tuple | None = None
last_kwargs: dict | None = None
last_result: Any = None
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal last_args, last_kwargs, last_result
if (
(last_args is not None and last_kwargs is not None)
and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs))
and all(a is b for a, b in zip(args, last_args, strict=False))
and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items())
):
return last_result
result = fn(*args, **kwargs)
last_args, last_kwargs, last_result = args, kwargs, result
return result
return wrapper
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Per-sequence lengths from cumulative sequence offsets."""
    return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache
def prepare_cu_seqlens_from_lens(
    lens: torch.LongTensor,
    dtype: torch.dtype | None = torch.int32,
) -> torch.LongTensor:
    """Cumulative offsets [0, l0, l0+l1, ...] built from per-sequence lengths."""
    running = lens.cumsum(dim=0, dtype=dtype)
    # Prepend the leading zero offset, matching dtype/device of the cumsum.
    zero = running.new_zeros((1,))
    return torch.cat((zero, running), dim=0)
@tensor_cache
def prepare_lens_from_cu_seqlens(
    cu_seqlens: torch.LongTensor,
) -> torch.LongTensor:
    """Inverse of `prepare_cu_seqlens_from_lens`: successive differences of the offsets."""
    return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Per-token positions (0..len-1 restarting at each sequence), concatenated."""
    per_sequence = []
    for length in prepare_lens(cu_seqlens).unbind():
        per_sequence.append(torch.arange(length, dtype=cu_seqlens.dtype, device=cu_seqlens.device))
    return torch.cat(per_sequence)
@tensor_cache
def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Sequence index of every token: counts position-0 markers seen so far, minus one."""
    sequence_starts = prepare_position_ids(cu_seqlens) == 0
    return sequence_starts.cumsum(0) - 1
@tensor_cache
def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Per-token (sequence_id, position) pairs, cast to cu_seqlens' dtype/device."""
    positions = prepare_position_ids(cu_seqlens)
    sequences = prepare_sequence_ids(cu_seqlens)
    return torch.stack((sequences, positions), dim=1).to(cu_seqlens)
import torch
import torch.nn.functional as F
from einops import einsum, repeat
import tilelang as tl
import tilelang.language as T
from typing import Optional
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
# Short dtype aliases used by the TileLang kernels below.
BF16 = T.bfloat16
FP32 = T.float32
INT32 = T.int32
# Compiler passes disabled for these kernels: TMA lowering and warp
# specialization are both turned off.
pass_configs = {
    tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tl.jit(pass_configs=pass_configs)
def tl_indexer_bwd_impl(
    heads: int,
    dim: int,
    topk: int,
    sm_scale: Optional[float] = None,
    block_I: int = 32,
    num_stages: int = 0,
    num_threads: int = 128,
):
    """Build the TileLang backward kernel for the indexer scores.

    One program instance per token: it recomputes the relu'd q.k logits for
    the token's top-k keys and accumulates gradients w.r.t. the index query
    (dIndexQ), the per-head weights (dWeights) and the index keys (dIndexK,
    via atomic adds since several tokens may select the same key).

    Args:
        heads: number of indexer heads (must be <= 64 and a multiple of 8).
        dim: indexer head dimension.
        topk: number of selected keys per token (power of two, multiple of block_I).
        sm_scale: logit scale; defaults to dim ** -0.5.
        block_I: number of top-k entries processed per inner iteration.
        num_stages: software-pipelining stages (only 0 supported here).
        num_threads: threads per program instance.

    Returns:
        The compiled `tl_indexer_bwd_kernel` prim_func.
    """
    assert num_stages == 0
    assert topk == tl.math.next_power_of_2(topk)
    assert topk % block_I == 0
    assert heads <= 64 and heads % 8 == 0
    # Symbolic dims: the kernel is compiled once per (heads, dim, topk) and
    # reused across batch sizes / sequence lengths.
    batch_plus_one = T.symbolic("batch_plus_one")
    seq_len = T.symbolic("seq_len")
    dtype: str = BF16
    accum_dtype: str = FP32
    index_q_shape = [seq_len, heads, dim]
    weights_shape = [seq_len, heads]
    index_k_shape = [seq_len, dim]
    shape_p = [seq_len, topk]
    topk_indices_shape = [seq_len, topk]
    offsets_shape = [batch_plus_one]
    token_indices_shape = [seq_len, 2]
    if sm_scale is None:
        sm_scale = dim**-0.5
    @T.prim_func
    def tl_indexer_bwd_kernel(
        IndexQ: T.Tensor(index_q_shape, dtype),
        Weights: T.Tensor(weights_shape, dtype),
        IndexK: T.Tensor(index_k_shape, dtype),
        dIndexQ: T.Tensor(index_q_shape, dtype),
        dWeights: T.Tensor(weights_shape, dtype),
        dIndexK: T.Tensor(index_k_shape, dtype),
        AttnScore: T.Tensor(shape_p, FP32),
        IndexScore: T.Tensor(shape_p, FP32),
        TopkIndices: T.Tensor(topk_indices_shape, INT32),
        Offsets: T.Tensor(offsets_shape, INT32),
        TokenIndices: T.Tensor(token_indices_shape, INT32),
    ):
        # Grid: one block per token in the flattened varlen batch.
        with T.Kernel(seq_len, threads=num_threads) as (bx):
            # (batch index, in-sequence position) of this token.
            i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
            bos = Offsets[i_b]
            num_blocks = T.ceildiv(topk, block_I)
            index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
            weights_shared = T.alloc_shared([heads], dtype=dtype)
            d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype)
            d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype)
            T.copy(IndexQ[bos + i_t, :, :], index_q_shared)
            T.copy(Weights[bos + i_t, :], weights_shared)
            T.fill(d_index_q_frag, 0)
            T.fill(d_weights_frag, 0)
            # Pre-scale q by sm_scale so the gemm directly yields scaled logits.
            for i, j in T.Parallel(heads, dim):
                index_q_shared[i, j] = index_q_shared[i, j] * sm_scale
            # Walk over the token's top-k list, block_I entries at a time.
            for bi_i in T.Pipelined(num_blocks, num_stages=num_stages):
                i_st = bi_i * block_I
                i_ed = (bi_i + 1) * block_I
                indices_shared = T.alloc_shared([block_I], dtype=INT32)
                T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared)
                index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype)
                # Gather keys; invalid (-1) or non-causal (pos > i_t) entries load 0.
                for i, j in T.Parallel(block_I, dim):
                    pos = indices_shared[i]
                    index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0)
                attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
                index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
                for i in T.Parallel(block_I):
                    attn_score_shared[i] = AttnScore[bos + i_t, i_st + i]
                    index_score_shared[i] = IndexScore[bos + i_t, i_st + i]
                # logits[i, h] = relu(k_i . (q_h * sm_scale))
                logits = T.alloc_fragment((block_I, heads), accum_dtype)
                T.gemm(
                    index_k_shared,
                    index_q_shared,
                    logits,
                    transpose_A=False,
                    transpose_B=True,
                    clear_accum=True,
                )
                for i, j in T.Parallel(block_I, heads):
                    logits[i, j] = T.max(logits[i, j], 0)
                # dw: (index_score - attn_score) is the KL-gradient of the
                # per-entry score; weight grad sums it against the relu'd logits.
                d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype)
                for i, j in T.Parallel(block_I, heads):
                    d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j]
                T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False)
                d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype)
                d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype)
                d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype)
                # Gradient through the relu gate and the per-head weight.
                for i, j in T.Parallel(block_I, heads):
                    d_relu = T.alloc_var(accum_dtype)
                    if logits[i, j] > 0:
                        d_relu = 1.0
                    else:
                        d_relu = 0.0
                    d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j]
                # dq
                T.copy(d_logits_qk, d_logits_qk_cast1)
                T.gemm(
                    d_logits_qk_cast1,  # [BS, HQ]
                    index_k_shared,  # [BS, K]
                    d_index_q_frag,  # [HQ, K]
                    transpose_A=True,
                    transpose_B=False,
                    clear_accum=False,
                )
                # dk
                T.copy(d_logits_qk, d_logits_qk_cast2)
                d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype)
                T.gemm(
                    d_logits_qk_cast2,  # [BS, HQ]
                    index_q_shared,  # [HQ, K]
                    d_index_k_frag,  # [BS, K]
                    transpose_A=False,
                    transpose_B=False,
                    clear_accum=True,
                )
                # Keys can be selected by many tokens, so dK must be accumulated
                # atomically into global memory.
                for i, j in T.Parallel(block_I, dim):
                    pos = indices_shared[i]
                    if (pos > -1) & (pos <= i_t):
                        T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j])
            # q was pre-scaled, so dq picks up the same sm_scale factor here.
            for i, j in T.Parallel(heads, dim):
                d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale
            T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :])
            T.copy(d_weights_frag, dWeights[bos + i_t, :])
    return tl_indexer_bwd_kernel
def indexer_bwd_interface(
    q: torch.Tensor,
    weights: torch.Tensor,
    k: torch.Tensor,
    attn_score: torch.Tensor,
    index_score: torch.Tensor,
    topk_indices: torch.Tensor,
    offsets: torch.Tensor,
):
    """Host-side wrapper around `tl_indexer_bwd_impl`.

    Allocates zeroed gradient buffers, builds the per-token (batch, position)
    index table from `offsets`, and launches the backward kernel.
    Returns (dq, dweights, dk).
    """
    heads, dim = q.shape[1], q.shape[2]
    topk = topk_indices.shape[-1]
    token_indices = prepare_token_indices(offsets)
    # The kernel accumulates into these (dk via atomic adds), so they must
    # start zero-initialized.
    grad_q = torch.zeros_like(q)
    grad_weights = torch.zeros_like(weights)
    grad_k = torch.zeros_like(k)
    kernel = tl_indexer_bwd_impl(heads, dim, topk)
    kernel(q, weights, k, grad_q, grad_weights, grad_k, attn_score, index_score, topk_indices, offsets, token_indices)
    return grad_q, grad_weights, grad_k
def ref_indexer_bwd(
    Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    """Autograd reference for the indexer backward pass.

    Per sequence: scaled q.k logits -> relu -> weighted sum over heads ->
    causal mask -> gather at the given top-k indices -> log-softmax, then a
    summed KL divergence against `AttnScore` drives `backward()`.

    Returns (topk_prob, Q.grad, Weights.grad, K.grad). Note: mutates the
    inputs' requires_grad flags, as the original did.
    """
    Q.requires_grad_(True)
    Weights.requires_grad_(True)
    K.requires_grad_(True)
    scale = Q.shape[-1] ** -0.5
    losses = []
    log_probs = []
    num_seqs = offsets.shape[0] - 1
    for seq in range(num_seqs):
        lo, hi = offsets[seq], offsets[seq + 1]
        assert (hi - lo).item() >= TopkIndices.shape[-1]
        q = Q[lo:hi]
        weights = Weights[lo:hi]
        k = K[lo:hi]
        topk_indices = TopkIndices[lo:hi]
        attn_score = AttnScore[lo:hi]
        s = q.shape[0]
        causal = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
        # Same contraction as einops' "s1 h k, s2 k -> s1 h s2".
        logits = torch.einsum("ahd,bd->ahb", q, k) * scale
        logits = F.relu(logits)
        score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32)
        score = torch.where(causal, score, float("-inf"))
        topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64))
        log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32)
        # Clipping keeps the KL term finite where probabilities underflow.
        loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum")
        losses.append(loss)
        log_probs.append(log_topk_prob)
    torch.stack(losses).sum().backward()
    log_topk_prob = torch.cat(log_probs, dim=0)
    return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad
def test_kernel(
    B=1,
    S=2048,
    H=16,
    D=128,
    topk=64,
):
    """Compare the TileLang indexer backward kernel against the autograd reference.

    Builds a two-sequence varlen batch with a uniform target attention
    distribution and identity top-k indices, runs `ref_indexer_bwd` and
    `indexer_bwd_interface`, and prints error stats for dq, dw and dk.
    Requires a CUDA device.
    """
    torch.manual_seed(42)
    q = torch.randn((S, H, D)).cuda().bfloat16()
    w = torch.randn((S, H)).cuda().bfloat16()
    k = torch.randn((S, D)).cuda().bfloat16()
    offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()
    all_attn_score = []
    for i in range(offsets.shape[0] - 1):
        seq_len = (offsets[i + 1] - offsets[i]).item()
        # Uniform target distribution over the causally-valid top-k slots.
        mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device)
        logits = torch.ones(seq_len, topk).cuda()
        logits = torch.where(mask, logits, float("-inf"))
        attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
        all_attn_score.append(attn_score)
    attn_score = torch.cat(all_attn_score, dim=0)
    # Every token selects keys 0..topk-1.
    topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
    index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets)
    dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets)
    print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}")
    # Fixed labels: these two lines previously mislabelled the dw/dk results as "dq".
    print(f"dw err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}")
    print(f"dk err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}")
if __name__ == "__main__":
test_kernel()
import math
import torch
import torch.nn.functional as F
from einops import einsum
import tilelang as tl
import tilelang.language as T
from typing import Optional
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
# Short dtype aliases used by the TileLang kernels below.
BF16 = T.bfloat16
FP32 = T.float32
INT32 = T.int32
# Compiler passes disabled for these kernels; automatic thread-storage sync
# is also off here, so the kernel inserts its own T.sync_threads() barriers.
pass_configs = {
    tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
    tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tl.jit(pass_configs=pass_configs)
def tl_indexer_topk_reducesum_impl(
    heads: int,
    dim: int,
    topk: int,
    sm_scale: Optional[float] = None,
    block_K: int = 32,
    dtype: str = FP32,
    num_stages: int = 0,
    num_threads: int = 128,
):
    """Build the fused top-k selection + softmax kernel for the indexer.

    One program instance per token: it scores all causally-visible keys
    (relu(q.k * sm_scale) weighted and summed over heads), keeps a running
    top-k in a 2*topk shared-memory scratchpad maintained with a bitonic
    sort, and finally writes the selected indices plus their softmax
    probabilities.

    Args:
        heads: number of indexer heads (<= 64, multiple of 8).
        dim: indexer head dimension.
        topk: number of keys to select (power of two, multiple of block_K).
        sm_scale: logit scale; defaults to dim ** -0.5.
        block_K: keys scored per inner iteration.
        dtype: element dtype of q/w/k inputs.
        num_stages: pipelining stages (only 0 supported).
        num_threads: threads per program instance.

    Returns:
        The compiled `tl_indexer_topk_reducesum_kernel` prim_func.
    """
    assert topk == tl.math.next_power_of_2(topk)
    assert topk % block_K == 0
    assert heads <= 64 and heads % 8 == 0
    assert num_stages == 0
    batch_plus_one = T.symbolic("batch_plus_one")
    seq_len = T.symbolic("seq_len")
    index_q_shape = [seq_len, heads, dim]
    weights_shape = [seq_len, heads]
    index_k_shape = [seq_len, dim]
    topk_indices_shape = [seq_len, topk]
    offsets_shape = [batch_plus_one]
    token_indices_shape = [seq_len, 2]
    # Scratchpad holds the current top-k plus one incoming block's worth of
    # candidates; a full bitonic sort over N keeps the best topk in front.
    N = 2 * topk
    num_iters = int(round(math.log2(N)))
    if sm_scale is None:
        sm_scale = dim**-0.5
    @T.macro
    def bitonic_sort(
        topk_index_shared: T.SharedBuffer([N], dtype=INT32),
        topk_value_shared: T.SharedBuffer([N], dtype=FP32),
    ):
        # In-place bitonic sort of (value, index) pairs in shared memory.
        # Alternating ascending/descending runs leave the largest values in
        # the first half of the buffer.
        T.sync_threads()
        for i1 in T.serial(num_iters):
            for i2 in T.serial(i1 + 1):
                for i in T.Parallel(N):
                    ascending = (i & (1 << (i1 + 1))) != 0
                    j = i ^ (1 << (i1 - i2))
                    if i < j and (
                        (ascending and topk_value_shared[i] > topk_value_shared[j])
                        or (not ascending and topk_value_shared[i] < topk_value_shared[j])
                    ):
                        val = topk_value_shared[i]
                        topk_value_shared[i] = topk_value_shared[j]
                        topk_value_shared[j] = val
                        idx = topk_index_shared[i]
                        topk_index_shared[i] = topk_index_shared[j]
                        topk_index_shared[j] = idx
                # Barrier after each compare-exchange pass (auto-sync is disabled).
                T.sync_threads()
    @T.prim_func
    def tl_indexer_topk_reducesum_kernel(
        IndexQ: T.Tensor(index_q_shape, dtype),
        Weights: T.Tensor(weights_shape, dtype),
        IndexK: T.Tensor(index_k_shape, dtype),
        TopkIndices: T.Tensor(topk_indices_shape, INT32),
        ReduceSum: T.Tensor(topk_indices_shape, FP32),
        Offsets: T.Tensor(offsets_shape, INT32),
        TokenIndices: T.Tensor(token_indices_shape, INT32),
    ):
        # Grid: one block per token in the flattened varlen batch.
        with T.Kernel(seq_len, threads=num_threads) as (bx):
            i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
            bos, eos = Offsets[i_b], Offsets[i_b + 1]
            # Only keys at positions <= i_t are causally visible.
            num_blocks = T.ceildiv(i_t + 1, block_K)
            topk_index_shared = T.alloc_shared([N], dtype=INT32)
            topk_value_shared = T.alloc_shared([N], dtype=FP32)
            T.fill(topk_index_shared, -1)
            T.fill(topk_value_shared, float("-inf"))
            T.sync_threads()
            index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
            T.copy(IndexQ[bos + i_t, :, :], index_q_shared)
            T.sync_threads()
            weights_frag = T.alloc_shared([heads], dtype=dtype)
            T.copy(Weights[bos + i_t, :], weights_frag)
            T.sync_threads()
            # Pre-scale q so the gemm yields scaled logits directly.
            for i, j in T.Parallel(heads, dim):
                index_q_shared[i, j] = index_q_shared[i, j] * sm_scale
            T.sync_threads()
            for bk_i in T.Pipelined(num_blocks, num_stages=num_stages):
                k_st = bk_i * block_K
                k_ed = T.min((bk_i + 1) * block_K, eos - bos)
                index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype)
                # Out-of-range rows load 0 (masked later via logits_sum = -inf).
                for i, j in T.Parallel(block_K, dim):
                    index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0)
                T.sync_threads()
                logits = T.alloc_fragment((block_K, heads), FP32)
                T.gemm(
                    index_k_shared,
                    index_q_shared,
                    logits,
                    transpose_A=False,
                    transpose_B=True,
                    clear_accum=True,
                )
                T.sync_threads()
                # relu + per-head weighting, then reduce over heads.
                for i, j in T.Parallel(block_K, heads):
                    logits[i, j] = T.max(logits[i, j], 0) * weights_frag[j]
                T.sync_threads()
                logits_sum = T.alloc_fragment(block_K, FP32)
                T.reduce_sum(logits, logits_sum, dim=1)
                T.sync_threads()
                # First topk entries fill slots [0, topk); later blocks go to
                # the candidate half [topk, 2*topk) of the scratchpad.
                offset = T.alloc_var(INT32)
                if k_st >= topk:
                    offset = topk + (k_st % topk)
                else:
                    offset = k_st
                T.sync_threads()
                for i in T.Parallel(block_K):
                    if k_st + i > i_t:
                        logits_sum[i] = float("-inf")
                    j = offset + i
                    topk_index_shared[j] = k_st + i
                    topk_value_shared[j] = logits_sum[i]
                T.sync_threads()
                # Once the candidate half is full, merge it into the top half.
                if k_ed > topk and k_ed % topk == 0:
                    bitonic_sort(topk_index_shared, topk_value_shared)
            # Final merge for any partially-filled candidate half.
            bitonic_sort(topk_index_shared, topk_value_shared)
            # Softmax over the selected top-k values (max-subtracted for stability).
            logits_max_frag = T.alloc_fragment([1], dtype=FP32)
            logits_frag = T.alloc_fragment([topk], dtype=FP32)
            reducesum_shared = T.alloc_shared([topk], dtype=FP32)
            T.copy(topk_value_shared[:topk], logits_frag)
            T.sync_threads()
            T.reduce_max(logits_frag, logits_max_frag, dim=-1)
            T.sync_threads()
            for i in T.Parallel(topk):
                logits_frag[i] = T.exp(logits_frag[i] - logits_max_frag[0])
            T.sync_threads()
            lse_frag = T.alloc_fragment([1], dtype=FP32)
            T.reduce_sum(logits_frag, lse_frag)
            T.sync_threads()
            for i in T.Parallel(topk):
                reducesum_shared[i] = logits_frag[i] / lse_frag[0]
            T.sync_threads()
            # for i in T.Parallel(topk):
            #     reducesum_shared[i] = logits_frag[i]
            # T.sync_threads()
            # Mark non-causal survivors (possible when i_t + 1 < topk) as -1.
            for i in T.Parallel(topk):
                if topk_index_shared[i] > i_t:
                    topk_index_shared[i] = -1
            T.sync_threads()
            T.copy(topk_index_shared[:topk], TopkIndices[bos + i_t, :])
            T.copy(reducesum_shared[:topk], ReduceSum[bos + i_t, :])
    return tl_indexer_topk_reducesum_kernel
def indexer_topk_reducesum_interface(
    q: torch.Tensor,
    weights: torch.Tensor,
    k: torch.Tensor,
    topk: int,
    offsets: torch.Tensor,
    dtype: str = BF16,
):
    """Run the fused top-k + softmax indexer kernel.

    Returns (topk_indices, topk_score), both of shape [seq_len, topk].
    """
    num_tokens, num_heads, head_dim = q.shape
    token_indices = prepare_token_indices(offsets)
    # Output buffers written by the kernel.
    out_indices = torch.zeros((num_tokens, topk), device=q.device, dtype=torch.int32)
    out_scores = torch.zeros((num_tokens, topk), device=q.device, dtype=torch.float32)
    kernel = tl_indexer_topk_reducesum_impl(heads=num_heads, dim=head_dim, topk=topk, dtype=dtype)
    kernel(q, weights, k, out_indices, out_scores, offsets, token_indices)
    return out_indices, out_scores
def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor:
    """Torch reference for the top-k indexer scores.

    Per sequence: relu(q.k) weighted-summed over heads, scaled, causally
    masked, then torch.topk + softmax. Returns (topk_indices, topk_score).
    """
    gathered_indices = []
    gathered_scores = []
    for seq in range(offsets.shape[0] - 1):
        lo, hi = offsets[seq], offsets[seq + 1]
        assert (hi - lo).item() >= topk
        q = Q[lo:hi]
        weights = Weights[lo:hi]
        k = K[lo:hi]
        scale = q.shape[-1] ** -0.5
        s = q.shape[0]
        causal = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
        # Same contraction as einops' "s1 h k, s2 k -> s1 h s2".
        logits = torch.einsum("ahd,bd->ahb", q, k)
        logits = F.relu(logits)
        logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * scale
        logits = torch.where(causal, logits, float("-inf"))
        topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1)
        topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32)
        gathered_indices.append(topk_indices)
        gathered_scores.append(topk_score)
    topk_indices = torch.cat(gathered_indices, dim=0)
    topk_score = torch.cat(gathered_scores, dim=0)
    return topk_indices, topk_score
def test_kernel(
    B=1,
    S=2048,
    H=64,
    D=128,
    topk=64,
):
    """Per-row comparison of the TileLang top-k kernel against the torch reference.

    Prints, for every token, the overlap between reference and kernel top-k
    index sets (restricted to reference entries with positive probability)
    plus score error stats. Requires a CUDA device.
    """
    torch.manual_seed(42)
    q = torch.randn((S, H, D)).cuda().bfloat16()
    weights = torch.randn((S, H)).cuda().bfloat16()
    k = torch.randn((S, D)).cuda().bfloat16()
    offsets = torch.tensor([0, S], dtype=torch.int32).cuda()
    ref_topk_indices, ref_topk_score = ref_index_score(q, weights, k, topk, offsets)
    topk_indices, topk_score = indexer_topk_reducesum_interface(q, weights, k, topk, offsets)
    for j in range(S):
        ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
        trt_np = topk_indices[j].cpu().to(torch.int32).numpy()
        ref_np_val = ref_topk_score[j]
        trt_np_val = topk_score[j]
        # Only compare slots the reference assigns non-zero probability.
        mask = (ref_np_val > 0).cpu().numpy()
        set_ref = set(ref_np[mask])
        set_trt = set(trt_np[mask])
        intersection = set_ref & set_trt
        print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref))
        print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}")
if __name__ == "__main__":
test_kernel()
# ruff: noqa
import tilelang
from tilelang import language as T
import torch
from index import prepare_token_indices
from utils import assert_tensors_similar
@tilelang.jit(out_idx=[-1])
def preprocess(
    H,
    D,
    block_ND=32,
    num_stages=5,
    dtype=T.bfloat16,
    accum_dtype=T.float32,
):
    """Build the backward 'delta' kernel: Delta[s, h] = sum_d O[s, h, d] * dO[s, h, d].

    This is the standard FlashAttention-style preprocess; the fp32 row dot
    products of O and dO feed the main backward kernel.

    Args:
        H: number of heads.
        D: value head dimension.
        block_ND: tile size along both the sequence and dim axes.
        num_stages: software-pipelining depth for the D-axis loop.
        dtype / accum_dtype: fixed to bf16 inputs / fp32 accumulation.
    """
    assert dtype == T.bfloat16
    assert accum_dtype == T.float32
    S = T.symbolic("S")
    shape = [S, H, D]
    @T.prim_func
    def preprocess_kernel(
        O: T.Tensor(shape, dtype),
        dO: T.Tensor(shape, dtype),
        Delta: T.Tensor([S, H], accum_dtype),
    ):
        # Grid: (head, sequence tile).
        with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by):
            o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
            do = T.alloc_fragment([block_ND, block_ND], accum_dtype)
            delta = T.alloc_fragment([block_ND], accum_dtype)
            acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
            T.clear(acc)
            # Accumulate elementwise O * dO over D in block_ND-wide slabs.
            for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
                T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o)
                T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do)
                for i, j in T.Parallel(block_ND, block_ND):
                    acc[i, j] += o[i, j] * do[i, j]
            # Reduce over the dim axis to one value per token row.
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx])
    return preprocess_kernel
@tilelang.jit(out_idx=[-1])
def postprocess(
    D,
    D_tail,
    kv_group=1,
    block_N=64,
    threads=128,
    dtype=T.bfloat16,
    accum_dtype=T.float32,
):
    """Build the dKV cast kernel: copy the fp32 atomic-accumulated dKV into bf16.

    The backward kernel accumulates dKV in fp32 (atomic adds); this kernel
    just tiles over the sequence and casts each row to the output dtype.

    Args:
        D: value head dimension.
        D_tail: extra (tail) dimension appended after D.
        kv_group: number of KV groups.
        block_N: sequence tile size.
        threads: threads per block.
        dtype / accum_dtype: fixed to bf16 output / fp32 input.
    """
    assert dtype == T.bfloat16
    assert accum_dtype == T.float32
    S_kv = T.symbolic("S_kv")
    dkv_shape = [S_kv, kv_group, D + D_tail]
    @T.prim_func
    def postprocess_kernel(
        dKV: T.Tensor(dkv_shape, accum_dtype),
        dKV_out: T.Tensor(dkv_shape, dtype),
    ):
        # Grid: (sequence tile, kv group); T.copy performs the dtype cast.
        with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by):
            T.copy(
                dKV[bx * block_N : (bx + 1) * block_N, by, :],
                dKV_out[bx * block_N : (bx + 1) * block_N, by, :],
            )
    return postprocess_kernel
@tilelang.jit(
    out_idx=[-2],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def bwd(
    H,
    D,
    D_tail,
    topk,
    kv_group=1,
    sm_scale=None,
    is_causal=True,
    block_size=32,
    num_stages=0,
    threads=128,
    indices_dtype=T.int32,
    dtype=T.bfloat16,
    accum_dtype=T.float32,
):
    """Build the main sparse-MLA backward kernel.

    One program instance per (token, kv group): it recomputes the attention
    probabilities for the token's top-k keys from Q/KV/Lse, forms
    dP = P * (dP_raw - Delta) * sm_scale, and accumulates dQ (written per
    token) and dKV (accumulated into global fp32 via atomic adds).

    Args:
        H: total number of query heads.
        D: value/nope head dimension; D_tail: RoPE tail dimension.
        topk: selected keys per token (multiple of block_size).
        kv_group: number of KV groups (Q heads per group = H // kv_group).
        sm_scale: softmax scale; defaults to (D + D_tail) ** -0.5.
        is_causal: must be True.
        block_size / num_stages / threads: tiling and launch parameters.
        indices_dtype / dtype / accum_dtype: fixed to int32 / bf16 / fp32.
    """
    assert is_causal == True, "non-casual is not supported now"
    assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
    assert dtype == T.bfloat16
    assert accum_dtype == T.float32
    assert indices_dtype == T.int32
    if sm_scale is None:
        sm_scale = (D + D_tail) ** (-0.5)
    B_plus_one = T.symbolic("B_plus_one")
    S = T.symbolic("S")
    H_kv = H // kv_group
    q_shape = [S, H, D + D_tail]
    k_shape = [S, kv_group, D + D_tail]
    o_shape = [S, H, D]
    indices_shape = [S, kv_group, topk]
    delta_shape = [S, H]
    lse_shape = [S, H]
    offsets_shape = [B_plus_one]
    token_indices_shape = [S, 2]
    assert indices_dtype == T.int32
    assert dtype == T.bfloat16
    assert accum_dtype == T.float32
    # NOTE(review): H is rebound to the per-group head count here, after the
    # shapes above were built with the full H — presumably intentional for
    # the kv_group == 1 case; verify before using kv_group > 1.
    H = H_kv
    padded_H = max(tilelang.math.next_power_of_2(H_kv), 16)
    BS = block_size
    NS = tilelang.cdiv(topk, block_size)
    # dKV stores are split in two halves so the fp32 staging view fits in
    # the reused bf16 shared buffers.
    split_store = 2
    @T.prim_func
    def sparse_mla_bwd_kernel(
        Q: T.Tensor(q_shape, dtype),
        KV: T.Tensor(k_shape, dtype),
        dO: T.Tensor(o_shape, dtype),
        Indices: T.Tensor(indices_shape, indices_dtype),
        Lse: T.Tensor(lse_shape, accum_dtype),
        Delta: T.Tensor(delta_shape, accum_dtype),
        Offsets: T.Tensor(offsets_shape, indices_dtype),
        TokenIndices: T.Tensor(token_indices_shape, indices_dtype),
        dQ: T.Tensor(q_shape, dtype),
        dKV: T.Tensor(k_shape, accum_dtype),
    ):
        # Grid: (flattened token, kv group).
        with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz):
            Q_shared = T.alloc_shared([padded_H, D], dtype)
            Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
            KV_shared = T.alloc_shared([BS, D], dtype)
            KV_tail_shared = T.alloc_shared([BS, D_tail], dtype)
            dO_shared = T.alloc_shared([padded_H, D], dtype)
            mask = T.alloc_fragment([BS], "bool")
            P_shared_cast = T.alloc_shared([padded_H, BS], dtype)
            dP_shared_cast = T.alloc_shared([padded_H, BS], dtype)
            dQ_shared = T.alloc_shared([padded_H, D], dtype)
            dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
            acc_p = T.alloc_fragment([padded_H, BS], accum_dtype)
            acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype)
            acc_dq = T.alloc_fragment([padded_H, D], accum_dtype)
            acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype)
            acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
            acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
            # fp32 staging views aliasing the bf16 KV shared buffers (hence
            # the split_store halving of the row count).
            acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
            acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
            b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
            bos, eos = Offsets[b_i], Offsets[b_i + 1]
            # Causal bound: only keys at positions <= s_i are valid.
            max_kv_i = s_i
            T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared)
            T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared)
            T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared)
            T.clear(acc_dq)
            T.clear(acc_dq_tail)
            T.annotate_layout(
                {
                    dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
                    dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
                }
            )
            # Process each block of indices
            for i_i in T.Pipelined(NS, num_stages=num_stages):
                # Check which indices are valid
                for bi_i in T.Parallel(BS):
                    mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)
                # Compute attention scores
                for h_i, bi_i in T.Parallel(padded_H, BS):
                    acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype))
                # Load KV, V for this block of indices
                for bi_i, d_i in T.Parallel(BS, D):
                    KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i]
                T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                for bi_i, d_i in T.Parallel(BS, D_tail):
                    KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i]
                T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                # P = exp(S * sm_scale - lse): probabilities without a fresh softmax.
                for h_i, bi_i in T.Parallel(padded_H, BS):
                    acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i])
                T.copy(acc_p, P_shared_cast)
                # dP_raw = dO . V (V is the first D columns of KV).
                T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
                for h_i, bi_i in T.Parallel(padded_H, BS):
                    acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale
                T.copy(acc_dp, dP_shared_cast)
                # dQ += dP . K; dKV = dP^T . Q + P^T . dO (value part).
                T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
                T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
                T.clear(acc_dkv_tail)
                T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
                # Stage fp32 dKV rows through shared memory in two halves.
                for s in range(split_store):
                    for bi_i, d_i in T.Parallel(BS, D):
                        if bi_i < BS // split_store:
                            acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i]
                    for bi_i, d_i in T.Parallel(BS, D_tail):
                        if bi_i < BS // split_store:
                            acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i]
                    for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
                        T.atomic_addx4(
                            dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4],
                            acc_dkv_shared[bi_i, d_i * 4],
                        )
                    # Atomically update dKV, dKV_tail tensors
                    for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
                        T.atomic_addx4(
                            dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4],
                            acc_dkv_tail_shared[bi_i, d_i * 4],
                        )
            # Store the accumulated dQ
            T.copy(acc_dq, dQ_shared)
            T.copy(acc_dq_tail, dQ_tail_shared)
            T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D])
            T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:])
    return sparse_mla_bwd_kernel
def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True, return_kernel=False, delta=None):
    """Sparse MLA backward entry point: returns (dq, dkv).

    Runs the preprocess (delta) kernel unless `delta` is supplied, the main
    backward kernel (dKV accumulated atomically in fp32), then the
    postprocess kernel that casts dKV back to the KV dtype.
    """
    for t in (q, kv, indices, lse):
        assert t.is_contiguous()
    seq_len, num_heads, total_dim = q.shape
    seq_len_kv, kv_group, _ = kv.shape
    assert kv.shape[-1] == total_dim
    assert seq_len == seq_len_kv
    # dim should be assigned
    value_dim = 512
    tail_dim = total_dim - value_dim
    topk = indices.shape[-1]
    assert indices.shape == (seq_len, kv_group, topk)
    assert lse.shape == (seq_len, num_heads)
    token_indices = prepare_token_indices(offsets)
    # Compile (or fetch cached) kernels.
    delta_kernel = preprocess(num_heads, value_dim)
    grad_kernel = bwd(num_heads, value_dim, tail_dim, topk, kv_group, sm_scale, is_casual)
    cast_kernel = postprocess(value_dim, tail_dim, kv_group)
    if delta is None:
        delta = delta_kernel(o, do)
    # fp32 buffer for atomic accumulation; cast to kv dtype afterwards.
    dkv_accum = torch.zeros_like(kv, dtype=torch.float32)
    dq = grad_kernel(q, kv, do, indices, lse, delta, offsets, token_indices, dkv_accum)
    dkv = cast_kernel(dkv_accum)
    return dq, dkv
def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True):
    """Autograd reference for sparse MLA backward (the `o`/`lse` args are unused)."""
    from sparse_mla_fwd import ref_sparse_mla_fwd_interface

    # Detach so gradients accumulate on fresh leaves, not the caller's tensors.
    q_leaf = q.detach().clone()
    kv_leaf = kv.detach().clone()
    q_leaf.requires_grad = True
    kv_leaf.requires_grad = True
    out = ref_sparse_mla_fwd_interface(q_leaf, kv_leaf, indices, offsets, sm_scale, is_casual)
    out.backward(do)
    return q_leaf.grad, kv_leaf.grad
def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True):
    """End-to-end check and benchmark of the sparse MLA backward kernels.

    Builds random q/kv/do and per-token random top-k index sets, runs the
    TileLang forward + backward, optionally compares dq/dkv against the
    autograd reference, then benchmarks and prints bandwidth/TFLOPs.
    Requires a CUDA device.
    """
    # Prepare data
    q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
    kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
    do = torch.randn((S, H, DV), dtype=dtype, device="cuda")
    offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda")
    # Default slot value S marks "unused"; valid slots get causal positions.
    indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda")
    for b in range(offsets.shape[0] - 1):
        seq_len = (offsets[b + 1] - offsets[b]).item()
        assert seq_len >= topk
        for pos in range(seq_len):
            for head in range(HKV):
                chosen = torch.randperm(max(1, pos))[:topk]
                indices[offsets[b] + pos, head, : len(chosen)] = chosen
    # Forward
    from sparse_mla_fwd import sparse_mla_fwd_interface

    tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets)
    tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)
    ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None, offsets)
    if check_correctness:
        assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq")
        assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
        print("assert_tensors_similar passed")
    # Five gemms per token: P, dP, dQ, dKV (x2 pieces).
    per_token_flop = 2 * sum(
        [
            H * DV * topk,
            H * DQKV * topk,
            H * DQKV * topk,
            H * DQKV * topk,
            H * DV * topk,
        ]
    )
    from tilelang.profiler import do_bench

    def fn():
        return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)

    ms = do_bench(fn, rep=100, warmup=250)
    print(f"Average time: {ms:.3f} ms")
    print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
    print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True)
# ruff: noqa
import torch
import tilelang
from tilelang import language as T
from index import prepare_token_indices
from utils import assert_tensors_similar
@tilelang.jit(
    out_idx=[-2, -1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def sparse_mla_fwd(
    heads,
    dim,
    tail_dim,
    topk,
    kv_group=1,
    sm_scale=None,
    is_causal=True,
    CP0=True,
    block_I=32,
    num_stages=2,
    threads=128,
):
    """Build the sparse MLA forward kernel (online-softmax over top-k keys).

    One program instance per (token [, 64-head slice], kv group): it gathers
    the token's top-k KV rows block_I at a time and runs a FlashAttention-
    style streaming softmax, producing Output and the log-sum-exp Lse used
    by the backward pass.

    Args:
        heads: total number of query heads.
        dim: value/nope head dimension; tail_dim: RoPE tail dimension
            (each must be a power of two).
        topk: selected keys per token (multiple of block_I).
        kv_group: number of KV groups.
        sm_scale: softmax scale; defaults to (dim + tail_dim) ** -0.5.
        is_causal: must be True.
        CP0: unused flag in this body — presumably context-parallel related;
            confirm with callers.
        block_I / num_stages / threads: tiling and launch parameters.
    """
    assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
    assert is_causal == True, "non-casual is not supported"
    assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5
    else:
        sm_scale = sm_scale
    batch_plus_one = T.symbolic("batch_plus_one")
    seq_len = T.symbolic("seq_len")
    head_kv = heads // kv_group
    q_shape = [seq_len, heads, dim + tail_dim]
    kv_shape = [seq_len, kv_group, dim + tail_dim]
    o_shape = [seq_len, heads, dim]
    indices_shape = [seq_len, kv_group, topk]
    lse_shape = [seq_len, heads]
    offsets_shape = [batch_plus_one]
    token_indices_shape = [seq_len, 2]
    indices_dtype = T.int32
    dtype = T.bfloat16
    accum_dtype = T.float32
    G = kv_group
    H = head_kv
    # Heads are padded up to a power of two (min 16) for the tensor-core gemms.
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert kv_group == 1, (
            "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
        )
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    D = dim
    D_tail = tail_dim
    # More than 64 heads are split across REPLICATE_H program instances.
    if head_kv > 64:
        assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = head_kv // 64
    else:
        REPLICATE_H = 1
    H_per_block = padded_H if REPLICATE_H == 1 else 64
    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Offsets: T.Tensor(offsets_shape, indices_dtype),  # type: ignore
        TokenIndices: T.Tensor(token_indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
        Lse: T.Tensor(lse_shape, accum_dtype),  # type: ignore
    ):
        # Grid: (token x head-slice replica, kv group).
        with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as (
            bx,
            by,
        ):
            Q_shared = T.alloc_shared([H_per_block, D], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared = T.alloc_shared([BI, D], dtype)
            K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
            mask = T.alloc_fragment([BI], "bool")
            acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            S_shared = T.alloc_shared([H_per_block, BI], dtype)
            sumexp = T.alloc_fragment([H_per_block], accum_dtype)
            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
            alpha = T.alloc_fragment([H_per_block], accum_dtype)
            m_i = T.alloc_fragment([H_per_block], accum_dtype)
            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
            T.fill(acc_o, 0)
            T.fill(sumexp, 0)
            T.fill(m_i, -(2**30))  # avoid -inf - inf to cause nan
            b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
            b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
            bos, eos = Offsets[b_i], Offsets[b_i + 1]
            g_i = by
            q_i = s_i
            # Causal bound: only keys at positions <= q_i are valid.
            max_kv_i = q_i
            # Head range handled by this program instance.
            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block
            T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
            T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
            # Stream over the token's top-k list, BI keys at a time.
            for i_i in T.Pipelined(NI, num_stages=num_stages):
                for bi_i in T.Parallel(BI):
                    mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
                for bi_i, d_i in T.Parallel(BI, D):
                    KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
                for bi_i, d_i in T.Parallel(BI, D_tail):
                    K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
                # Invalid slots start at -inf so they vanish in the softmax.
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
                T.gemm(
                    Q_shared,
                    KV_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                T.gemm(
                    Q_tail_shared,
                    K_tail_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                # Online softmax: update running max, rescale previous state.
                T.copy(m_i, m_i_prev)
                T.reduce_max(acc_s, m_i, dim=1, clear=False)
                for h_i in T.Parallel(H_per_block):
                    alpha[h_i] = T.exp((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
                T.reduce_sum(acc_s, sumexp_i, dim=1)  # is this a accumulate operator?
                for h_i in T.Parallel(H_per_block):
                    sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
                for h_i, d_i in T.Parallel(H_per_block, D):
                    acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
                T.copy(acc_s, S_shared)
                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            # Rescale
            for h_i, d_i in T.Parallel(H_per_block, D):
                acc_o[h_i, d_i] /= sumexp[h_i]
            # Lse = log(sum exp) + max * scale, consumed by the backward pass.
            for h_i in T.Parallel(H_per_block):
                sumexp[h_i] = T.log(sumexp[h_i]) + m_i[h_i] * sm_scale
            T.copy(acc_o, Output[bos + s_i, H0:H1, :])
            T.copy(sumexp, Lse[bos + s_i, H0:H1])
    return main
def sparse_mla_fwd_interface(
    q, kv, indices, offsets, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=32, num_stages=2, threads=128
):
    """Shape-check the packed inputs and launch the sparse-MLA forward kernel.

    q: [seq_len, heads, 576] queries (main dim and tail dim packed together).
    kv: [seq_len, kv_group, 576] latent KV cache.
    indices: [seq_len, kv_group, topk] selected key positions per query.
    offsets: [batch + 1] boundaries of the packed variable-length sequences.

    Returns the (out, lse) pair produced by the TileLang kernel.
    """
    is_casual = True  # only causal attention is supported by this kernel
    assert return_p_sum == False, "This kernel file is for fwd only"
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    n_tokens, n_heads, d_packed = q.shape
    n_tokens_kv, n_groups, _ = kv.shape
    assert n_tokens == n_tokens_kv
    assert d_packed == 576, "you should assign dim otherwise"
    d_main = d_v
    assert kv.shape[-1] == d_packed
    d_tail = d_packed - d_main
    n_topk = indices.shape[-1]
    assert indices.shape == (n_tokens, n_groups, n_topk)
    token_indices = prepare_token_indices(offsets)
    kernel = sparse_mla_fwd(
        n_heads,
        d_main,
        d_tail,
        n_topk,
        n_groups,
        sm_scale,
        is_casual,
        block_I=block_I,
        num_stages=num_stages,
        threads=threads,
    )
    out, lse = kernel(q, kv, indices, offsets, token_indices)
    return out, lse
def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casual=True):
    """Pure-PyTorch reference for the sparse-MLA forward pass.

    Processes each packed sequence (delimited by `offsets`) independently,
    builds a boolean attention mask from the top-k `Indices` combined with a
    causal mask, and runs dense softmax attention under that mask.
    Returns the concatenated outputs cast to bfloat16.
    """
    Q = Q.float()
    KV = KV.float()
    all_o = []
    for i in range(offsets.shape[0] - 1):
        # Slice out one sequence and add a leading batch axis of size 1.
        q = Q[None, offsets[i] : offsets[i + 1]]
        kv = KV[None, offsets[i] : offsets[i + 1]]
        indices = Indices[None, offsets[i] : offsets[i + 1]].clone()
        indices = indices.transpose(1, 2)  # [b, sq, g, topk] -> [b, g, sq, topk]
        b, sq, h, dim_q = q.shape
        b, sk, g, _ = kv.shape
        assert kv.shape[-1] == 576, "you should assign dim otherwise"
        dim = 512  # value dim: first 512 channels; the remaining 64 are key-only tail
        k = kv
        v = kv[..., :dim]
        b, _, _, dim_v = v.shape
        g_index = g
        h_index = h // g  # heads per KV group
        # Causal mask: query position >= key position.
        compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
            1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda"
        ).view(1, -1)
        # Route out-of-range indices to an extra scratch column sk, which is
        # dropped right after the scatter below.
        indices[indices > sk] = sk
        mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
        mask = mask[..., :-1]
        mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
        # NOTE(review): `: 1 - 1` is an empty slice, so this line is a no-op —
        # presumably a leftover from a block-compressed variant of this code.
        mask[:, :, : 1 - 1, 0] = True
        mask = mask.view(b, g_index, 1, sq, sk)
        # Split heads into (group, head-per-group) for grouped attention.
        q = q.view(b, sq, g, -1, dim_q)
        score = torch.einsum("bmghd,bngd->bghmn", q, k)
        sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale
        score = score.masked_fill(~mask, float("-inf")).mul(sm_scale)
        p = score.softmax(dim=-1)
        p = p.view(b, g_index, h_index, -1, sq, sk)
        p = p.view(b, g, -1, sq, sk)
        o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v)
        o = o.reshape(b, sq, h, dim_v)
        all_o.append(o.squeeze(0))
    o = torch.cat(all_o, dim=0)
    return o.to(torch.bfloat16)
def test_sparse_mla_fwd(
    B=1,
    S=4096,
    H=128,
    HKV=1,
    DQK=576,
    DV=512,
    topk=2048,
    dtype=torch.bfloat16,
    check_correctness=True,
    block_I=64,
    num_stages=2,
    threads=256,
):
    """Correctness + performance smoke test for sparse_mla_fwd_interface.

    Builds two packed sequences, fills `indices` with random causal top-k
    selections (slot value S marks unused/padding entries), compares against
    the PyTorch reference, then benchmarks the kernel. Requires a CUDA device.
    """
    torch.random.manual_seed(0)
    q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
    kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
    # Two packed sequences: [0, S//2 - 1) and [S//2 - 1, S).
    offsets = torch.tensor([0, S // 2 - 1, S], dtype=torch.int32, device="cuda")
    # S is an out-of-range sentinel; the kernel/reference mask such slots out.
    indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda")
    for i in range(offsets.shape[0] - 1):
        seq_len = (offsets[i + 1] - offsets[i]).item()
        assert seq_len >= topk
        for t in range(seq_len):
            for h in range(HKV):
                # Random selection of up to topk causal positions (< t);
                # token 0 selects position 0 via max(1, t).
                i_i = torch.randperm(max(1, t))[:topk]
                indices[offsets[i] + t, h, : len(i_i)] = i_i
    tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
    if check_correctness:
        # otherwise may cause out of memory
        ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, offsets)
        assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
        print("assert_tensors_similar passed")

    def fn():
        # Benchmark closure: forward pass only.
        return sparse_mla_fwd_interface(q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)

    from tilelang.profiler import do_bench

    ms = do_bench(
        fn,
        rep=100,
        warmup=250,
    )
    print(f"Average time: {ms:.3f} ms")
    print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
    print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
    # All arguments except topk (1024 instead of the default 2048) match the
    # defaults of test_sparse_mla_fwd, so they are left implicit here.
    test_sparse_mla_fwd(topk=1024)
# ruff: noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
import tilelang
from tilelang import language as T
from einops import repeat, rearrange, einsum
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
# Short aliases for the TileLang scalar dtypes used in this file.
BF16 = T.bfloat16
FP32 = T.float32
INT32 = T.int32
# Compiler pass overrides applied to the @tilelang.jit kernels below: TMA
# lowering and warp specialization are both disabled.
# NOTE(review): presumably required for correctness/compatibility on the
# targeted backend — confirm before re-enabling either pass.
pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tilelang.jit(pass_configs=pass_configs)
def tl_sparse_mla_topk_reducesum_impl(
    heads,
    dim,
    tail_dim,
    topk,
    kv_group=1,
    sm_scale=None,
    block_I=32,
    num_stages=2,
    threads=128,
):
    """Build a TileLang kernel that, given precomputed per-(token, head) LSE
    values, recomputes the sparse-MLA attention probabilities for the top-k
    selected keys and writes the per-key probability mass summed over heads
    into ReduceSum.

    Output layout: ReduceSum[token, group, replica, topk_slot], where `replica`
    indexes 64-head chunks when head_kv > 64 (see REPLICATE_H below); callers
    are expected to sum over the replica axis.
    """
    # Padding for non-power-of-two head dims is untested; fail fast.
    assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
    # The index loop loads whole BI-sized tiles; a ragged tail would read slot 0.
    assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
    if sm_scale is None:
        # Default softmax scale 1/sqrt(dim + tail_dim) over the packed head dim.
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5
    # Runtime-symbolic sizes: offsets length (batch + 1) and packed seq lengths.
    batch_plus_one = T.symbolic("batch_plus_one")
    seq_len = T.symbolic("seq_len")
    seq_len_kv = T.symbolic("seq_len_kv")
    head_kv = heads // kv_group  # query heads sharing one KV group
    indices_dtype = T.int32
    dtype = T.bfloat16
    accum_dtype = T.float32
    G = kv_group
    H = head_kv
    # Pad the head count up to a power of two (minimum 16) for the GEMM tiles.
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert kv_group == 1, (
            "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
        )
    BI = block_I  # top-k indices processed per pipelined iteration
    NI = tilelang.cdiv(topk, block_I)  # iterations over the top-k list
    D = dim
    D_tail = tail_dim
    # With more than 64 heads per group, replicate blocks along the sequence
    # axis and give each replica a 64-head slice.
    if head_kv > 64:
        assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = head_kv // 64
    else:
        REPLICATE_H = 1
    H_per_block = padded_H if REPLICATE_H == 1 else 64
    q_shape = [seq_len, heads, dim + tail_dim]
    kv_shape = [seq_len_kv, kv_group, dim + tail_dim]
    indices_shape = [seq_len, kv_group, topk]
    lse_shape = [seq_len, heads]
    reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk]
    offsets_shape = [batch_plus_one]
    token_indices_shape = [seq_len, 2]

    @T.prim_func
    def tl_sparse_mla_topk_reducesum_kernel(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Lse: T.Tensor(lse_shape, accum_dtype),  # type: ignore
        Offsets: T.Tensor(offsets_shape, indices_dtype),  # type: ignore
        TokenIndices: T.Tensor(token_indices_shape, indices_dtype),  # type: ignore
        ReduceSum: T.Tensor(reducesum_shape, accum_dtype),  # type: ignore
    ):
        # Grid: (token x head-replica) along x, KV group along y.
        with T.Kernel(seq_len * REPLICATE_H, kv_group, threads=threads) as (
            bx,
            by,
        ):
            Q_shared = T.alloc_shared([H_per_block, D], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared = T.alloc_shared([BI, D], dtype)
            K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
            mask = T.alloc_fragment([BI], "bool")
            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            reducesum = T.alloc_fragment([BI], accum_dtype)
            lse = T.alloc_fragment([H_per_block], accum_dtype)
            T.fill(lse, 0)  # NOTE(review): immediately overwritten by the Lse copy below
            # Recover (sequence id, token-within-sequence) for this block.
            b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
            b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
            bos, eos = Offsets[b_i], Offsets[b_i + 1]
            r_i = bx % REPLICATE_H  # which 64-head replica this block handles
            g_i = by
            q_i = s_i
            max_kv_i = q_i  # causal bound: key positions > q_i are masked out
            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block
            T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
            T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
            # Precomputed log-sum-exp for this token's head slice.
            T.copy(Lse[bos + s_i, H0:H1], lse)
            for i_i in T.Pipelined(NI, num_stages=num_stages):
                # A slot is valid if it is causal and not the -1 padding value.
                for bi_i in T.Parallel(BI):
                    mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
                # Gather the selected KV rows: main dim and tail dim separately.
                for bi_i, d_i in T.Parallel(BI, D):
                    KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, d_i]
                for bi_i, d_i in T.Parallel(BI, D_tail):
                    K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
                # Invalid slots start at -inf so they contribute exp(-inf) = 0.
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
                T.gemm(
                    Q_shared,
                    KV_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                T.gemm(
                    Q_tail_shared,
                    K_tail_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow,
                )
                # p = exp(score * scale - lse): the softmax probability, since
                # lse was computed over the full causal key range.
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i])
                # Sum the probabilities over this block's heads (dim 0 of acc_s).
                T.reduce_sum(acc_s, reducesum, dim=0)
                T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI : i_i * BI + BI])

    return tl_sparse_mla_topk_reducesum_kernel
def sparse_mla_topk_reducesum_interface(
    q: torch.Tensor,
    kv: torch.Tensor,
    topk_indices: torch.Tensor,
    lse: torch.Tensor,
    offsets: torch.Tensor,
    dim_v: int,
):
    """Run the top-k reduce-sum kernel and return normalized attention scores.

    q: [seq_len, heads, dim_v + tail_dim]; kv: [seq_len_kv, 1, dim_v + tail_dim];
    topk_indices: [seq_len, 1, topk]; lse: [seq_len, heads] precomputed LSE;
    offsets: [batch + 1] packed-sequence boundaries.
    Returns attn_score of shape [seq_len, 1, topk], normalized to sum to 1
    over the topk axis.
    """
    # Kernel currently supports a single KV group only.
    assert kv.shape[-2] == 1
    seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1]
    # Kernel splits >64-head workloads into 64-head replicas (see impl).
    REPLICATE_H = max(heads // 64, 1)
    tail_dim = dim_plus_tail_dim - dim_v
    token_indices = prepare_token_indices(offsets)
    reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device)
    kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, dim=dim_v, tail_dim=tail_dim, topk=topk)
    kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum)
    reducesum = reducesum.sum(dim=-2)  # [seq_len, 1, RH, topk] -> [seq_len, 1, topk]
    attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True)
    return attn_score
def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, offsets: torch.Tensor):
    # Pure-PyTorch reference. Per packed sequence (boundaries in `offsets`):
    #   Q: [total_len, heads, dim], K: [total_len, dim] (single shared KV),
    #   TopkIndices: [total_len, topk].
    # Returns (lse, topk_score) concatenated over sequences: lse is the causal
    # log-sum-exp per (token, head); topk_score is the head-summed softmax mass
    # gathered at the top-k key positions, renormalized to sum to 1.
    sm_scale = Q.shape[-1] ** -0.5
    all_lse = []
    all_topk_score = []
    for i in range(offsets.shape[0] - 1):
        q = Q[offsets[i] : offsets[i + 1]]
        k = K[offsets[i] : offsets[i + 1]]
        topk_indices = TopkIndices[offsets[i] : offsets[i + 1]]
        seq_len = q.shape[0]
        # Causal mask broadcast over the heads axis: query s1 sees keys s2 <= s1.
        mask = (torch.arange(seq_len)[:, None] >= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda()
        logits = einsum(q, k, "s1 h d, s2 d -> s1 h s2") * sm_scale
        logits = torch.where(mask, logits, float("-inf"))
        score = F.softmax(logits, dim=-1, dtype=torch.float32)
        # Total probability mass per key position, summed over heads.
        score_sum = score.sum(dim=-2)
        topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64))
        topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True)
        # Numerically stable log-sum-exp of the masked, scaled logits.
        max_logits = logits.amax(dim=-1).to(torch.float32)
        lse = torch.log((logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits
        all_lse.append(lse)
        all_topk_score.append(topk_score)
    lse = torch.cat(all_lse, dim=0)
    topk_score = torch.cat(all_topk_score, dim=0)
    return lse, topk_score
def test_kernel(
    B=1,
    S=2048,
    H=16,
    D=512,
    tail_D=64,
    topk=128,
):
    """Smoke-test sparse_mla_topk_reducesum_interface against the PyTorch
    reference and print the error statistics. Requires a CUDA device."""
    torch.manual_seed(42)
    q = torch.randn((S, H, D + tail_D)).cuda().bfloat16()
    kv = torch.randn((S, D + tail_D)).cuda().bfloat16()
    # Two packed sequences: [0, 1023) and [1023, S).
    offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()
    # Every token selects the first `topk` key positions.
    topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
    lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets)
    # The kernel interface expects an explicit KV-group axis of size 1.
    kv = kv.unsqueeze(-2)
    topk_indices = topk_indices.unsqueeze(-2)
    attn_score = sparse_mla_topk_reducesum_interface(q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2)
    print(f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}")
# Entry point: run the smoke test with default sizes (requires a CUDA device).
if __name__ == "__main__":
    test_kernel()
import torch
def get_abs_err(y, x):
    """Return the maximum element-wise absolute difference between *y* and *x*,
    computed in float32, as a Python float."""
    diff = x.to(torch.float32) - y.to(torch.float32)
    return diff.abs().max().item()
def get_err_ratio(y, x):
    """Return the RMS error of *y* relative to *x*, normalized by the RMS
    magnitude of *x* (both computed in float32)."""
    xf = x.to(torch.float32)
    yf = y.to(torch.float32)
    rms_err = (xf - yf).square().mean().sqrt().item()
    rms_ref = xf.square().mean().sqrt().item()
    return rms_err / rms_ref
def calculate_tensor_similarity(x, y, name="tensor"):
    """Return a global similarity score between two tensors.

    Computes sim = 2 * <x, y> / (||x||^2 + ||y||^2) in float64. The score is
    1 for identical tensors and moves toward 0 as they diverge; unlike
    element-wise closeness checks it is robust to a few outlier elements.
    If both tensors are all zeros the denominator vanishes; a warning is
    printed and 1 is returned (identical-by-convention).

    Args:
        x: First tensor to compare.
        y: Second tensor to compare.
        name: Label used in the warning message.

    Returns:
        Similarity score in [0, 1] (a 0-dim tensor, or the int 1 for the
        all-zero case).
    """
    xd = x.data.double()
    yd = y.data.double()
    denom = (xd * xd + yd * yd).sum()
    if denom == 0:
        print(f"\033[33mWARNING: {name} all zero\033[0m")
        return 1
    return 2 * (xd * yd).sum() / denom
def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
    """Check that two tensors are globally similar; optionally assert on failure.

    The check passes when diff = 1 - sim lies in [0, eps], where sim is the
    normalized dot product 2<x,y>/(||x||^2+||y||^2) from
    calculate_tensor_similarity. Compared with element-wise tolerance checks
    (e.g. torch.testing.assert_close), this single global score tolerates a
    few outlier elements and judges overall agreement instead.

    Args:
        x: First tensor to compare.
        y: Second tensor to compare.
        eps: Maximum allowed difference (1 - similarity), default 1e-8.
        name: Label used in the error message.
        raise_assert: When True, raise AssertionError on failure.
    """
    diff = 1.0 - calculate_tensor_similarity(x, y, name)
    if 0 <= diff <= eps:
        return
    print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m")
    if raise_assert:
        assert False  # noqa: B011
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang import tvm as tvm
@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8})
def matmul_dynamic_mnk(
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build a tiled GEMM kernel C = A @ B where M, N and K are all dynamic
    (runtime-symbolic) sizes; only the tile sizes, transposition flags, dtypes
    and pipelining parameters are fixed at compile time.
    """
    # M/N/K are TVM symbolic vars: a single compiled kernel serves any shape.
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    K = tvm.te.var("k")
    # Global and shared-memory layouts follow the transposition flags.
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def dynamic_matmul(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        # One block per (block_N x block_M) output tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Software-pipelined reduction over K tiles.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return dynamic_matmul
def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
                   accum_dtype, num_stages, threads):
    """Compile the dynamic-shape GEMM kernel, check it against a PyTorch
    reference at the given (M, N, K), and report its latency.

    Raises via torch.testing.assert_close if the kernel output deviates from
    the float32 reference by more than rtol/atol = 1e-2. Requires CUDA.
    """
    print(
        f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}"
    )
    # The kernel itself is shape-agnostic; M/N/K only size the test tensors.
    kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
                                accum_dtype, num_stages, threads)
    import torch

    # Input layouts depend on the transposition flags (see matmul_dynamic_mnk).
    if trans_A:
        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    if trans_B:
        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    else:
        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
    kernel(A, B, C)

    def ref_program(A, B):
        # Reference GEMM in float32, cast back to the kernel's output dtype.
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
    print("Kernel output matches PyTorch reference.")
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
    latency = profiler.do_bench(input_tensors=[A, B, C])
    print(f"Latency: {latency} ms")
def main(M=16384, N=16384, K=16384):
    """Run the dynamic-shape matmul check/benchmark with a fixed fp16 config."""
    matmul_dynamic(
        M, N, K,
        block_M=128, block_N=128, block_K=32,
        trans_A=False, trans_B=False,
        in_dtype="float16", out_dtype="float16",
        accum_dtype="float32",
        num_stages=3, threads=128,
    )
# Entry point: benchmark the default 16384^3 problem size (requires CUDA).
if __name__ == "__main__":
    main()
......@@ -3,19 +3,25 @@ import itertools
import torch
import tilelang
import tilelang.language as T
from tilelang.autotuner import AutoTuner
def ref_program(x, y):
    """Reference elementwise addition used to validate the TileLang kernel."""
    result = x + y
    return result
def get_configs():
    """Enumerate the autotuning search space: every combination of tile sizes
    (block_M, block_N) and thread counts from {64, 128, 256}."""
    options = [64, 128, 256]
    return [
        {"block_M": bm, "block_N": bn, "threads": th}
        for bm in options
        for bn in options
        for th in options
    ]
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
    """Tiled elementwise addition kernel: C = A + B over (M, N) tensors.

    This region contained interleaved pre/post-merge diff lines (duplicate
    signatures and loop headers plus a hunk header); it is reconstructed here
    as the post-merge version.
    """

    @T.prim_func
    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N), out_dtype)):
        # One block per (block_M x block_N) tile of the output.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
            # NOTE(review): the diff view elided the allocations of C_shared and
            # C_local; they are reconstructed from their uses below — confirm
            # shapes/dtypes against the upstream file.
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(B[by * block_M, bx * block_N], B_shared)
            for local_y, local_x in T.Parallel(block_M, block_N):
                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return elem_add
def get_configs(M, N):
    """Enumerate the (block_M, block_N, threads) search space.

    M and N are accepted for API symmetry with the tuned kernel but do not
    affect the generated configurations.
    """
    options = [64, 128, 256]
    return [
        {"block_M": bm, "block_N": bn, "threads": th}
        for bm in options
        for bn in options
        for th in options
    ]
def get_best_config(M, N):
    """Autotune elementwise_add over get_configs(M, N) and return the result.

    Returns the AutoTuner run result; its .kernel attribute holds the tuned
    kernel. Requires a CUDA device.
    """

    # Partial-application wrapper: the autotuner supplies the config knobs.
    def kernel(block_M=None, block_N=None, threads=None):
        return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads)

    autotuner = AutoTuner.from_kernel(
        kernel=kernel, configs=get_configs(M, N)).set_compile_args(
            out_idx=[-1],
            target="cuda",
        ).set_profile_args(
            supply_type=tilelang.TensorSupplyType.Auto,
            ref_prog=ref_program,
            skip_check=False,
        )
    return autotuner.run(warmup=3, rep=20)
def main(M=1024, N=1024, use_autotune=False):
    """Validate the elementwise-add kernel against the PyTorch reference.

    This region contained interleaved pre/post-merge diff lines (an old
    argparse-based `def main()` immediately followed by the new parameterized
    signature, and both autotune branches); it is reconstructed here as the
    post-merge version.

    Args:
        M, N: tensor dimensions.
        use_autotune: when True, let the @tilelang.autotune decorator pick the
            best tile/thread configuration; otherwise use a fixed default.
    """
    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
    b = torch.randn(M, N, dtype=torch.float32, device="cuda")
    if use_autotune:
        # Leaving block_M/block_N/threads unspecified triggers the autotuner.
        kernel = elementwise_add(M, N, in_dtype=T.float32, out_dtype=T.float32)
    else:
        # Default config
        config = {"block_M": 32, "block_N": 32, "threads": 128}
        kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32)
    out = kernel(a, b)
    torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
# Reconstructed post-merge entry point: the diff residue left the argparse
# statements dedented to module level after a bare `main()` call.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=1024)
    parser.add_argument("--n", type=int, default=1024)
    parser.add_argument("--use_autotune", action="store_true", default=False)
    args, _ = parser.parse_known_args()
    main(args.m, args.n, args.use_autotune)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment