import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def gemm_schedule(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 Cache Locality
T.use_swizzle(panel_size=10)
# Clear the local buffer
T.clear(C_local)
# Auto pipeline the computation
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, ko * block_K], A_shared)
                # Instead of using
                # T.copy(B[ko * block_K, bx * block_N], B_shared)
                # we can also use T.Parallel to auto-map the thread
                # bindings and vectorize the copy operation.
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return gemm_schedule
def main():
kernel = matmul(1024, 1024, 1024, 128, 128, 32)
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = kernel(a, b)
ref_c = a @ b
print("c:")
print(c)
print("ref_c:")
print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("All check passed.")
# Get CUDA Source
print("CUDA Source:")
print(kernel.get_kernel_source())
if __name__ == "__main__":
main()
import tilelang.testing
import example_gemm_autotune
import example_gemm_intrinsics
import example_gemm_schedule
import example_gemm
def test_example_gemm_autotune():
# enable roller for fast tuning
example_gemm_autotune.main(M=1024, N=1024, K=1024, with_roller=True)
def test_example_gemm_intrinsics():
example_gemm_intrinsics.main(M=1024, N=1024, K=1024)
def test_example_gemm_schedule():
example_gemm_schedule.main()
def test_example_gemm():
example_gemm.main()
if __name__ == "__main__":
tilelang.testing.main()
**Notes**: FP8 is currently supported only via MMA instructions rather than `T.gemm`, because the CUTLASS version bundled with TileLang is too old; we should upgrade CUTLASS in the future.
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import torch_assert_close
import itertools
def ref_program(A, B):
return (A.half() @ B.half().T).to(dtype=torch.float32)
def manual_check_prog(C, C_ref):
torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)
def supply_prog(args):
a_param, b_param = args
M, K = a_param.shape
N, _ = b_param.shape
a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
return [a, b]
def get_configs():
block_Ms = [32, 64, 128]
block_Ns = [32, 64, 128]
block_Ks = [64, 128]
num_stages = [0]
num_threads = [256]
k_packs = [1, 2]
gemm_types = ["ss", "rs"]
valid_configs = []
for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks,
num_stages, num_threads, k_packs,
gemm_types):
valid_configs.append({
"block_M": m,
"block_N": n,
"block_K": k,
"num_stages": stages,
"num_threads": t,
"k_pack": kp,
"gemm_type": gemm_type,
})
return valid_configs
@tilelang.autotune(
configs=get_configs(),
cache_input_tensors=True,
ref_prog=ref_program,
manual_check_prog=manual_check_prog,
supply_prog=supply_prog)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
dtype = "float8_e4m3fnuz"
accum_dtype = "float"
@T.prim_func
def gemm_fp8_rs(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_local = T.alloc_fragment((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_local)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_local,
B_shared,
C_local,
transpose_B=True,
k_pack=k_pack,
policy=T.GemmWarpPolicy.FullRow)
T.copy(C_local, C[by * block_M, bx * block_N])
@T.prim_func
def gemm_fp8_ss(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
k_pack=k_pack,
policy=T.GemmWarpPolicy.FullRow)
T.copy(C_local, C[by * block_M, bx * block_N])
if gemm_type == "ss":
return gemm_fp8_ss
elif gemm_type == "rs":
return gemm_fp8_rs
else:
raise ValueError(f"Invalid gemm_type: {gemm_type}")
def test_gemm_fp8(M, N, K):
kernel = fp8_matmul(M, N, K)
a = (torch.randn(M, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
b = (torch.randn(N, K, dtype=torch.float16, device='cuda') *
0.01).to(dtype=torch.float8_e4m3fnuz)
c = kernel(a, b)
ref_c = ref_program(a, b)
torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("passed~")
if __name__ == "__main__":
test_gemm_fp8(512, 512, 512)
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
def calc_diff(x, y):
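    # Relative error metric: 1 - 2*sum(x*y) / (sum(x*x) + sum(y*y)); 0 when x and y match exactly.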
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
@T.prim_func
def gemm_fp8(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
T.copy(C_local, C[by * block_M, bx * block_N])
return gemm_fp8
def test_gemm_fp8(M, N, K, dtype):
torch_dtype = map_torch_type(dtype)
kernel = matmul(M, N, K, 128, 128, 64, dtype)
a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
c = kernel(a, b)
ref_c = (a.half() @ b.half().T).to(dtype=torch_dtype)
print(c)
print(ref_c)
diff = calc_diff(c, ref_c)
print(f"diff: {diff}")
assert diff < 1e-3
def main():
test_gemm_fp8(1024, 1024, 1024, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2')
if __name__ == "__main__":
main()
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
    # For FP8 GEMM, promote the accumulator once every 4 WGMMA instructions, i.e. every 128
    # elements along K: if block_K < 128, promote every 128 // block_K iterations;
    # if block_K >= 128, promote every iteration.
update_interval = 128 // block_K if block_K < 128 else 1
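    # e.g. with block_K = 64 (as used in the test below), update_interval = 2, so partial
    # sums are promoted into C_local_accum every 2 K-iterations.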
@T.prim_func
def gemm_fp8_2xAcc(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Promote to enable 2xAcc
if (k + 1) % update_interval == 0:
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j]
T.clear(C_local)
# Tail processing
if K_iters % update_interval != 0:
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j]
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return gemm_fp8_2xAcc
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def test_gemm_fp8(M, N, K, dtype):
torch_dtype = map_torch_type(dtype)
kernel = matmul(M, N, K, 128, 128, 64, dtype)
a = torch.rand(M, K, dtype=torch.float16, device='cuda')
a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
b = torch.rand(N, K, dtype=torch.float16, device='cuda')
b = (100 * (2 * b - 1)).to(dtype=torch_dtype)
c = kernel(a, b)
ref_c = (a.float() @ b.float().T)
diff = calc_diff(c, ref_c)
print(f"diff: {diff}")
assert diff < 1e-3
def main():
test_gemm_fp8(1024, 1024, 8192, 'float8_e4m3')
test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2')
if __name__ == "__main__":
main()
import torch
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type
tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
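    # Only apply swizzling when a shared-memory row spans exactly 512 bits (64 bytes).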
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"float8_e4m3",
"float8_e5m2",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == "float16" else 64
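    # 8-bit inputs use a larger K chunk so block_K // micro_size_k stays the same
    # (micro_size_k is 32 for int8/fp8, 16 for float16).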
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def gemm_fp8_intrinsic(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return gemm_fp8_intrinsic
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
print(src_code)
# src_code is the generated cuda source
assert src_code is not None
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
else:
A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
C = profiler(A, B)
latency = profiler.do_bench(warmup=25)
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype)
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def main():
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
if __name__ == "__main__":
main()
import tilelang.testing
import example_tilelang_gemm_fp8_2xAcc
import example_tilelang_gemm_fp8_intrinsic
import example_tilelang_gemm_fp8
def test_example_tilelang_gemm_fp8_2xAcc():
example_tilelang_gemm_fp8_2xAcc.main()
def test_example_tilelang_gemm_fp8_intrinsic():
example_tilelang_gemm_fp8_intrinsic.main()
def test_example_tilelang_gemm_fp8():
example_tilelang_gemm_fp8.main()
if __name__ == "__main__":
tilelang.testing.main()
# TileLang SM100 Support (Preview)
This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality.
## Current Limitations (Manual Implementation Required)
### 1. Manual TCGEN5.MMA Management
Users must manually handle TCGEN5MMA operations using:
- `T.alloc_tmem()` - Allocate Tensor Memory
- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting
- Manual synchronization with mbarrier
### 2. Manual mbarrier Synchronization
TCGEN5MMA is asynchronous and requires explicit synchronization:
```python
mbar = T.alloc_barrier(1) # expect-arrive-count = 1
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k==0)
T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required
```
## Examples
### TCGEN5MMA Example (`gemm_tcgen5mma.py`)
Demonstrates TCGEN5MMA operations with:
- Tensor Memory allocation
- Manual mbarrier synchronization
- TCGEN5MMA gemm operations
### Traditional MMA Example (`gemm_mma.py`)
Shows standard MMA operations that work across architectures for comparison.
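For a quick comparison, here is a minimal sketch of that pattern (it closely mirrors the tile-level `T.gemm` kernel used elsewhere in these examples; the block sizes and dtypes are illustrative):
```python
import tilelang
import tilelang.language as T

def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Stage tiles of A and B into shared memory
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                # Regular MMA: accumulate into registers, no Tensor Memory or mbarrier needed
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```
It can be compiled with a `tilelang.compile(...)` call like the one shown in the Compilation and Usage section below.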
## Code Example
The following code is based on `gemm_tcgen5mma.py`, demonstrating TCGEN5MMA matrix multiplication:
```python
import torch
import tilelang
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((M, K), "bfloat16"),
B: T.Tensor((N, K), "bfloat16"),
C: T.Tensor((M, N), "bfloat16"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
# 1. Allocate memory buffers
A_shared = T.alloc_shared((block_M, block_K), "bfloat16") # A matrix shared memory
B_shared = T.alloc_shared((block_N, block_K), "bfloat16") # B matrix shared memory
C_tmem = T.alloc_tmem([block_M, block_N], "float") # TCGEN5MMA output to Tensor Memory
mbar = T.alloc_barrier(1) # mbarrier synchronization primitive
C_local = T.alloc_fragment((block_M, block_N), "float") # Register storage
C_shared = T.alloc_shared((block_M, block_N), "bfloat16") # Output shared memory
# 2. Main computation loop
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
# Data loading: global memory to shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
# TCGEN5MMA computation: asynchronous launch, output to Tensor Memory
T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True,
mbar=mbar, wg_wait=-1, clear_accum=k==0)
# Critical: wait for TCGEN5MMA completion
T.mbarrier_wait_parity(mbar, k%2)
        # 3. Output processing (only a subset of threads participates)
T.copy(C_tmem, C_local) # Tensor Memory → registers
T.copy(C_local, C_shared) # registers → shared memory
# 4. Write back to global memory
T.copy(C_shared, C[by * block_M, bx * block_N])
```
### Compilation and Usage
```python
# Parameter setup
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
# Compile kernel
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required
})
# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS")
```
import tilelang
import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Clear local accumulation
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
# Copy tile of A
                # This is syntactic sugar for a parallelized copy:
                # for i, k in T.Parallel(block_M, block_K):
                #     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
T.copy(A[by * block_M, ko * block_K], A_shared)
# Copy tile of B
T.copy(B[bx * block_N, ko * block_K], B_shared)
# Perform a tile-level GEMM on the shared buffers
                # Currently this dispatches to CUTE on NVIDIA GPUs and HIP on AMD GPUs
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return main
M = 128 # M = T.dynamic("m") if you want to use dynamic shape
N = 128
K = 32
block_M = 128
block_N = 128
block_K = 32
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(M, N, K, block_M, block_N, block_K)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target can currently be "cuda", "hip", or "cpu".
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b.T
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile the kernel latency
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import torch
import tilelang
import tilelang.language as T
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
mbar = T.alloc_barrier(1)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
T.copy(C_tmem, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 2
threads = 256
func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
print(jit_kernel.get_kernel_source())
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS")
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse
import tilelang
import tilelang.language as T
from tilelang.layout import make_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc
from triton.testing import do_bench
import torch
arch = nvcc.get_target_compute_version()
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
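# Maps compute capability -> (e_factor, e_dtype): the packing factor and dtype used to size
# the semi-structured (2:4) sparsity metadata tensor E, whose shape is (M, K // e_factor).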
default_config = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 64,
'num_stages': 1,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 256,
'block_N': 128,
'block_K': 64,
'num_stages': 2,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
},
"h20": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}
}
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
enable_rasterization):
e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func
def gemm_sp_fp16(
A_sparse: T.Tensor((M, K // 2), 'float16'),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), 'float16')
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
backend="cutlass",
block_k=block_K,
arch=arch),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return gemm_sp_fp16
def main():
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
**default_config[args.cfg][args.accum_dtype])
a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
a_sparse, e = compress(
a,
transposed=False,
block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
arch=arch)
c = kernel(a_sparse, e, b)
ref_c = a @ b
    assert not c.isnan().any(), "Kernel result contains NaNs, please report an issue"
torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
print(f"Precision check passed. diff: {(c - ref_c).abs().mean()}")
latency = do_bench(lambda: kernel(a_sparse, e, b))
ref_latency = do_bench(lambda: a @ b)
total_flops = 2 * args.m * args.n * args.k
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")
if __name__ == "__main__":
main()
import tilelang
import tilelang.language as T
@tilelang.jit
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
split_k,
dtype="float16",
accum_dtype="float",
out_dtype="float32"):
splitK = K // split_k
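    # Each bz index in the grid reduces a contiguous K slice of length splitK
    # (K is assumed to be divisible by split_k).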
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0):
T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
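            # Accumulate this block's partial tile into global C element-wise; blocks along
            # the split-K dimension may update the same tile concurrently, so atomics are required.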
for i, j in T.Parallel(block_M, block_N):
T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])
return main
def main():
M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
split_k = 4
kernel = matmul(M, N, K, block_M, block_N, block_K, split_k)
import torch
torch.random.manual_seed(42)
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
c = torch.zeros(M, N).cuda().float()
kernel(a, b, c)
ref_c = a @ b
torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
main()
import tilelang
import tilelang.language as T
@tilelang.jit
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
split_k,
dtype="float16",
accum_dtype="float",
out_dtype="float32"):
splitK = K // split_k
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0):
T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
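            # Tile-level atomic add: TileLang can vectorize the atomic update over the whole
            # tile, unlike the element-wise atomic_add loop in the plain split-K example.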
T.atomic_add(C[by * block_M, bx * block_N], C_shared)
return main
def main():
M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
split_k = 4
kernel = matmul(M, N, K, block_M, block_N, block_K, split_k)
import torch
torch.random.manual_seed(42)
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
c = torch.zeros(M, N).cuda().float()
kernel(a, b, c)
ref_c = a @ b
torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
main()
import tilelang.testing
import example_tilelang_gemm_splitk
import example_tilelang_gemm_splitk_vectorize_atomicadd
def test_example_tilelang_gemm_splitk():
example_tilelang_gemm_splitk.main()
def test_example_tilelang_gemm_splitk_vectorize_atomicadd():
example_tilelang_gemm_splitk_vectorize_atomicadd.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch
import torch.backends
import tilelang
from tilelang import language as T
import math
def cdiv(a, b):
return math.ceil(a / b)
# disable tf32
torch.backends.cuda.matmul.allow_tf32 = False
m = 256
n = 1024
k = 512
total_sm = 108
torch.random.manual_seed(0)
# uniform distribution from -1 to 1
A = torch.rand(m, k, device="cuda", dtype=torch.float16) * 2 - 1
B = torch.rand(n, k, device="cuda", dtype=torch.float16) * 2 - 1
streamk_programs = total_sm
BLOCK_SIZE_M = 16
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 32
two_tiles = False
M, K = A.shape
N, K = B.shape
# accumulator types
# compute grid (work to do per SM on the first wave)
num_block_m = tilelang.cdiv(M, BLOCK_SIZE_M)
num_block_n = tilelang.cdiv(N, BLOCK_SIZE_N)
iters_per_tile = tilelang.cdiv(K, BLOCK_SIZE_K)
total_tiles = num_block_m * num_block_n
# Two-tile SK + DP
streamk_tiles = total_tiles % streamk_programs
if (total_tiles - streamk_tiles > streamk_programs):  # i.e. total_tiles // streamk_programs > 1
streamk_tiles += streamk_programs
blocking_tiles = total_tiles - streamk_tiles
streamk_iters = streamk_tiles * iters_per_tile
streamk_full_tiles = streamk_iters // streamk_programs
streamk_partial_tiles = streamk_iters % streamk_programs
print(f"{total_tiles=} ")
print(f"{iters_per_tile=} ")
sm_patition_factor = max(blocking_tiles // total_sm, 1)
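# For the shapes above (m=256, n=1024, k=512, 108 SMs):
#   num_block_m = 16, num_block_n = 8, total_tiles = 128, iters_per_tile = 512 // 32 = 16
#   streamk_tiles = 128 % 108 = 20 (128 - 20 = 108 is not > 108, so no extra wave is added)
#   blocking_tiles = 108, streamk_iters = 20 * 16 = 320
#   streamk_full_tiles = 320 // 108 = 2, streamk_partial_tiles = 320 % 108 = 104
#   sm_patition_factor = max(108 // 108, 1) = 1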
@tilelang.jit
def tl_matmul_streamk(
M,
N,
K,
streamk_tiles,
block_M,
block_N,
block_K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
accum_dtype,
num_stages,
threads,
):
assert not trans_A
A_shape = (M, K) if not trans_A else (K, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
@T.macro
def compute_first_wave(
pid: T.int32,
A_buf: T.Tensor,
A_buf_shared: T.SharedBuffer,
B_buf: T.Tensor,
B_buf_shared: T.SharedBuffer,
C: T.Tensor,
C_local: T.LocalBuffer,
):
start_iter = T.alloc_fragment((1,), "int32", "local")
end_iter = T.alloc_fragment((1,), "int32", "local")
start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles)
last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles)
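        # Each program owns the K-iterations in [start_iter, last_iter); the first
        # streamk_partial_tiles programs take streamk_full_tiles + 1 iterations, the rest
        # take streamk_full_tiles.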
while start_iter[0] < last_iter:
end_iter[0] = T.min(
start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)),
last_iter,
)
tile_id = start_iter[0] // iters_per_tile
remain_iters = start_iter[0] % iters_per_tile
pid_m = tile_id // T.ceildiv(N, block_N)
pid_n = tile_id % T.ceildiv(N, block_N)
T.clear(C_local)
for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages):
T.copy(
A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K],
A_buf_shared,
)
T.copy(
B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K],
B_buf_shared,
)
T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B)
# last iteration of the tile always happens before its start on another SM
if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0):
T.copy(C_local, C[pid_m * block_M, pid_n * block_N])
else:
for i, j in T.Parallel(block_M, block_N):
T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j])
start_iter[0] = end_iter[0]
@T.macro
def compute_full_tiles(
pid: T.int32,
A_buf: T.Tensor,
A_shared: T.SharedBuffer,
B_buf: T.Tensor,
B_shared: T.SharedBuffer,
C: T.Tensor,
C_local: T.LocalBuffer,
):
for p in T.serial(sm_patition_factor):
tile_id = pid + streamk_tiles + p * total_sm
pid_m = tile_id // T.ceildiv(N, block_N)
pid_n = tile_id % T.ceildiv(N, block_N)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(A_buf[pid_m * block_M, k * block_K], A_shared)
T.copy(B_buf[pid_n * block_N, k * block_K], B_shared)
T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B)
T.copy(C_local, C[pid_m * block_M, pid_n * block_N])
@T.prim_func
def main(
A: T.Tensor(A_shape, dtypeAB),
B: T.Tensor(B_shape, dtypeAB),
C: T.Tensor((M, N), dtypeC),
):
with T.Kernel(streamk_programs, threads=threads) as pid:
A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB)
B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local)
if sm_patition_factor > 0:
compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local)
return main
def main():
kernel = tl_matmul_streamk(
m,
n,
k,
streamk_tiles,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
False,
True,
"float16",
"float16",
"float32",
2,
64,
)
print(kernel.get_kernel_source())
b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16)
kernel(A, B, b_c)
C = torch.matmul(A, B.T)
print(b_c)
print(C)
torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
main()
import tilelang.testing
from example_tilelang_gemm_streamk import main
# not fully supported on sm90
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_tilelang_gemm_streamk():
main()
if __name__ == "__main__":
tilelang.testing.main()