"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "b91aba399b0927e8850ed6fb253dc88439953a7c"
Commit c2480907 authored by Lei Wang, committed by LeiWang1999

[Refactor] Improve layout equality checks and error messaging (#471)

* [Refactor] Simplify buffer_region_to_tile_region function in copy.py

* Removed redundant logic for handling region extents in the buffer_region_to_tile_region function, streamlining the code for better readability and maintainability.
* Enhanced error handling by focusing on essential checks while eliminating unnecessary complexity related to variable extents.

* [Refactor] Improve layout equality checks and error messaging

* Updated the `IsEqual` method in `FragmentNode` to ensure consistent evaluation of thread ranges.
* Enhanced error messaging in `ParallelOp::InferLayout` to include source buffer information for better debugging.
* Adjusted `ReduceOp::InferLayout` to set thread range during layout condensation, improving layout inference accuracy.

* lint fix

* [Refactor] Rename SetThreadRange to BindThreadRange for clarity

* Updated the `SetThreadRange` method in `FragmentNode` and related classes to `BindThreadRange`, improving method naming consistency and clarity.
* Adjusted all references to the renamed method across the codebase, ensuring proper functionality and maintaining existing behavior.
* Enhanced layout equality checks in the `IsEqual` method to handle thread ranges more robustly (ranges are compared only when both fragments define one).
* Updated layout inference methods in `Gemm`, `ParallelOp`, and `ReduceOp` to utilize the new method name, ensuring seamless integration with the updated API.

* [Refactor] Update BindThreadRange usage across layout inference methods

* Modified the implementation of `BindThreadRange` in `FragmentNode` to create a new object instance, so binding a thread range returns a fresh `Fragment` instead of mutating the shared node (sketched below, just before the first diff hunk).
* Updated all references to `BindThreadRange` in layout inference methods across `Gemm`, `ParallelOp`, and `ReduceOp` to ensure consistency with the new implementation.
* Adjusted the return statements in various layout inference functions to utilize the updated method, maintaining existing behavior while improving clarity.

* lint fix
parent 273be768
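
For context on the central change, here is a minimal standalone sketch of the copy-on-write pattern that `BindThreadRange` adopts. The `Range` and `FragmentNode` types below are simplified stand-ins using `std::shared_ptr` in place of TVM's `make_object` machinery, not the project's actual classes:

```cpp
#include <iostream>
#include <memory>

// Simplified stand-in for a thread range; not TVM's Range.
struct Range {
  int begin = 0, extent = 0;
};

// Simplified stand-in for FragmentNode. The old SetThreadRange wrote
// through `this`, so every handle aliasing the node saw the new range.
struct FragmentNode {
  Range thread_range_;

  // New style: copy the node, set the range on the copy, and return a
  // fresh handle; existing handles keep their original state.
  std::shared_ptr<FragmentNode> BindThreadRange(Range r) const {
    auto n = std::make_shared<FragmentNode>(*this);  // mirrors make_object<FragmentNode>(*this)
    n->thread_range_ = r;
    return n;
  }
};

int main() {
  auto a = std::make_shared<FragmentNode>();
  auto b = a->BindThreadRange({0, 128});
  std::cout << a->thread_range_.extent << " "   // 0: `a` is untouched
            << b->thread_range_.extent << "\n"; // 128: `b` carries the binding
}
```

This is also why call sites change from `fragment.SetThreadRange(...)` (a helper on the `Fragment` handle, now removed) to `fragment->BindThreadRange(...)` (a const method on the node that returns a new `Fragment`).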
@@ -201,9 +201,10 @@ Fragment FragmentNode::DeReplicate() const {
                   int(*rep_size) / factor, NullOpt);
 }

-Fragment FragmentNode::SetThreadRange(Range thread_range) {
-  thread_range_ = thread_range;
-  return GetRef<Fragment>(this);
+Fragment FragmentNode::BindThreadRange(Range thread_range) const {
+  auto n = make_object<FragmentNode>(*this);
+  n->thread_range_ = thread_range;
+  return Fragment(n);
 }

 Layout LayoutNode::Inverse() const {
@@ -415,11 +416,13 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
   // a[i, j] = b[j, i] in register level.
   bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
-  ret &= StructuralEqual()(this->ThreadRange(), other->ThreadRange());
   if (!ret) {
     // may be broadcast case
     return true;
   }
+  if (this->thread_range_.defined() && other->thread_range_.defined()) {
+    ret &= StructuralEqual()(this->thread_range_, other->thread_range_);
+  }
   ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
   ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent());
   ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent());
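
A self-contained model of the relaxed equality rule above, with hypothetical types standing in for the TVM ones: a thread range now participates in the comparison only when both fragments have one bound, so an unbound range behaves like a wildcard:

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for a bound thread range.
struct ThreadRange {
  int begin, extent;
  bool operator==(const ThreadRange &o) const {
    return begin == o.begin && extent == o.extent;
  }
};

// Mirrors the new IsEqual behavior for the thread-range component:
// compare only when both sides are defined; otherwise impose no constraint.
bool ThreadRangesCompatible(const std::optional<ThreadRange> &a,
                            const std::optional<ThreadRange> &b) {
  if (a && b) return *a == *b;
  return true;  // at least one side unbound: treated as a wildcard
}

int main() {
  std::optional<ThreadRange> lo{{0, 128}}, hi{{128, 128}}, unbound;
  assert(ThreadRangesCompatible(lo, lo));       // identical bound ranges match
  assert(ThreadRangesCompatible(lo, unbound));  // unbound side passes
  assert(!ThreadRangesCompatible(lo, hi));      // conflicting bound ranges fail
}
```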
@@ -95,7 +95,7 @@ public:
   std::string DebugOutput() const final;

-  Fragment SetThreadRange(Range thread_range);
+  Fragment BindThreadRange(Range thread_range) const;

   Range ThreadRange() const { return thread_range_; }
@@ -127,12 +127,6 @@ public:
                    Optional<Var> replicate_var);
   TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
-
-  Fragment SetThreadRange(Range thread_range) {
-    auto node = make_object<FragmentNode>(*this->get());
-    node->SetThreadRange(thread_range);
-    return Fragment(node);
-  }
 };

 Var InputPlaceholder(size_t idx);
@@ -172,7 +172,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         ComputeWarpPartition(block_size / warp_size, T.target);
     auto fragment =
         makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment.SetThreadRange(thread_range));
+    results.Set(C, fragment->BindThreadRange(thread_range));
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
       int dim_A = A->shape.size();
       results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
@@ -181,7 +181,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   } else if (A.scope() == "local.fragment") {
     ICHECK(trans_A == false);
     auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
-    results.Set(A, fragment.SetThreadRange(thread_range));
+    results.Set(A, fragment->BindThreadRange(thread_range));
   } else {
     ICHECK(0);
   }
@@ -197,7 +197,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         ComputeWarpPartition(block_size / warp_size, T.target);
     auto fragment =
         makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment.SetThreadRange(thread_range));
+    results.Set(C, fragment->BindThreadRange(thread_range));
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
       int dim_A = A->shape.size();
@@ -210,7 +210,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     ICHECK(trans_A == false);
     auto fragment =
         makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
-    results.Set(A, fragment.SetThreadRange(thread_range));
+    results.Set(A, fragment->BindThreadRange(thread_range));
   } else {
     ICHECK(0);
   }
@@ -225,7 +225,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     ICHECK(trans_B == false) << "B is local.fragment, trans_B must be false, "
                                 "please raise an issue if you see this";
     auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
-    results.Set(B, fragment.SetThreadRange(thread_range));
+    results.Set(B, fragment->BindThreadRange(thread_range));
   } else {
     ICHECK(0);
   }
@@ -239,7 +239,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
             ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
                                       C->dtype.bits())
             : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment.SetThreadRange(thread_range));
+    results.Set(C, fragment->BindThreadRange(thread_range));
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
       int dim_A = A->shape.size();
       const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
@@ -252,7 +252,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     ICHECK(trans_A == false);
     auto fragment =
         makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
-    results.Set(A, fragment.SetThreadRange(thread_range));
+    results.Set(A, fragment->BindThreadRange(thread_range));
   }
   if (B.scope() == "shared" || B.scope() == "shared.dyn") {
     int dim_B = B->shape.size();
@@ -272,7 +272,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     auto fragment =
         makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
-    results.Set(C, fragment.SetThreadRange(thread_range));
+    results.Set(C, fragment->BindThreadRange(thread_range));
     if (A.scope() == "shared" || A.scope() == "shared.dyn") {
       int dim_A = A->shape.size();
@@ -283,7 +283,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   } else if (A.scope() == "local.fragment") {
     auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
                                           A->dtype.bits(), trans_A);
-    results.Set(A, fragment.SetThreadRange(thread_range));
+    results.Set(A, fragment->BindThreadRange(thread_range));
   } else {
     ICHECK(0);
   }
@@ -296,7 +296,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     results.Set(B, shared_layout);
   } else if (B.scope() == "local.fragment") {
     auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
-    results.Set(B, fragment.SetThreadRange(thread_range));
+    results.Set(B, fragment->BindThreadRange(thread_range));
   } else {
     ICHECK(0);
   }
@@ -181,7 +181,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       PrimExpr loop_var_to_thread =
           src_layout->ForwardThread(indice_map_[buffer], rep);
       return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
-          .SetThreadRange(T.thread_bounds);
+          ->BindThreadRange(T.thread_bounds);
     }
   };
   if (source_buffer.defined()) {
@@ -272,7 +272,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   LayoutMap results;
   for (const auto &[buffer, _] : indice_map_) {
     if (!T.layout_map.count(buffer)) {
-      results.Set(buffer, CompleteBufferFragment(buffer).SetThreadRange(
+      results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
                               T.thread_bounds));
     }
     // Though they may exist some conflicts, but it's fine.
@@ -285,13 +285,13 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
       const FragmentNode *src_layout =
           T.layout_map[buffer].as<Fragment>().get();
       Fragment dst_layout_fragment =
-          CompleteBufferFragment(buffer).SetThreadRange(T.thread_bounds);
+          CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
       const FragmentNode *dst_layout =
           dst_layout_fragment.as<Fragment>().get();
       if (src_layout && dst_layout) {
         ICHECK(src_layout->IsEqual(dst_layout, true))
             << "Layout may conflict with ParallelOp for buffer " << buffer
-            << "\nError body begin:\n"
+            << " vs. " << source_buffer << "\nError body begin:\n"
             << GetRoot()->body << "\nError body end"
             << "\nLHS = " << src_layout->DebugOutput()
             << "\nRHS = " << dst_layout->DebugOutput()
@@ -272,7 +272,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
     Fragment dst_layout =
         Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
-            ->CondenseReplicateVar();
+            ->CondenseReplicateVar()
+            ->BindThreadRange(T.thread_bounds);
     return {{dst, dst_layout}};
   }
   return {};
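
Binding the thread range here dovetails with the `IsEqual` change above: since an unbound range now compares as a wildcard, a layout produced without `BindThreadRange` would never be flagged for a thread-range conflict. A small sketch of the fluent, non-mutating chain, again with stand-in types rather than the real `Fragment` API:

```cpp
#include <memory>

struct Range { int begin = 0, extent = 0; };

// Stand-in node with two chainable, non-mutating steps whose names mirror
// CondenseReplicateVar/BindThreadRange; not the actual TVM classes.
struct Node {
  Range thread_range_;
  bool condensed_ = false;

  std::shared_ptr<Node> CondenseReplicateVar() const {
    auto n = std::make_shared<Node>(*this);
    n->condensed_ = true;
    return n;
  }
  std::shared_ptr<Node> BindThreadRange(Range r) const {
    auto n = std::make_shared<Node>(*this);
    n->thread_range_ = r;
    return n;
  }
};

int main() {
  // Mirrors: Fragment(...)->CondenseReplicateVar()->BindThreadRange(T.thread_bounds)
  auto dst = std::make_shared<Node>()
                 ->CondenseReplicateVar()
                 ->BindThreadRange({0, 256});
  return (dst->condensed_ && dst->thread_range_.extent == 256) ? 0 : 1;
}
```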
@@ -191,7 +191,7 @@ Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) {
   size_t num_thread = *as_const_int(thread_range->extent);
   LoopPartitioner partitioner;
   Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
-  return fragment.SetThreadRange(thread_range);
+  return fragment->BindThreadRange(thread_range);
 }

 For LoopPragmaUnroll(For stmt) {