Commit e2bc1cb6, authored by Yuxuan Hu and committed by LeiWang1999
Browse files

[Bugfix] Fix `K // block_K` to `T.ceildiv(K, block_K)` and add tests (#210)

parent 227ed7ec
......@@ -119,7 +119,7 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
......@@ -182,7 +182,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(compressed_A, compressed_B, C)
print(C)
latency = mod.do_bench(mod.func, warmup=25, profiler="tvm")
latency = mod.do_bench(mod.func, warmup=25)
print(latency)
# Ensure that the latency is not None
assert latency is not None
......@@ -194,6 +194,11 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def test_assert_tl_matmul_correctness():
    """Run the matmul correctness check for two values of K.

    Both cases use M = N = 128 with "int8" inputs and "int32" output and
    accumulation dtypes; only K differs (128 first, then 64).
    """
    # NOTE(review): K=64 was presumably added to cover the ceildiv change in
    # the kernel's pipelined loop -- confirm against the kernel's block_K.
    for k_size in (128, 64):
        assert_tl_matmul_correctness(128, 128, k_size, "int8", "int32", "int32")
@simplify_prim_func
def tl_matmul_weight_only_transform(
M,
......@@ -302,7 +307,7 @@ def tl_matmul_weight_only_transform(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment