"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "23071a560eda671f8324337c6688a74d9896c47e"
Unverified Commit 47039f06 authored by Lei Wang, committed by GitHub

[Language] Refactor reduce and support shared memory as its in/out (#1219)

* [Refactor] Update ReduceOpNode to use absolute values in Max computation and remove unused shared memory reduction logic

* Changed the Max computation for the AbsMax type to use the absolute values of lhs and rhs.
* Removed unused shared memory reduction logic and related checks for buffer dimensions and thread extents, simplifying the Lower method.
* Added a fatal log for unsupported buffer scope reductions.

* reduce fix

* [Fix] Update type check for eval value in Builder class

* Changed the type check for eval values to raise a TypeError for unsupported types, specifically excluding instances of tvm.tir.Buffer. This improves error handling and clarity in the Builder class.
parent 2957afca
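
For reference, the two AbsMax formulations agree on real-valued inputs, since max(|a|, |b|) = max(max(a, b), -min(a, b)); the new form just spells the intent directly. A minimal plain-Python sketch of the identity (illustrative only, not part of the patch):

    def absmax_old(a, b):
        # Old formulation: |x| encoded as max(x, -x), folded into the reduction.
        return max(max(a, b), -min(a, b))

    def absmax_new(a, b):
        # New formulation: take absolute values before the max.
        return max(abs(a), abs(b))

    # The two agree on a small grid of signed inputs.
    assert all(absmax_old(a, b) == absmax_new(a, b)
               for a in range(-3, 4) for b in range(-3, 4))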
@@ -104,7 +104,7 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
   } else if (type->isMin()) {
     return Min(lhs, rhs);
   } else if (type->isAbsMax()) {
-    return Max(Max(lhs, rhs), -Min(lhs, rhs));
+    return Max(tvm::abs(lhs), tvm::abs(rhs));
   } else if (type->isBitAnd()) {
     return lhs & rhs;
   } else if (type->isBitOr()) {
@@ -360,70 +360,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return body;
   }
-  auto is_shared_scope = [](const std::string &scope) {
-    return scope == "shared" || scope == "shared.dyn";
-  };
-  if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
-    Buffer src_buffer = get_buffer(this->src);
-    Buffer dst_buffer = get_buffer(this->dst);
-    size_t src_dim = src_buffer->shape.size();
-    size_t dst_dim = dst_buffer->shape.size();
-    bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1);
-    if (!is_1d_reduce) {
-      ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
-    } else {
-      ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce.";
-    }
-    auto thread_extent = as_const_int(T.thread_bounds->extent);
-    ICHECK(thread_extent)
-        << "Shared-memory reduce requires static thread extent.";
-    int threads = *thread_extent;
-    if (TargetIsCuda(T.target)) {
-      ICHECK_EQ(threads % 32, 0)
-          << "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA.";
-    } else if (TargetIsRocm(T.target)) {
-      ICHECK_EQ(threads % 64, 0)
-          << "Shared reduce expects blockDim.x to be a multiple of 64 on HIP.";
-    }
-    bool use_abs = this->type->isAbsSum() || this->type->isAbsMax();
-    bool need_accumulate =
-        (!this->clear) && (this->type->isSum() || this->type->isAbsSum() ||
-                           this->type->isBitAnd() || this->type->isBitOr() ||
-                           this->type->isBitXor());
-    PrimExpr reduce_extent = src_buffer->shape[this->dim];
-    PrimExpr tail_extent = make_const(DataType::Int(32), 1);
-    for (size_t i = this->dim + 1; i < src_dim; ++i) {
-      tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]);
-    }
-    PrimExpr total_dest = make_const(DataType::Int(32), 1);
-    for (size_t i = 0; i < dst_dim; ++i) {
-      total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]);
-    }
-    std::stringstream ss;
-    std::string reducer = this->MakeCodegenReducer();
-    ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", "
-       << (use_abs ? "true" : "false") << ", "
-       << (need_accumulate ? "true" : "false") << ">::run";
-    Array<PrimExpr> call_args = {StringImm(ss.str()),
-                                 src_buffer.access_ptr(1),
-                                 dst_buffer.access_ptr(3),
-                                 cast(DataType::Int(32), total_dest),
-                                 cast(DataType::Int(32), reduce_extent),
-                                 cast(DataType::Int(32), tail_extent),
-                                 this->MakeInitValue()};
-    return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
-  }
   LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", "
              << dst_scope << ") is not implemented.";
   return Stmt();
...
@@ -2,7 +2,9 @@
 from __future__ import annotations
 from tvm import tir
-from tilelang.language import copy, macro, alloc_shared
+from tilelang.language import copy, macro, alloc_shared, alloc_fragment
+from tilelang.utils.language import is_shared, is_fragment
+from tvm.script.ir_builder import IRBuilder

 def _legalize_dim(buffer: tir.Buffer, dim: int):
@@ -34,17 +36,70 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
         raise ValueError(
             f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, "
             f"output shape is {out.shape}, expected shapes are {expected_shapes_str}")
-    buffer = buffer.access_ptr("r")
-    out = out.access_ptr("w")
-    return tir.call_intrin(
-        "handle",
-        tir.op.Op.get("tl.reduce"),
-        buffer,
-        out,
-        reduce_type,
-        dim,
-        clear,
-    )
+
+    @macro
+    def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
+        if is_shared(buffer) and is_shared(out):
+            red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
+            red_frag_out = alloc_fragment(out.shape, out.dtype)
+
+            # rename buffers
+            IRBuilder.name(buffer.name + "_frag", red_frag_in)
+            IRBuilder.name(out.name + "_frag", red_frag_out)
+
+            copy(buffer, red_frag_in)
+            tir.call_intrin(
+                "handle",
+                tir.op.Op.get("tl.reduce"),
+                red_frag_in.access_ptr("r"),
+                red_frag_out.access_ptr("w"),
+                reduce_type,
+                dim,
+                clear,
+            )
+            copy(red_frag_out, out)
+        elif is_shared(buffer) and is_fragment(out):
+            red_frag_in = alloc_fragment(buffer.shape, buffer.dtype)
+            IRBuilder.name(buffer.name + "_frag", red_frag_in)
+            copy(buffer, red_frag_in)
+            tir.call_intrin(
+                "handle",
+                tir.op.Op.get("tl.reduce"),
+                red_frag_in.access_ptr("r"),
+                out.access_ptr("w"),
+                reduce_type,
+                dim,
+                clear,
+            )
+        elif is_fragment(buffer) and is_shared(out):
+            red_frag_out = alloc_fragment(out.shape, out.dtype)
+            IRBuilder.name(out.name + "_frag", red_frag_out)
+            tir.call_intrin(
+                "handle",
+                tir.op.Op.get("tl.reduce"),
+                buffer.access_ptr("r"),
+                red_frag_out.access_ptr("w"),
+                reduce_type,
+                dim,
+                clear,
+            )
+            copy(red_frag_out, out)
+        elif is_fragment(buffer) and is_fragment(out):
+            tir.call_intrin(
+                "handle",
+                tir.op.Op.get("tl.reduce"),
+                buffer.access_ptr("r"),
+                out.access_ptr("w"),
+                reduce_type,
+                dim,
+                clear,
+            )
+        else:
+            raise ValueError(f"Invalid buffer scopes: {buffer.scope()} and {out.scope()}")
+
+    return reduce_macro(buffer, out, reduce_type, dim, clear)

 def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
...
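
With this refactor, reduce accepts shared-memory buffers directly on either side: the macro stages a shared input or output through a freshly allocated fragment and copies the result back, so only fragment buffers ever reach the tl.reduce intrinsic. A minimal usage sketch, assuming TileLang's usual kernel syntax (T.prim_func, T.Kernel, T.Tensor); the kernel name, shapes, and thread count are illustrative:

    import tilelang.language as T

    @T.prim_func
    def main(A: T.Tensor((128, 64), "float32"), B: T.Tensor((128,), "float32")):
        with T.Kernel(1, threads=128) as bx:  # single block; bx unused
            A_shared = T.alloc_shared((128, 64), "float32")
            B_shared = T.alloc_shared((128,), "float32")
            T.copy(A, A_shared)
            # Shared in/out is now legal: reduce_macro stages both buffers
            # through fragments before emitting tl.reduce, then copies back.
            T.reduce_max(A_shared, B_shared, dim=1)
            T.copy(B_shared, B)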
@@ -245,7 +245,7 @@ class Builder(BaseBuilder):
             pass
         elif isinstance(val, tvm.tir.stmt.BufferStore):
             tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)
-        else:
+        elif not isinstance(val, tvm.tir.Buffer):
             raise TypeError(f"Unsupported eval value: {val} of type {type(val)}")

     def ctx_for(self, it):
...
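
The Builder change above means a statement-position expression that evaluates to a tvm.tir.Buffer is now silently discarded rather than rejected, presumably so expansions like the macro-based reduce can leave a bare buffer in eval position. A simplified mirror of the dispatch after the change (a sketch, not the actual Builder code):

    from tvm import tir

    def eval_value(val):
        # Simplified sketch of Builder's eval handling after this patch.
        if isinstance(val, tir.stmt.BufferStore):
            pass  # lowered via tir.buffer_store, as before
        elif not isinstance(val, tir.Buffer):
            raise TypeError(f"Unsupported eval value: {val} of type {type(val)}")
        # A bare tir.Buffer falls through: evaluating it is a no-op.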