Unverified commit 1b0efb65, authored by Lei Wang, committed by GitHub
Browse files

[Bugfix] Minor fix for some cases (#1278)

parent 0f980f15
...@@ -191,7 +191,7 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): ...@@ -191,7 +191,7 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
if __name__ == "__main__": if __name__ == "__main__":
# tilelang.testing.main() tilelang.testing.main()
# # Test Pass # # Test Pass
# for m in [32, 64, 128, 256]: # for m in [32, 64, 128, 256]:
...@@ -203,6 +203,16 @@ if __name__ == "__main__": ...@@ -203,6 +203,16 @@ if __name__ == "__main__":
# run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128) # run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass") # print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [32, 64, 128, 256]:
# for n in [32, 64, 128]:
# for k in [16, 32, 64, 128]:
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 256)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass # # Test Pass
# for m in [32, 64, 128, 256]: # for m in [32, 64, 128, 256]:
# for n in [16, 32, 64, 128]: # for n in [16, 32, 64, 128]:
...@@ -211,16 +221,3 @@ if __name__ == "__main__": ...@@ -211,16 +221,3 @@ if __name__ == "__main__":
# continue # continue
# print(f"======================= Test {m} {n} {k} False True =============================") # print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128) # run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
tilelang.disable_cache()
run_gemm(32, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(32, 512, 32, False, True, "float16", "float32", "float32", 32, 512, 32, 0, 128)
run_gemm(32, 512, 64, False, True, "float16", "float32", "float32", 32, 512, 64, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 64, 512, 16, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
# run_gemm(64, 512, 32, False, True, "float16", "float32", "float32", 64, 512, 32, 0, 128)
# run_gemm(64, 512, 64, False, True, "float16", "float32", "float32", 64, 512, 64, 0, 128)
# run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
...@@ -247,8 +247,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -247,8 +247,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
mask_zero = T.Cast("int32", 0) mask_zero = T.Cast("int32", 0)
mask0 = mask1 = mask2 = mask3 = mask_zero mask0 = mask1 = mask2 = mask3 = mask_zero
num_inst_m = 4 * self.warp_row_tiles // atom_m # TCGEN05 only has one warp group
num_inst_n = self.warp_col_tiles // atom_n num_inst_m = self.block_row_warps * self.warp_row_tiles // atom_m
num_inst_n = self.block_col_warps * self.warp_col_tiles // atom_n
# Helper to allow BufferRegion/BufferLoad as inputs # Helper to allow BufferRegion/BufferLoad as inputs
def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"): def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment