Commit b4b527d3 authored by Yuxi Chi, committed by LeiWang1999

[Examples] Add fp8 gemm 2xAcc and deepgemm example (#217)

* add fp8 gemm 2xAcc and deepgemm example.

* format deepgemm example.

* fix the format lint.

* format with the updated format.sh
parent 889451eb
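
The commit adds two example scripts: a DeepGEMM-style fp8 GEMM with 128x128 blockwise scales, and a plain fp8 GEMM, both using two-level ("2x") accumulation. As a minimal sketch of the promotion idea (plain PyTorch, illustrative only; gemm_2x_acc and chunk are not part of the commit), each K chunk's partial product is folded into a separate fp32 accumulator rather than letting everything accumulate inside the low-precision tensor-core path:

import torch

def gemm_2x_acc(a, b, chunk=128):
    # Illustrative sketch (not part of the commit): two-level accumulation in plain PyTorch.
    # a: (M, K), b: (N, K); returns a @ b.T accumulated chunk by chunk over K.
    M, K = a.shape
    acc = torch.zeros(M, b.shape[0], dtype=torch.float32)
    for k0 in range(0, K, chunk):
        # Inner product over one K chunk (done by fp8 tensor-core instructions in the real kernels).
        partial = a[:, k0:k0 + chunk].float() @ b[:, k0:k0 + chunk].float().T
        # Promotion step: fold the chunk result into the fp32 accumulator.
        acc += partial
    return acc

The two example scripts follow; the second one begins at the second import block.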
from typing import Tuple
import torch
import tilelang.testing
import tilelang as TL
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
tilelang.testing.set_random_seed(0)


def tl_gemm(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "e4m3_float8",
    ], "Currently only e4m3_float8 is supported"
    assert out_dtype in [
        "bfloat16",
        "float32",
    ], "Currently only bfloat16 and float32 are supported"

    TILE_SIZE = (128, 128, 128)
    block_M = TILE_SIZE[0]
    block_N = TILE_SIZE[1]
    block_K = TILE_SIZE[2]

    A_shape = (M, K)
    Scales_A_shape = (M, T.ceildiv(K, block_K))
    B_shape = (N, K)
    Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (block_M, block_N)

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
            scales_a: T.Buffer(Scales_A_shape, "float32"),
            scales_b: T.Buffer(Scales_B_shape, "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
            Scale_C_shared = T.alloc_shared((block_M), "float32")
            C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
            C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                # Load A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Load B into shared memory
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Combine the per-row A scale with the per-block B scale for this K tile
                # and stage it in shared memory
                Scale_B = scales_b[bx, k]
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
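
                # fp8 tensor-core accumulation has limited precision, so each K tile's
                # partial result is scaled by the combined per-row scale and folded
                # into C_local_accum in fp32; this is the two-level ("2x") accumulation.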
                # Promote to enable 2xAcc
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                T.clear(C_local)

            # TMA store
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


def ceildiv(a, b):
    return (a + b - 1) // b
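

# The two casts below quantize to e4m3: each 1x128 group (per token) or 128x128 block
# is scaled so that its absolute maximum maps to 448, the largest finite e4m3 value,
# and the inverse scale is returned for dequantization.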
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n), (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
        x_view.size(0), x_view.size(2))


def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
    # A_scale: (M, K//128) ==> (M//128, K//128, 128)
    # B_scale: (N//128, K//128) ==> (N//128, K//128, 128)
    # A_fp8: (M, K)
    # B_fp8: (N, K)
    # out_dtype: bfloat16 or float32
    # return C: (M, N)
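    # The reference mirrors the kernel: each 128x128 output tile is accumulated over the
    # K tiles in an fp32 buffer (c_acc), with torch._scaled_mm applying the per-row A
    # scale and per-block B scale for each tile.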
    M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1]
    A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1)
    B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128)
    C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
    c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32)
    for i in range(ceildiv(M, 128)):
        for j in range(ceildiv(N, 128)):
            c_acc.zero_()
            for k in range(ceildiv(K, 128)):
                c = torch._scaled_mm(
                    A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128],
                    B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T,
                    scale_a=A_scales[i, k].view(128, 1).contiguous(),
                    scale_b=B_scales[j, k].view(1, 128).contiguous(),
                    out_dtype=torch.bfloat16)
                c_acc += c.to(torch.float32)
            C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype)
    return C
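

# calc_diff returns 1 - 2 * sum(x * y) / sum(x * x + y * y): exactly 0 when x == y and
# close to 0 when the two tensors agree up to rounding error.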
def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def assert_tl_gemm_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
    gemm = tl_gemm(M, N, K, in_dtype, out_dtype, accum_dtype)
    mod, params = TL.lower(gemm)
    src_code = mod.imported_modules[0].get_source()
    # src_code is the generated cuda source
    assert src_code is not None

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)
    accum_dtype = map_torch_type(accum_dtype)

    A = torch.randn(M, K).to(torch.bfloat16).cuda()
    B = torch.randn(N, K).to(torch.bfloat16).cuda()
    A_fp8, A_scale = per_token_cast_to_fp8(A.clone())
    B_fp8, B_scale = per_block_cast_to_fp8(B.clone())
    C = torch.zeros(M, N, device="cuda", dtype=out_dtype)

    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
    mod(A_fp8, B_fp8, C, A_scale, B_scale)

    # Get Reference Result
    ref_c = ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype)
    diff = calc_diff(C, ref_c)
    print(f"diff: {diff}")
    assert diff < 1e-3

    latency = mod.do_bench(mod.func, warmup=25)
    # Ensure that the latency is not None
    assert latency is not None
    print(f"latency: {latency} ms")
    tflops = 2 * M * N * K / latency / 1e9
    print(f"tflops: {tflops}")


if __name__ == "__main__":
    for dtype in ["e4m3_float8"]:
        for out_dtype in ["bfloat16", "float32"]:
            assert_tl_gemm_correctness(1024, 1024, 8192, dtype, out_dtype, "float32")
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type


def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
    # For fp8 GEMM, do one promotion after 4 WGMMA instructions, i.e. block_K = 128.
    # If block_K < 128, promote after 128 / block_K iterations.
    # If block_K >= 128, promote after every iteration.
    update_interval = 128 // block_K if block_K < 128 else 1
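    # e.g. block_K = 64 -> update_interval = 2 (promote every other K tile);
    # block_K = 128 -> update_interval = 1 (promote every K tile).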

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Promote to enable 2xAcc
                if (k + 1) % update_interval == 0:
                    for i, j in T.Parallel(block_M, block_N):
                        C_local_accum[i, j] += C_local[i, j]
                    T.clear(C_local)

            # Tail processing: fold leftover partials when K_iters is not a
            # multiple of update_interval
            if K_iters % update_interval != 0:
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j]

            # TMA store
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def test_gemm_fp8(M, N, K, dtype):
    torch_dtype = map_torch_type(dtype)

    func = matmul(M, N, K, 128, 128, 64, dtype)
    kernel = tilelang.compile(func, out_idx=-1)

    a = torch.rand(M, K, dtype=torch.float16, device='cuda')
    a = (100 * (2 * a - 1)).to(dtype=torch_dtype)
    b = torch.rand(N, K, dtype=torch.float16, device='cuda')
    b = (100 * (2 * b - 1)).to(dtype=torch_dtype)

    c = kernel(a, b)
    ref_c = (a.float() @ b.float().T)
    diff = calc_diff(c, ref_c)
    print(f"diff: {diff}")
    assert diff < 1e-3


if __name__ == "__main__":
    test_gemm_fp8(1024, 1024, 8192, 'e4m3_float8')
    test_gemm_fp8(1024, 1024, 8192, 'e5m2_float8')