Commit 34a94d42 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Typo] Fix a typo in gemm splitk examples (#111)

parent 5cea760c
...@@ -18,9 +18,9 @@ def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_d ...@@ -18,9 +18,9 @@ def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_d
): ):
with T.Kernel( with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype, "shared") A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype, "shared") B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype, "shared") C_shared = T.alloc_shared((block_M, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
if bz == 0: if bz == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment