Commit cffcf1c2 authored by Zhengju Tang's avatar Zhengju Tang Committed by LeiWang1999
Browse files

[BugFix] Address should be aligned with access size in tail split (#401)

* [BugFix] Address should be aligned with access size in tail split

* Lint

* Lint
parent 6c63bb40
......@@ -374,6 +374,13 @@ private:
if (!dynamic_) {
return fnode;
}
if (!disable_dynamic_tail_split) {
// To handle the fact that cp.async only supports addresses aligned with
// the access size
vector_size_ = 1;
}
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
......@@ -424,7 +431,7 @@ private:
}
const ForNode *inner_for_;
const int vector_size_;
int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
const bool disable_dynamic_tail_split_;
......@@ -466,7 +473,9 @@ private:
disable_dynamic_tail_split_);
NestedLoopChecker checker;
int nest_num = checker.GetNestLoopNum(for_node);
if (nest_num > 1) { // only rewrite the innermost loop
if (nest_num > 1 ||
for_node->kind == ForKind::kVectorized) { // only rewrite the innermost
// non-vectorized loop
return for_node;
}
int vectorize_hint = res.vector_size;
......
......@@ -3,7 +3,6 @@ import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
import pytest
tilelang.testing.set_random_seed(0)
tilelang.disable_cache()
......@@ -449,14 +448,12 @@ def assert_tl_matmul_block_dynamic_mnk(
print(f"Dynamic MNK Latency with pass_configs: {pass_configs} is {latency} ms")
@pytest.mark.skip("Skip static test")
def test_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16",
"float16", "float32")
@pytest.mark.skip("Skip dynamic m test")
def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_dynamic_m(
M,
N,
......@@ -488,8 +485,7 @@ def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
pass_configs={"tl.disable_dynamic_tail_split": False})
@pytest.mark.skip("Skip dynamic mn test")
def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_dynamic_mn(
M,
N,
......@@ -521,8 +517,7 @@ def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
pass_configs={"tl.disable_dynamic_tail_split": False})
@pytest.mark.skip("Skip dynamic mnk test")
def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_dynamic_mnk(
M,
N,
......@@ -537,7 +532,7 @@ def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float32",
pass_configs={
"tl.disable_dynamic_tail_split": True,
"tl.dynamic_alignment": 8
"tl.dynamic_alignment": 4
})
assert_tl_matmul_block_dynamic_mnk(
M,
......@@ -555,10 +550,10 @@ def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
def test_all():
test_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32)
test_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32)
test_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32)
test_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32)
run_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment