"vscode:/vscode.git/clone" did not exist on "9cbbbbc6df9c243a65f64539846afad295696209"
Commit bb1a5fd8 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Remove DeReplicate during parallel loop layout inference (#430)

* [Refactor] Adjust layout inference calculations in Gemm and ParallelOp

* Updated the block size calculation in Gemm to account for the range of the thread bounds, improving accuracy in layout inference (see the arithmetic sketch after this list).
* Simplified layout conflict error messages in ParallelOp for better clarity, enhancing debugging experience.
* Removed redundant buffer checks in ParallelOp layout inference logic, streamlining the code.
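
The block-size arithmetic behind the first bullet is easiest to see with concrete numbers. Below is a minimal standalone C++ sketch; the values 128 and 256 are made up for illustration, and plain integers stand in for the project's thread_bounds fields rather than the real TVM-based types:

    #include <cassert>
    #include <cstdio>

    int main() {
      // Stand-ins for *as_const_int(T.thread_bounds->min) and ->extent.
      // These numbers are hypothetical, chosen only to show the difference.
      long long min = 128;
      long long extent = 256;

      long long block_size_old = extent;        // previous computation: extent alone
      long long block_size_new = extent - min;  // new computation from the diff below

      const int warp_size = 32;
      // block_size feeds ComputeWarpPartition(block_size / warp_size, ...),
      // so the warp count drops from 8 to 4 with these made-up numbers.
      std::printf("old: %lld threads (%lld warps)\n",
                  block_size_old, block_size_old / warp_size);
      std::printf("new: %lld threads (%lld warps)\n",
                  block_size_new, block_size_new / warp_size);
      assert(block_size_new == 128);
      return 0;
    }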

* [Refactor] Clean up layout inference logic in Gemm and ParallelOp

* Removed unnecessary warning log in Gemm related to WGMMA conditions, streamlining the layout inference process.
* Commented out redundant checks in ParallelOp's layout inference, improving code clarity while maintaining functionality.
* Enhanced error messages in ParallelOp to provide clearer context for layout conflicts, aiding in debugging efforts.

* lint fix
Parent commit: 97d63fab
@@ -161,7 +161,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     return {};
   LayoutMap results;
   ICHECK(C.scope() == "local.fragment");
-  auto block_size = *as_const_int(T.thread_bounds->extent);
+  auto block_size = *as_const_int(T.thread_bounds->extent) -
+                    *as_const_int(T.thread_bounds->min);
   if (TargetIsVolta(T.target)) {
     const int warp_size = 32;
     auto [warp_m, warp_n] =
@@ -220,10 +221,6 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   } else if (TargetIsHopper(T.target)) {
     const int warp_size = 32;
     bool maybe_wgmma = (this->M >= 64) && (block_size / warp_size % 4 == 0);
-    if (!maybe_wgmma) {
-      LOG(WARNING)
-          << "WGMMA is not enabled because M < 64 or block_size % 128 != 0";
-    }
     auto [warp_m, warp_n] =
         ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
     auto fragment =
...
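The Hopper branch above now decides WGMMA eligibility silently (the warning log was removed). The condition itself is plain integer arithmetic: M must be at least 64 and the block must hold a whole number of warp groups (4 warps, i.e. 128 threads). A minimal sketch of that test, with made-up M and block_size values and no project types involved:

    #include <cstdio>

    // Mirrors the condition in the hunk above: WGMMA needs M >= 64 and a
    // warp count that is a multiple of 4 (a full 128-thread warp group).
    bool maybe_wgmma(int m, int block_size) {
      const int warp_size = 32;
      return (m >= 64) && (block_size / warp_size % 4 == 0);
    }

    int main() {
      std::printf("%d\n", maybe_wgmma(64, 128));  // 1: 4 warps, M large enough
      std::printf("%d\n", maybe_wgmma(64, 96));   // 0: 3 warps, not a warp group
      std::printf("%d\n", maybe_wgmma(32, 128));  // 0: M too small
      return 0;
    }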
@@ -181,24 +181,22 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   };
   if (source_buffer.defined()) {
     loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
-  } else if (read_source_buffer.defined()) {
-    loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
   } else if (level == InferLevel::kFree) {
     if (read_source_buffer.defined()) {
       loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
-      // Loop don't need to be replicated.
-      if (!is_one(loop_layout_->ReplicateExtent()))
-        loop_layout_ = loop_layout_->DeReplicate();
-      // if still has replication, add a condition
-      if (!is_one(loop_layout_->ReplicateExtent())) {
-        auto inv = loop_layout_->Inverse();
-        Array<PrimExpr> fwd;
-        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
-          fwd.push_back(0);
-        fwd.push_back(InputPlaceholder(0));
-        auto rep = inv->Forward(fwd).back();
-        AddPredicate(EQ(rep, 0));
-      }
+      // // Loop don't need to be replicated.
+      // if (!is_one(loop_layout_->ReplicateExtent()))
+      //   loop_layout_ = loop_layout_->DeReplicate();
+      // // if still has replication, add a condition
+      // if (!is_one(loop_layout_->ReplicateExtent())) {
+      //   auto inv = loop_layout_->Inverse();
+      //   Array<PrimExpr> fwd;
+      //   for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
+      //     fwd.push_back(0);
+      //   fwd.push_back(InputPlaceholder(0));
+      //   auto rep = inv->Forward(fwd).back();
+      //   AddPredicate(EQ(rep, 0));
+      // }
     } else {
       // Vectorize Size must be aware of the buffer_remap
       // As the pass will do post processing to the layout
@@ -229,6 +227,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   } else {
     return {};
   }
   // Step 2: Check that the loop's partition can correctly align with all source
   // fragment
   for (const auto &[buffer, _] : indice_map_) {
...
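The commented-out block above is the behavior this commit retires: when a fragment layout still carried replication, the op previously either folded it away (DeReplicate) or guarded the loop with a predicate so that only replica 0 executed the body. A standalone sketch of that guarding idea, assuming a toy layout where 128 threads cover 64 elements, so each element is replicated across 2 threads (none of the names below are project APIs):

    #include <cstdio>

    int main() {
      const int num_threads = 128;
      const int num_elems = 64;
      const int replicate_extent = num_threads / num_elems;  // 2 copies per element

      int writes_without_predicate = 0, writes_with_predicate = 0;
      for (int tid = 0; tid < num_threads; ++tid) {
        int rep = tid / num_elems;      // which replica of the element this thread holds
        ++writes_without_predicate;     // every replica would run the loop body
        if (rep == 0)                   // analogous to the removed AddPredicate(EQ(rep, 0))
          ++writes_with_predicate;      // only the first replica runs it
      }
      // Prints "2 128 64": the predicate removes the duplicated work.
      std::printf("%d %d %d\n", replicate_extent,
                  writes_without_predicate, writes_with_predicate);
      return 0;
    }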