Commit e2bc1cb6 authored by Yuxuan Hu, committed by LeiWang1999

[Bugfix] Fix `K // block_K` to T.ceildiv(K,block_K) and add tests (#210)

parent 227ed7ec
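Why the change matters: `K // block_K` rounds down, so when K is not an exact multiple of block_K the pipelined reduction loop skips the trailing partial block, while `T.ceildiv(K, block_K)` rounds up so the tail is still visited. A minimal plain-Python sketch of the two trip counts (the `ceildiv` helper below is illustrative, not the TileLang API):

# Illustrative only: compares floor-division and ceiling-division trip counts
# for the reduction loop when K is not a multiple of block_K.
def ceildiv(a: int, b: int) -> int:
    # Assumed helper mirroring T.ceildiv semantics: round up instead of down.
    return (a + b - 1) // b

K, block_K = 100, 64
print(K // block_K)         # 1 -> the last 36 elements of K would be skipped
print(ceildiv(K, block_K))  # 2 -> the partial tail block is still processed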
@@ -119,7 +119,7 @@ def tl_matmul(
             T.clear(C_local)
-            for ko in T.Pipelined((K // block_K), num_stages=stage):
+            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
                 # Load A into shared memory
                 for i, k in T.Parallel(block_M, block_K):
@@ -182,7 +182,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
     mod(compressed_A, compressed_B, C)
     print(C)
-    latency = mod.do_bench(mod.func, warmup=25, profiler="tvm")
+    latency = mod.do_bench(mod.func, warmup=25)
     print(latency)
     # Ensure that the latency is not None
     assert latency is not None
@@ -194,6 +194,11 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


+def test_assert_tl_matmul_correctness():
+    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
+    assert_tl_matmul_correctness(128, 128, 64, "int8", "int32", "int32")
+
+
 @simplify_prim_func
 def tl_matmul_weight_only_transform(
     M,
@@ -302,7 +307,7 @@ def tl_matmul_weight_only_transform(
             T.clear(C_local)
-            for ko in T.Pipelined((K // block_K), num_stages=stage):
+            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
                 # Load A into shared memory
                 for i, k in T.Parallel(block_M, block_K):
...