Commit 6c737768 authored by Lei Wang, committed by LeiWang1999

[Language] Support accumulative `T.reduce_sum` (#436)

* [Enhancement] Update reduce operations to support clear option in sum and abs sum (#436)

* Modified the `reduce_sum` and `reduce_absmax` functions to accept a `clear` parameter, allowing results to be accumulated onto existing values (see the usage sketch below).
* Updated ReduceOp::Lower method to handle initialization and buffer duplication based on the clear flag for sum and abs sum operations.
* Added new tests for reduce_sum and reduce_max with clear functionality to ensure correctness in various scenarios.
* Enhanced documentation for reduce functions to clarify the behavior of the clear parameter.

* lint fix

* Update tensor type annotations in test_tilelang_transform_annotate_device_regions.py from Buffer to Tensor

* Update tensor type in reduce sum tests from float16 to float32 for improved precision
parent 181267c7
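
For context, a minimal usage sketch of the new accumulative mode. It closely mirrors the `reduce_sum_test_clear` test added in this commit; the kernel name, shapes, and dtype here are purely illustrative:

import tilelang.language as T

@T.prim_func
def accumulate_row_sum(
        A: T.Tensor((256, 256), "float32"),
        B: T.Tensor((256,), "float32"),
):
    with T.Kernel(1, threads=32) as _:
        A_local = T.alloc_fragment((256, 256), "float32")
        B_local = T.alloc_fragment((256,), "float32")
        T.copy(A, A_local)
        T.fill(B_local, 1)                                  # pre-existing output values
        T.reduce_sum(A_local, B_local, dim=1, clear=False)  # accumulate: B_local += sum(A_local, dim=1)
        T.copy(B_local, B)

With the default clear=True, the same call overwrites B_local with the row sums instead of adding to them.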
......@@ -12,6 +12,7 @@
#include "../layout/utils.h"
#include "../transform/loop_partition.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
......@@ -121,10 +122,33 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Array<Stmt> stmts;
bool require_init = this->clear;
// sum and abs-sum must always initialize their reduction target (dst or the scratch buffer below)
if (this->type == ReduceType::kSum) {
require_init = true;
} else if (this->type == ReduceType::kAbsSum) {
require_init = true;
}
Buffer clear_buffer = dst_buffer;
bool need_duplicate = false;
if (this->type == ReduceType::kSum && !this->clear) {
need_duplicate = true;
} else if (this->type == ReduceType::kAbsSum && !this->clear) {
need_duplicate = true;
}
if (need_duplicate) {
// Create a scratch buffer with the same shape and dtype as dst_buffer
clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
dst_buffer->name + "_clear",
GetPtrStorageScope(dst_buffer->data));
}
// make reduce-init stmt
if (this->clear)
if (require_init)
stmts.push_back(
BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));
BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
// make thread-local reduce
Array<PrimExpr> src_indice_compressed;
......@@ -138,8 +162,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
src_var_compressed.push_back(var);
}
Stmt reduce_local = BufferStore(
dst_buffer,
this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
clear_buffer,
this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
BufferLoad(src_buffer, src_indice_compressed)),
dst_indices);
for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
......@@ -176,20 +200,37 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
<< reducing_threads << ", " << (*scale) << ">::run";
}
Array<PrimExpr> thread_reduce_args = {
StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
if (reducing_threads >= 32) {
PrimExpr workspace = T.AddWorkspace(
*as_const_int(T.thread_bounds->extent), dst_buffer->dtype);
*as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
thread_reduce_args.push_back(workspace);
}
auto call =
Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
}
}
Stmt reduce_interthread =
BufferStore(dst_buffer, BufferLoad(dst_buffer, dst_indices), dst_indices);
Stmt reduce_interthread = BufferStore(
clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);
// accumulate the partial result in clear_buffer onto dst_buffer
if (need_duplicate) {
// for sum and abs-sum, add clear_buffer onto the existing values in dst_buffer
if (this->type == ReduceType::kSum) {
stmts.push_back(BufferStore(dst_buffer,
Add(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else if (this->type == ReduceType::kAbsSum) {
stmts.push_back(BufferStore(dst_buffer,
Add(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else {
ICHECK(false) << "Unsupported reduce type: " << (int)this->type;
}
}
// make the outer spatial loop
Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
......@@ -198,6 +239,10 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
if (need_duplicate) {
body = Allocate(clear_buffer->data, clear_buffer->dtype,
clear_buffer->shape, const_true(), body);
}
return body;
}
......
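
The net effect of the duplicated scratch buffer is easiest to see in plain array terms. Below is a NumPy sketch of the semantics of the sum lowering above (not the generated code); the function name is made up for illustration, and `dst` is assumed to already hold values from earlier work:

import numpy as np

def reduce_sum_lowering_sketch(src, dst, clear):
    # Row-sum over the last axis, mimicking ReduceOp::Lower for ReduceType::kSum.
    if clear:
        # clear=True: dst is initialized and the reduction writes straight into it
        dst[:] = src.sum(axis=-1)
    else:
        # clear=False: reduce into a zero-initialized scratch buffer first, then add
        # it onto dst exactly once, so the old dst value is not re-added by every
        # thread taking part in the warp reduction
        scratch = np.zeros_like(dst)      # plays the role of clear_buffer
        scratch += src.sum(axis=-1)
        dst += scratch
    return dst

dst = np.ones(4, dtype=np.float32)
src = np.arange(8, dtype=np.float32).reshape(4, 2)
print(reduce_sum_lowering_sketch(src, dst, clear=False))  # [ 2.  6. 10. 14.] = 1 + row sums
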
......@@ -389,7 +389,6 @@ private:
// Handle trailing unassigned copy stages:
// These are typically final copy operations needing post-main-stage
// insertion
auto &head_pinfo = pipeline_stage_infos.at(0);
int unassigned_order_elem = -1;
......@@ -422,7 +421,7 @@ private:
int copy_order_min = pipeline_stage_infos.size();
int non_copy_order_max = 0;
for (auto &pinfo : pipeline_stage_infos) {
if (pinfo.copy_stage) {
if (pinfo.copy_stage || pinfo.prepare_for_condition) {
copy_stage_cnt++;
copy_order_min = std::min(copy_order_min, pinfo.order);
} else {
......@@ -437,7 +436,7 @@ private:
for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning
pinfo.order =
(pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
if (!pinfo.copy_stage)
if (!pinfo.copy_stage && !pinfo.prepare_for_condition)
pinfo.stage--;
}
}
......
......@@ -47,5 +47,46 @@ def test_reduce_max():
    run_reduce_max(256, 256, "float16")


def reduce_max_test_clear(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, -T.infinity(dtype))
            T.reduce_max(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main


def run_reduce_max_clear(M, N, dtype="float16"):
    program = reduce_max_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    print(jit_kernel.get_kernel_source())

    def ref_program(A):
        return A.max(dim=1).values

    import torch

    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummy_A)
    tl_out = jit_kernel(dummy_A)
    print(tl_out)
    print(ref_out)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_max_clear():
    run_reduce_max_clear(256, 256, "float16")


if __name__ == "__main__":
    tilelang.testing.main()
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

tilelang.testing.set_random_seed()


def reduce_sum_test(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            # Copy input to local
            T.copy(A, A_local)
            # Perform reduce_sum operation
            T.reduce_sum(A_local, B_local, dim=1)
            # Copy result back
            T.copy(B_local, B)

    return main


def run_reduce_sum(M, N, dtype="float16"):
    program = reduce_sum_test(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.sum(dim=1)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reduce_sum():
    # Test different sizes
    run_reduce_sum(256, 256)
    run_reduce_sum(512, 128)
    run_reduce_sum(128, 512)
    # Test different dtypes
    run_reduce_sum(256, 256, "float32")
    run_reduce_sum(256, 256, "float16")


def reduce_sum_test_clear(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, 1)
            T.reduce_sum(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main


def run_reduce_sum_clear(M, N, dtype="float16"):
    program = reduce_sum_test_clear(M, N, dtype)
    jit_kernel = tl.compile(
        program,
        out_idx=-1,
        pass_configs={
            "tl.disable_tma_lower": True,
            "tl.disable_warp_specialized": True,
        })
    print(jit_kernel.get_kernel_source())

    def ref_program(A):
        return A.sum(dim=1) + 1

    import torch

    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummy_A)
    tl_out = jit_kernel(dummy_A)
    print(tl_out)
    print(ref_out)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_sum_clear():
    run_reduce_sum_clear(256, 256, "float32")
    run_reduce_sum_clear(512, 128, "float32")
    run_reduce_sum_clear(128, 512, "float32")


if __name__ == "__main__":
    tilelang.testing.main()
......@@ -10,12 +10,12 @@ class BaseCompare(tilelang.testing.CompareBeforeAfter):
class TestAnnotateThreadExtent(BaseCompare):
    """Annotation inserted at the "thread_extent" attribute"""

    def before(A: T.Buffer(16, "float32")):
    def before(A: T.Tensor(16, "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        i = T.launch_thread("threadIdx.x", 16)
        A[i] = 0.0

    def expected(A: T.Buffer(16, "float32")):
    def expected(A: T.Tensor(16, "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        T.attr(T.target("cuda"), "target", 0)
        i = T.launch_thread("threadIdx.x", 16)

......@@ -25,12 +25,12 @@ class TestAnnotateThreadExtent(BaseCompare):
class TestAnnotateDeviceScope(BaseCompare):
    """Annotation inserted at the "device_scope" attribute"""

    def before(A: T.Buffer(1, "float32")):
    def before(A: T.Tensor(1, "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        T.attr(0, "device_scope", 0)
        A[0] = 0.0

    def expected(A: T.Buffer(1, "float32")):
    def expected(A: T.Tensor(1, "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        T.attr(T.target("cuda"), "target", 0)
        T.attr(0, "device_scope", 0)
......
......@@ -66,18 +66,28 @@ def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
return reduce(buffer, out, "min", dim, clear)
def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
def reduce_sum(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
"""Perform reduce sum on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
clear (bool, optional): If True, output buffer will be cleared before reduction.
If False, results will be accumulated on existing values.
Defaults to True.
Note: When clear=True, reduce_sum will not compute directly on the output buffer. This is because
during warp reduction, the same value would be accumulated multiple times (number of threads
in the warp). Therefore, the implementation with clear=True follows these steps:
1. create a temp buffer with same shape and dtype as out
2. copy out to temp buffer
3. call reduce_sum with temp buffer and out
4. Add temp buffer to out
Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "sum", dim, True)
return reduce(buffer, out, "sum", dim, clear)
def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
......@@ -94,7 +104,7 @@ def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
return reduce(buffer, out, "abssum", dim, True)
def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
"""Perform reduce absolute max on input buffer, store the result to output buffer.
Args:
......@@ -105,7 +115,7 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
Returns:
tir.Call: Handle to the reduction operation
"""
return reduce(buffer, out, "absmax", dim, True)
return reduce(buffer, out, "absmax", dim, clear)
@macro
......
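
The docstring note above boils down to a simple observable contract for `clear`, the same one exercised by the `ref_program` helpers in the tests; a small torch sketch (torch availability and the shapes are assumptions carried over from those tests):

import torch

A = torch.randn(256, 256)

# clear=True (default): the output is overwritten with the reduction result
B_sum = A.sum(dim=1)

# clear=False: the reduction is added onto whatever the output already holds
# (the test pre-fills the output with 1, hence the reference A.sum(dim=1) + 1)
B_sum_acc = torch.ones(256) + A.sum(dim=1)

# reduce_max with clear=False behaves analogously: pre-filling with -inf makes
# the accumulated result equal to a plain max
B_max_acc = torch.maximum(torch.full((256,), float("-inf")), A.max(dim=1).values)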