Commit d2f59cfa authored by Chunan Zeng, committed by LeiWang1999
Browse files

Support block_N sizes that are 2^n in deepgemm example (#319)

parent 853898a7
...@@ -13,6 +13,7 @@ def tl_gemm( ...@@ -13,6 +13,7 @@ def tl_gemm(
M, M,
N, N,
K, K,
block_N,
in_dtype, in_dtype,
out_dtype, out_dtype,
accum_dtype, accum_dtype,
...@@ -25,15 +26,14 @@ def tl_gemm( ...@@ -25,15 +26,14 @@ def tl_gemm(
"float32", "float32",
], "Currently only float16 and float32 are supported" ], "Currently only float16 and float32 are supported"
TILE_SIZE = (128, 128, 128) group_size = 128
block_M = TILE_SIZE[0] block_M = 128
block_N = TILE_SIZE[1] block_K = 128
block_K = TILE_SIZE[2]
A_shape = (M, K) A_shape = (M, K)
Scales_A_shape = (M, T.ceildiv(K, block_K)) Scales_A_shape = (M, T.ceildiv(K, group_size))
B_shape = (N, K) B_shape = (N, K)
Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K)) Scales_B_shape = (T.ceildiv(N, group_size), T.ceildiv(K, group_size))
A_shared_shape = (block_M, block_K) A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K) B_shared_shape = (block_N, block_K)
C_shared_shape = (block_M, block_N) C_shared_shape = (block_M, block_N)
...@@ -67,7 +67,7 @@ def tl_gemm( ...@@ -67,7 +67,7 @@ def tl_gemm(
# Load B into shared memory # Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared) T.copy(B[bx * block_N, k * block_K], B_shared)
# Load scale into shared memory # Load scale into shared memory
Scale_B = scales_b[bx, k] Scale_B = scales_b[bx * block_N // group_size, k]
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
...@@ -181,4 +181,5 @@ def assert_tl_gemm_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): ...@@ -181,4 +181,5 @@ def assert_tl_gemm_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
if __name__ == "__main__": if __name__ == "__main__":
for dtype in ["e4m3_float8"]: for dtype in ["e4m3_float8"]:
for out_dtype in ["bfloat16", "float32"]: for out_dtype in ["bfloat16", "float32"]:
assert_tl_gemm_correctness(1024, 1024, 8192, dtype, out_dtype, "float32") for block_N in [16, 32, 64, 128]:
assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32")
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment