Commit bb62f6bf authored by qisan

[Bugfix] Pass pre commit check

parent 667632cc
@@ -4,7 +4,8 @@ import tilelang
 import tilelang.language as T
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mmac_macro_generator import (
-    MatrixCoreIntrinEmitter,)
+    MatrixCoreIntrinEmitter,
+)
 from tilelang.transform import simplify_prim_func
 from tilelang import disable_cache
@@ -107,7 +108,6 @@ def tl_matmul(
             C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -115,10 +115,12 @@ def tl_matmul(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-            T.annotate_layout({
+            T.annotate_layout(
+                {
                 A_shared: make_swizzle_layout(A_shared),
                 B_shared: make_swizzle_layout(B_shared),
-            })
+                }
+            )
             # Improve L2 Cache
             T.use_swizzle(panel_size=10)
@@ -126,7 +128,6 @@ def tl_matmul(
             T.clear(C_local)
             for ko in T.Pipelined((K // block_K), num_stages=stage):
-
                 # Load A into shared memory
                 for i, k in T.Parallel(block_M, block_K):
                     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -136,7 +137,6 @@ def tl_matmul(
                     B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
                 for ki in T.serial(0, (block_K // micro_size_k)):
-
                     # Load A into fragment
                     mmac_emitter.ldmatrix_a(A_local, A_shared, ki)
...
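Side note on the shared-memory loads in the hunks above: the indexing A[by * block_M + i, ko * block_K + k] and B[bx * block_N + j, ko * block_K + k] is ordinary block tiling over the K loop. The following is a minimal plain-Python sketch of that tiling; the matrix sizes, tile sizes, and the row-major (N, K) layout of B are assumed for illustration and are not taken from the kernel.

# Illustration only: plain-Python analogue of the tiled indexing used above,
#   A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Problem and tile sizes below are assumed example values.
M, N, K = 8, 8, 8
block_M, block_N, block_K = 4, 4, 4

A = [[i * K + k for k in range(K)] for i in range(M)]
B = [[j * K + k for k in range(K)] for j in range(N)]   # B stored as (N, K), i.e. K-major per output column
C = [[0 for _ in range(N)] for _ in range(M)]

for by in range(M // block_M):          # block row, plays the role of the `by` program id
    for bx in range(N // block_N):      # block column, plays the role of the `bx` program id
        for ko in range(K // block_K):  # K loop, pipelined in the kernel
            # "Load" the A and B tiles, mirroring the shared-memory copies
            A_tile = [[A[by * block_M + i][ko * block_K + k] for k in range(block_K)]
                      for i in range(block_M)]
            B_tile = [[B[bx * block_N + j][ko * block_K + k] for k in range(block_K)]
                      for j in range(block_N)]
            # Accumulate the partial product contributed by this K slice
            for i in range(block_M):
                for j in range(block_N):
                    for k in range(block_K):
                        C[by * block_M + i][bx * block_N + j] += A_tile[i][k] * B_tile[j][k]

# Check against an unblocked reference
ref = [[sum(A[i][k] * B[j][k] for k in range(K)) for j in range(N)] for i in range(M)]
assert C == ref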
@@ -978,7 +978,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
     // arg 11: C accumulator index
     ICHECK(op->args.size() == 12U)
-        << "Invalid number of arguments for tvm_mfma";
+        << "Invalid number of arguments for tvm_mmac";
     std::string prefix = Downcast<StringImm>(op->args[0])->value;
     std::string A_layout = Downcast<StringImm>(op->args[1])->value;
     std::string B_layout = Downcast<StringImm>(op->args[2])->value;
...
@@ -3,10 +3,12 @@ import tilelang.testing
 from tilelang import tvm as tvm
 from tvm import DataType
 import tilelang.language as T
 # from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
 from tilelang.intrinsics import get_swizzle_layout
 from tilelang.intrinsics.mmac_macro_generator import (
-    MatrixCoreIntrinEmitter,)
+    MatrixCoreIntrinEmitter,
+)
 from tilelang.transform import simplify_prim_func
 tilelang.testing.set_random_seed(0)
@@ -111,7 +113,6 @@ def tl_matmul(
             C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
@@ -119,10 +120,12 @@ def tl_matmul(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
-            T.annotate_layout({
+            T.annotate_layout(
+                {
                 A_shared: make_swizzle_layout(A_shared),
                 B_shared: make_swizzle_layout(B_shared),
-            })
+                }
+            )
             # Improve L2 Cache
             T.use_swizzle(panel_size=10)
@@ -130,7 +133,6 @@ def tl_matmul(
             T.clear(C_local)
             for ko in T.Pipelined((K // block_K), num_stages=0):
-
                 # Load A into shared memory
                 if a_transposed:
                     T.copy(A[ko * block_K, by * block_M], A_shared)
@@ -144,7 +146,6 @@ def tl_matmul(
                     T.copy(B[ko * block_K, bx * block_N], B_shared)
                 for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
-
                     # Load A into fragment
                     mmac_emitter.ldmatrix_a(
                         A_local,
@@ -180,17 +181,8 @@ def tl_matmul(
     return main
-def assert_tl_matmul_correctness(M,
-                                 N,
-                                 K,
-                                 in_dtype,
-                                 out_dtype,
-                                 accum_dtype="float32",
-                                 a_transposed=False,
-                                 b_transposed=True,
-                                 k_pack=1):
-    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
-                       k_pack)
+def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32", a_transposed=False, b_transposed=True, k_pack=1):
+    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack)
     print(matmul)
     kernel = tilelang.compile(matmul)
     src_code = kernel.get_kernel_source()
@@ -218,16 +210,13 @@ def assert_tl_matmul_correctness(M,
     if a_transposed and b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.T.to(torch.float32),
-                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
     elif a_transposed and not b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.Tto(torch.float32),
-                             B.to(torch.float32)).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.T.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
     elif not a_transposed and b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.to(torch.float32),
-                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
     else:
         # Get Reference Result
         ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
@@ -245,10 +234,8 @@ def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
     assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
     assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
-    assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
-    assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
     # assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
     # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
     # assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
...
@@ -5,7 +5,8 @@ from tvm import DataType
 from tvm.tir import PrimExpr
 from tvm.runtime import convert
 from .utils import (
-    mfma_store_index_map,)
+    mfma_store_index_map,
+)
 lift = convert
@@ -77,7 +78,7 @@ class MatrixCoreIntrinEmitter:
         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
         self.reduce_k = reduce_k
-        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
+        self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
         self.num_elems_per_byte = num_elems_per_byte
     def _initialize_k_dim(self, a_dtype="float16"):
@@ -107,19 +108,9 @@ class MatrixCoreIntrinEmitter:
     def _initialize_mmac_prefix(self, k_dim=16):
         in_dtype, out_dtype = self.a_dtype, self.accum_dtype
         M_DIM, N_DIM = self.M_DIM, self.N_DIM
-        out_dtype_abbrv = {
-            "float16": "f16",
-            "float32": "f32",
-            "int8": "i8",
-            "int32": "i32"
-        }[out_dtype]
-        in_dtype_abbrv = {
-            "float16": "f16",
-            "float32": "f32",
-            "int8": "i8",
-            "bfloat16": "bf16"
-        }[in_dtype]
+        out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype]
+        in_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "bfloat16": "bf16"}[in_dtype]
         self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
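For context on the hunk above: the two inlined dictionaries only abbreviate dtype names; the intrinsic suffix itself is assembled by the f-string on the last context line. A minimal standalone sketch of that composition follows; the M_DIM, N_DIM, and k_dim defaults here are assumed example values, not read from the class.

# Illustration only: how the dtype-abbreviation tables and the f-string above compose
# a suffix such as "f32_16x16x16f16". Tile dimensions are assumed example values.
out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}
in_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "bfloat16": "bf16"}

def mmac_suffix(in_dtype, out_dtype, M_DIM=16, N_DIM=16, k_dim=16):
    return f"{out_dtype_abbrv[out_dtype]}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv[in_dtype]}"

assert mmac_suffix("float16", "float32") == "f32_16x16x16f16"
assert mmac_suffix("int8", "int32", k_dim=32) == "i32_16x16x32i8"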
@@ -167,41 +158,53 @@ class MatrixCoreIntrinEmitter:
             reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A
             if is_b:
                 index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B
-                reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x1_to_16x4_layout_A if transposed else thread_id_shared_access_64x1_to_4x16_layout_B
+                )
         elif k_dim == 16:
             index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A
-            reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
+            reverse_index_map = (
+                thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
+            )
             if is_b:
                 index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
-                reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
+                )
         elif k_dim == 32:
             index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A
-            reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
+            reverse_index_map = (
+                thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
+            )
             if is_b:
                 index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
-                reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
+                )
         elif k_dim == 64:
             index_map = shared_16x64_to_local_64x16_layout_B if transposed else shared_16x64_to_local_64x16_layout_A
-            reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
+            reverse_index_map = (
+                thread_id_shared_access_64x16_to_16x64_layout_B if transposed else thread_id_shared_access_64x16_to_16x64_layout_A
+            )
             if is_b:
                 index_map = shared_16x64_to_local_64x16_layout_A if transposed else shared_16x64_to_local_64x16_layout_B
-                reverse_index_map = thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
+                reverse_index_map = (
+                    thread_id_shared_access_64x16_to_16x64_layout_A if transposed else thread_id_shared_access_64x16_to_16x64_layout_B
+                )
         else:
             raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
         return index_map, reverse_index_map
-    def extract_thread_binding(self,
-                               thread_id,
-                               is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
-        '''
+    def extract_thread_binding(self, thread_id, is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
+        """
         is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
             which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
             Otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)]
-        '''
+        """
         WARP_SIZE = self.WARP_SIZE
         block_row_warps = self.block_row_warps
         block_col_warps = self.block_col_warps
@@ -211,16 +214,18 @@ class MatrixCoreIntrinEmitter:
             is_m_first = self.is_m_first
         if is_m_first:
-            lane_id, warp_n, warp_m = thread_id % WARP_SIZE, (
-                thread_id //
-                WARP_SIZE) % block_col_warps, (thread_id //
-                                               (WARP_SIZE * block_col_warps)) % block_row_warps,
+            lane_id, warp_n, warp_m = (
+                thread_id % WARP_SIZE,
+                (thread_id // WARP_SIZE) % block_col_warps,
+                (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
+            )
             return lane_id, warp_n, warp_m
         else:
-            lane_id, warp_m, warp_n = thread_id % WARP_SIZE, (
-                thread_id //
-                WARP_SIZE) % block_row_warps, (thread_id //
-                                               (WARP_SIZE * block_row_warps)) % block_col_warps,
+            lane_id, warp_m, warp_n = (
+                thread_id % WARP_SIZE,
+                (thread_id // WARP_SIZE) % block_row_warps,
+                (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
+            )
             return lane_id, warp_n, warp_m
     def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
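To make the reformatted index arithmetic above concrete, here is a minimal standalone sketch of the non-m-first branch of extract_thread_binding; WARP_SIZE and the 2x2 warp grid are assumed example values, not taken from the emitter.

# Illustration only: the (lane_id, warp_m, warp_n) decomposition written out by the
# reformatted tuple assignment in the else-branch above.
WARP_SIZE = 64
block_row_warps = 2
block_col_warps = 2

def extract(thread_id):
    lane_id = thread_id % WARP_SIZE
    warp_m = (thread_id // WARP_SIZE) % block_row_warps
    warp_n = (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps
    return lane_id, warp_n, warp_m

assert extract(0) == (0, 0, 0)      # lane 0 of warp (warp_m=0, warp_n=0)
assert extract(64) == (0, 0, 1)     # each group of 64 threads advances warp_m first
assert extract(130) == (2, 1, 0)    # thread 130 is lane 2 of warp (warp_m=0, warp_n=1)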
@@ -249,18 +254,14 @@ class MatrixCoreIntrinEmitter:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
-                        l, r = (rk * chunk + ki * (k_pack * micro_size_k),
-                                warp_m * warp_row_tiles + i * micro_size_x)
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
-                                                                                         r + col]
+                        l, r = (rk * chunk + ki * (k_pack * micro_size_k), warp_m * warp_row_tiles + i * micro_size_x)
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col]
             else:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
-                        l, r = (warp_m * warp_row_tiles + i * micro_size_x,
-                                rk * chunk + ki * (k_pack * micro_size_k))
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
-                                                                                         r + col]
+                        l, r = (warp_m * warp_row_tiles + i * micro_size_x, rk * chunk + ki * (k_pack * micro_size_k))
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col]
         return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
@@ -289,28 +290,22 @@ class MatrixCoreIntrinEmitter:
             if is_transposed:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(
-                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
-                                              local_id))
+                        row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
                         l, r = (
                             warp_n * warp_col_tiles + j * micro_size_y,
                             rk * chunk + ki * (k_pack * micro_size_k),
                         )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-                                                                                         r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
             else:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(
-                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
-                                              local_id))
+                        row, col = T.meta_var(reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16, local_id))
                         l, r = (
                             rk * chunk + ki * (k_pack * micro_size_k),
                             warp_n * warp_col_tiles + j * micro_size_y,
                         )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
-                                                                                         r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
         return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
@@ -374,14 +369,13 @@ class MatrixCoreIntrinEmitter:
                 for local_id in T.vectorized(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
                     if C_buf_dims == 2:
-                        C_buf[(warp_m * warp_rows + i) * M_DIM + row,
-                              (warp_n * warp_cols + j) * N_DIM +
-                              col] = C_local_buf[j * (warp_rows * local_size_out) +
-                                                 i * local_size_out + local_id]
+                        C_buf[(warp_m * warp_rows + i) * M_DIM + row, (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
+                            j * (warp_rows * local_size_out) + i * local_size_out + local_id
+                        ]
                     else:
-                        C_buf[warp_n * warp_cols + j, warp_m * warp_rows + i, row,
-                              col] = C_local_buf[j * warp_rows * local_size_out +
-                                                 i * local_size_out + local_id]
+                        C_buf[warp_n * warp_cols + j, warp_m * warp_rows + i, row, col] = C_local_buf[
+                            j * warp_rows * local_size_out + i * local_size_out + local_id
+                        ]
         @T.macro
         def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
@@ -389,18 +383,18 @@ class MatrixCoreIntrinEmitter:
             for i, j in T.grid(warp_rows, warp_cols):
                 for local_id in T.vectorized(local_size_out):
                     row, col = T.meta_var(mfma_store_index_map(tx, local_id))
-                    C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
-                          (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM +
-                          col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
-                                             local_id]
+                    C_buf[
+                        (pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row, (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col
+                    ] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out + local_id]
-        return _warp_stmatrix_global(C_local_buf, C_buf,
-                                     thread_binding) if is_global else _warp_stmatrix_shared(
-                                         C_local_buf, C_buf, thread_binding)
+        return (
+            _warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
+            if is_global
+            else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding)
+        )
 class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
     def __init__(
         self,
         a_dtype: str = "float16",
@@ -420,7 +414,6 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         a_preshuffle: bool | None = False,
         b_preshuffle: bool | None = False,
     ):
-
         self.a_dtype = a_dtype
         self.b_dtype = b_dtype
         self.accum_dtype = accum_dtype
@@ -444,7 +437,7 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
         self.reduce_k = reduce_k
-        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
+        self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
         self.num_elems_per_byte = num_elems_per_byte
     def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
@@ -513,19 +506,19 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
                             rk * (chunk // micro_size_k) + ki,
                             warp_m * warp_rows + i,
                         )
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row,
-                                                                                         col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
             else:
                 for i in T.serial(warp_rows):
                     for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki)
-                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row,
-                                                                                         col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
-        return _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding,
-                                       rk) if is_global else _warp_ldmatrix_a_shared(
-                                           A_local_buf, A_buf, ki, thread_binding, rk)
+        return (
+            _warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk)
+            if is_global
+            else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk)
+        )
     def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None):
         warp_cols = self.warp_cols
@@ -582,28 +575,24 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
             if is_transposed:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(
-                            reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4),
-                                              local_id))
+                        row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
                         l, r = (
                             warp_n * warp_cols + j,
                             rk * (chunk // micro_size_k) + ki,
                         )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row,
-                                                                                         col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
             else:
                 for j in T.serial(warp_cols):
                     for local_id in T.vectorized(k_pack * local_size_b):
-                        row, col = T.meta_var(
-                            reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4),
-                                              local_id))
+                        row, col = T.meta_var(reverse_index_map(((tx & 15) >> 2) + ((tx & 3) << 2) + ((tx >> 4) << 4), local_id))
                         l, r = (
                             rk * (chunk // micro_size_k) + ki,
                             warp_n * warp_cols + j,
                         )
-                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row,
-                                                                                         col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
-        return _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding,
-                                       rk) if is_global else _warp_ldmatrix_b_shared(
-                                           B_local_buf, B_buf, ki, thread_binding, rk)
+        return (
+            _warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk)
+            if is_global
+            else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk)
+        )