"testing/python/jit/test_tilelang_jit_gemm_cython.py" did not exist on "38ba083b581d3f9f1424e1bcfd45ca068bce65cd"
Commit 8df45c9d authored by Lei Wang, committed by LeiWang1999

[Bugfix] Avoid duplicate data access when a cross-thread buffer meets a replicate register (#606)

* [Enhancement] Improve debug output formatting in layout and fragment nodes

- Updated the `DebugOutput` methods in `LayoutNode` and `FragmentNode` to provide more structured and informative output, including transformation details and thread range information.
- Enhanced layout inference logic in `ParallelOp` to add predicates for cross-thread shared memory access, improving layout handling in parallel operations.
- Minor adjustment in `layout_inference.cc` to ensure clarity in parallel loop handling.

* lint fix
parent a8811d9b
@@ -370,19 +370,22 @@ Fragment FragmentNode::CondenseReplicateVar() const {

 std::string LayoutNode::DebugOutput() const {
   std::stringstream ss;
-  ss << "Layout Shape: " << InputShape() << " -> " << OutputShape() << " -> "
-     << GetForwardIndex();
+  ss << "Layout(" << InputShape() << " -> " << OutputShape()
+     << ", transform: " << GetForwardVars() << " -> " << GetForwardIndex()
+     << ")";
   return ss.str();
 }

 std::string FragmentNode::DebugOutput() const {
   std::stringstream ss;
-  ss << "Fragment Shape: " << InputShape() << " -> " << OutputShape();
-  ss << " -> replicate: " << ReplicateExtent();
-  ss << " -> thread: " << ThreadExtent();
-  ss << " -> forward_thread: " << forward_thread_;
-  ss << " -> forward_index: " << GetForwardIndex();
-  ss << " -> thread_range: " << thread_range_;
+  ss << "Fragment(" << InputShape() << " -> " << OutputShape()
+     << ", replicate: " << ReplicateExtent() << ", thread: " << ThreadExtent()
+     << ", forward_thread: " << forward_thread_
+     << ", forward_index: " << GetForwardIndex();
+  if (thread_range_.defined()) {
+    ss << ", thread_range: " << thread_range_;
+  }
+  ss << ")";
   return ss.str();
 }
...
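For reference, the rewrite turns the old arrow-chained strings into a single parenthesized record per layout. A minimal standalone sketch of the resulting format (plain C++ with made-up placeholder values standing in for InputShape(), GetForwardVars(), and friends, not the actual TVM objects):

#include <iostream>
#include <sstream>
#include <string>

// Placeholder values only, chosen to show the shape of the string the
// new LayoutNode::DebugOutput() produces.
int main() {
  std::string input_shape = "[16, 16]";
  std::string output_shape = "[256]";
  std::string forward_vars = "[i, j]";
  std::string forward_index = "[i * 16 + j]";

  std::stringstream ss;
  ss << "Layout(" << input_shape << " -> " << output_shape
     << ", transform: " << forward_vars << " -> " << forward_index << ")";
  std::cout << ss.str() << "\n";
  // Prints: Layout([16, 16] -> [256], transform: [i, j] -> [i * 16 + j])
  return 0;
}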
@@ -174,16 +174,38 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {

   // // Loop don't need to be replicated.
   // if (!is_one(loop_layout_->ReplicateExtent()))
   //   loop_layout_ = loop_layout_->DeReplicate();
-  // // if still has replication, add a condition
-  // if (!is_one(loop_layout_->ReplicateExtent())) {
-  //   auto inv = loop_layout_->Inverse();
-  //   Array<PrimExpr> fwd;
-  //   for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
-  //     fwd.push_back(0);
-  //   fwd.push_back(InputPlaceholder(0));
-  //   auto rep = inv->Forward(fwd).back();
-  //   AddPredicate(EQ(rep, 0));
-  // }
+
+  // For free layout inference
+  // If replication exists and buffer has cross-thread shared memory access,
+  // add predicate
+  bool has_cross_thread_access = false;
+  PostOrderVisit(root_, [&](const ObjectRef &obj) {
+    if (const auto *store = obj.as<BufferStoreNode>()) {
+      // check if scope is shared or global
+      if (store->buffer.scope() == "shared" ||
+          store->buffer.scope() == "shared.dyn" ||
+          store->buffer.scope() == "global") {
+        has_cross_thread_access = true;
+      }
+    } else if (const auto *load = obj.as<BufferLoadNode>()) {
+      // check if scope is shared or global
+      if (load->buffer.scope() == "shared" ||
+          load->buffer.scope() == "shared.dyn" ||
+          load->buffer.scope() == "global") {
+        has_cross_thread_access = true;
+      }
+    }
+  });
+
+  if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access) {
+    auto inv = loop_layout_->Inverse();
+    Array<PrimExpr> fwd;
+    for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
+      fwd.push_back(0);
+    fwd.push_back(InputPlaceholder(0));
+    auto rep = inv->Forward(fwd).back();
+    AddPredicate(EQ(rep, 0));
+  }
 } else {
   // Vectorize Size must be aware of the buffer_remap
   // As the pass will do post processing to the layout
...
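For intuition about the fix itself, here is a minimal standalone sketch (plain C++; the thread-to-replica mapping is a made-up stand-in for loop_layout_->Inverse(), not the TVM API) of the hazard the predicate removes: when a fragment is replicated, several threads hold the same register value, and without a guard each replica issues the identical store to shared or global memory. Predicating on replica index 0 leaves exactly one writer per element, mirroring AddPredicate(EQ(rep, 0)) above.

#include <cassert>
#include <cstdio>

// Hypothetical configuration: 8 threads with replicate extent 2, so each
// of the 4 logical elements is held by 2 threads.
constexpr int kThreads = 8;
constexpr int kReplicate = 2;
constexpr int kElems = kThreads / kReplicate;

// Stand-in for the inverse layout: recover (element, replica) for a thread.
struct Placement { int element; int replica; };
Placement InverseLayout(int thread) {
  return {thread % kElems, thread / kElems};
}

int main() {
  int writes[kElems] = {0};
  for (int t = 0; t < kThreads; ++t) {  // simulate every thread in the block
    Placement p = InverseLayout(t);
    // Without the predicate every replica would write; with it, only the
    // replica-0 copy performs the cross-thread store.
    if (p.replica == 0) writes[p.element] += 1;
  }
  for (int e = 0; e < kElems; ++e)
    assert(writes[e] == 1);  // exactly one writer per element, no duplicates
  std::printf("each element written exactly once by %d threads\n", kThreads);
  return 0;
}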
@@ -598,6 +598,7 @@ private:

   auto loop_layout = result_.for_map[root];
   bool parallel_loop = !is_register_store && !skip_thread_partition_;
   if (parallel_loop) {
     for_node =
         PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
...