Unverified Commit 48c9a352 authored by Jiaxing Ding, committed by GitHub

[AMD] refactor MatrixCoreIntrinEmitter (#860)

parent b12a63cf
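
This commit splits the preshuffle handling out of `MatrixCoreIntrinEmitter` into a new subclass, `MatrixCorePreshuffleIntrinEmitter`, and switches the preshuffle matmul test over to it; the test also gains a `b_g2l_load` switch that loads B fragments directly from global memory instead of staging them through shared memory. Below is a minimal construction sketch mirroring the emitter arguments visible in the diff; the concrete values are assumptions taken from the test's preshuffle configuration, and it assumes a tilelang build that contains this commit.

```python
from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter

# Hypothetical configuration mirroring the refactored test; the subclass takes
# the same hints as MatrixCoreIntrinEmitter plus the preshuffle flags.
mfma_emitter = MatrixCorePreshuffleIntrinEmitter(
    a_dtype="int8",
    b_dtype="int8",
    accum_dtype="int32",
    a_transposed=False,
    b_transposed=True,
    block_row_warps=1,
    block_col_warps=4,
    warp_row_tiles=64,
    warp_col_tiles=16,
    chunk=256,
    k_pack=1,
    b_preshuffle=True,
)
```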
......@@ -234,6 +234,10 @@ def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
if __name__ == "__main__":
......
......@@ -3,8 +3,7 @@ import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,)
from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
from tilelang.transform import simplify_prim_func
tilelang.testing.set_random_seed(0)
......@@ -22,16 +21,8 @@ def tl_matmul(
b_transposed=True,
k_pack=1,
b_preshuffle=False,
b_g2l_load=False,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
......@@ -47,15 +38,14 @@ def tl_matmul(
if b_preshuffle:
block_row_warps = 1
block_col_warps = 4
warp_row_tiles = 128
warp_col_tiles = 32
warp_row_tiles = 64
warp_col_tiles = 16
chunk = 32 * k_pack
chunk = 256 * k_pack
pack_size_k = micro_size_k * k_pack
shared_scope = "shared"
cache_write_shared = False
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
......@@ -68,6 +58,7 @@ def tl_matmul(
pack_size_k, micro_size_y)
else:
B_shape = (N, K) if b_transposed else (K, N)
A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
if b_preshuffle:
B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
......@@ -76,12 +67,6 @@ def tl_matmul(
micro_size_y)
else:
B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 64
threads = warp_size * (block_row_warps * block_col_warps)
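
When `b_preshuffle` is enabled, B is consumed in a 4-D tile layout: `(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k)` in the transposed case (with the two leading axes swapped otherwise), as the `B_shared_shape` computation in the hunk above shows. Here is a small NumPy sketch of one plausible host-side shuffle that produces this shape from a row-major `(N, K)` B; the framework's actual preshuffle may order elements differently, so treat this only as an illustration of the layout, not as its definition.

```python
import numpy as np

# Tile sizes taken from the kernel above (16x16x16 MFMA tiles); k_pack=2 assumed.
micro_size_y = 16
micro_size_k = 16
k_pack = 2
pack_size_k = micro_size_k * k_pack

N, K = 256, 512
rng = np.random.default_rng(0)
B = rng.integers(-128, 127, size=(N, K), dtype=np.int8)  # b_transposed layout (N, K)

# Group rows into micro_size_y-sized tiles and columns into pack_size_k-sized
# tiles, then move both tile indices to the front:
# (N, K) -> (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
B_preshuffled = (
    B.reshape(N // micro_size_y, micro_size_y, K // pack_size_k, pack_size_k)
     .transpose(0, 2, 1, 3)
     .copy()
)
assert B_preshuffled.shape == (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
```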
......@@ -92,7 +77,7 @@ def tl_matmul(
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mfma_emitter = MatrixCoreIntrinEmitter(
mfma_emitter = MatrixCorePreshuffleIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
......@@ -117,7 +102,6 @@ def tl_matmul(
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
......@@ -126,12 +110,15 @@ def tl_matmul(
A_shared: make_swizzle_layout(A_shared),
})
num_ko = K // block_K
num_ki = block_K // (k_pack * micro_size_k)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=0):
for ko in T.Pipelined(num_ko, num_stages=0):
# Load A into shared memory
if a_transposed:
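
The hunk above replaces the inline loop bound with named counts: `num_ko` outer pipelined steps over K and `num_ki` inner MFMA steps per block. A worked-numbers sketch follows, using sizes from one of the new test cases (K=512, k_pack=2); the `block_K` value is a stand-in assumption for illustration.

```python
# Assumed sizes for illustration; K and k_pack follow one of the new
# preshuffle test cases, block_K is a stand-in value.
K = 512
k_pack = 2
micro_size_k = 16
block_K = 256

num_ko = K // block_K                        # outer, pipelined steps over K -> 2
num_ki = block_K // (k_pack * micro_size_k)  # inner MFMA steps per block    -> 8
assert num_ko * num_ki * k_pack * micro_size_k == K
```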
......@@ -140,7 +127,7 @@ def tl_matmul(
T.copy(A[by * block_M, ko * block_K], A_shared)
# Load B into shared memory
if b_preshuffle:
if b_g2l_load is False:
if b_transposed:
for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
block_K // pack_size_k, micro_size_y,
......@@ -153,53 +140,37 @@ def tl_matmul(
micro_size_y):
B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
bx * block_N // micro_size_y + j, kk, jj]
else:
if b_transposed:
T.copy(B[bx * block_N, ko * block_K], B_shared)
else:
T.copy(B[ko * block_K, bx * block_N], B_shared)
for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
for ki in T.serial(0, num_ki):
# Load A into fragment
# Load A S2L
mfma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
if b_g2l_load:
# Load B G2L
mfma_emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx)
else:
# Load B S2L
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local)
# Perform STMatrix
if cache_write_shared:
mfma_emitter.stmatrix(
C_local,
C_shared,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
else:
mfma_emitter.stmatrix(
C_local,
C,
pid_m=by,
pid_n=bx,
)
mfma_emitter.stmatrix(
C_local,
C,
pid_m=by,
pid_n=bx,
)
return main
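
In the rewritten inner loop, `b_g2l_load` selects how B fragments are produced: the default path keeps staging B through `B_shared` and issues `ldmatrix_b(B_local, B_shared, ki)`, while the global path bypasses shared memory and issues `ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx)`. Supplying `pid_m`/`pid_n` is what routes the call to the new emitter's global-load macro, as the `is_global` check in the emitter below shows. A self-contained sketch of that dispatch rule (the helper name is hypothetical):

```python
# Minimal sketch of the dispatch rule used by the new emitter's ldmatrix_a/_b:
# passing block indices switches from the shared-memory path to the global path.
def pick_ldmatrix_path(pid_m=None, pid_n=None):
    is_global = pid_m is not None and pid_n is not None
    return "global (G2L)" if is_global else "shared (S2L)"

assert pick_ldmatrix_path() == "shared (S2L)"                  # ldmatrix_b(B_local, B_shared, ki)
assert pick_ldmatrix_path(pid_m=0, pid_n=1) == "global (G2L)"  # ldmatrix_b(B_local, B, ..., pid_m=by, pid_n=bx)
```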
......@@ -232,9 +203,10 @@ def assert_tl_matmul_correctness(M,
a_transposed=False,
b_transposed=True,
k_pack=1,
b_preshuffle=False):
b_preshuffle=False,
b_g2l_load=False):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
k_pack, b_preshuffle)
k_pack, b_preshuffle, b_g2l_load)
print(matmul)
kernel = tilelang.compile(matmul)
src_code = kernel.get_kernel_source()
......@@ -285,30 +257,25 @@ def assert_tl_matmul_correctness(M,
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(
128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(
128,
256,
256,
512,
"int8",
"int32",
b_transposed=False,
......
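
The new preshuffle cases still go through `assert_tl_matmul_correctness`, which compares the kernel output against a reference with `torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)`. For the int8/int32 cases added here, a reference of roughly the following shape would be expected; the test's actual construction of `ref_c` is outside the hunks shown, so this is an assumed sketch, not the test's code.

```python
import torch

# Assumed sketch of the reference computation for the int8 -> int32 cases.
M, N, K = 256, 256, 256
A = torch.randint(-4, 4, (M, K), dtype=torch.int8)
B = torch.randint(-4, 4, (N, K), dtype=torch.int8)  # b_transposed=True layout (N, K)

# Accumulate in float32 (exact for these small magnitudes) and cast to int32,
# matching accum_dtype="int32" in the new test cases.
ref_c = (A.float() @ B.float().T).to(torch.int32)
assert ref_c.shape == (M, N) and ref_c.dtype == torch.int32
```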
......@@ -293,52 +293,27 @@ class MatrixCoreIntrinEmitter(object):
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k),
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
r + col]
# 4 dim
if self.b_preshuffle:
if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
warp_n * warp_cols + j,
rk * (chunk // micro_size_k) + ki,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
row,
col]
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * (chunk // micro_size_k) + ki,
warp_n * warp_cols + j,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
row,
col]
else:
if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * (k_pack * micro_size_k),
)
B_local_buf[j * k_pack * local_size_b +
local_id] = B_shared_buf[l + row, r + col]
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,
)
B_local_buf[j * k_pack * local_size_b +
local_id] = B_shared_buf[l + row, r + col]
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * chunk + ki * (k_pack * micro_size_k),
warp_n * warp_col_tiles + j * micro_size_y,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
r + col]
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
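
In the base-class `ldmatrix_b` shown above, each warp derives a 2-D base offset `(l, r)` into the plain (non-preshuffled) shared tile and then adds the per-thread `(row, col)` produced by the reverse index map. A pure-Python sketch of that base-offset arithmetic for the non-transposed `(K, N)` layout; the tile sizes are assumptions chosen for the illustration.

```python
# Illustration of the (l, r) base-offset arithmetic in the non-transposed
# branch of ldmatrix_b above; the sizes below are assumptions for the sketch.
warp_col_tiles = 32
micro_size_y = 16
micro_size_k = 16
k_pack = 1
chunk = 32
warp_cols = warp_col_tiles // micro_size_y   # 2 MFMA tiles per warp along N

def b_tile_base(warp_n, j, ki, rk=0):
    # Row offset along K: reduction-split block, then the ki-th packed micro tile.
    l = rk * chunk + ki * (k_pack * micro_size_k)
    # Column offset along N: this warp's column strip, then the j-th micro tile.
    r = warp_n * warp_col_tiles + j * micro_size_y
    return l, r

# Warp 1, second tile (j=1), inner step ki=1 lands at rows 16.., cols 48..
assert b_tile_base(warp_n=1, j=1, ki=1) == (16, 48)
```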
......@@ -425,3 +400,210 @@ class MatrixCoreIntrinEmitter(object):
return _warp_stmatrix_global(C_local_buf, C_buf,
thread_binding) if is_global else _warp_stmatrix_shared(
C_local_buf, C_buf, thread_binding)
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
k_pack: Optional[int] = None,
is_m_first: Optional[bool] = False,
a_preshuffle: Optional[bool] = False,
b_preshuffle: Optional[bool] = False,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_mfma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self._initialize_k_pack(k_pack)
self._initialize_is_m_first(is_m_first)
self._initialize_preshuffle(a_preshuffle, b_preshuffle)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.num_elems_per_byte = num_elems_per_byte
def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
if a_preshuffle is not None:
self.a_preshuffle = a_preshuffle
if b_preshuffle is not None:
self.b_preshuffle = b_preshuffle
def ldmatrix_a(self, A_local_buf, A_buf, ki, rk=0, pid_m=None, pid_n=None):
warp_rows = self.warp_rows
chunk = self.chunk
micro_size_k = self.micro_size_k
local_size_a = self.local_size_a
k_pack = self.k_pack
is_transposed = self.a_transposed
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
is_global = pid_m is not None and pid_n is not None
# no preshuffle, use the default implementation
if self.a_preshuffle is False:
return super().ldmatrix_a(A_local_buf, A_buf, ki, rk)
@T.macro
def _warp_ldmatrix_a_global(
A_local_buf,
A_buf,
ki,
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
if is_transposed:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * (chunk // micro_size_k) + ki,
(pid_m * self.block_row_warps + warp_m) * warp_rows + i,
)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col]
else:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
(pid_m * self.block_row_warps + warp_m) * warp_rows + i,
rk * (chunk // micro_size_k) + ki,
)
A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col]
@T.macro
def _warp_ldmatrix_a_shared(
A_local_buf,
A_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, _, warp_m = self.extract_thread_binding(thread_binding)
if is_transposed:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * (chunk // micro_size_k) + ki,
warp_m * warp_rows + i,
)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row,
col]
else:
for i in T.serial(warp_rows):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki)
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row,
col]
return _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding,
rk) if is_global else _warp_ldmatrix_a_shared(
A_local_buf, A_buf, ki, thread_binding, rk)
def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None):
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
k_pack = self.k_pack
is_transposed = self.b_transposed
current_frame = T.KernelLaunchFrame.Current()
thread_binding = current_frame.get_thread_binding()
_, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
is_global = pid_m is not None and pid_n is not None
if self.b_preshuffle is False:
return super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n)
@T.macro
def _warp_ldmatrix_b_global(
B_local_buf,
B_buf,
ki,
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
(pid_n * self.block_col_warps + warp_n) * warp_cols + j,
rk * (chunk // micro_size_k) + ki,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col]
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * (chunk // micro_size_k) + ki,
(pid_n * self.block_col_warps + warp_n) * warp_cols + j,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col]
@T.macro
def _warp_ldmatrix_b_shared(
B_local_buf,
B_shared_buf,
ki,
thread_binding,
rk=0,
):
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
warp_n * warp_cols + j,
rk * (chunk // micro_size_k) + ki,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row,
col]
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * (chunk // micro_size_k) + ki,
warp_n * warp_cols + j,
)
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row,
col]
return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding,
rk) if is_global else _warp_ldmatrix_b_shared(
B_local_buf, B_buf, ki, thread_binding, rk)