Commit d2f59cfa authored by Chunan Zeng, committed by LeiWang1999
Browse files

Support block_N sizes that are 2^n in deepgemm example (#319)

parent 853898a7
......@@ -13,6 +13,7 @@ def tl_gemm(
M,
N,
K,
+ block_N,
in_dtype,
out_dtype,
accum_dtype,
......@@ -25,15 +26,14 @@ def tl_gemm(
"float32",
], "Currently only float16 and float32 are supported"
- TILE_SIZE = (128, 128, 128)
- block_M = TILE_SIZE[0]
- block_N = TILE_SIZE[1]
- block_K = TILE_SIZE[2]
+ group_size = 128
+ block_M = 128
+ block_K = 128
A_shape = (M, K)
- Scales_A_shape = (M, T.ceildiv(K, block_K))
+ Scales_A_shape = (M, T.ceildiv(K, group_size))
B_shape = (N, K)
- Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
+ Scales_B_shape = (T.ceildiv(N, group_size), T.ceildiv(K, group_size))
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (block_M, block_N)
......@@ -67,7 +67,7 @@ def tl_gemm(
# Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared)
# Load scale into shared memory
- Scale_B = scales_b[bx, k]
+ Scale_B = scales_b[bx * block_N // group_size, k]
for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
......@@ -181,4 +181,5 @@ def assert_tl_gemm_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
if __name__ == "__main__":
for dtype in ["e4m3_float8"]:
for out_dtype in ["bfloat16", "float32"]:
- assert_tl_gemm_correctness(1024, 1024, 8192, dtype, out_dtype, "float32")
+ for block_N in [16, 32, 64, 128]:
+ assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32")
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment