import torch
import tilelang
import tilelang.language as T
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
}
def convert_to_uint16(x):
hval = T.Cast("float16", x)
bits_uint = T.reinterpret("uint16", hval)
bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000))
return bits_uint >> 8
def convert_to_uint32(x):
bits_uint = T.reinterpret("uint32", x)
bits_uint = T.if_then_else(
x < 0,
~bits_uint & T.Cast("uint32", (0xFFFFFFFF)),
bits_uint | T.Cast("uint32", (0x80000000)),
)
return bits_uint
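# Illustrative check (not executed by the kernel): both helpers produce keys whose
# unsigned order matches the numeric order of the inputs, so a radix pass over the
# keys selects the largest values first. For example, with float16 inputs:
#   convert_to_uint16(2.0)  -> (0x4000 | 0x8000) >> 8 = 0xC0 = 192
#   convert_to_uint16(1.0)  -> (0x3C00 | 0x8000) >> 8 = 0xBC = 188
#   convert_to_uint16(-1.0) -> (~0xBC00 & 0xFFFF) >> 8 = 0x43 =  67
# i.e. 2.0 > 1.0 > -1.0 maps to 192 > 188 > 67; convert_to_uint32 applies the same
# sign trick to the full 32-bit pattern for the refinement passes below.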
@tilelang.jit(pass_configs=pass_configs)
def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
batch = T.dynamic("batch")
seq_len = T.dynamic("seq_len")
RADIX = 1 << 8
BLOCK_SIZE = 1024
SMEM_INPUT_SIZE = 4096 # assume the threshold bucket size after first pass is less than 4K
@T.prim_func
def tl_topk_kernel(
input: T.Tensor[(batch, seq_len), in_dtype],
index: T.Tensor[(batch, topk), out_dtype],
starts: T.Tensor[(batch), out_dtype],
ends: T.Tensor[(batch), out_dtype],
):
with T.Kernel(batch, threads=BLOCK_SIZE) as (bx):
tx = T.get_thread_binding()
s_threshold_bin_id = T.alloc_shared([1], "int32")
s_histogram = T.alloc_shared([RADIX + 1], "int32")
s_num_input = T.alloc_shared([2], "int32")
s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32")
l_threshold_bin_id = T.alloc_var("int32")
l_new_topk = T.alloc_var("int32")
l_num_input = T.alloc_var("int32")
l_bin_id32 = T.alloc_var("int32")
l_val = T.alloc_var("int32")
l_start_pos = T.alloc_var("int32")
l_start_idx = T.alloc_var("int32")
l_end_idx = T.alloc_var("int32")
l_out_pos = T.alloc_var("int32")
l_new_topk = topk
l_start_idx = starts[bx]
l_end_idx = ends[bx]
# stage 1: quick radix pass over the top 8 bits of the order-preserving fp16 key
T.fill(s_histogram, 0)
T.fill(s_num_input[0], 0)
T.sync_threads()
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
input_idx = s * BLOCK_SIZE + tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
inval_int16 = convert_to_uint16(input[bx, input_idx])
T.atomic_add(s_histogram[inval_int16], 1)
T.sync_threads()
# cumsum
if tx < RADIX:
for i in T.serial(8):
offset = 1 << i
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
l_val = s_histogram[tx] + s_histogram[tx + offset]
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
s_histogram[tx] = l_val
# find threshold bin id
T.sync_threads(3, RADIX)
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
s_threshold_bin_id[0] = tx
T.sync_threads()
l_threshold_bin_id = s_threshold_bin_id[0]
l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
T.sync_threads()
# collect all elements whose 8-bit bucket id is at or above the threshold bucket
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
T.sync_threads()
input_idx = s * BLOCK_SIZE + tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
bin_id = convert_to_uint16(input[bx, input_idx])
l_bin_id32 = T.Cast("int32", bin_id)
if l_bin_id32 > l_threshold_bin_id:
# bucket strictly above the threshold: atomically reserve an output slot
pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
# bucket equals the threshold: append to the candidate buffer for the refinement passes
pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
s_input_idx[0, pos] = input_idx
# stage 2: tail pass
for round in T.serial(4):
if l_new_topk <= 0:
T.loop_break()
r_idx = round % 2
l_start_pos = topk - l_new_topk
T.sync_threads()
T.fill(s_histogram, 0)
if tx == 0:
s_num_input[r_idx ^ 1] = 0
T.sync_threads()
l_num_input = s_num_input[r_idx]
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast("int32", ((
convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
(24 - round * 8)) & 0xFF))
T.atomic_add(s_histogram[l_bin_id32], 1)
T.sync_threads()
# cumsum
if tx < RADIX:
for i in T.serial(8):
offset = 1 << i
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
l_val = s_histogram[tx] + s_histogram[tx + offset]
T.sync_threads(3, RADIX)
if tx < RADIX - offset:
s_histogram[tx] = l_val
# find threshold bin id
T.sync_threads(3, RADIX)
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
s_threshold_bin_id[0] = tx
T.sync_threads()
l_threshold_bin_id = s_threshold_bin_id[0]
l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
T.sync_threads()
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
T.sync_threads()
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast("int32", ((
convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
(24 - round * 8)) & 0xFF))
if l_bin_id32 > l_threshold_bin_id:
pos = T.atomic_add(
s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
if round == 3:
l_out_pos = T.atomic_add(
s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
if l_out_pos < topk:
index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
else:
pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True)
s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx,
s * BLOCK_SIZE + tx]
return tl_topk_kernel
def tl_topk(input, starts, ends, topk):
batch, seq_len = input.shape
indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)
kernel = tl_topk_impl(topk)
kernel(input, indexes, starts, ends)
return indexes
def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
torch.manual_seed(1)
input = torch.randn(batch, seq_len, dtype=torch.float32).cuda()
starts = torch.zeros(batch, dtype=torch.int32).cuda()
ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len
indexes = tl_topk(input, starts, ends, topk)
print(indexes)
indexes_ref = torch.topk(input, topk, dim=-1)[1]
print(indexes_ref)
# indexes_ref = fast_topk(input, topk)
# print(indexes_ref)
# Calculate intersection of out_ref and out_trt
for i in range(batch):
ref_np = indexes_ref[i].cpu().to(torch.int32).numpy()
trt_np = indexes[i].cpu().to(torch.int32).numpy()
set_ref = set(ref_np)
set_trt = set(trt_np)
intersection = set_ref & set_trt
print("selected/all:", len(intersection), "/", len(set_ref), "=",
len(intersection) / len(set_ref))
# Performance test with CUDA events
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Warmup
for _ in range(5):
_ = tl_topk(input, starts, ends, topk)
torch.cuda.synchronize()
n_iters = 20
start_event.record()
for _ in range(n_iters):
_ = tl_topk(input, starts, ends, topk)
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"Average tl_topk time: {elapsed_time_ms / n_iters:.3f} ms")
# Torch topk time
start_event.record()
for _ in range(n_iters):
_ = torch.topk(input, topk, dim=-1)[1]
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms")
if __name__ == "__main__":
test_topk_selector()
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import contextlib
import functools
import logging
import os
import sys
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional, Tuple
from packaging import version
def _is_equal(a, b):
if isinstance(a, torch.Tensor):
return a is b
# Whitelist of types that are safe to compare by value for caching.
if isinstance(a, (int, float, str, bool, type(None))) and isinstance(
b, (int, float, str, bool, type(None))):
return a == b
# For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check.
return False
def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent result of a function with tensor inputs.
This decorator will store the output of the decorated function for the most recent set of input tensors.
If the function is called again with the same input tensors, it will return the cached result.
Args:
fn (Callable[..., torch.Tensor]):
The function to be decorated. It should take tensor inputs and return tensor outputs.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with single-entry caching.
"""
last_args: Optional[Tuple] = None
last_kwargs: Optional[Dict] = None
last_result: Any = None
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal last_args, last_kwargs, last_result
if last_args is not None and last_kwargs is not None:
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
# For Tensors, check for object identity. For other types, check for equality.
# Python caches small integers, so `is` works for them but not for large integers like 4096.
if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \
set(kwargs.keys()) == set(last_kwargs.keys()) and \
all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()):
return last_result
result = fn(*args, **kwargs)
last_args, last_kwargs, last_result = args, kwargs, result
return result
return wrapper
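# Usage sketch (illustrative; `lengths_from_cu_seqlens` is a hypothetical helper):
# repeated calls with the *same* tensor object hit the single-entry cache, since
# tensors are compared by identity in `_is_equal`.
#
#   @tensor_cache
#   def lengths_from_cu_seqlens(cu_seqlens: torch.Tensor) -> torch.Tensor:
#       return cu_seqlens[1:] - cu_seqlens[:-1]
#
#   cu = torch.tensor([0, 3, 5])
#   a = lengths_from_cu_seqlens(cu)                       # computed
#   b = lengths_from_cu_seqlens(cu)                       # cached (same tensor object)
#   c = lengths_from_cu_seqlens(torch.tensor([0, 3, 5]))  # recomputed (new tensor object)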
@tensor_cache
def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int):
seq_idx = cu_seqlens.new_zeros(seq_len + 1)
seq_idx.scatter_add_(0, cu_seqlens[1:].long(), torch.ones_like(seq_idx))
seq_idx.cumsum_(0)
return seq_idx[:-1]
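# Worked example (illustrative): cu_seqlens = [0, 3, 5], seq_len = 5
#   scatter_add_ marks positions 3 and 5 -> [0, 0, 0, 1, 0, 1]
#   cumsum                               -> [0, 0, 0, 1, 1, 2]
#   drop the last slot                   -> [0, 0, 0, 1, 1]
# i.e. tokens 0-2 belong to sequence 0 and tokens 3-4 to sequence 1.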
@tensor_cache
def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
seq_len: int) -> torch.IntTensor:
seq_idx_for_q = torch.full((seq_len,),
len(cu_seqlens_qs),
dtype=torch.int32,
device=cu_seqlens_qs.device)
for i in range(len(cu_seqlens_qs)):
seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i
return seq_idx_for_q
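# Worked example (illustrative): cu_seqlens_qs = [0, 3], cu_seqlens_qe = [3, 5], seq_len = 6
#   -> [0, 0, 0, 1, 1, 2]; positions outside every [qs, qe) window keep the
#   out-of-range id len(cu_seqlens_qs) (here 2).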
@tensor_cache
def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor:
cu_seqlen_ks_for_each_q = torch.gather(
input=torch.cat([
cu_seqlens_ks,
torch.full((1,),
torch.iinfo(torch.int32).max,
dtype=torch.int32,
device=cu_seqlens_qs.device)
]),
dim=0,
index=cal_seq_idx_for_q(
cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long())
return cu_seqlen_ks_for_each_q.int()
@tensor_cache
def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor,
q_start_idxs: torch.LongTensor, seq_len: int,
kv_stride: int) -> torch.IntTensor:
cu_seqlen_ke_for_each_q = torch.gather(
input=torch.cat(
[cu_seqlens_ke,
torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]),
dim=0,
index=cal_seq_idx_for_q(
cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long())
casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,),
dtype=torch.int32,
device=cu_seqlens_qs.device)
for i in range(len(cu_seqlens_qs)):
casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange(
q_start_idxs[i],
q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i],
dtype=torch.int32,
device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i]
cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q)
return cu_seqlen_ke_for_each_q.int()
@tensor_cache
def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor,
cu_seqlens_k: torch.LongTensor = None,
offs_q: torch.LongTensor = None,
*,
seq_len: int,
kv_stride: int = 1,
cp_rank: int = 0,
cp_size: int = 1,
balanced_cp=False):
'''
seq_len: seq len per cp rank
balanced cp slice assignment: 0 1 2 3 3 2 1 0
'''
n_seq = len(cu_seqlens_q) - 1
assert n_seq > 0
assert cu_seqlens_q.shape == (n_seq + 1,)
seq_idx = cal_seq_idx_from_cu_seqlens(cu_seqlens_q.long(), seq_len * cp_size)
qs = cu_seqlens_q.gather(0, seq_idx)
pos = torch.arange(len(qs), dtype=qs.dtype, device=qs.device) - qs
if offs_q is not None:
assert offs_q.shape == (n_seq,), offs_q.shape
qoff = offs_q.gather(0, seq_idx)
pos += qoff
if cu_seqlens_k is None or cu_seqlens_k is cu_seqlens_q:
ks = qs
else:
assert cu_seqlens_k.shape == (n_seq + 1,)
ks = cu_seqlens_k.gather(0, seq_idx)
ke = ks + (pos + 1) // kv_stride
if cp_size == 1:
pass
elif balanced_cp:
assert cp_size % 2 == 0, cp_size
def f(x: torch.Tensor):
chunks = x.chunk(cp_size * 2)
return torch.cat([
chunks[cp_rank],
chunks[cp_size - cp_rank - 1],
])
ks = f(ks)
ke = f(ke)
else:
ks = ks.chunk(cp_size)[cp_rank]
ke = ke.chunk(cp_size)[cp_rank]
return ks, ke
def ceil_to_ue8m0(x: torch.Tensor):
assert x.view(-1).amax().item() > 0
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
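# Example (illustrative): ceil_to_ue8m0 rounds magnitudes up to the next power of two,
# e.g. tensor([0.3, 3.0]) -> tensor([0.5, 4.0]); signs are dropped via abs(), and the
# assert requires the largest element of x to be strictly positive.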
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int],
use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled, sf.squeeze()
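# Example (illustrative): for a 2-D x and dims=(0,), the amax is reduced over dim 1,
# so each row gets its own fp8 scale; the returned sf then has one entry per kept
# dimension (shape (M,) after squeeze).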
def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512):
total_seqlen = per_cp_seqlen * cp_size
cu_seqlens = torch.randint(0, average_q_len * 2, (total_seqlen // average_q_len * 2,)).cuda()
last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0]
cu_seqlens = cu_seqlens[:last_seq_id]
if cu_seqlens.sum() < total_seqlen:
cu_seqlens = torch.cat([cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()])
cu_seqlens_cumsum = torch.cumsum(cu_seqlens, dim=0)
cu_seqlens_k_cumsum = torch.cumsum(cu_seqlens // kv_stride, dim=0)
cu_seqlens_qs = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_cumsum[:-1]])
cu_seqlens_ks = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_k_cumsum[:-1]])
cu_seqlens_qe = cu_seqlens_cumsum.clone()
cu_seqlens_ke = cu_seqlens_k_cumsum.clone()
cu_seqlens_ks_for_each_q = cal_cu_seqlen_ks_for_q(
cu_seqlens_qs=cu_seqlens_qs,
cu_seqlens_qe=cu_seqlens_qe,
cu_seqlens_ks=cu_seqlens_ks,
seq_len=total_seqlen,
)
cu_seqlens_ke_for_each_q = cal_cu_seqlen_ke_for_q(
cu_seqlens_qs=cu_seqlens_qs,
cu_seqlens_qe=cu_seqlens_qe,
cu_seqlens_ks=cu_seqlens_ks,
cu_seqlens_ke=cu_seqlens_ke,
q_start_idxs=torch.zeros_like(cu_seqlens_qs),
seq_len=total_seqlen,
kv_stride=kv_stride,
)
assert per_cp_seqlen % 2 == 0
per_chunk_seqlen = per_cp_seqlen // 2
slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen)
slice_long = slice(
total_seqlen - (cp_rank + 1) * per_chunk_seqlen,
total_seqlen - cp_rank * per_chunk_seqlen,
)
ks = torch.cat([
cu_seqlens_ks_for_each_q[slice_short],
cu_seqlens_ks_for_each_q[slice_long],
])
ke = torch.cat([
cu_seqlens_ke_for_each_q[slice_short],
cu_seqlens_ke_for_each_q[slice_long],
])
assert len(ks) == len(ke) == per_cp_seqlen
return ks, ke
def calculate_tensor_similarity(x, y, name="tensor"):
"""
Calculate similarity between two tensors using a normalized dot product metric.
Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
element-wise differences, this function computes a global similarity score:
sim = 2 * <x, y> / (||x||^2 + ||y||^2)
This metric is scale-invariant and measures the cosine-like similarity normalized
by the magnitude of both tensors. It returns 1 for identical tensors and values
closer to 0 for dissimilar ones. This is particularly useful for comparing tensors
with varying magnitudes where relative errors matter more than absolute differences.
Args:
x: First tensor to compare
y: Second tensor to compare
name: Name of the tensor for logging purposes
Returns:
Similarity score in [-1, 1], where 1 means identical and values near (or below) 0 indicate dissimilar tensors
"""
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print(f"\033[33mWARNING: {name} all zero\033[0m")
return 1
sim = 2 * (x * y).sum() / denominator
return sim
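# Worked example (illustrative): calculate_tensor_similarity(x, x) == 1 exactly, while
# calculate_tensor_similarity(x, 2 * x) == 2 * 2||x||^2 / (||x||^2 + 4||x||^2) = 0.8.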
def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
"""
Assert that two tensors are similar using a global similarity metric.
Key differences from torch.testing.assert_close:
- torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking
that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers
and requires all elements to satisfy the tolerance.
- assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the
normalized dot product. It's more robust to outliers and focuses on overall
tensor similarity rather than element-wise precision. This is better suited for
comparing large tensors where a few outlier elements shouldn't fail the test.
Args:
x: First tensor to compare
y: Second tensor to compare
eps: Maximum allowed difference (1 - similarity), default 1e-8
name: Name of the tensor for error messages
raise_assert: Whether to raise assertion error on failure
"""
sim = calculate_tensor_similarity(x, y, name)
diff = 1. - sim
if not (0 <= diff <= eps):
print(
f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m"
)
if raise_assert:
assert False # noqa: B011
if __name__ == "__main__":
seq_len = 32768
cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda")
last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0]
cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0)
cu_seqlens_qs = torch.cat(
[torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum])
cu_seqlens_qe = torch.cat(
[cu_seqlens_cumsum,
torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len])
from tilelang.profiler import do_bench
fn = lambda: cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len) # noqa: E731
ms = do_bench(fn, warmup=25, rep=100)
print(f"cal_seq_idx_for_q: {ms:.3f} ms")
### Dequantization GEMM
An example of implementing a dequantization GEMM:
```python
@T.prim_func
def dequant_matmul(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
T.clear(Ct_local)
for k in T.Pipelined(
T.ceildiv(K, block_K),
num_stages=num_stages
):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_packed_to_unsigned_convert("int", 8)(
num_bits,
B_local[i, j // 2],
j % 2,
dtype=in_dtype,
)
T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct[bx * block_N, by * block_M])
```
**Notes:** A dequantized GEMM with the layout transformations required for optimal performance can be found in the [BitBLAS](https://github.com/microsoft/BitBLAS) project; example kernels are available at `testing/python/kernel/test_tilelang_dequantize_gemm.py`, and a detailed explanation with more examples is coming soon.
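For readers without the full kernel scaffold, the following is a minimal plain-PyTorch sketch of the unpacking step that the `_tir_packed_to_unsigned_convert("int", 8)` call performs element-wise in the loop above. The helper name `unpack_uint4_reference` and the low-nibble-first ordering are illustrative assumptions, not the tuned TileLang path:

```python
import torch

def unpack_uint4_reference(packed: torch.Tensor) -> torch.Tensor:
    """Expand (N, K//2) uint8 bytes into (N, K) unsigned 4-bit values in [0, 15]."""
    low = packed & 0xF           # nibble read when the output column index j is even (j % 2 == 0)
    high = (packed >> 4) & 0xF   # nibble read when j is odd (j % 2 == 1)
    # Interleave low/high nibbles so output column j corresponds to byte j // 2, nibble j % 2.
    return torch.stack([low, high], dim=-1).reshape(packed.shape[0], -1)
```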
import torch
def torch_convert_bit_twiddling(tensor):
"""
This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
Parameters:
tensor (torch.Tensor): 2-D input tensor with dtype `torch.uint8`. Shape (N, K).
Returns:
torch.Tensor: New tensor of dtype `torch.bfloat16` with shape (N, K*2), where each input column pair produces two bf16 output columns.
Raises:
AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
"""
assert tensor.dim() == 2 and tensor.dtype == torch.uint8
N, K = tensor.shape
assert K % 2 == 0, "Number of columns must be even"
# Combine pairs of uint8 values into int32 so the bitwise ops below are safe on CUDA
val0 = tensor[:, 0::2].to(torch.int32)
val1 = tensor[:, 1::2].to(torch.int32)
val_concat = (val0 << 8) | val1  # (N, K//2), int32
# Expand to match output shape where each pair generates 4 values
val_concat_expanded = val_concat.repeat_interleave(4, dim=1) # (N, K//2*4)
# Positional encoding for bit-twiddling logic
pos = torch.arange(K * 2, device=tensor.device) % 4 # (K*2,)
# Bit masks for decoding (applied to the int32 working values)
mask = 0b1000000111000000
mask1 = 0b1000000000000000
mask2 = 0b0000000110000000
mask3 = 0b0000000001000000
# Calculate results for all 4 positions in parallel
res0 = val_concat_expanded & mask
res1 = (val_concat_expanded << 3) & mask
res2 = (val_concat_expanded << 6) & mask
res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | (
(val_concat_expanded >> 7) & mask3)
# Select the correct result based on position
bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1,
torch.where(pos == 2, res2, res3)))
# Convert to uint16 for .view(torch.bfloat16)
bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
bf16_bf16 = bf16_uint16.view(torch.bfloat16)
# Avoid integer overflow by using a float32 multiplier for the exponent scaling
bf16_new = bf16_bf16 * (2.0**126)
return bf16_new
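# Illustrative usage (shapes only, assuming a uint8 input tensor):
#   packed  = torch.randint(0, 256, (2, 4), dtype=torch.uint8)
#   decoded = torch_convert_bit_twiddling(packed)  # shape (2, 8), dtype torch.bfloat16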
def torch_convert(tensor, scale_size=None, Scale=None):
"""
Decode a 2D uint8 tensor into a 2D bfloat16 tensor by expanding each byte into two bf16 values using a 4-bit (nibble) encoding.
Each input byte holds two 4-bit encoded values (low and high nibble). For each nibble this function derives sign/scale bits, a 3-bit exponent fragment and a 1-bit mantissa fragment, assembles a 16-bit bf16 pattern, and returns the resulting tensor with shape (N, K*2) and dtype torch.bfloat16 on the same device as the input.
Parameters:
tensor (torch.Tensor): 2D tensor of dtype torch.uint8 and shape (N, K). Each byte contains two encoded 4-bit entries that become two bf16 values.
scale_size (int, optional): If provided, controls how elements of the optional Scale tensor are indexed. When supplied, per-output-element scaling is applied to the exponent using Scale.
Scale (torch.Tensor, optional): A 2D tensor used to supply per-element integer scale adjustments to the exponent. If scale_size is provided, the scale used for output element (i, j) is Scale[i][j // scale_size].
Returns:
torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values.
"""
def _convert(val, pos, scale=None):
assert val.dtype == torch.uint8
# val = val.view(torch.int8)
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = (f4 & 6) >> 1
e_f16 = e_f4 + 126
if scale is not None:
e_f16 = min(e_f16 + scale, (1 << 8) - 1)
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 8)) << 7) | (m_f16 << 6)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.bfloat16)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
if scale_size is not None:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2, Scale[i][j // scale_size])
else:
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
def print_bit(name, val):
"""
Print the 32-bit binary representation of a CPU scalar extracted from a PyTorch tensor.
Converts `val` to CPU, reads its Python scalar with `.item()`, formats it as a 32-bit binary string, and prints it prefixed by `name`.
Parameters:
name (str): Label printed before the binary representation.
val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown.
"""
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
def print_red_warning(message):
print(f"\033[31mWARNING: {message}\033[0m")
def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch')
if raise_assert:
raise AssertionError
if not torch.isclose(
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if raise_assert:
raise AssertionError
x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name)
diff = (1. - sim).item()
print(f'{diff=}')
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff=}')
if raise_assert:
raise AssertionError
import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tvm import tir
import torch
from dequantize_utils import torch_convert_bit_twiddling, torch_convert
def get_configs():
"""
Return a list of tuning configuration dictionaries for the autotuned matmul kernel.
Each dictionary is a single combination (Cartesian product) of the following parameters:
- block_M: tile size for M dimension (one of 64, 128, 256)
- block_N: tile size for N dimension (one of 64, 128, 256)
- block_K: tile size for K dimension
- num_stages: pipeline stages for K-loop (0 or 2)
- threads: number of threads to launch (128, 256, or 512)
- split: K-splitting factor (1 or 2)
Returns:
list[dict]: List of configuration dicts usable by the autotuner, where each dict maps
the parameter name to its chosen value.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[128],
num_stages=[0, 2],
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
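# The Cartesian product above yields 3 * 3 * 1 * 2 * 3 * 2 = 108 candidate configurations.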
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
},
)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
fast_dequant=True,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
B_shape = (N, QK)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
# requires that 2 consecutive uint8 elements (16 bits) contain 4 fp4 elements in a bit-twiddled layout.
# The layout is shown below: the pair (x, y) means that the bit at this position is the y-th bit of the x-th fp4 element.
# (0,0)(3,0)(3,3)(1,0)(3,1)(3,2)(2,0)(0,1)(0,2)(0,3)(1,1)(1,2)(1,3)(2,1)(2,2)(2,3)
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
use_twiddling=True,
)
import_source = mxfp_intrin_info["c_source"]
func_name = mxfp_intrin_info["func_name"]
assert import_source is not None, "mxfp_intrin_info c_source is not found"
assert func_name is not None, "mxfp_intrin_info func_name is not found"
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin.
This function validates the requested input/output datatypes and returns a TileLang `@T.macro` named `fast_dequant_bf16_fp4_twiddling` which:
- Loads compressed FP4 bytes from a shared buffer into per-thread local registers (vectorized loads).
- Invokes an external dequantization routine (via `T.call_extern`) to expand the packed FP4 values into BF16 in registers.
- Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel.
Notes and preconditions:
- Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`.
- The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel.
- The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly.
- The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
local_compress_size = local_size // num_elems_per_byte
@T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared):
# import fast_dequantize plugin
"""
Fast dequantization kernel routine that converts packed FP4 values in shared memory to BF16 and writes the results back into a shared dequantized buffer.
This function is intended to run inside a tiled GPU kernel: each thread loads a small packed segment from the quantized shared buffer `B_shared` into a per-thread local register buffer, calls an external dequantization routine (provided by the runtime plugin imported from `import_source` and identified by `func_name`) to expand the packed values to BF16 in a per-thread local output buffer, and stores the expanded values into `B_dequantize_shared`. It performs vectorized per-thread loads and stores and is sized according to the surrounding kernel's tiling and threading parameters.
Parameters:
B_shared: Shared-memory buffer containing packed quantized values (packed FP4 layout).
B_dequantize_shared: Shared-memory buffer to receive dequantized BF16 values (written in-place by this routine).
Side effects:
- Imports the external dequantization plugin via `import_source` and invokes `func_name`.
- Writes dequantized BF16 results into `B_dequantize_shared`.
Notes:
- This routine expects the surrounding kernel to define and provide the tiling/threading constants (e.g., thread count, local buffer sizes, block dimensions) and the runtime plugin identifiers (`import_source`, `func_name`).
- No value is returned; results are produced by mutation of `B_dequantize_shared`.
"""
T.import_source(import_source)
tx = T.get_thread_binding()
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from share memory to register.
# Prepare for dequant.
for v in T.vectorized(0, local_compress_size):
index = i * threads * local_compress_size + tx * local_compress_size + v
B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]
# Then, dequant.
T.call_extern(
func_name,
T.address_of(B_local_thread[0]),
T.address_of(B_dequantize_local_thread[0]),
1,
dtype=out_dtype,
)
# Finally, store the dequantized data to shared memory.
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16.
The returned macro (named `simple_dequant_bf16_fp4`) expects B_shared and B_dequantize_shared buffers (shapes and a few loop/constant names like
`B_shared_shape`, `B_dequantize_shared_shape`, `storage_dtype`, `out_dtype`, `num_bits`, `num_elems_per_byte`, `block_N`, and `block_K`) to be available in the surrounding TIR scope. It:
- Unpacks 4-bit FP values from the packed uint8 representation in B_shared.
- Converts each 4-bit value to a bfloat16 element using an internal helper `_tir_u8_to_f4_to_bf16`.
- Writes the dequantized bfloat16 block into B_dequantize_shared.
Constraints:
- Supports only in_dtype="fp4" and out_dtype="bfloat16".
- The helper assumes nbit == 4 and produces bfloat16 values.
- The macro uses a fixed test-scale of 0 (no per-element scaling) as written.
Returns:
A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr,
scale: tir.PrimExpr, dtype: str):
"""
Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
This helper extracts the 4-bit field located at the bit position `pos` within the
byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
Parameters:
nbit (int): Number of bits in the packed element; must be 4.
val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be "bfloat16".
Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes:
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
bit fields and clamps the computed exponent to fit into 8 bits.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
# The exponent bias difference between fp4 (bias 1) and bf16 (bias 127) is 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponent part, stored within the range of uint8.
# To handle overflow, clamp the exponent to 8 bits with T.min.
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret(
"bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
return val_bf16
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared):
"""
Dequantize a packed FP4 uint8 shared buffer into BF16 and store the result into a shared dequantized buffer.
This helper:
- Loads B_shared into a local fragment, converts each packed FP4 element to BF16 using `_tir_u8_to_f4_to_bf16`, and writes the dequantized values into B_dequantize_shared.
- Iterates in parallel over the logical block columns (block_N) and block_K, unpacking elements from bytes using `num_elems_per_byte`.
- Uses a fixed scale of 0 in the conversion (placeholder for testing); `num_bits` and `num_elems_per_byte` are expected to be available from the enclosing scope.
Parameters:
B_shared: shared-memory buffer containing packed FP4 data (uint8-packed).
B_dequantize_shared: shared-memory buffer to receive BF16 dequantized values.
Side effects:
Writes dequantized BF16 values into B_dequantize_shared. No return value.
"""
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
0, # No scale for test
dtype=out_dtype,
)
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Pipelines over K in chunks of `block_K` for `num_stages` stages:
- Loads A and packed B tiles into shared memory.
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared.
Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`.
Side effects:
- Writes the computed output block into the global tensor `C`.
- Uses and updates shared memory buffers and per-thread accumulators.
No value is returned.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
return main
def ref_program_twiddling(A, qB):
"""
Compute reference BF16 matrix multiply using bit-twiddled FP4 quantized B.
Converts qB (a bit-twiddled, packed FP4 representation of matrix B) back to floating,
performs C = A @ B^T in full precision, and returns the result converted to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K). Treated as floating-point (converted to torch.float for compute).
qB (torch.Tensor): Bit-twiddled, packed FP4 representation of B (quantized). Shape corresponds to B's packed layout.
Returns:
torch.Tensor: Result matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(getattr(torch, dtypeC))
return C
def ref_program_simple(A, qB):
"""
Compute a reference BF16 matrix multiply using a simple (non-twiddled) dequantization of qB.
Converts the quantized tensor `qB` to full-precision values via `torch_convert`, computes C = A @ B^T in float32, and casts the result to bfloat16 before returning.
Parameters:
A (torch.Tensor): Left input matrix with shape (M, K).
qB (torch.Tensor): Quantized representation of the right matrix; expected to be compatible with `torch_convert` and represent a matrix whose transpose will be multiplied by A.
Returns:
torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N).
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(getattr(torch, dtypeC))
return C
def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
"""
Run and benchmark the tiled, optionally autotuned FP4->BF16 GEMM kernel and validate results against a PyTorch reference.
This function builds a matmul kernel (either with autotuning or fixed tiling), obtains a profiler, validates numerical correctness against the appropriate reference implementation (bit-twiddled fast dequantization or simple dequantization), and runs a benchmark that prints measured latency (ms) and effective TFLOPs.
Parameters:
m (int): Number of rows of A and output C (default 256).
n (int): Number of columns of B and output C (default 256).
k (int): Inner dimension (columns of A, rows of B) (default 256).
fast_dequant (bool): If True use the fast twiddling dequantization path and validate against the twiddling reference; otherwise use the simple dequant path (default True).
tune (bool): If True build the kernel with autotuning configurations; if False use a fixed tiling and threading configuration for reproducible benchmarking (default False).
Side effects:
- Prints latency and TFLOPs to stdout.
- Raises an assertion via the profiler if the kernel's outputs do not match the chosen reference within the tolerances (rtol=0.01, atol=0.01).
"""
total_flops = 2 * m * n * k
if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
fast_dequant=fast_dequant,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
main(256, 256, 256, True)
main(256, 256, 256, False)
import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tvm import tir
import torch
from dequantize_utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
"""
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
Parameters:
nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16").
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
# The exponent bias difference between fp4 (bias 1) and bf16 (bias 127) is 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponent part, stored within the range of uint8. To handle overflow,
# the exponent could be clamped to 8 bits with T.min:
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
return val_bf16
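# Worked example (illustrative, with the scale clamp above left commented out):
#   nibble 0b0011 -> s = 0, e_f4 = 1, m_f4 = 1 -> e_bf16 = 127
#   bit pattern ((0 << 8 | 127) << 7) | (1 << 6) = 0x3FC0, i.e. bfloat16 1.5;
#   nibble 0b1011 flips only the sign bit and decodes to -1.5.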
def get_configs():
"""
Generate a list of hyperparameter configuration dictionaries for tuning.
Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
'num_stages', 'threads', and 'split'. The function returns the Cartesian
product of the parameter value lists:
- block_M, block_N, block_K: tiling sizes (64, 128, 256)
- num_stages: pipeline stages (0, 2)
- threads: thread counts (128, 256, 512)
- split: K-splitting factor (1, 2)
Returns:
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[64, 128, 256],
num_stages=[0, 2],
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
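# The Cartesian product above yields 3 * 3 * 3 * 2 * 3 * 2 = 324 candidate configurations.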
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1],)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
"""
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
B_shape = (N, QK)
Bias_shape = (M, N)
Scale_shape = (N, K // scale_size)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = (block_M, block_N)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
use_twiddling=True,
)
import_source = mxfp_intrin_info["c_source"]
func_name = mxfp_intrin_info["func_name"]
assert import_source is not None, "mxfp_intrin_info c_source is not found"
assert func_name is not None, "mxfp_intrin_info func_name is not found"
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
- Loads packed FP4 elements from B_shared into per-thread local registers.
- Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values.
- Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two).
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
local_compress_size = local_size // num_elems_per_byte
@T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k):
# import fast_dequantize plugin
"""
Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16
in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4,
applying per-block scale factors from Scale.
This routine is a tiled, thread-parallel helper that:
- Imports and calls an external dequantization function (via `import_source`/`func_name`)
to expand compressed uint8-packed FP4 values into BF16 fragments in-thread.
- Loads the corresponding per-block scale entry, treats it as an exponent, and multiplies
the dequantized BF16 fragment by 2^Scale (applied as a shift).
- Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place.
Parameters:
- B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout).
- B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values.
- Scale: per-block scale tensor; each entry is an exponent, giving a multiplicative scale of 2^Scale.
- k: block index along the K dimension used to select the appropriate Scale entries.
Side effects:
- Mutates B_dequantize_shared in shared memory.
- Calls an external intrinsic function (must be provided by the environment via `import_source`
and `func_name`) to perform the low-level unpacking/dequantization.
"""
T.import_source(import_source)
tx = T.get_thread_binding()
bx = T.get_block_binding(0)
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
Scale_local_thread = T.alloc_local((1,), storage_dtype)
Scale_local_thread_exponent = T.alloc_local((1,), out_dtype)
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from share memory to register.
# Prepare for dequant.
index_base = i * threads * local_compress_size + tx * local_compress_size
for v in T.vectorized(0, local_compress_size):
index = index_base + v
B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]
index_scale = index_base // (scale_size // num_elems_per_byte)
si = index_scale // (block_K // scale_size)
sj = index_scale % (block_K // scale_size)
Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj]
Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
# Then, dequant.
T.call_extern(
func_name,
T.address_of(B_local_thread[0]),
T.address_of(B_dequantize_local_thread[0]),
1,
dtype=out_dtype,
)
# Finally, store the dequantized data to shared memory.
for v in T.Parallel(local_size):
B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0]
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
"""
Dequantizes a packed 4-bit (FP4) block from B_shared into BF16 values in B_dequantize_shared using per-element scale exponents.
Per-element behavior:
- Reads packed 4-bit entries from B_shared (uint8 storage, multiple nibbles per byte).
- Uses Scale to obtain an exponent term (stored as uint8) and reconstructs BF16 values via _tir_u8_to_f4_to_bf16.
- Writes the dequantized BF16 block into B_dequantize_shared.
Parameters:
- B_shared: shared-memory buffer holding packed 4-bit values (uint8-packed layout).
- B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results.
- Scale: per-element exponent buffer; used to compute the scale factor for each dequantized element.
- k: current block index along the K dimension (used to select the appropriate slice of Scale).
Side effects:
- Mutates B_dequantize_shared by storing the dequantized BF16 fragment.
"""
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
bx = T.get_block_binding(0)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale[
bx * block_N + i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
dtype=out_dtype,
) * T.shift_left(
1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
if with_bias:
T.annotate_layout({
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
})
if threads == 512:
T.disable_warp_group_reg_alloc()
if with_bias:
T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
Bias_shared)
T.copy(Bias_shared, C_local)
else:
T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale, k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale, k)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
return main
def ref_program_twiddling(A, qB, Scale, Bias=None):
"""
Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^Scale (Scale entries are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
"""
Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^Scale (Scale entries are grouped by 32 columns of B), computes the matrix product A · B^T in float, adds Bias, and casts the result to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Bias (torch.Tensor): Bias tensor with shape (M, N).
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
def ref_program_simple(A, qB, Scale, Bias=None):
"""
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32]) (Scale supplies exponents in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16.
Parameters:
- A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
- qB: Quantized representation of B accepted by `torch_convert`.
- Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
Returns:
- 2D bfloat16 tensor C containing the matrix product A · B^T.
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
def ref_program_simple_with_bias(A, qB, Scale, Bias):
"""
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32]) (Scale supplies exponents in 32-column groups), then computes C = A · B^T + Bias and returns the result converted to bfloat16.
Parameters:
- A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
- qB: Quantized representation of B accepted by `torch_convert`.
- Scale: 2D tensor of exponents; Scale[i][g] is applied to columns j where g == j // 32.
- Bias: 2D bias tensor added to the matmul result.
Returns:
- 2D bfloat16 tensor C containing the matrix product A · B^T.
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False):
"""
Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS.
Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS.
Parameters:
m (int): Number of rows of A / output rows. Default 256.
n (int): Number of columns of B / output columns. Default 256.
k (int): Reduction dimension. Default 256.
scale_size (int): Size of the per-block scale vector used for dequantization. Default 32.
fast_dequant (bool): If True validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True.
with_bias (bool): If True, add a Bias tensor to the output and validate against the bias-aware reference. Default False.
tune (bool): If True obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False.
Returns:
None
"""
total_flops = 2 * m * n * k
if tune:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
fast_dequant=fast_dequant,
with_bias=with_bias)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
if with_bias:
profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
else:
if with_bias:
profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
M, N, K = 256, 256, 256
scale_size = 32
main(M, N, K, scale_size, fast_dequant=True, with_bias=True)
main(M, N, K, scale_size, fast_dequant=False, with_bias=True)
main(M, N, K, scale_size, fast_dequant=True, with_bias=False)
main(M, N, K, scale_size, fast_dequant=False, with_bias=False)
import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tvm import tir
import torch
from dequantize_utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
"""
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as sign/exponent/mantissa in the FP4 (e2m1) layout, re-biases the exponent for bfloat16,
and assembles the corresponding bfloat16 bit pattern. The `scale` argument is accepted for
interface compatibility; the in-function exponent adjustment by `scale` is commented out, and
callers multiply the converted value by the scale factor instead.
Parameters:
nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment; accepted but not applied inside this helper (see Notes).
dtype (str): Destination dtype string (must be "bfloat16").
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The adjusted-exponent clamp (T.min against the 8-bit maximum) is present only as a commented-out line; in this version the scale is applied by the caller after conversion.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
return val_bf16
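# --- Illustrative sketch (not part of the kernel; helper names here are hypothetical) ---
# A minimal pure-Python rendering of the nibble -> bf16 assembly performed by
# _tir_u8_to_f4_to_bf16 above (FP4 e2m1 layout, exponent re-bias of 126), useful for
# checking the bit pattern by hand.
import struct


def _fp4_nibble_to_bf16_bits(byte_val: int, pos: int) -> int:
    f4 = (byte_val >> (pos * 4)) & 0xF
    s = f4 >> 3                   # sign bit
    e_f4 = (f4 & 0b110) >> 1      # 2-bit exponent
    m_f4 = f4 & 0b001             # 1-bit mantissa
    e_bf16 = e_f4 + 126           # re-bias into bf16's 8-bit exponent field
    return (((s << 8) | e_bf16) << 7) | (m_f4 << 6)


def _bf16_bits_to_float(bits: int) -> float:
    # bf16 occupies the upper 16 bits of an IEEE-754 float32
    return struct.unpack(">f", bytes([bits >> 8, bits & 0xFF, 0, 0]))[0]


# Worked values: the packed byte 0b00110101 holds nibble 0x5 at pos 0 and 0x3 at pos 1.
#   _fp4_nibble_to_bf16_bits(0b00110101, 0) == 0x4040  -> 3.0 in bf16
#   _fp4_nibble_to_bf16_bits(0b00110101, 1) == 0x3FC0  -> 1.5 in bf16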
def get_configs():
"""
Generate a list of hyperparameter configuration dictionaries for tuning.
Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
'num_stages', 'threads', and 'split'. The function returns the Cartesian
product of the parameter value lists:
- block_M, block_N, block_K: tiling sizes (64, 128, 256)
- num_stages: pipeline stages (0, 1, 2)
- threads: thread counts (128, 256, 512)
- split: K-splitting factor (1, 2)
Returns:
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[64, 128, 256],
num_stages=[0, 1, 2],
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
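# With the value lists above, get_configs() yields 3 * 3 * 3 * 3 * 3 * 2 = 486 candidate
# configurations; e.g. the first entry is
# {'block_M': 64, 'block_N': 64, 'block_K': 64, 'num_stages': 0, 'threads': 128, 'split': 1}.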
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1],)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
"""
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
- Bias: optional bias matrix added to the output when `with_bias` is True.
- C: the MxN output tensor, written in place.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "bfloat16" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
B_shape = (N, QK)
Bias_shape = (M, N)
Scale_shape = (N, K // scale_size)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = (block_M, block_N)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
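# Worked shape example (using the default sizes from main() below; illustrative only):
#   M = N = K = 256, num_bits = 4 -> num_elems_per_byte = 2, QK = 128, B_shape = (256, 128)
#   block_K = 128                 -> Block_QK = 64, B_shared_shape = (block_N, 64)
#   scale_size = 32               -> Scale_shape = (256, 8)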
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
use_twiddling=True,
)
import_source = mxfp_intrin_info["c_source"]
func_name = mxfp_intrin_info["func_name"]
assert import_source is not None, "mxfp_intrin_info is missing its C source"
assert func_name is not None, "mxfp_intrin_info is missing its function name"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
- Loads packed FP4 elements from B_shared into per-thread local registers.
- Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values.
- Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two).
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
local_compress_size = local_size // num_elems_per_byte
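# With out_dtype = "bfloat16": local_size = 128 / 16 = 8 dequantized elements per thread
# per iteration, and local_compress_size = 8 / 2 = 4 packed uint8 bytes per thread.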
@T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k):
# import fast_dequantize plugin
"""
Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16
in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4,
applying per-block scale factors from Scale.
This routine is a tiled, thread-parallel helper that:
- Imports and calls an external dequantization function (via `import_source`/`func_name`)
to expand compressed uint8-packed FP4 values into BF16 fragments in-thread.
- Loads the corresponding per-block scale entry, interprets it as an exponent
(the applied factor is 2^Scale), and multiplies the dequantized BF16 fragment by that factor.
- Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place.
Parameters:
- B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout).
- B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values.
- Scale_shared: per-block scale tile staged in shared memory; each entry supplies the
exponent of the multiplicative scale factor (2^Scale) applied to the dequantized fragment.
- k: block index along the K dimension used to select the appropriate Scale entries.
Side effects:
- Mutates B_dequantize_shared in shared memory.
- Calls an external intrinsic function (must be provided by the environment via `import_source`
and `func_name`) to perform the low-level unpacking/dequantization.
"""
T.import_source(import_source)
tx = T.get_thread_binding()
bx = T.get_block_binding(0) # noqa: F841
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
Scale_local_thread = T.alloc_local((1,), storage_dtype)
Scale_local_thread_exponent = T.alloc_local((1,), out_dtype)
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from shared memory to registers.
# Prepare for dequant.
index_base = i * threads * local_compress_size + tx * local_compress_size
for v in T.vectorized(0, local_compress_size):
index = index_base + v
B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]
index_scale = index_base // (scale_size // num_elems_per_byte)
si = index_scale // (block_K // scale_size)
sj = index_scale % (block_K // scale_size)
Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj]
Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
# Then, dequant.
T.call_extern(
func_name,
T.address_of(B_local_thread[0]),
T.address_of(B_dequantize_local_thread[0]),
1,
dtype=out_dtype,
)
# Finally, store the dequantized data to shared memory.
for v in T.Parallel(local_size):
B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0]
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
"""
Dequantizes a packed 4-bit (FP4) block from B_shared into BF16 values in B_dequantize_shared using per-element scale exponents.
Per-element behavior:
- Reads packed 4-bit entries from B_shared (uint8 storage, multiple nibbles per byte).
- Uses Scale to obtain an exponent term (stored as uint8) and reconstructs BF16 values via _tir_u8_to_f4_to_bf16.
- Writes the dequantized BF16 block into B_dequantize_shared.
Parameters:
- B_shared: shared-memory buffer holding packed 4-bit values (uint8-packed layout).
- B_dequantize_shared: shared-memory buffer to receive dequantized BF16 results.
- Scale_shared: per-element exponent buffer staged in shared memory; used to compute the scale factor for each dequantized element.
- k: current block index along the K dimension (used to select the appropriate slice of Scale).
Side effects:
- Mutates B_dequantize_shared by storing the dequantized BF16 fragment.
"""
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
bx = T.get_block_binding(0) # noqa: F841
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_shared[
i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
# To use 1D TMA, the last dim of Scale_shared must have stride=1
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
if with_bias:
T.annotate_layout({
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
})
if threads == 512:
T.disable_warp_group_reg_alloc()
if with_bias:
# T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
# Bias_shared)
# T.copy(Bias_shared, C_local)
T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
C_local)
else:
T.clear(C_local)
# Use 1D TMA to load Scale
T.copy(Scale[bx * block_N:(bx + 1) * block_N, :], Scale_shared)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
return main
def ref_program_twiddling(A, qB, Scale, Bias=None):
"""
Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^Scale (Scale entries are grouped by 32 columns of B), computes the matrix product A · B^T in float, and casts the result to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
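# Note: the element-wise loop above can be written as one broadcasted indexing step
# (a sketch of the equivalent vectorized form used by the earlier reference programs):
#   B *= 2 ** Scale[:, torch.arange(B.shape[1], device=B.device) // 32]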
def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
"""
Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Converts the quantized matrix `qB` to floating-point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^Scale (Scale entries are grouped by 32 columns of B), computes the matrix product A · B^T in float, adds Bias, and casts the result to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K), used in floating precision.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Bias (torch.Tensor): Bias tensor with shape (M, N).
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
def ref_program_simple(A, qB, Scale, Bias=None):
"""
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32]) (Scale supplies exponents in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16.
Parameters:
- A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
- qB: Quantized representation of B accepted by `torch_convert`.
- Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
Returns:
- 2D bfloat16 tensor C containing the matrix product A · B^T.
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
def ref_program_simple_with_bias(A, qB, Scale, Bias):
"""
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
Converts the quantized tensor `qB` to floating B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32]) (Scale supplies exponents in 32-column groups), then computes C = A · B^T + Bias and returns the result converted to bfloat16.
Parameters:
- A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
- qB: Quantized representation of B accepted by `torch_convert`.
- Scale: 2D tensor of exponents; Scale[i][g] is applied to columns j where g == j // 32.
- Bias: 2D bias tensor added to the matmul result.
Returns:
- 2D bfloat16 tensor C containing the matrix product A · B^T.
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False):
"""
Run and validate the tiled quantized matmul kernel, then benchmark its latency and report TFLOPS.
Builds a matmul kernel for the given matrix sizes and quantization scale size. If `tune` is True the kernel is obtained via the autotuning path; otherwise a fixed-parameter kernel is used. Validates numerical correctness against the appropriate reference implementation (bit-twiddling reference when `fast_dequant` is True, plain reference otherwise) with rtol/atol=0.01, prints a confirmation, then runs a benchmark (500 warmup iterations) and prints the measured latency (ms) and achieved TFLOPS.
Parameters:
m (int): Number of rows of A / output rows. Default 256.
n (int): Number of columns of B / output columns. Default 256.
k (int): Reduction dimension. Default 256.
scale_size (int): Size of the per-block scale vector used for dequantization. Default 32.
fast_dequant (bool): If True validate against the twiddling (fast dequant) reference and exercise the fast dequant path; otherwise use the simple dequant reference. Default True.
with_bias (bool): If True, add a Bias tensor to the output and validate against the bias-aware reference. Default False.
tune (bool): If True obtain a tuned/autotuned kernel; otherwise use a fixed-parameter kernel. Default False.
Returns:
None
"""
total_flops = 2 * m * n * k
if tune:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
fast_dequant=fast_dequant,
with_bias=with_bias)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
if with_bias:
profiler.assert_allclose(ref_program_twiddling_with_bias, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
else:
if with_bias:
profiler.assert_allclose(ref_program_simple_with_bias, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
if __name__ == "__main__":
M, N, K = 256, 256, 256
scale_size = 32
main(M, N, K, scale_size, fast_dequant=True, with_bias=True)
main(M, N, K, scale_size, fast_dequant=False, with_bias=True)
main(M, N, K, scale_size, fast_dequant=True, with_bias=False)
main(M, N, K, scale_size, fast_dequant=False, with_bias=False)
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang.language as T
tilelang.testing.set_random_seed(0)
@tilelang.jit(out_idx=[2])
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
num_bits=4,
):
from tilelang.quantize import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
MAX_TRANSACTION_SIZE_IN_BITS = 128
local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
local_size_compressed = local_size // num_elems_per_byte
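# Per-thread vector widths implied above: for in_dtype "float16", local_size = 128 / 16 = 8
# and local_size_compressed = 4; for "int8", local_size = 16 and local_size_compressed = 8.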
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_local([local_size_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([local_size], in_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
tx = T.get_thread_binding()
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + tx * local_size_compressed + v
vi = index // (block_K // num_elems_per_byte)
vj = index % (block_K // num_elems_per_byte)
B_local[v] = B_shared[vi, vj]
for v in T.serial(0, local_size):
B_dequantize_local[v] = _tir_packed_to_unsigned_convert(
storage_type, storage_nbit)(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
vi = index // block_K
vj = index % block_K
B_dequantize_shared[vi, vj] = B_dequantize_local[v]
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
kernel = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
out = profiler.run_once()
assert out is not None
def ref_program(A, qB):
import torch
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
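# Unpacking convention used by ref_program: even columns take the low nibble and odd
# columns the high nibble, e.g. qB[i][j // 2] = 0xB7 gives B[i][j] = 7.0 for even j
# and 11.0 for odd j.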
profiler.assert_allclose(ref_program)
@tvm.testing.requires_package("bitblas")
def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform,)
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_rows = 4
warp_cols = 4
warp_row_tiles = micro_size_x * warp_rows
warp_col_tiles = micro_size_y * warp_cols
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
reduce_k = 1
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = 32 if in_dtype == "float16" else 64
chunk = block_K // reduce_k
is_smooth_a = False
can_swizzle = block_K * DataType(in_dtype).bits == 512
apply_pad_a = not (is_smooth_a or can_swizzle)
pad_factor = 8
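# For the dtypes allowed here: "float16" gives block_K = 32 and 32 * 16 = 512 bits, and
# "int8" gives block_K = 64 and 64 * 8 = 512 bits, so can_swizzle is True and no padding
# of A_shared is applied in either case.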
A_shape = (M, K)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y,
micro_size_k // num_elems_per_byte)
A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
B_shared_shape = (
block_N // micro_size_y,
block_K // micro_size_k,
micro_size_y,
micro_size_k // num_elems_per_byte,
)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
reduce_k=reduce_k,
transform_kind_b=transform_b,
num_elems_per_byte=num_elems_per_byte)
vec_load_qb = 16
if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
prelude=decode_i4_to_f16) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size), in_dtype)
B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype)
B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
reduced_accum_res = T.alloc_local(0, accum_dtype)
thread_binding = T.get_thread_binding(0)
rk = T.get_thread_binding(1)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
})
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, (block_K // reduce_k)):
vk = rk * (block_K // reduce_k) + k
A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
# TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
(threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb):
t = thread_binding
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
block_K // micro_size_k)
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
(block_K // micro_size_k)) % (
block_N // micro_size_y)
B_shared[vj, vk, vjj,
vkk] = B[bx * (block_N // micro_size_y) + vj,
ko * (block_K // micro_size_k) + vk, vjj, vkk]
for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
rk=rk,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
rk=rk,
)
for j in T.serial(warp_cols):
local_size_b = mma_emitter.local_size_b
T.call_extern('handle', 'decode_i4u_to_f16',
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]), 8)
mma_emitter.mma(A_local, B_dequantize_local, C_local)
if reduce_k > 1:
for n in T.serial(warp_rows * warp_cols * local_size):
T.attr(
T.comm_reducer(lambda x, y: x + y, [T.float16(0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
)
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_local[n],
True,
reduced_accum_res[0],
rk,
dtype="handle",
))
if rk == 0:
C_local[n] = reduced_accum_res[0]
if rk == 0:
mma_emitter.stmatrix(
C_local,
C_shared,
)
for i, j in T.Parallel(block_M, (block_N // reduce_k)):
vj = rk * (block_N // reduce_k) + j
C[by * block_M + i,
bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y,
i % micro_size_x, vj % micro_size_y]
return main
def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
import bitblas
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
kernel = tilelang.compile(matmul, out_idx=[2])
src_code = kernel.get_kernel_source()
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
# src_code is the generated cuda source
assert src_code is not None
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
M=N,
N=K,
transform_kind=transform_b,
transpose_matrix=True,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
lop3_permutate_config = bitblas.ops.LOP3PermutateConfig(
M=N,
N=K,
datatype=in_dtype,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
lop3_permutate = bitblas.ops.LOP3Permutate(
config=lop3_permutate_config,
target=tvm.target.Target("llvm"),
)
QLB = ladder_permutate(qB.cpu()).cuda()
QLB = lop3_permutate(QLB.cpu()).cuda()
kernel(A, QLB, C)
latency = profiler.do_bench(warmup=25)
# Ensure that the latency is not None
assert latency is not None
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
print("Ref C: ", ref_c)
print("C: ", C)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_package("bitblas")
def test_run_dequantize_gemm():
run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
@tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
256, 1024, 512, "float16", "float16", "float16", 3)
def main():
test_run_dequantize_gemm()
if __name__ == "__main__":
main()
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import itertools
import torch
import argparse
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float16"
assert val.dtype == "uint8"
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14
# s1e2m1
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
e_f16 = e_f4 + tir.const(14, "uint16")
m_f4 = f4 & tir.const(1, "uint16")
m_f16 = m_f4
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")
| m_f16 << tir.const(9, "uint16")).astype("uint16"))
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
return val_f16
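# Worked values for the assembly above (illustrative): nibble 0b0101 -> s=0, e_f4=2, m=1
# -> bits 0x4200 -> 3.0 in float16; nibble 0b1101 (sign bit set) -> bits 0xC200 -> -3.0.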
def torch_convert(tensor):
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
print(name, binary_repr)
def _convert(val, pos):
assert val.dtype == torch.uint8
val = val.view(torch.int8)
mask = (1 << 4) - 1
f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
s = f4 >> 3
e_f4 = (f4 & 6) >> 1
e_f16 = e_f4 + 14
m_f4 = f4 & 1
m_f16 = m_f4
val_f16 = (((e_f16 | (s << 5)) << 10) | (m_f16 << 9)) & 0xFFFF
lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
return lower_16_bits.view(torch.float16)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
@tilelang.jit(out_idx=[1])
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
def test_fp4_fp16_convert_close():
N, K = 256, 256
block_N, block_K = 64, 64
kernel = test_convert(
N,
K,
block_N,
block_K,
"float16",
)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
tl_out = kernel(B)
ref_out = torch_convert(B)
assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
print("Pass")
def get_configs():
block_M = [64, 128]
block_N = [64, 128]
block_K = [128, 256]
num_stages = [1, 2]
threads = [128, 256]
splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[4],
'split': c[5]
} for c in _configs]
return configs
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
KK = K // split
@T.prim_func
def main_split(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
SplitC = T.alloc_buffer([
split, (N + block_N - 1) // block_N * block_N,
(M + block_M - 1) // block_M * block_M
], out_dtype)
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
T.copy(A[by * block_M, KK * bz + k * block_K], A_shared)
T.copy(B[bx * block_N, (KK * bz + k * block_K) // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc)
for k in range(split):
for i, j in T.Parallel(block_N, block_M):
acc[i, j] += SplitC[k, bx * block_N + i, by * block_M + j]
T.copy(acc, Ct[bx * block_N, by * block_M])
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
if split == 1:
return main
else:
return main_split
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[2])
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
threads=None,
split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads, split=1):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel
def ref_program(A, qB):
dtypeC = "float16"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1)
def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if (not tune):
kernel = matmul(
m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M')
parser.add_argument('--n', type=int, default=256, help='N')
parser.add_argument('--k', type=int, default=256, help='K')
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
main(M, N, K, args.tune)
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import itertools
import torch
import argparse
def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "int8"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint8")
i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask
i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
i8 = i8_shifted >> tir.const(4, "int8")
return i8
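# Worked example of the sign-extension trick above (values only, not kernel code):
#   nibble 0b0111 (7):  7 << 4 = 0x70 -> reinterpret as int8 = 112 -> 112 >> 4 = 7
#   nibble 0b1111 (15): 15 << 4 = 0xF0 -> reinterpret as int8 = -16 -> -16 >> 4 = -1
# i.e. the unsigned nibbles 0..15 map onto signed int4 values 0..7 and -8..-1.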
def get_configs():
iter_params = dict(
block_M=[64, 128],
block_N=[64, 128],
block_K=[128, 256],
num_stages=[1, 2],
threads=[128, 256, 512],
)
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
return main
def torch_convert(tensor):
def _convert(val, pos):
assert val.dtype == torch.uint8
val = val.view(torch.int8)
mask = (1 << 4) - 1
i4_shifted = ((val >> (pos * 4)) & mask)
i4 = ((i4_shifted << 4) >> 4)
return i4.view(torch.int8)
N = tensor.shape[0]
K = tensor.shape[1]
new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
for i in range(new_tensor.shape[0]):
for j in range(new_tensor.shape[1]):
new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
return new_tensor
def ref_program(A, qB):
dtypeC = "int32"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C.transpose(0, 1)
def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_local_shape = (block_N, block_K)
assert K % (block_K) == 0
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
return main
if tune:
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[2])
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
return kernel()
else:
def kernel(block_M, block_N, block_K, num_stages, threads):
return kernel_func(block_M, block_N, block_K, num_stages, threads)
return kernel
def main(m=128, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if (not tune):
kernel = matmul_int8xint4(
m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
print("All checks pass.")
latency = profiler.do_bench(warmup=50)
print(f"Tilelang: {latency} ms")
else:
best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Bset latency: {best_latency}")
print(f"Best config: {best_config}")
print(f"Best tflops: {total_flops / best_latency * 1e-9}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=512, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=512, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=512, help="Matrix dimension K")
parser.add_argument("--tune", action="store_true", help="Enable tuning")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
main(M, N, K, args.tune)
# main(M, N, K, True)
import tilelang
from tilelang import language as T
from typing import Optional, Callable, Any
import torch
from tilelang import DataType
from tilelang.quantize import (
_tir_packed_int_to_int_convert,)
@tilelang.jit
def dequantize_gemv(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
num_bits: int = 4,
storage_dtype: str = "int8",
source_format: str = "uint",
n_partition: int = 4,
reduce_thread: int = 32,
fast_decoding: bool = False,
trans_A: bool = False,
trans_B: bool = True,
group_size: int = -1,
with_scaling: bool = False,
) -> Callable[..., Any]:
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as the related bitblas.gpu.gemv.GEMV "
"sch_outer_reduction_with_config is not implemented")
assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
storage_type = "".join(c for c in storage_dtype if not c.isdigit())
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
micro_size_k_compressed = micro_size_k // num_elems_per_byte
block_K = reduce_thread * micro_size_k
if group_size == -1:
group_size = K
A_shape = (M, K)
B_shape = (N, K // storage_nbit * num_bits)
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
import_source: Optional[str] = None
func_name: str = ""
if fast_decoding is True:
# Lazy import to decrease the startup time
# as intrin registry may take a while to load
from tilelang.quantize import get_lop3_intrin_group
lop3_intrin_info = get_lop3_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
with_scaling=with_scaling,
with_zeros=False,
)
import_source = lop3_intrin_info["c_source"]
func_name = lop3_intrin_info["func_name"]
assert import_source is not None, "lop3_intrin_info is not found"
assert func_name is not None, "lop3_intrin_info is not found"
import_source = import_source
@T.prim_func
def main(
A: T.Tensor[A_shape, in_dtype],
B: T.Tensor[B_shape, storage_dtype],
C: T.Tensor[C_shape, out_dtype],
):
with T.Kernel(
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
) as (
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([micro_size_k], in_dtype)
accum_res = T.alloc_local((1,), accum_dtype)
reduced_accum_res = T.alloc_local((1,), accum_dtype)
kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
T.import_source(import_source)
T.clear(accum_res)
for ko in T.serial(T.ceildiv(K, block_K)):
for v in T.vectorized(micro_size_k):
A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[
bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) +
kr * micro_size_k_compressed + v,
]
if fast_decoding:
T.call_extern(
func_name,
T.address_of(B_quant_local[0]),
T.address_of(B_dequantize_local[0]),
dtype=in_dtype,
)
else:
for ki in T.serial(micro_size_k):
B_dequantize_local[ki] = _tir_packed_int_to_int_convert(
storage_type,
storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte],
ki % num_elems_per_byte, in_dtype)
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
T.dp4a(
A_local[ki * dp4a_size],
B_dequantize_local[ki * dp4a_size],
accum_res[0],
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki] * B_dequantize_local[ki]
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
accum_res[0],
True,
reduced_accum_res[0],
kr,
dtype="handle",
))
if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
return main
def main() -> None:
M = 1
N = 1024
K = 1024
in_dtype = "float16"
out_dtype = "float16"
accum_dtype = "float16"
num_bits = 4
storage_dtype = "int8"
source_format = "uint"
n_partition = 4
reduce_thread = 32
fast_decoding = True
trans_A = False
trans_B = True
group_size = -1
with_scaling = False
kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
source_format, n_partition, reduce_thread, fast_decoding, trans_A,
trans_B, group_size, with_scaling)
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda()
if fast_decoding:
from tilelang.quantize.utils import interleave_weight
qB = interleave_weight(qB, num_bits, in_dtype)
kernel(A, qB, C)
# int4 reference
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for j in range(B.shape[1]):
B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
print("C: ", C)
print("Ref C: ", ref_c)
# doesn't apply scaling, the absolute error is large
torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1)
if __name__ == "__main__":
main()
import tilelang
import tilelang.language as T
from tilelang.quantize import _tir_u8_to_f4_to_bf16
from tilelang import tvm as tvm
from tvm import DataType
import torch
from dequantize_utils import torch_convert_bit_twiddling, assert_similar
from tilelang.autotuner import set_autotune_inputs
import argparse
def get_configs():
"""
Generate a list of hyperparameter configuration dictionaries for tuning.
Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
'num_stages', 'threads', and 'split'. The function returns the Cartesian
product of the parameter value lists:
- block_M, block_N, block_K: tiling sizes
- num_stages: pipeline stages
- threads: thread counts
- split: K-splitting factor
Returns:
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[128],
block_N=[64, 128, 256],
block_K=[128],
num_stages=[0, 1, 2],
threads=[128, 256, 512],
split=[1],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
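# Example element of the returned list (illustrative):
#   {'block_M': 128, 'block_N': 64, 'block_K': 128, 'num_stages': 0, 'threads': 128, 'split': 1}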
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def matmul(M,
N,
K,
topk,
E,
padding_M,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=128,
block_N=256,
block_K=128,
num_stages=2,
threads=256,
split=1):
"""
Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype` and shape (M, K).
- B: packed quantized matrix for all experts, stored as uint8 with `num_bits` bits per element, shape (E, N, QK), where QK = K / (8/num_bits).
- Scale: per-expert, per-block scale/exponent information for dequantizing B, shape (E, N, K // scale_size).
- Bias: per-expert, per-output bias, shape (E, N).
- topk_weights: router weights for the top-k experts of each token, flattened to shape (M * topk,).
- sorted_token_ids: flattened and padded tensor of token indices, shape (padding_M,).
- expert_ids: expert id for each token in the padded batch, shape (padding_M // block_M,).
- C: output tensor, shape (M, topk, N).
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is (M, topk, N)). K must be divisible by (block_K * split).
topk (int): number of experts selected per token.
E (int): number of experts.
padding_M (int): padded number of tokens after grouping and block alignment.
in_dtype (str): element type of A (e.g., "bfloat16").
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for the M, N, and K dimensions (defaults 128, 256, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the grouped, pipelined GEMM that:
- loads tiled blocks of A and packed B for each expert to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- applies per-token topk weights and bias,
- writes the final (M, topk, N) block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = (block_N)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
use_twiddling=True,
)
import_source = mxfp_intrin_info["c_source"]
func_name = mxfp_intrin_info["func_name"]
assert import_source is not None, "mxfp_intrin_info is not found"
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
# the dequant part is the same as in dequant_gemm
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
- Loads packed FP4 elements from B_shared into per-thread local registers.
- Calls an external fast dequantization intrinsic (provided via `import_source` / `func_name` in the outer scope) to expand packed FP4 -> BF16 values.
- Applies a per-block scale factor derived from the Scale tensor (using exponentiation by powers of two).
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
local_compress_size = local_size // num_elems_per_byte
@T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, k):
# import fast_dequantize plugin
"""
Fast dequantization kernel: convert packed 4-bit quantized values in B_shared to bfloat16
in B_dequantize_shared using an external intrinsic optimized for twiddled (bit-packed) FP4,
applying per-block scale factors from Scale.
This routine is a tiled, thread-parallel helper that:
- Imports and calls an external dequantization function (via `import_source`/`func_name`)
to expand compressed uint8-packed FP4 values into BF16 fragments in-thread.
- Loads the corresponding per-block scale entry, interprets it as an exponent
(the code applies 2^Scale via a left shift, consistent with `ref_moe`), and multiplies the dequantized BF16 fragment by that factor.
- Writes the scaled BF16 results back into the shared B_dequantize_shared buffer in-place.
Parameters:
- B_shared: read-only shared buffer containing compressed FP4 data (packed uint8 layout).
- B_dequantize_shared: shared output buffer that is overwritten with BF16 dequantized values.
- Scale_shared: per-block scale tensor; entries are interpreted such that the multiplicative scale
= 2^Scale (applied directly as a left shift, matching `ref_moe`).
- k: block index along the K dimension used to select the appropriate Scale entries.
Side effects:
- Mutates B_dequantize_shared in shared memory.
- Calls an external intrinsic function (must be provided by the environment via `import_source`
and `func_name`) to perform the low-level unpacking/dequantization.
"""
T.import_source(import_source)
tx = T.get_thread_binding()
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
Scale_local_thread = T.alloc_local((1,), storage_dtype)
Scale_local_thread_exponent = T.alloc_local((1,), out_dtype)
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from shared memory to registers.
# Prepare for dequant.
index_base = i * threads * local_compress_size + tx * local_compress_size
for v in T.vectorized(0, local_compress_size):
index = index_base + v
B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]
index_scale = index_base // (scale_size // num_elems_per_byte)
si = index_scale // (block_K // scale_size)
sj = index_scale % (block_K // scale_size)
Scale_local_thread[0] = Scale_shared[si, k * block_K // scale_size + sj]
Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
# Then, dequant.
T.call_extern(
func_name,
T.address_of(B_local_thread[0]),
T.address_of(B_dequantize_local_thread[0]),
1,
dtype=out_dtype,
)
# Finally, store the dequantized data to shared memory.
for v in T.Parallel(local_size):
B_dequantize_local_thread[v] *= Scale_local_thread_exponent[0]
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_shared[
i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
@T.prim_func
def main(
A: T.Tensor((M, K), in_dtype),
B: T.Tensor((E, N, QK), storage_dtype),
Scale: T.Tensor((E, N, K // scale_size), storage_dtype),
Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype),
sorted_token_ids: T.Tensor((padding_M), "int32"),
expert_ids: T.Tensor((padding_M // block_M), "int32"),
C: T.Tensor((M, topk, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
Bias_shared = T.alloc_shared(Bias_shared_shape, out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
topk_weights_shared = T.alloc_shared((block_M), out_dtype)
sorted_token_ids_shared = T.alloc_shared((block_M), "int32")
expert_id = T.alloc_local((1), "int32") # the expert id for the current block
# To use 1D TMA, the last dim of Scale_shared must have stride=1
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.use_swizzle(10)
if threads == 512:
T.disable_warp_group_reg_alloc()
T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared)
expert_id[0] = expert_ids[by]
# Get the topk weights of each token in the current block
for i in T.Parallel(block_M):
if sorted_token_ids_shared[i] != -1:
topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]]
# Get bias and scale based on the expert id
if with_bias:
T.copy(Bias[expert_id[0], bx * block_N:(bx + 1) * block_N], Bias_shared)
else:
T.clear(Bias_shared)
T.copy(Scale[expert_id[0], bx * block_N:(bx + 1) * block_N, :], Scale_shared)
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = Bias_shared[j]
tx = T.get_thread_binding()
for k in T.Pipelined(K // block_K, num_stages=num_stages):
# Each thread copies 16 contiguous elements (one vectorized transaction) per iteration
for copy_i in T.serial(block_M * block_K // threads // 16):
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_K] != -1:
for copy_j in T.vectorized(16):
A_shared[base // block_K, base % block_K +
copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
k * block_K + base % block_K + copy_j]
T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = C_local[i, j] * topk_weights_shared[i]
T.copy(C_local, C_shared)
for copy_i in T.serial(block_M * block_N // threads // 16):
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_N] != -1:
for copy_j in T.vectorized(16):
C[sorted_token_ids_shared[base // block_N] // topk,
sorted_token_ids_shared[base // block_N] % topk, bx * block_N +
base % block_N + copy_j] = C_shared[base // block_N,
base % block_N + copy_j]
return main
def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256):
dtypeC = "bfloat16"
M, K = A.shape
E, N, QK = qB.shape
topk = topk_weights.shape[0] // M
scale_size = K // Scale.shape[2]
assert scale_size == 32 # MXFP4
# Initialize output tensor
C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device='cuda')
# Iterate over sorted_token_ids
for idx in range(len(sorted_token_ids)): # padding_M
token_id = sorted_token_ids[idx]
if token_id == -1:
continue
expert_id = expert_ids[idx // block_M]
topk_idx = token_id % topk
# Get the token embedding
token_embedding = A[token_id // topk]
# Dequantize the expert weights
B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K)
B *= 2**(
Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(
torch.bfloat16))
# Compute the output for this token-expert pair
# token_embedding @ B.T + bias
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(
torch.bfloat16)) + Bias[expert_id]
output = output.to(torch.__getattribute__(dtypeC))
# Apply the topk weight
weight = topk_weights[token_id]
output = output * weight
# Store the result
C[token_id // topk, topk_idx] = output
return C
def get_data(m, n, k, qk, scale_size, topk, E, block_M):
A = torch.empty(m, k, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
qB = torch.randint(
0, 256, (E, n, qk), dtype=torch.uint8,
device='cuda') # Quantized weight tensor for E experts.
Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device='cuda')
Bias = torch.empty(E, n, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
weights = torch.empty(m, E, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
# topk_weights: Router weights for the top-k experts for each token.
# Shape: (m, topk)
# tokens_experts: A flattened tensor of expert assignments for each token.
# For each of m tokens, topk unique experts are chosen. Shape: (m * topk,)
topk_weights, tokens_experts = torch.topk(weights, topk, dim=-1)
tokens_experts = tokens_experts.reshape(m * topk)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.reshape(m * topk)
sorted_expert_vals, sorted_indices = torch.sort(tokens_experts, stable=True)
sorted_token_ids = sorted_indices
unique_expert_ids, counts = torch.unique_consecutive(sorted_expert_vals, return_counts=True)
expert_ids = []
padded_token_ids = []
start = 0
for eid, cnt in zip(unique_expert_ids.tolist(), counts.tolist()):
end = start + cnt
group_token_ids = sorted_token_ids[start:end]
pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt
if pad_len > 0:
# -1 marks padding (vLLM's moe_align_block_size() uses `M` instead)
group_token_ids = torch.cat([
group_token_ids,
torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device='cuda')
])
padded_token_ids.append(group_token_ids)
expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M))
start = end
# sorted_token_ids: The final flattened and padded tensor of token indices.
sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,)
# expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`.
expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # (padding_M // block_M,)
padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding
return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M
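# Illustrative example of the padding scheme above (hypothetical counts, block_M=4):
# if expert 0 receives 3 (token, topk) pairs and expert 2 receives 5, each group is
# padded to a multiple of block_M, giving padding_M = 4 + 8 = 12,
# sorted_token_ids = [t0, t1, t2, -1, t3, ..., t7, -1, -1, -1], and expert_ids = [0, 2, 2].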
def main(m=256,
n=256,
k=256,
scale_size=32,
topk=4,
E=32,
fast_dequant=True,
with_bias=False,
tune=False):
# Tunable parameters
block_M, block_N, block_K = 128, 256, 128 # noqa: F841
num_stages = 1 # noqa: F841
threads = 512 # noqa: F841
split = 1 # noqa: F841
total_flops = 2 * m * n * k * topk
num_bits = 4
num_elems_per_byte = 8 // num_bits
qk = k // num_elems_per_byte
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(
m, n, k, qk, scale_size, topk, E, block_M)
if tune:
with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
# Autotune with inputs manually composed
kernel = matmul(
m,
n,
k,
topk,
E,
padding_M,
"bfloat16",
"bfloat16",
"float32",
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias,
)
else:
kernel = matmul(
m,
n,
k,
topk,
E,
padding_M,
"bfloat16",
"bfloat16",
"float32",
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias,
block_M=block_M,
block_N=block_N,
block_K=block_K,
num_stages=num_stages,
threads=threads,
split=split,
)
output = kernel(
A,
qB,
Scale,
Bias,
topk_weights,
sorted_token_ids,
expert_ids,
)
print('Tilelang kernel run finished.')
ref_output = ref_moe(
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids,
block_M=block_M) # Maybe a little bit slow...
latency = tilelang.profiler.do_bench(
lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
diff = (output - ref_output).abs()
max_val = diff.max()
max_idx = diff.argmax()
print(f"max abs diff: {max_val} at index: {max_idx}")
assert_similar(
output, ref_output, name="output",
eps=2e-5) # We care about the similarity rather than abs. difference
print("All checks pass. ✅")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm
parser.add_argument("--N", type=int, default=5760, help="N")
parser.add_argument("--K", type=int, default=2944, help="K")
parser.add_argument("--scale_size", type=int, default=32, help="scale size")
parser.add_argument(
"--topk", type=int, default=4, help="topk") # experts activated for each token
parser.add_argument("--E", type=int, default=32, help="E") # number of experts
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(
args.M,
args.N,
args.K,
args.scale_size,
topk=args.topk,
E=args.E,
fast_dequant=True,
with_bias=True,
tune=args.tune)
import tilelang.testing
import example_dequant_gemv_fp16xint4
import example_dequant_gemm_fp4_hopper
import example_dequant_gemm_bf16_mxfp4_hopper
import example_dequant_gemm_bf16_mxfp4_hopper_tma
import example_dequant_groupedgemm_bf16_mxfp4_hopper
import example_dequant_gemm_w4a8
@tilelang.testing.requires_cuda
def test_example_dequant_gemv_fp16xint4():
example_dequant_gemv_fp16xint4.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_fp4_hopper():
example_dequant_gemm_fp4_hopper.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_bf16_mxfp4_hopper():
example_dequant_gemm_bf16_mxfp4_hopper.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():
example_dequant_gemm_bf16_mxfp4_hopper_tma.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_w4a8():
example_dequant_gemm_w4a8.main()
if __name__ == "__main__":
tilelang.testing.main()
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang import tvm as tvm
@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8})
def matmul_dynamic_mnk(
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
M = tvm.te.var("m")
N = tvm.te.var("n")
K = tvm.te.var("k")
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def dynamic_matmul(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return dynamic_matmul
def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
accum_dtype, num_stages, threads):
print(
f"M: {M}, N: {N}, K: {K}, block_M: {block_M}, block_N: {block_N}, block_K: {block_K}, trans_A: {trans_A}, trans_B: {trans_B}, in_dtype: {in_dtype}, out_dtype: {out_dtype}, accum_dtype: {accum_dtype}, num_stages: {num_stages}, threads: {threads}"
)
kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
accum_dtype, num_stages, threads)
import torch
if trans_A:
A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
else:
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
if trans_B:
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
else:
B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
kernel(A, B, C)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
# Get Reference Result
ref_c = ref_program(A, B)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench(input_tensors=[A, B, C])
print(f"Latency: {latency} ms")
def main(M=16384, N=16384, K=16384):
block_M, block_N, block_K = 128, 128, 32
trans_A, trans_B = False, False
in_dtype, out_dtype = "float16", "float16"
accum_dtype = "float32"
num_stages = 3
threads = 128
matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
accum_dtype, num_stages, threads)
if __name__ == "__main__":
main()
import tilelang.testing
import example_dynamic
def test_example_dynamic():
example_dynamic.main(M=1024, N=1024, K=1024)
if __name__ == "__main__":
tilelang.testing.main()
import argparse
import itertools
import torch
import tilelang
import tilelang.language as T
from tilelang.autotuner import AutoTuner
def ref_program(x, y):
return x + y
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
(M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), in_dtype)
B_shared = T.alloc_shared((block_M, block_N), in_dtype)
C_local = T.alloc_fragment((block_M, block_N), out_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(B[by * block_M, bx * block_N], B_shared)
for (local_y, local_x) in T.Parallel(block_M, block_N):
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return elem_add
def get_configs(M, N):
block_M = [64, 128, 256]
block_N = [64, 128, 256]
threads = [64, 128, 256]
configs = list(itertools.product(block_M, block_N, threads))
return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs]
def get_best_config(M, N):
def kernel(block_M=None, block_N=None, threads=None):
return elementwise_add(M, N, block_M, block_N, "float32", "float32", threads)
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(M, N)).set_compile_args(
out_idx=[-1],
target="cuda",
).set_profile_args(
supply_type=tilelang.TensorSupplyType.Auto,
ref_prog=ref_program,
skip_check=False,
)
return autotuner.run(warmup=3, rep=20)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=1024)
parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False)
args, _ = parser.parse_known_args()
M, N = args.m, args.n
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
if args.use_autotune:
result = get_best_config(M, N)
kernel = result.kernel
else:
# Default config
config = {"block_M": 32, "block_N": 32, "threads": 128}
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
main()
import argparse
import tilelang
import tilelang.language as T
import torch
def ref_program(x, y):
return x + y
@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
@T.prim_func
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
(M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), in_dtype)
B_shared = T.alloc_shared((block_M, block_N), in_dtype)
C_local = T.alloc_fragment((block_M, block_N), out_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(B[by * block_M, bx * block_N], B_shared)
for (local_y, local_x) in T.Parallel(block_M, block_N):
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return elem_add
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=128)
parser.add_argument("--n", type=int, default=128)
args, _ = parser.parse_known_args()
M, N = args.m, args.n
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
b = torch.randn(M, N, dtype=torch.float32, device="cuda")
# Default config
config = {"block_M": 128, "block_N": 128, "threads": 128}
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
print("All passed!")
if __name__ == "__main__":
main()
import tilelang.testing
import example_elementwise_add
import example_elementwise_add_tma_1d
def test_example_elementwise_add():
example_elementwise_add.main()
def test_example_elementwise_add_tma_1d():
example_elementwise_add_tma_1d.main()
if __name__ == "__main__":
tilelang.testing.main()
# FlashAttention
Using tile-lang, we can define buffers at different memory layers. For instance, `Q_shared`, `K_shared`, and `V_shared` can be defined in shared memory, while `acc_s` and `acc_o` can be placed in registers. This flexibility allows us to represent a complex fusion pattern like FlashAttention in a simple way.
```python
@T.prim_func
def flash_attention(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
# Launch a specialized T.Kernel with 3D mapping: (bx, by, bz)
# bx: block index in sequence dimension
# by: block index in "heads" dimension
# bz: block index in "batch" dimension
# threads=thread_num means how many threads per block
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
# Allocate shared memory for Q, K, V to reduce global memory accesses
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
# Allocate buffers on register
# acc_s: buffer to hold intermediate attention scores
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
# acc_s_cast: buffer for storing casted/adjusted scores
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
# acc_o: partial accumulation of output
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
# Buffers to track per-row maximum score and related stats
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
# Annotate layout for Q_shared, e.g., use a swizzled layout to optimize memory access
T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
# Copy a block of Q from global memory to Q_shared
T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
# Initialize accumulators
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N)
)
# Pipeline the loop to overlap copies/gemm stages
for k in T.Pipelined(loop_range, num_stages=num_stages):
# Copy K block into shared memory
T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)
)
else:
T.clear(acc_s)
# Perform the Q*K^T multiplication. transpose_B=True indicates that K_shared is transposed;
# policy=T.GemmWarpPolicy.FullRow means each warp computes an entire row of acc_s,
# and the resulting acc_s is kept in registers.
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Copy V block into shared memory
T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
# Scale every attention score in the block (acc_s has shape [block_M, block_N])
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] *= scale
# Save old scores_max, then reset scores_max
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
# Compute the maximum value per row on dimension 1 (block_N)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# Compute the factor by which we need to rescale previous partial sums
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
# Rescale the partial output accumulation to keep exponents consistent
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
# Exponentiate (scores - max) for the new block
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
# Make a cast of acc_s to fp16 for the next GEMM
T.copy(acc_s, acc_s_cast)
# Multiply the attention acc_s_cast by V and add to partial output (acc_o)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
# Update the "logsum" tracker with the newly accumulated sum
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
# Final step: divide each partial output by logsum (completing the softmax)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
# Write back the final output block from acc_o to the Output buffer
T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])
```
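For reference, the kernel above computes the same result as ordinary softmax attention; the online-softmax bookkeeping (`scores_max`, `scores_scale`, `logsum`) only lets it accumulate the output block by block without materializing the full `seq_len x seq_len` score matrix. The sketch below is a minimal PyTorch check, not part of the kernel: it assumes the `[batch, seq_len, heads, dim]` tensor layout used in the copies above and that `scale` corresponds to the usual `1/sqrt(dim)` softmax scaling.
```python
import math
import torch

def ref_attention(Q, K, V, is_causal=False):
    # Q, K, V: [batch, seq_len, heads, dim], matching the layout the kernel reads.
    batch, seq_len, heads, dim = Q.shape
    q = Q.permute(0, 2, 1, 3).float()           # [batch, heads, seq_len, dim]
    k = K.permute(0, 2, 1, 3).float()
    v = V.permute(0, 2, 1, 3).float()
    scores = q @ k.transpose(-1, -2) / math.sqrt(dim)
    if is_causal:
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=Q.device), diagonal=1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v     # [batch, heads, seq_len, dim]
    return out.permute(0, 2, 1, 3).to(Q.dtype)  # back to [batch, seq_len, heads, dim]
```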
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
# ruff: noqa
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0,
repeat(indices, "z -> z d", d=second_dim)).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
return output, input.detach()
@staticmethod
def backward(ctx, grad_output, grad_residual):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis_residual = IndexFirstAxisResidual.apply
def unpad_input(hidden_states, attention_mask):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
bool mask, then call nonzero to get the indices, then index with those. The indices tensor is `dim`
times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
"""
Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
```
[
[2, 3, 0, 0, 0, 0],
[3, 2, 0, 0, 0, 0],
[6, 0, 0, 0, 0, 0]
]
```
, which refers to the 3D-attention mask:
```
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]
]
]
```.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(
seqlen, device=length.device, dtype=length.dtype).expand(len(length),
seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
bool mask, then call nonzero to get the indices, then index with those. The indices tensor is `dim`
times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz)
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
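# Minimal usage sketch (illustrative, not part of the adapted upstream file): round-trips
# a tiny padded batch through unpad_input -> pad_input. Shapes follow the docstrings above;
# CPU tensors are assumed to be fine for this check.
if __name__ == "__main__":
    batch, seqlen, dim = 2, 4, 3
    hidden_states = torch.arange(
        batch * seqlen * dim, dtype=torch.float32).reshape(batch, seqlen, dim)
    attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
    unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
    # For this mask: cu_seqlens == [0, 2, 5], max_seqlen == 3.
    repadded = pad_input(unpadded, indices, batch, seqlen)
    # Masked-out positions come back as zeros; valid positions are unchanged.
    assert torch.equal(repadded * attention_mask.unsqueeze(-1), repadded)
    print("unpad/pad round trip OK:", cu_seqlens.tolist(), max_seqlen)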