"examples/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "716dbef52f550dd4d0864c340eb2362904b0ea33"
Commit 2a286ae6 authored by Lei Wang, committed by LeiWang1999

[Refactor] Refactor for Better Layout Conflict Handling (#240)

* [Feature] Add reduce_max functionality and corresponding tests

* Introduced a new test file for the reduce_max operation in the tilelang language module.
* Implemented the reduce_max functionality using T.prim_func, including local memory allocation and result copying (see the sketch after this list).
* Added tests for various input sizes and data types to ensure correctness of the reduce_max implementation.
* Enhanced profiling assertions to validate the output against reference implementations.
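
The list above describes how the new test is put together; below is a minimal, hedged sketch of that kind of kernel, assuming the public TileLang API (`T.prim_func`, `T.Kernel`, `T.alloc_fragment`, `T.reduce_max`, `T.copy`). Function names, shapes, and the compile/validation steps are illustrative and are not taken from the actual test file.

```python
# Hedged sketch only: a reduce_max kernel in the spirit of the new test.
# Shapes, names (reduce_max_program, main), and the way the kernel is compiled
# are assumptions, not copied from the test file added by this commit.
import tilelang.language as T


def reduce_max_program(M, N, dtype="float16"):

    @T.prim_func
    def main(A: T.Buffer((M, N), dtype), B: T.Buffer((M,), dtype)):
        with T.Kernel(1, threads=128) as bx:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)                     # global -> register fragment
            T.reduce_max(A_local, B_local, dim=1)  # row-wise max into B_local
            T.copy(B_local, B)                     # register fragment -> global

    return main


# Validation would mirror the profiler assertions described above, e.g.
# (assuming the compiled kernel exposes get_profiler() as in the benchmark
# diff further down):
#   profiler = kernel.get_profiler()
#   profiler.assert_allclose(lambda a: a.max(dim=1).values)
```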

* Fix whitespace issues in reduce_max test file for improved readability

* [Refactor] Update DebugOutput methods to return strings instead of void

* Modified DebugOutput methods in LayoutNode, FragmentNode, and SwizzledLayoutNode to return std::string instead of void, enhancing usability for logging and debugging.
* Updated corresponding header files to reflect the new return types.
* Improved layout inference error messages by incorporating DebugOutput for better clarity in layout conflicts.

* lint fix

* Fix matmul test kernel: changed the compute loop from T.Parallel to T.grid for correct execution in the webgpu code generation tests.

* [Enhancement] Improve layout inference conflict handling in ParallelOp

* Updated the layout inference logic in ParallelOp to better handle conflicts for local.fragment buffers.
* Added checks to ensure that layout conflicts are reported only when both source and destination buffers are defined, improving clarity in error messages.
* Enhanced the overall robustness of the layout inference process by addressing specific cases where conflicts may arise.

* [Feature] Add IsEqual methods for layout comparison

* Introduced IsEqual methods in LayoutNode, FragmentNode, and SwizzledLayoutNode to facilitate structural equality checks, allowing for optional index comparison.
* Enhanced layout inference logic in Copy and ParallelOp to utilize the new IsEqual methods for better conflict detection in local.fragment layouts.
* Improved error messages for layout conflicts to provide clearer guidance on potential issues.

* [Refactor] Update profiler usage in benchmark_nsa_fwd.py and improve layout inference in elem.cc and parallel.cc

* Modified the profiler call in benchmark_nsa_fwd.py to streamline latency measurement.
* Updated layout inference logic in elem.cc and parallel.cc to use const pointers for FragmentNode, enhancing type safety and clarity.
* Improved error messages in layout conflict checks to provide better guidance on potential issues.

* [Refactor] Clean up pointer formatting in layout inference files

* Standardized pointer formatting for FragmentNode in elem.cc and parallel.cc to improve code readability.
* Minor adjustments to error message formatting in layout conflict checks for better clarity.
parent 45534789
@@ -614,7 +614,7 @@ def benchmark_nsa(batch_size,
     profiler = kernel.get_profiler()
-    profiler_latency = profiler.do_bench(profiler.mod)
+    profiler_latency = profiler.do_bench()
     print(f"Profiler latency: {profiler_latency} ms")

     # Create input tensors
...
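For readers skimming the hunk above, a short hedged sketch of the updated profiler usage; how `kernel` is produced and the default arguments of `do_bench` are assumptions, not part of this commit.

```python
# Hedged sketch of the streamlined latency measurement shown in the diff above.
profiler = kernel.get_profiler()

# Before: the compiled module had to be passed back in explicitly.
# latency = profiler.do_bench(profiler.mod)

# After: the profiler benchmarks its own module by default.
latency = profiler.do_bench()
print(f"Profiler latency: {latency} ms")
```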
@@ -364,17 +364,21 @@ Fragment FragmentNode::CondenseReplicateVar() const {
       new_thread_replicate->dom->extent, new_thread_replicate->var);
 }

-void LayoutNode::DebugOutput() const {
-  LOG_DEBUG << "Layout Shape: " << InputShape() << " -> " << OutputShape();
-  LOG_DEBUG << "Layout Index: " << forward_index_;
+std::string LayoutNode::DebugOutput() const {
+  std::stringstream ss;
+  ss << "Layout Shape: " << InputShape() << " -> " << OutputShape() << " -> "
+     << GetForwardIndex();
+  return ss.str();
 }

-void FragmentNode::DebugOutput() const {
-  LOG_DEBUG << "Fragment Shape: " << InputShape() << " -> " << OutputShape();
-  LOG_DEBUG << "Fragment Replicate: " << ReplicateExtent();
-  LOG_DEBUG << "Fragment ThreadExtent: " << ThreadExtent();
-  LOG_DEBUG << "Fragment Index: " << forward_index_;
-  LOG_DEBUG << "Fragment ThreadIndex: " << forward_thread_;
+std::string FragmentNode::DebugOutput() const {
+  std::stringstream ss;
+  ss << "Fragment Shape: " << InputShape() << " -> " << OutputShape();
+  ss << " -> replicate: " << ReplicateExtent();
+  ss << " -> thread: " << ThreadExtent();
+  ss << " -> forward_thread: " << forward_thread_;
+  ss << " -> forward_index: " << GetForwardIndex();
+  return ss.str();
 }

 bool LayoutNode::SEqualReduce(const LayoutNode *other,
@@ -392,6 +396,30 @@ bool FragmentNode::SEqualReduce(const FragmentNode *other,
          equal(this->forward_thread_, other->forward_thread_);
 }

+bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
+  bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
+  ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
+  if (!skip_index) {
+    ret &= StructuralEqual()(this->forward_index_, other->forward_index_);
+  }
+  return ret;
+}
+
+bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
+  // Fragment Layout Comparison can skip the index comparison
+  // when the output shape is the same, as we can do
+  // a[i, j] = b[j, i] in register level.
+  bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
+  ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
+  ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent());
+  ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent());
+  if (!skip_index) {
+    ret &= StructuralEqual()(this->forward_index_, other->forward_index_);
+  }
+  return ret;
+}
+
 TVM_REGISTER_NODE_TYPE(LayoutNode);
 TVM_REGISTER_NODE_TYPE(FragmentNode);
...
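The comment in `FragmentNode::IsEqual` above refers to register-level re-indexing between fragments; the following is only an illustration of that `a[i, j] = b[j, i]` pattern. Buffer names and shapes are invented, and whether a given pair of fragments can actually be reconciled this way is decided by layout inference, not by this snippet.

```python
# Illustration only of the register-level re-indexing that lets IsEqual
# skip the forward-index comparison. Names and shapes are made up.
import tilelang.language as T


@T.prim_func
def reindex_fragments(A: T.Buffer((64, 64), "float16"),
                      B: T.Buffer((64, 64), "float16")):
    with T.Kernel(1, threads=128) as bx:
        a_frag = T.alloc_fragment((64, 64), "float16")
        b_frag = T.alloc_fragment((64, 64), "float16")
        T.copy(A, a_frag)
        # The two fragments may carry different forward_index mappings; the
        # elementwise rewrite below is why IsEqual(..., skip_index=true)
        # ignores that component when comparing fragment layouts.
        for i, j in T.Parallel(64, 64):
            b_frag[i, j] = a_frag[j, i]
        T.copy(b_frag, B)
```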
@@ -40,7 +40,9 @@ public:
   virtual Layout Inverse() const;
-  virtual void DebugOutput() const;
+  virtual std::string DebugOutput() const;
+
+  virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;

   static constexpr bool _type_has_method_sequal_reduce = true;
   static constexpr const char *_type_key = "tl.Layout";
@@ -94,7 +96,9 @@ public:
   Fragment CondenseReplicateVar() const;
-  void DebugOutput() const final;
+  std::string DebugOutput() const final;
+
+  bool IsEqual(const FragmentNode *other, bool skip_index = false) const;

   void VisitAttrs(tvm::AttrVisitor *v);
   bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
...
@@ -58,10 +58,12 @@ Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr> &vars) const {
   return expr_list;
 }

-void SwizzledLayoutNode::DebugOutput() const {
-  LayoutNode::DebugOutput();
-  std::cout << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits()
-            << " " << pattern_.Shift();
+std::string SwizzledLayoutNode::DebugOutput() const {
+  std::stringstream ss;
+  ss << LayoutNode::DebugOutput();
+  ss << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits() << " "
+     << pattern_.Shift();
+  return ss.str();
 }

 Layout SwizzledLayoutNode::Inverse() const {
@@ -69,6 +71,11 @@ Layout SwizzledLayoutNode::Inverse() const {
   return {};
 }

+bool SwizzledLayoutNode::IsEqual(const SwizzledLayoutNode *other,
+                                 bool skip_index) const {
+  return LayoutNode::IsEqual(other, skip_index) && pattern_ == other->pattern_;
+}
+
 SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var,
                                Array<PrimExpr> forward_index,
                                SwizzlePattern pattern) {
...
@@ -45,8 +45,8 @@ public:
   Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const final;
   Layout Inverse() const final;
-  void DebugOutput() const final;
+  std::string DebugOutput() const final;
+  bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;

   static constexpr const char *_type_key = "tl.SwizzledLayout";
   bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
   void VisitAttrs(tvm::AttrVisitor *v);
...
@@ -340,6 +340,20 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     arith::Analyzer analyzer;
     par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
   }
+  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
+    // Only compare fragment layout
+    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
+      const FragmentNode *src_layout = T.layout_map[src].as<Fragment>().get();
+      const FragmentNode *dst_layout = T.layout_map[dst].as<Fragment>().get();
+      if (src_layout && dst_layout) {
+        ICHECK(src_layout->IsEqual(dst_layout, true))
+            << "Get different layout for " << src << " and " << dst
+            << "\nLHS = " << src_layout->DebugOutput()
+            << "\nRHS = " << dst_layout->DebugOutput()
+            << "\nYou may need to use a shared memory to transform the layout";
+      }
+    }
+  }
   return par_op_->InferLayout(T, level);
 }
...
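The new ICHECK above suggests routing through shared memory when two local.fragment buffers end up with irreconcilable layouts. A hedged sketch of that workaround follows; the kernel structure, names, shapes, and dtypes are invented for illustration and do not come from the repository.

```python
# Hedged sketch of the "use a shared memory to transform the layout" hint in
# the error message above. Everything here is illustrative.
import tilelang.language as T


@T.prim_func
def relayout_via_smem(A: T.Buffer((128, 128), "float16"),
                      B: T.Buffer((128, 128), "float16")):
    with T.Kernel(1, threads=128) as bx:
        src_frag = T.alloc_fragment((128, 128), "float16")
        dst_frag = T.alloc_fragment((128, 128), "float16")
        staging = T.alloc_shared((128, 128), "float16")

        T.copy(A, src_frag)
        # A direct T.copy(src_frag, dst_frag) can now trip the layout-conflict
        # ICHECK when the two fragments were inferred with different layouts.
        # Staging through shared memory lets each fragment keep its own layout.
        T.copy(src_frag, staging)
        T.copy(staging, dst_frag)
        T.copy(dst_frag, B)
```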
@@ -225,8 +225,31 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   // Step 3: Infer other fragment's layout from the loop's partition
   LayoutMap results;
   for (const auto &[buffer, _] : indice_map_) {
-    if (!T.layout_map.count(buffer))
-      results.Set(buffer, CompleteBufferFragment(buffer));
+    if (!T.layout_map.count(buffer)) {
+      results.Set(buffer, CompleteBufferFragment(buffer));
+    }
+    // There may still be some conflicts here, and that is fine.
+    // Layout inference conflicts for local.fragment cannot always be handled
+    // here because the source_buffer is not always available.
+    if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
+        source_buffer.scope() == "local.fragment") {
+      if (T.layout_map.count(buffer)) {
+        const FragmentNode *src_layout =
+            T.layout_map[buffer].as<Fragment>().get();
+        Fragment dst_layout_fragment = CompleteBufferFragment(buffer);
+        const FragmentNode *dst_layout =
+            dst_layout_fragment.as<Fragment>().get();
+        if (src_layout && dst_layout) {
+          ICHECK(src_layout->IsEqual(dst_layout, true))
+              << "Layout may conflict with ParallelOp for buffer " << buffer
+              << "\nLHS = " << src_layout->DebugOutput()
+              << "\nRHS = " << dst_layout->DebugOutput()
+              << "\nYou may need to use a shared memory to transform the "
+                 "layout";
+        }
+      }
+    }
   }
   return results;
 }
...
@@ -128,7 +128,6 @@ public:
           LayoutInferArgs{target_, static_cast<size_t>(*extent_ptr),
                           layout_map},
           level);
-
       // Process the returned updates
       for (const auto &[buffer, layout] : updates) {
         // Basic validity checks
@@ -139,7 +138,8 @@ public:
         // If already in map, ensure they are structurally equal
         ICHECK(StructuralEqual()(layout, layout_map[buffer]))
             << "Get different layout for " << buffer
-            << " in cur_infer_id = " << cur_infer_id;
+            << " current layout: " << layout->DebugOutput()
+            << " previous layout: " << layout_map[buffer]->DebugOutput();
       } else {
         // Otherwise, update map
         layout_map.Set(buffer, layout);
@@ -188,9 +188,10 @@ public:
     for (int i = 0; i < num_infer; i++) {
       run_infer_step(i, InferLevel::kStrict, false);
     }
+
     // step 2: infer common layout with BFS
     finish_infer_queue();

     // step 3: relax constraints to free and re-run
     for (int i = 0; i < num_infer; i++) {
       run_infer_step(i, InferLevel::kFree, true);
...
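The step comments above describe a three-phase inference driver: a strict pass, a BFS over a pending work queue to settle common layouts, then a free pass. The following is schematic Python pseudocode of that control flow only; `run_infer_step` and `finish_infer_queue` are the names from the diff, but the meaning of the boolean flag (called `update_queue` here) is an assumption, not verified against the C++ implementation.

```python
# Schematic pseudocode of the three-phase driver whose comments appear in the
# diff above. Only the phase structure is taken from the diff.
def run_layout_inference(num_infer, run_infer_step, finish_infer_queue):
    # step 1: strict inference for every op
    for i in range(num_infer):
        run_infer_step(i, level="strict", update_queue=False)

    # step 2: infer common layout with BFS over the pending work queue
    finish_infer_queue()

    # step 3: relax constraints to free and re-run
    for i in range(num_infer):
        run_infer_step(i, level="free", update_queue=True)
```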
@@ -27,7 +27,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
             T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
             T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)
-            for i, j, k in T.Parallel(block_M, block_N, block_K):
+            for i, j, k in T.grid(block_M, block_N, block_K):
                 C_local[i, j] += A_shared[i, k] * B_shared[k, j]
             T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2)
...
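For context on the one-line change above, here is a hedged reconstruction of the surrounding test kernel. Only the statements visible in the diff are taken from it; the allocations, the pipelined `ko` loop, and the launch configuration are assumed scaffolding.

```python
# Hedged reconstruction of the webgpu test kernel around the changed line.
# Everything outside the three copies and the accumulation loop is assumed.
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype),
             C: T.Buffer((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
                      threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
                T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)
                # T.grid is a plain serial loop nest; this commit switches the
                # accumulation loop from T.Parallel to T.grid for the webgpu
                # code generation test.
                for i, j, k in T.grid(block_M, block_N, block_K):
                    C_local[i, j] += A_shared[i, k] * B_shared[k, j]
            T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2)

    return main
```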