Unverified Commit e84b24bc authored by Xiangwen Wang, committed by GitHub

[Enhancement] Improve vectorization invariant check (#1398)

* Improve loop vectorize

* Improve loop vectorize

* Improve loop vectorize

* Improve loop vectorize

* Improve loop vectorize

* Add some vectorize tests and comments
parent 6f67da84
@@ -291,6 +291,24 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
                               0))
    return false;
  // Check whether expr is invariant within each vector of lanes.
  // We try to prove that the index expression f(var) depends only on
  // floor(var/vecsize), not on var%vecsize.
  // Mathematically:
  //   \forall var, f(floor(var/vecsize)*vecsize + var%vecsize) ==
  //                f(floor(var/vecsize)*vecsize + 0)
  // Example: for i in T.vectorized(8):
  //            A[i] = B[i] * C[i//4]
  //          With vecsize=4, the index i//4 is constant within each group of
  //          4 lanes, so A[i] = B[i] * C[i//4] can be vectorized with vecsize=4.
  PrimExpr var_aligned =
      floordiv(var, target_vectorized_size) * target_vectorized_size;
  PrimExpr expr_aligned = Substitute(expr, {{var, var_aligned}});
  if (analyzer->CanProveEqual(expr, expr_aligned)) {
    return true;
  }

  auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}}));
  // The base offset must be divisible
  if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr),
......
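The substitution check added above can be sanity-checked numerically for the example in the comment. Below is a minimal sketch in plain Python (the helper `is_invariant` is ours, not part of the codebase) of the condition f(i) == f((i // vecsize) * vecsize) that the C++ code proves symbolically with `arith::Analyzer`:

```python
# Hypothetical helper: brute-force check of the invariance condition that
# IndiceCanVectorize now proves symbolically.
def is_invariant(f, vecsize, extent=64):
    # True iff f(i) depends only on i // vecsize, i.e. it is constant
    # within every aligned group of `vecsize` consecutive lanes.
    return all(f(i) == f((i // vecsize) * vecsize) for i in range(extent))

f_c = lambda i: i // 4  # index of C in A[i] = B[i] * C[i // 4]
f_a = lambda i: i       # index of A and B

assert is_invariant(f_c, vecsize=4)      # C[i//4] is constant within each 4-lane vector
assert not is_invariant(f_c, vecsize=8)  # but not within an 8-lane vector
assert not is_invariant(f_a, vecsize=4)  # A[i]/B[i] rely on the existing contiguous-access check
```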
@@ -5,7 +5,6 @@ import tilelang.language as T
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test(N, M, stride_A, stride_B):
    assert N % 128 == 0 and M % 128 == 0
    @T.prim_func
    def main(
@@ -23,6 +22,7 @@ def vectorize_test(N, M, stride_A, stride_B):
def run_vectorize(N, M, stride_A, stride_B):
    assert N % 128 == 0 and M % 128 == 0
    assert stride_A >= N and stride_B >= N
    jit_kernel = vectorize_test(N, M, stride_A, stride_B)
@@ -59,5 +59,62 @@ def test_vectorize():
    run_vectorize(N, M, N + 8, N + 16)
@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test_invariant_index(N, M, K):

    @T.prim_func
    def main(
            A: T.Tensor[(N, M), "float32"],  # noqa: F821
            B: T.Tensor[(N, M), "float32"],  # noqa: F821
            C: T.Tensor[(N, M // K), "float32"],  # noqa: F821
    ):
        with T.Kernel(N // 128, threads=128) as (bx):
            tx = T.get_thread_binding(0)
            row = bx * 128 + tx
            # The index col // K is constant within any vector whose width
            # divides K, so the access to C stays invariant across the lanes.
            for col in T.vectorized(M):
                B[row, col] = A[row, col] * C[row, col // K]

    return main


def run_vectorize_invariant_index(N, M, K):
    assert N % 128 == 0 and M % K == 0
    jit_kernel = vectorize_test_invariant_index(N, M, K)
    a = torch.randn(N, M, device="cuda", dtype=torch.float32)
    b = torch.zeros(N, M, device="cuda", dtype=torch.float32)
    c = torch.randn(N, M // K, device="cuda", dtype=torch.float32)
    jit_kernel(a, b, c)
    # Reference result: broadcast each column of C across K consecutive columns.
    indices = torch.arange(a.size(1)) // K
    ret = a * c[:, indices]
    torch.testing.assert_close(b, ret, atol=1e-8, rtol=1e-8)
    code = jit_kernel.get_kernel_source()
    # Largest power-of-two vector width (at most 4 float32 lanes, i.e. 128 bits)
    # for which col // K stays constant within a vector.
    vectorize_size = 1
    while vectorize_size <= 2 and K % (vectorize_size * 2) == 0:
        vectorize_size *= 2
    if vectorize_size == 4:
        assert "float4" in code
    elif vectorize_size == 2:
        assert "float2" in code


def test_vectorize_invariant_index():
    N, M = 512, 256
    run_vectorize_invariant_index(N, M, 2)
    run_vectorize_invariant_index(N, M, 4)
    run_vectorize_invariant_index(N, M * 3, 6)
    run_vectorize_invariant_index(N, M, 8)
    run_vectorize_invariant_index(N, M * 3, 12)
    run_vectorize_invariant_index(N, M * 7, 14)


if __name__ == "__main__":
    tilelang.testing.main()
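The float2/float4 assertions above follow from the same condition: `col // K` is constant within every aligned vector of width `v` exactly when `v` divides `K`, and with `TL_DISABLE_VECTORIZE_256` enabled the width is presumably capped at 128 bits, i.e. 4 float32 lanes. A small sketch (the helper `expected_width` is ours, not part of the test) reproducing the expected widths for the K values used in `test_vectorize_invariant_index`:

```python
# Hypothetical helper mirroring the while-loop in run_vectorize_invariant_index:
# largest power-of-two width (capped at 4 float32 lanes = 128 bits) for which
# col // K is constant within a vector.
def expected_width(K, cap=4):
    width = 1
    while width < cap and K % (width * 2) == 0:
        width *= 2
    return width

# Matches the assertions above: K=2 -> 2 (float2), K=4, 8, 12 -> 4 (float4),
# K=6, 14 -> 2 (float2).
for K in (2, 4, 6, 8, 12, 14):
    print(K, expected_width(K))
```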