Commit c2480907 authored by Lei Wang, committed by LeiWang1999

[Refactor] Improve layout equality checks and error messaging (#471)

* [Refactor] Simplify buffer_region_to_tile_region function in copy.py

* Removed redundant logic for handling region extents in the buffer_region_to_tile_region function, streamlining the code for better readability and maintainability.
* Enhanced error handling by focusing on essential checks while eliminating unnecessary complexity related to variable extents.

* [Refactor] Improve layout equality checks and error messaging

* Updated the `IsEqual` method in `FragmentNode` to ensure consistent evaluation of thread ranges.
* Enhanced error messaging in `ParallelOp::InferLayout` to include source buffer information for better debugging.
* Adjusted `ReduceOp::InferLayout` to set thread range during layout condensation, improving layout inference accuracy.

* lint fix

* [Refactor] Rename SetThreadRange to BindThreadRange for clarity

* Updated the `SetThreadRange` method in `FragmentNode` and related classes to `BindThreadRange`, improving method naming consistency and clarity.
* Adjusted all references to the renamed method across the codebase, ensuring proper functionality and maintaining existing behavior.
* Enhanced layout equality checks to handle thread ranges more robustly in `IsEqual` method.
* Updated layout inference methods in `Gemm`, `ParallelOp`, and `ReduceOp` to utilize the new method name, ensuring seamless integration with the updated API.

* [Refactor] Update BindThreadRange usage across layout inference methods

* Modified the implementation of `BindThreadRange` in `FragmentNode` to create a new object instance, enhancing thread range binding functionality.
* Updated all references to `BindThreadRange` in layout inference methods across `Gemm`, `ParallelOp`, and `ReduceOp` to ensure consistency with the new implementation.
* Adjusted the return statements in various layout inference functions to utilize the updated method, maintaining existing behavior while improving clarity.

* lint fix
parent 273be768
```diff
@@ -201,9 +201,10 @@ Fragment FragmentNode::DeReplicate() const {
                     int(*rep_size) / factor, NullOpt);
 }
 
-Fragment FragmentNode::SetThreadRange(Range thread_range) {
-  thread_range_ = thread_range;
-  return GetRef<Fragment>(this);
+Fragment FragmentNode::BindThreadRange(Range thread_range) const {
+  auto n = make_object<FragmentNode>(*this);
+  n->thread_range_ = thread_range;
+  return Fragment(n);
 }
 
 Layout LayoutNode::Inverse() const {
```
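The behavioral difference here is the classic copy-on-write idiom: the old `SetThreadRange` mutated the node in place and returned a reference to it, so any other layout-map entry aliasing the same `FragmentNode` silently changed too. Below is a minimal self-contained sketch of the idiom, using `std::shared_ptr` as a stand-in for TVM's `make_object`/`ObjectRef` machinery; all names are illustrative, not the project's API.

```cpp
#include <cassert>
#include <memory>

// Stand-in for FragmentNode; thread_lo/thread_hi mimic thread_range_.
struct Node {
  int thread_lo = -1, thread_hi = -1;
};

// Const "bind": shallow-copy the node, set the field on the copy, and return
// the copy, leaving every existing alias of the original untouched.
std::shared_ptr<Node> BindThreadRange(const std::shared_ptr<Node>& self,
                                      int lo, int hi) {
  auto n = std::make_shared<Node>(*self);  // copy, as make_object(*this) does
  n->thread_lo = lo;
  n->thread_hi = hi;
  return n;
}

int main() {
  auto a = std::make_shared<Node>();
  auto b = a;  // b aliases a, like a fragment shared across a LayoutMap
  auto bound = BindThreadRange(a, 0, 128);
  assert(a->thread_lo == -1 && b->thread_lo == -1);  // aliases untouched
  assert(bound->thread_lo == 0 && bound->thread_hi == 128);
  return 0;
}
```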
```diff
@@ -415,11 +416,13 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
   // a[i, j] = b[j, i] in register level.
   bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
-  ret &= StructuralEqual()(this->ThreadRange(), other->ThreadRange());
   if (!ret) {
     // may be broadcast case
     return true;
   }
+  if (this->thread_range_.defined() && other->thread_range_.defined()) {
+    ret &= StructuralEqual()(this->thread_range_, other->thread_range_);
+  }
   ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
   ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent());
   ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent());
```
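The guard changes the equality semantics: thread ranges used to be compared unconditionally, whereas now they participate only when both sides actually have a range bound, so an unbound fragment can still compare equal to a bound one of matching shape. A self-contained sketch of the rule, with `std::optional` standing in for `Range::defined()` and illustrative names throughout:

```cpp
#include <cassert>
#include <optional>
#include <utility>

using ThreadRange = std::pair<int, int>;

struct Frag {
  int shape;                                // stand-in for InputShape()
  std::optional<ThreadRange> thread_range;  // stand-in for thread_range_
};

bool IsEqual(const Frag& a, const Frag& b) {
  bool ret = (a.shape == b.shape);
  if (!ret) return true;  // mirrors the broadcast escape hatch in the diff
  if (a.thread_range && b.thread_range) {  // compare only when both are bound
    ret &= (*a.thread_range == *b.thread_range);
  }
  return ret;
}

int main() {
  Frag bound{8, ThreadRange{0, 128}};
  Frag unbound{8, std::nullopt};
  Frag other{8, ThreadRange{128, 256}};
  assert(IsEqual(bound, unbound));  // unbound side: range check skipped
  assert(!IsEqual(bound, other));   // both bound and different: unequal
  return 0;
}
```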
```diff
@@ -95,7 +95,7 @@ public:
   std::string DebugOutput() const final;
 
-  Fragment SetThreadRange(Range thread_range);
+  Fragment BindThreadRange(Range thread_range) const;
 
   Range ThreadRange() const { return thread_range_; }
@@ -127,12 +127,6 @@ public:
                       Optional<Var> replicate_var);
 
   TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
-
-  Fragment SetThreadRange(Range thread_range) {
-    auto node = make_object<FragmentNode>(*this->get());
-    node->SetThreadRange(thread_range);
-    return Fragment(node);
-  }
 };
 
 Var InputPlaceholder(size_t idx);
```
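With the `ObjectRef`-level wrapper deleted, the only remaining entry point is the const method on the node, reached through the reference's `operator->`; that is why every call site below switches from `fragment.SetThreadRange(...)` to `fragment->BindThreadRange(...)`. A toy sketch of the reference/node split (illustrative types, not TVM's actual `ObjectRef`; the real method returns a reference rather than a node copy):

```cpp
#include <memory>

// The node carries the data and the const method, like FragmentNode.
struct FragNode {
  int range = -1;
  FragNode Bind(int r) const {  // returns a fresh copy, never mutates *this
    FragNode n = *this;
    n.range = r;
    return n;
  }
};

// Thin reference, like Fragment: no methods of its own, only forwarding.
struct FragRef {
  std::shared_ptr<FragNode> node;
  const FragNode* operator->() const { return node.get(); }
};

int main() {
  FragRef f{std::make_shared<FragNode>()};
  FragNode bound = f->Bind(128);  // f.operator->(), then the node's method
  (void)bound;
  return 0;
}
```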
```diff
@@ -172,7 +172,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         ComputeWarpPartition(block_size / warp_size, T.target);
     auto fragment =
         makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment.SetThreadRange(thread_range));
+    results.Set(C, fragment->BindThreadRange(thread_range));
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
       int dim_A = A->shape.size();
       results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
@@ -181,7 +181,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     } else if (A.scope() == "local.fragment") {
       ICHECK(trans_A == false);
       auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
-      results.Set(A, fragment.SetThreadRange(thread_range));
+      results.Set(A, fragment->BindThreadRange(thread_range));
     } else {
       ICHECK(0);
     }
@@ -197,7 +197,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         ComputeWarpPartition(block_size / warp_size, T.target);
     auto fragment =
         makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment.SetThreadRange(thread_range));
+    results.Set(C, fragment->BindThreadRange(thread_range));
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
       int dim_A = A->shape.size();
@@ -210,7 +210,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       ICHECK(trans_A == false);
       auto fragment =
           makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
-      results.Set(A, fragment.SetThreadRange(thread_range));
+      results.Set(A, fragment->BindThreadRange(thread_range));
     } else {
       ICHECK(0);
     }
@@ -225,7 +225,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       ICHECK(trans_B == false) << "B is local.fragment, trans_B must be false, "
                                   "please raise an issue if you see this";
       auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
-      results.Set(B, fragment.SetThreadRange(thread_range));
+      results.Set(B, fragment->BindThreadRange(thread_range));
     } else {
       ICHECK(0);
     }
@@ -239,7 +239,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
             ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                       C->dtype.bits())
             : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment.SetThreadRange(thread_range));
+    results.Set(C, fragment->BindThreadRange(thread_range));
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
       int dim_A = A->shape.size();
       const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
@@ -252,7 +252,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       ICHECK(trans_A == false);
       auto fragment =
           makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
-      results.Set(A, fragment.SetThreadRange(thread_range));
+      results.Set(A, fragment->BindThreadRange(thread_range));
     }
     if (B.scope() == "shared" || B.scope() == "shared.dyn") {
       int dim_B = B->shape.size();
@@ -272,7 +272,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     auto fragment =
         makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment.SetThreadRange(thread_range));
+    results.Set(C, fragment->BindThreadRange(thread_range));
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
       int dim_A = A->shape.size();
@@ -283,7 +283,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     } else if (A.scope() == "local.fragment") {
       auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                             A->dtype.bits(), trans_A);
-      results.Set(A, fragment.SetThreadRange(thread_range));
+      results.Set(A, fragment->BindThreadRange(thread_range));
     } else {
      ICHECK(0);
     }
@@ -296,7 +296,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       results.Set(B, shared_layout);
     } else if (B.scope() == "local.fragment") {
       auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
-      results.Set(B, fragment.SetThreadRange(thread_range));
+      results.Set(B, fragment->BindThreadRange(thread_range));
     } else {
       ICHECK(0);
     }
```
```diff
@@ -181,7 +181,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       PrimExpr loop_var_to_thread =
           src_layout->ForwardThread(indice_map_[buffer], rep);
       return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
-          .SetThreadRange(T.thread_bounds);
+          ->BindThreadRange(T.thread_bounds);
     }
   };
   if (source_buffer.defined()) {
@@ -272,7 +272,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     LayoutMap results;
     for (const auto &[buffer, _] : indice_map_) {
       if (!T.layout_map.count(buffer)) {
-        results.Set(buffer, CompleteBufferFragment(buffer).SetThreadRange(
+        results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
                                 T.thread_bounds));
       }
       // Though they may exist some conflicts, but it's fine.
@@ -285,13 +285,13 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       const FragmentNode *src_layout =
           T.layout_map[buffer].as<Fragment>().get();
       Fragment dst_layout_fragment =
-          CompleteBufferFragment(buffer).SetThreadRange(T.thread_bounds);
+          CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
       const FragmentNode *dst_layout =
           dst_layout_fragment.as<Fragment>().get();
       if (src_layout && dst_layout) {
         ICHECK(src_layout->IsEqual(dst_layout, true))
             << "Layout may conflict with ParallelOp for buffer " << buffer
-            << "\nError body begin:\n"
+            << " vs. " << source_buffer << "\nError body begin:\n"
            << GetRoot()->body << "\nError body end"
            << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << dst_layout->DebugOutput()
```
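The added `<< " vs. " << source_buffer` context is cheap because TVM's `ICHECK` evaluates its condition and only streams the appended message on the failure path. A toy reimplementation of that stream-style check, to show the mechanism only; this is a sketch, not TVM's actual macro:

```cpp
#include <cstdlib>
#include <iostream>
#include <sstream>

// Toy stand-in for ICHECK: a temporary collects the message via operator<<
// and, in its destructor, aborts with that message when the check failed.
struct Checker {
  bool ok;
  std::ostringstream msg;
  template <typename T>
  Checker& operator<<(const T& v) {
    if (!ok) msg << v;  // format the message only on the failure path
    return *this;
  }
  ~Checker() {
    if (!ok) {
      std::cerr << "Check failed: " << msg.str() << '\n';
      std::abort();
    }
  }
};

#define CHECK_SKETCH(cond) Checker{(cond)}

int main() {
  // Passing check: the streamed context is never even formatted.
  CHECK_SKETCH(1 + 1 == 2) << "Layout may conflict for buffer " << "C_frag"
                           << " vs. " << "A_frag";
  return 0;
}
```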
```diff
@@ -272,7 +272,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
     Fragment dst_layout =
         Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
-            ->CondenseReplicateVar();
+            ->CondenseReplicateVar()
+            ->BindThreadRange(T.thread_bounds);
     return {{dst, dst_layout}};
   }
   return {};
```
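Note the chaining this enables: because `CondenseReplicateVar` and `BindThreadRange` are both const methods returning a fresh object, layout construction composes as a pipeline with no hidden mutation. A minimal sketch of that fluent style, with illustrative stand-ins for `Fragment`'s methods:

```cpp
#include <cassert>

struct Layout {
  int rep = 4;     // stand-in for the replication extent
  int range = -1;  // stand-in for the bound thread range
  Layout Condense() const {  // like CondenseReplicateVar: copy, then modify
    Layout n = *this;
    n.rep = 1;
    return n;
  }
  Layout Bind(int r) const {  // like BindThreadRange: copy, then modify
    Layout n = *this;
    n.range = r;
    return n;
  }
};

int main() {
  Layout base;                             // never mutated below
  Layout dst = base.Condense().Bind(128);  // fluent copy-on-write chain
  assert(base.rep == 4 && base.range == -1);
  assert(dst.rep == 1 && dst.range == 128);
  return 0;
}
```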
```diff
@@ -191,7 +191,7 @@ Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) {
   size_t num_thread = *as_const_int(thread_range->extent);
   LoopPartitioner partitioner;
   Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
-  return fragment.SetThreadRange(thread_range);
+  return fragment->BindThreadRange(thread_range);
 }
 
 For LoopPragmaUnroll(For stmt) {
```