Commit 34a94d42, authored by Lei Wang, committed by GitHub
Browse files

[Typo] Fix a typo in gemm splitk examples (#111)

parent 5cea760c
......@@ -18,9 +18,9 @@ def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_d
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
C_shared = T.alloc_shared((block_M, block_N), dtype, "shared")
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
if bz == 0:
......@@ -42,9 +42,9 @@ def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_d
m, n = by * block_M + i, bx * block_N + j * 2
# vectorized atomic
T.atomic_addx2(C[m, n], C_shared[i, j * 2])
else:
for i, j in T.Parallel(block_M, block_N):
T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])
else:
for i, j in T.Parallel(block_M, block_N):
T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])
return main
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment