Commit 3c5190e0 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Simplify vectorization process in loop_vectorize.cc and add...

[Enhancement] Simplify vectorization process in loop_vectorize.cc and add atomic add test (#436) (#439)

* Removed redundant simplification step in vectorization logic to streamline performance.
* Introduced a new test for atomic addition in TileLang, validating functionality with a reference implementation using PyTorch.
parent bfb5b0a3
...@@ -244,9 +244,7 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, ...@@ -244,9 +244,7 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
PrimExpr expr_simplified = analyzer->Simplify(expr_transformed); PrimExpr expr_simplified = analyzer->Simplify(expr_transformed);
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size)); Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
analyzer->Simplify(vectorizer.VisitExpr(expr_transformed));
auto ramp_node = expr_vectorized.as<RampNode>(); auto ramp_node = expr_vectorized.as<RampNode>();
if (!ramp_node) { if (!ramp_node) {
// Broadcast value // Broadcast value
......
import tilelang.testing
import tilelang.language as T
def atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
@T.prim_func
def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.copy(A[bz, bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N],
A_shared)
for i, j in T.Parallel(block_M, block_N):
T.atomic_add(B[bx * block_M + i, by * block_N + j], A_shared[i, j])
return atomic_add
def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
program = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype)
kernel = tilelang.compile(program)
# print(kernel.get_kernel_source())
import torch
def ref_program(A, B):
for k in range(K):
for i in range(M):
for j in range(N):
B[i, j] += A[k, i, j]
A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda()
B = torch.zeros(M, N, dtype=getattr(torch, dtype)).cuda()
ref_B = B.clone()
ref_program(A, ref_B)
kernel(A, B)
torch.testing.assert_close(B, ref_B)
def test_atomic_add():
run_atomic_add(8, 128, 128, 32, 32)
if __name__ == "__main__":
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment