Unverified Commit 143b5222 authored by Jiaxing Ding, committed by GitHub

[AMD] support preshuffle weight mfma (#806)


Co-authored-by: Jiaxing Ding <jiaxing.ding@bytedance.com>
parent 409ab83d
import torch
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
tilelang.testing.set_random_seed(0)
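# Tests AMD Matrix Core (MFMA) GEMM with an optional preshuffled B operand:
# when b_preshuffle is set, B is pre-blocked into MFMA-sized (micro_size_y, pack_size_k) tiles.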
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
a_transposed=False,
b_transposed=True,
k_pack=1,
b_preshuffle=False,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if in_dtype in {"float8_e4m3fnuz", "int8"}:
micro_size_k = 32
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
# For b_preshuffle, use a (1, 4) warp layout: 1 warp along M, 4 warps along N.
if b_preshuffle:
block_row_warps = 1
block_col_warps = 4
warp_row_tiles = 128
warp_col_tiles = 32
chunk = 32 * k_pack
pack_size_k = micro_size_k * k_pack
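# chunk is the K extent of one shared-memory tile (block_K);
# pack_size_k is the K extent of one k_pack-packed MFMA fragment.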
shared_scope = "shared"
cache_write_shared = False
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (K, M) if a_transposed else (M, K)
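# Preshuffled B is stored as 4-D tiles so that each (micro_size_y, pack_size_k) MFMA
# fragment is contiguous in memory instead of strided across the full K (or N) dimension.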
if b_preshuffle:
B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
pack_size_k, micro_size_y)
else:
B_shape = (N, K) if b_transposed else (K, N)
A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
if b_preshuffle:
B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
pack_size_k) if b_transposed else (block_K // pack_size_k,
block_N // micro_size_y, pack_size_k,
micro_size_y)
else:
B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 64
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
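# Per-lane register footprint of one MFMA tile: each 64-lane wave holds local_size_a/b
# operand elements and local_size_c accumulator elements per 16x16 output tile.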
# MFMA emitter that auto-generates Matrix Core intrinsic code;
# b_preshuffle switches ldmatrix_b to the preshuffled 4-D B indexing.
mfma_emitter = MatrixCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=a_transposed,
b_transposed=b_transposed,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
k_pack=k_pack,
b_preshuffle=b_preshuffle,
)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=0):
# Load A into shared memory
if a_transposed:
T.copy(A[ko * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, ko * block_K], A_shared)
# Load B into shared memory
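# With b_preshuffle, whole 4-D tiles are copied; the inner (jj, kk) dims already
# match the MFMA fragment layout, so only the tile indices depend on the block offsets.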
if b_preshuffle:
if b_transposed:
for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
block_K // pack_size_k, micro_size_y,
pack_size_k):
B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j,
ko * block_K // pack_size_k + k, jj, kk]
else:
for k, j, kk, jj in T.Parallel(block_K // pack_size_k,
block_N // micro_size_y, pack_size_k,
micro_size_y):
B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
bx * block_N // micro_size_y + j, kk, jj]
else:
if b_transposed:
T.copy(B[bx * block_N, ko * block_K], B_shared)
else:
T.copy(B[ko * block_K, bx * block_N], B_shared)
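# One MFMA step per iteration: each step consumes k_pack * micro_size_k elements along K.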
for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local)
# Perform STMatrix
if cache_write_shared:
mfma_emitter.stmatrix(
C_local,
C_shared,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
else:
mfma_emitter.stmatrix(
C_local,
C,
pid_m=by,
pid_n=bx,
)
return main
def shuffle_weight(
x: torch.Tensor,
layout=(16, 32),
k_pack=1,
is_transpose=False,
) -> torch.Tensor:
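# Reorder a weight into the 4-D tile layout expected when b_preshuffle=True:
# with is_transpose=True an (N, K) tensor becomes (N // IN, K // (IK * k_pack), IN, IK * k_pack);
# otherwise a (K, N) tensor becomes (K // (IK * k_pack), N // IN, IK * k_pack, IN).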
IN, IK = layout
BK = IK * k_pack
BN = IN
N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
assert N % BN == 0
assert K % BK == 0
x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
x = x.permute(0, 2, 1, 3)
return x.contiguous()
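# Illustrative shape check (not part of the test): with layout=(16, 32), k_pack=1 and
# is_transpose=True, a (256, 256) int8 weight comes back contiguous with shape
# (16, 8, 16, 32), matching the preshuffled B_shape constructed in tl_matmul.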
def assert_tl_matmul_correctness(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype="float32",
a_transposed=False,
b_transposed=True,
k_pack=1,
b_preshuffle=False):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
k_pack, b_preshuffle)
print(matmul)
kernel = tilelang.compile(matmul)
src_code = kernel.get_kernel_source()
# src_code is the generated device (HIP) kernel source
assert src_code is not None
A_shape = (K, M) if a_transposed else (M, K)
B_shape = (N, K) if b_transposed else (K, N)
if in_dtype == "int8":
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
else:
A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
B_preshuffle = B
if b_preshuffle:
B_preshuffle = shuffle_weight(B_preshuffle, k_pack=k_pack, is_transpose=b_transposed)
kernel(A, B_preshuffle, C)
else:
kernel(A, B, C)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler()
latency = profiler.do_bench()
# Ensure that the latency is not None
assert latency is not None
if a_transposed and b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.T.to(torch.float32),
B.T.to(torch.float32)).to(getattr(torch, out_dtype))
elif a_transposed and not b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.T.to(torch.float32),
B.to(torch.float32)).to(getattr(torch, out_dtype))
elif not a_transposed and b_transposed:
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32),
B.T.to(torch.float32)).to(getattr(torch, out_dtype))
else:
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(
128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(
128,
256,
256,
"int8",
"int32",
b_transposed=False,
accum_dtype="int32",
k_pack=2,
b_preshuffle=True)
if __name__ == "__main__":
tilelang.testing.main()
tilelang/intrinsics/mfma_macro_generator.py
@@ -53,6 +53,7 @@ class MatrixCoreIntrinEmitter(object):
num_elems_per_byte: int = 1,
k_pack: Optional[int] = None,
is_m_first: Optional[bool] = False,
+b_preshuffle: Optional[bool] = False,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
@@ -72,6 +73,7 @@ class MatrixCoreIntrinEmitter(object):
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self._initialize_k_pack(k_pack)
self._initialize_is_m_first(is_m_first)
+self._initialize_b_preshuffle(b_preshuffle)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
@@ -141,6 +143,10 @@ class MatrixCoreIntrinEmitter(object):
if is_m_first is not None:
self.is_m_first = is_m_first
+def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False):
+if b_preshuffle is not None:
+self.b_preshuffle = b_preshuffle
def get_ldmatrix_index_map(self, is_b=False):
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
@@ -288,6 +294,31 @@ class MatrixCoreIntrinEmitter(object):
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
+# 4 dim
+if self.b_preshuffle:
+if is_transposed:
+for j in T.serial(warp_cols):
+for local_id in T.vectorized(k_pack * local_size_b):
+row, col = T.meta_var(reverse_index_map(tx, local_id))
+l, r = (
+warp_n * warp_cols + j,
+rk * (chunk // micro_size_k) + ki,
+)
+B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
+row,
+col]
+else:
+for j in T.serial(warp_cols):
+for local_id in T.vectorized(k_pack * local_size_b):
+row, col = T.meta_var(reverse_index_map(tx, local_id))
+l, r = (
+rk * (chunk // micro_size_k) + ki,
+warp_n * warp_cols + j,
+)
+B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
+row,
+col]
+else:
if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
@@ -296,8 +327,8 @@ class MatrixCoreIntrinEmitter(object):
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k),
)
-B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-r + col]
+B_local_buf[j * k_pack * local_size_b +
+local_id] = B_shared_buf[l + row, r + col]
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
@@ -306,8 +337,8 @@ class MatrixCoreIntrinEmitter(object):
rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,
)
-B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-r + col]
+B_local_buf[j * k_pack * local_size_b +
+local_id] = B_shared_buf[l + row, r + col]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
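For reference, a minimal sketch of the element mapping that the new 4-D b_preshuffle branch of ldmatrix_b relies on (transposed case). It reuses the shuffle_weight helper from the test above; the concrete sizes and index names are illustrative only:

import torch
# Assumptions for illustration: int8 weight, k_pack = 1, so pack_size_k = 32.
N, K, pack_size_k = 256, 256, 32
B = torch.randint(-128, 127, (N, K), dtype=torch.int8)
B_pre = shuffle_weight(B, k_pack=1, is_transpose=True)  # shape (N // 16, K // 32, 16, 32)
# Every element B[n, k] lands at B_pre[n // 16, k // pack_size_k, n % 16, k % pack_size_k],
# which mirrors the (l, r, row, col) indexing performed per MFMA tile above.
n, k = 37, 100
assert B_pre[n // 16, k // pack_size_k, n % 16, k % pack_size_k] == B[n, k]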