"vscode:/vscode.git/clone" did not exist on "e24e54fdfac69df29711c3bf99e88e624395bb13"
Commit cffcf1c2 authored by Zhengju Tang's avatar Zhengju Tang Committed by LeiWang1999
Browse files

[BugFix] Address should be aligned with access size in tail split (#401)

* [BugFix] Address should be aligned with access size in tail split

* Lint

* Lint
parent 6c63bb40
...@@ -374,6 +374,13 @@ private: ...@@ -374,6 +374,13 @@ private:
if (!dynamic_) { if (!dynamic_) {
return fnode; return fnode;
} }
if (!disable_dynamic_tail_split) {
// To handle the fact that cp.async only supports addresses aligned with the
// access size
vector_size_ = 1;
}
ICHECK(extent % vector_size_ == 0) ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_; << "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min)); ICHECK(is_zero(fnode->min));
...@@ -424,7 +431,7 @@ private: ...@@ -424,7 +431,7 @@ private:
} }
const ForNode *inner_for_; const ForNode *inner_for_;
const int vector_size_; int vector_size_;
const PrimExpr condition_; const PrimExpr condition_;
const bool dynamic_; const bool dynamic_;
const bool disable_dynamic_tail_split_; const bool disable_dynamic_tail_split_;
...@@ -466,7 +473,9 @@ private: ...@@ -466,7 +473,9 @@ private:
disable_dynamic_tail_split_); disable_dynamic_tail_split_);
NestedLoopChecker checker; NestedLoopChecker checker;
int nest_num = checker.GetNestLoopNum(for_node); int nest_num = checker.GetNestLoopNum(for_node);
if (nest_num > 1) { // only rewrite the innermost loop if (nest_num > 1 ||
for_node->kind == ForKind::kVectorized) { // only rewrite the innermost
// non-vectorized loop
return for_node; return for_node;
} }
int vectorize_hint = res.vector_size; int vectorize_hint = res.vector_size;
......
...@@ -3,7 +3,6 @@ import torch.backends ...@@ -3,7 +3,6 @@ import torch.backends
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
import tilelang.language as T import tilelang.language as T
import pytest
tilelang.testing.set_random_seed(0) tilelang.testing.set_random_seed(0)
tilelang.disable_cache() tilelang.disable_cache()
...@@ -449,14 +448,12 @@ def assert_tl_matmul_block_dynamic_mnk( ...@@ -449,14 +448,12 @@ def assert_tl_matmul_block_dynamic_mnk(
print(f"Dynamic MNK Latency with pass_configs: {pass_configs} is {latency} ms") print(f"Dynamic MNK Latency with pass_configs: {pass_configs} is {latency} ms")
@pytest.mark.skip("Skip static test") def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
def test_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16",
"float16", "float32") "float16", "float32")
@pytest.mark.skip("Skip dynamic m test") def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_dynamic_m( assert_tl_matmul_block_dynamic_m(
M, M,
N, N,
...@@ -488,8 +485,7 @@ def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): ...@@ -488,8 +485,7 @@ def test_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
pass_configs={"tl.disable_dynamic_tail_split": False}) pass_configs={"tl.disable_dynamic_tail_split": False})
@pytest.mark.skip("Skip dynamic mn test") def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_dynamic_mn( assert_tl_matmul_block_dynamic_mn(
M, M,
N, N,
...@@ -521,8 +517,7 @@ def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): ...@@ -521,8 +517,7 @@ def test_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
pass_configs={"tl.disable_dynamic_tail_split": False}) pass_configs={"tl.disable_dynamic_tail_split": False})
@pytest.mark.skip("Skip dynamic mnk test") def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_dynamic_mnk( assert_tl_matmul_block_dynamic_mnk(
M, M,
N, N,
...@@ -537,7 +532,7 @@ def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): ...@@ -537,7 +532,7 @@ def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
"float32", "float32",
pass_configs={ pass_configs={
"tl.disable_dynamic_tail_split": True, "tl.disable_dynamic_tail_split": True,
"tl.dynamic_alignment": 8 "tl.dynamic_alignment": 4
}) })
assert_tl_matmul_block_dynamic_mnk( assert_tl_matmul_block_dynamic_mnk(
M, M,
...@@ -555,10 +550,10 @@ def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): ...@@ -555,10 +550,10 @@ def test_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
def test_all(): def test_all():
test_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32) run_assert_tl_matmul_block_static(16384, 16384, 16384, 128, 128, 32)
test_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32) run_assert_tl_matmul_block_dynamic_m(16384, 16384, 16384, 128, 128, 32)
test_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32) run_assert_tl_matmul_block_dynamic_mn(16384, 16384, 16384, 128, 128, 32)
test_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32) run_assert_tl_matmul_block_dynamic_mnk(16384, 16384, 16384, 128, 128, 32)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment