"ts/webui/src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "60e1e01f625aa63905e1a43e33876d246141d375"
Commit bb1a5fd8 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Remove DeReplicate during parallel loop layout inference (#430)

* [Refactor] Adjust layout inference calculations in Gemm and ParallelOp

* Updated block size calculation in Gemm to account for the range of thread bounds, improving accuracy in layout inference.
* Simplified layout conflict error messages in ParallelOp for better clarity, improving the debugging experience.
* Removed redundant buffer checks in ParallelOp layout inference logic, streamlining the code.

* [Refactor] Clean up layout inference logic in Gemm and ParallelOp

* Removed unnecessary warning log in Gemm related to WGMMA conditions, streamlining the layout inference process.
* Commented out redundant checks in ParallelOp's layout inference, improving code clarity while maintaining functionality.
* Enhanced error messages in ParallelOp to provide clearer context for layout conflicts, aiding in debugging efforts.

* lint fix
parent 97d63fab
@@ -161,7 +161,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    return {};
  LayoutMap results;
  ICHECK(C.scope() == "local.fragment");
-  auto block_size = *as_const_int(T.thread_bounds->extent);
+  auto block_size = *as_const_int(T.thread_bounds->extent) -
+                    *as_const_int(T.thread_bounds->min);
  if (TargetIsVolta(T.target)) {
    const int warp_size = 32;
    auto [warp_m, warp_n] =
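
Note (not part of the diff): a minimal sketch, assuming TVM's Range and as_const_int semantics, of how a constant block size is extracted from the thread bounds as in the updated line above; the helper name is hypothetical.

#include <tvm/ir/expr.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>

// Hypothetical helper mirroring the updated expression in the hunk above.
int64_t BlockSizeFromBounds(const tvm::Range &thread_bounds) {
  const int64_t *extent = tvm::tir::as_const_int(thread_bounds->extent);
  const int64_t *min = tvm::tir::as_const_int(thread_bounds->min);
  // as_const_int returns nullptr when the expression is not a compile-time
  // constant, so the dereference in the diff assumes constant thread bounds.
  ICHECK(extent != nullptr && min != nullptr);
  return *extent - *min;
}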
@@ -220,10 +221,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  } else if (TargetIsHopper(T.target)) {
    const int warp_size = 32;
    bool maybe_wgmma = (this->M >= 64) && (block_size / warp_size % 4 == 0);
-    if (!maybe_wgmma) {
-      LOG(WARNING)
-          << "WGMMA is not enabled because M < 64 or block_size % 128 != 0";
-    }
    auto [warp_m, warp_n] =
        ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
    auto fragment =
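
Note (not part of the diff): a small standalone sketch of the WGMMA eligibility check kept above; it requires M >= 64 and a warp count divisible by 4, i.e. block_size a multiple of 128 threads. The values below are hypothetical.

#include <cstdio>

int main() {
  const int warp_size = 32;
  const int M = 64;
  for (int block_size : {64, 128, 256}) {
    // Same condition as in the hunk above: at least one full warp group.
    bool maybe_wgmma = (M >= 64) && (block_size / warp_size % 4 == 0);
    std::printf("block_size=%d -> maybe_wgmma=%d\n", block_size, maybe_wgmma);
  }
  return 0;
}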
@@ -181,24 +181,22 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  };
  if (source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
  } else if (read_source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
  } else if (level == InferLevel::kFree) {
-    if (read_source_buffer.defined()) {
-      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
-      // Loop don't need to be replicated.
-      if (!is_one(loop_layout_->ReplicateExtent()))
-        loop_layout_ = loop_layout_->DeReplicate();
-      // if still has replication, add a condition
-      if (!is_one(loop_layout_->ReplicateExtent())) {
-        auto inv = loop_layout_->Inverse();
-        Array<PrimExpr> fwd;
-        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
-          fwd.push_back(0);
-        fwd.push_back(InputPlaceholder(0));
-        auto rep = inv->Forward(fwd).back();
-        AddPredicate(EQ(rep, 0));
-      }
+    // // Loop don't need to be replicated.
+    // if (!is_one(loop_layout_->ReplicateExtent()))
+    //   loop_layout_ = loop_layout_->DeReplicate();
+    // // if still has replication, add a condition
+    // if (!is_one(loop_layout_->ReplicateExtent())) {
+    //   auto inv = loop_layout_->Inverse();
+    //   Array<PrimExpr> fwd;
+    //   for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
+    //     fwd.push_back(0);
+    //   fwd.push_back(InputPlaceholder(0));
+    //   auto rep = inv->Forward(fwd).back();
+    //   AddPredicate(EQ(rep, 0));
+    // }
  } else {
    // Vectorize Size must be aware of the buffer_remap
    // As the pass will do post processing to the layout
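
Note (not part of the diff): a simplified, hypothetical model of what the now-commented replication guard did: compute each thread's replica index and predicate the loop body on replica 0, so replicated iterations execute only once. AddPredicate(EQ(rep, 0)) played the role of the guard below.

#include <cstdio>

int main() {
  const int num_threads = 8;
  const int replicate_extent = 4;  // each iteration mirrored across 4 threads
  const int threads_per_replica = num_threads / replicate_extent;
  for (int tid = 0; tid < num_threads; ++tid) {
    // Hypothetical replica index; the real mapping comes from the inverse
    // fragment layout (loop_layout_->Inverse() in the commented block).
    int rep = tid / threads_per_replica;
    if (rep == 0)  // counterpart of AddPredicate(EQ(rep, 0))
      std::printf("thread %d executes the replicated iteration\n", tid);
  }
  return 0;
}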
@@ -229,6 +227,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  } else {
    return {};
  }
  // Step 2: Check that the loop's partition can correctly align with all source
  // fragment
  for (const auto &[buffer, _] : indice_map_) {