Unverified commit 9d7d45be authored by Lei Wang, committed by GitHub

[TMA] Automatically lower 1d tma in appropriate cases (#788)

* Enhance layout inference and copy operations with 1D TMA support

- Updated `CopyNode` to introduce separate handling for 1D bulk load/store operations, including new methods for checking and lowering these operations.
- Modified `InferLayout` and `GetCopyInst` to accommodate additional parameters for layout maps and analyzers.
- Enhanced `AtomicAddNode` and `FillNode` to utilize the updated layout inference logic.
- Improved buffer out-of-bounds checks during layout inference to ensure safe memory access.

This update allows eligible copies between global and shared memory to be lowered to 1D TMA automatically, improving the efficiency and correctness of memory operations in the TileLang framework; a simplified decision sketch follows below.
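For illustration only, the standalone sketch below shows the kind of decision order described above: fall back to a normal copy when TMA is unavailable or the access may be out of bounds, prefer the 1D bulk path when the region is provably contiguous, and otherwise use the multi-dimensional TMA path. It is not TileLang's actual implementation; names such as `CopyDesc` and `SelectCopyInst` are hypothetical simplifications.

```cpp
// Standalone sketch (not the TileLang implementation): a possible ordering of
// the copy-instruction decision once 1D TMA variants exist.
#include <cstdint>
#include <iostream>

enum class CopyInst : uint8_t {
  kNormal = 0,      // ldg/stg or cp.async fallback
  kBulkLoad = 3,    // multi-dim TMA load
  kBulkStore = 4,   // multi-dim TMA store
  kBulkLoad1D = 5,  // 1D TMA load
  kBulkStore1D = 6, // 1D TMA store
};

struct CopyDesc {
  bool is_load;       // global->shared (load) vs shared->global (store)
  bool contiguous_1d; // both sides form one contiguous 1D region
  bool buffer_oob;    // some accessed range may exceed the buffer shape
  bool tma_supported; // target supports TMA at all
};

// Prefer the 1D bulk path when the region is provably contiguous and in
// bounds; otherwise fall back to multi-dim TMA or a normal copy.
CopyInst SelectCopyInst(const CopyDesc &d) {
  if (!d.tma_supported || d.buffer_oob)
    return CopyInst::kNormal;
  if (d.contiguous_1d)
    return d.is_load ? CopyInst::kBulkLoad1D : CopyInst::kBulkStore1D;
  return d.is_load ? CopyInst::kBulkLoad : CopyInst::kBulkStore;
}

int main() {
  CopyDesc d{/*is_load=*/true, /*contiguous_1d=*/true,
             /*buffer_oob=*/false, /*tma_supported=*/true};
  std::cout << static_cast<int>(SelectCopyInst(d)) << "\n"; // prints 5
}
```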

* Refactor layout inference calls for improved readability

- Updated `InferLayout` calls in `AtomicAddNode`, `CopyNode`, and `FillNode` to enhance code clarity by formatting parameters across multiple lines.
- Cleaned up whitespace and formatting in `copy.h` and `layout_inference.cc` to adhere to coding standards and improve maintainability.

This refactor aims to streamline the layout inference logic and improve overall code organization.

* Fix shared tensor check in CopyNode for bulk copy operations

- Updated the condition in `CheckBulkCopy1D` to verify contiguity of `shared_tensor` instead of `dst`, ensuring correct handling of shared memory layouts during bulk copy operations.
- This change enhances the accuracy of memory operations in the TileLang framework; the sketch below illustrates the contiguity rule involved.
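As a rough, hypothetical illustration of the contiguity condition such a check relies on (simplified standalone code; `IsContiguous1D` and its types are not the actual TileLang API): in a row-major buffer, a sliced region collapses to a single contiguous 1D run only if every dimension copied with extent greater than one is followed exclusively by dimensions copied in full.

```cpp
// Hedged sketch of a "contiguous 1D region" test over a row-major buffer.
#include <cstdint>
#include <iostream>
#include <vector>

struct Range { int64_t min; int64_t extent; };

bool IsContiguous1D(const std::vector<int64_t> &shape,
                    const std::vector<Range> &range) {
  // Every dimension sliced with extent > 1 must be followed only by
  // dimensions that are copied in full, otherwise gaps appear.
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (range[i].extent <= 1) continue;
    for (std::size_t j = i + 1; j < shape.size(); ++j) {
      if (range[j].extent != shape[j]) return false;
    }
  }
  return true;
}

int main() {
  // Whole rows of a [64, 128] buffer are contiguous; a [32, 64] tile is not.
  std::cout << IsContiguous1D({64, 128}, {{0, 16}, {0, 128}}) << "\n"; // 1
  std::cout << IsContiguous1D({64, 128}, {{0, 32}, {0, 64}}) << "\n";  // 0
}
```

With that rule, copying whole rows qualifies for the 1D path, while a partial sub-tile keeps the multi-dimensional path.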

* Update test_example_gdn_compilation.py to invoke test function directly

- Commented out the call to `tilelang.testing.main()` in `test_example_gdn_compilation.py` and replaced it with a direct call to `test_example_chunk_delta_bwd_compilation()`. This change simplifies the test execution flow and focuses on the specific test case.

* Enhance bulk load/store checks in CopyNode with last dimension validation

- Updated `CheckBulkLoad` and `CheckBulkStore` methods in `CopyNode` to include an optional parameter for validating the last dimension during bulk copy operations.
- Adjusted related methods `CheckBulkLoad1D` and `CheckBulkStore1D` to pass the new parameter, improving the accuracy of bulk copy checks.
- This change enhances the robustness of memory operations in the TileLang framework by ensuring compliance with dimensional requirements; the sketch below illustrates the distinction.
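The sketch below is only a hedged illustration of why a `check_last_dim` switch can be useful; the 16-byte granularity and the helper names are assumptions for the example, not the exact hardware or TileLang rules. Multi-dimensional TMA descriptors constrain the innermost dimension of the copied box, whereas the 1D bulk path treats the data as a flat byte run, so callers probing 1D eligibility can skip the last-dimension check.

```cpp
// Hypothetical sketch of skipping the last-dimension constraint when probing
// eligibility for the 1D bulk path. The 16-byte granularity is an assumption.
#include <cstdint>
#include <iostream>

bool LastDimOk(int64_t last_extent_elems, int64_t elem_bytes) {
  const int64_t granularity = 16; // assumed alignment granularity in bytes
  return (last_extent_elems * elem_bytes) % granularity == 0;
}

bool CheckBulkEligible(int64_t last_extent_elems, int64_t elem_bytes,
                       bool check_last_dim) {
  if (!check_last_dim)
    return true; // 1D path: eligibility decided by contiguity/size elsewhere
  return LastDimOk(last_extent_elems, elem_bytes);
}

int main() {
  std::cout << CheckBulkEligible(/*elems=*/30, /*bytes=*/2, true) << "\n";  // 0
  std::cout << CheckBulkEligible(/*elems=*/30, /*bytes=*/2, false) << "\n"; // 1
}
```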

* Refactor CheckBulkLoad and CheckBulkStore methods for improved readability

- Reformatted the parameter lists of `CheckBulkLoad` and `CheckBulkStore` methods in `CopyNode` to enhance code clarity by aligning parameters across multiple lines.
- This change improves the maintainability of the code and adheres to coding standards.
parent b6b02dab
@@ -360,8 +360,9 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
                                     InferLevel::kFree};
   for (auto level : levels) {
-    (par_op)->InferLayout(
-        {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
+    (par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
+                           false, T.buffer_remap},
+                          level);
   }
   auto loop_layout = par_op->GetLoopLayout();
   Var thread_var = T.thread_var;
This diff is collapsed.
@@ -15,11 +15,15 @@ using namespace tir;
 /// Copy instruction types for different memory access patterns
 enum class CopyInst : uint8_t {
-  kNormal = 0,    ///< Standard memory copy (ldg/stg/cpasync)
-  kLDSM = 1,      ///< Load matrix instruction
-  kSTSM = 2,      ///< Store matrix instruction
-  kBulkLoad = 3,  ///< Tensor Memory Access load
-  kBulkStore = 4, ///< Tensor Memory Access store
+  kNormal = 0,      // utilize ldg/stg or cpasync or any buffer copy
+  kLDSM = 1,        // ldmatrix memory copy
+  kSTSM = 2,        // stmatrix memory copy
+  kBulkLoad = 3,    // utilize tma load
+  kBulkStore = 4,   // utilize tma store
+  // we should separate the bulk load and store for 1d and multi-dim
+  // as they have different memory access patterns
+  kBulkLoad1D = 5,  // utilize tma load 1d
+  kBulkStore1D = 6, // utilize tma store 1d
 };
 
 /// Descriptor for Tensor Memory Access (TMA) copy operations
@@ -137,17 +141,41 @@ public:
    * \param T Arguments for layout inference.
    * \param level Level of inference (basic or detailed).
    */
-  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
+  LayoutMap InferLayout(const LayoutInferArgs &T,
+                        InferLevel level) const override;
 
   /*!
    * \brief Check if bulk copy is supported.
    */
-  bool CheckBulkLoad(Target target) const;
+  bool CheckBulkLoad(Target target, arith::Analyzer *analyzer,
+                     bool check_last_dim = true) const;
 
   /*!
    * \brief Check if bulk store is supported.
    */
-  bool CheckBulkStore(Target target) const;
+  bool CheckBulkStore(Target target, arith::Analyzer *analyzer,
+                      bool check_last_dim = true) const;
+
+  /*!
+   * \brief Check if bulk copy 1d load is supported.
+   */
+  bool CheckBulkLoad1D(Target target, const LayoutMap &layout_map,
+                       arith::Analyzer *analyzer) const;
+
+  /*!
+   * \brief Check if bulk copy 1d store is supported.
+   */
+  bool CheckBulkStore1D(Target target, const LayoutMap &layout_map,
+                        arith::Analyzer *analyzer) const;
+
+  /*!
+   * \brief Check if bulk copy 1d is supported.
+   */
+  bool CheckBulkCopy1D(const Buffer &global_tensor, const Buffer &shared_tensor,
+                       const Array<Range> &global_range,
+                       const Array<Range> &shared_range,
+                       const LayoutMap &layout_map,
+                       arith::Analyzer *analyzer) const;
 
   /*!
    * \brief Check if lds memory copy is supported.
@@ -162,11 +190,10 @@ public:
   /*!
    * \brief Get the copy instruction type.
    */
-  CopyInst GetCopyInst(Target target, bool disable_tma_lower) const;
+  CopyInst GetCopyInst(Target target, bool disable_tma_lower,
+                       const LayoutMap &layout_map, arith::Analyzer *analyzer,
+                       bool buffer_oob) const;
 
-  /*!
-   * \brief Clone this copy operator.
-   */
 protected:
   /*!
    * \brief Generate lowering for bulk/global-to-shared copy.
@@ -174,6 +201,12 @@
   Stmt LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
                      CopyInst copy_inst) const;
 
+  /*!
+   * \brief Generate lowering for bulk copy 1d.
+   */
+  Stmt LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
+                       CopyInst copy_inst) const;
+
   /*!
    * \brief Generate lowering for LDS Memory Copy (shared memory to shared
    * memory or smem usage).
@@ -316,7 +349,8 @@ public:
   /*!
    * \brief Infer layout for this operator.
   */
-  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
+  LayoutMap InferLayout(const LayoutInferArgs &T,
+                        InferLevel level) const override;
 
   /*!
    * \brief Get TVM Op handle.
@@ -170,9 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
 Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (dst.scope() == "local.fragment") {
     auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
-    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
-                        InferLevel::kFree);
-    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
-                        InferLevel::kFree);
+    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
+                         false, T.buffer_remap},
+                        InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                      par_op->GetLoopLayout());
@@ -189,7 +188,8 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
              dst.scope() == "global") {
     auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
-    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
+    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
+                         false, T.buffer_remap},
                         InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                      par_op->GetLoopLayout());
@@ -225,9 +225,7 @@ TIR_REGISTER_TL_OP(Fill, fill)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
-TVM_FFI_STATIC_INIT_BLOCK({
-  FillNode::RegisterReflection();
-});
+TVM_FFI_STATIC_INIT_BLOCK({ FillNode::RegisterReflection(); });
 
 } // namespace tl
 } // namespace tvm
\ No newline at end of file
@@ -45,6 +45,8 @@ struct LayoutInferArgs {
   Target target;
   Range thread_bounds;
   LayoutMap layout_map;
+  arith::Analyzer *analyzer;
+  bool buffer_oob = false;
   Map<Buffer, Buffer> buffer_remap;
 };
@@ -14,8 +14,10 @@
 #include <queue>
 
 #include "../layout/utils.h"
+#include "../op/copy.h"
 #include "../op/parallel.h"
 #include "../op/region.h"
 #include "arith/ir_mutator_with_analyzer.h"
 #include "arith/ir_visitor_with_analyzer.h"
 #include "common/loop_fusion_utils.h"
@@ -64,6 +66,8 @@ public:
   BufferUseDefCollector(bool skip_thread_partition)
       : skip_thread_partition_(skip_thread_partition) {}
 
+  using arith::IRVisitorWithAnalyzer::IRVisitorWithAnalyzer;
+
   void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
                     LayoutMap &layout_map, const LayoutMap &strict_layout_map,
                     std::queue<int> &q, std::vector<bool> &in_queue) {
@@ -80,6 +84,7 @@ public:
     auto &next = infer_list_[cur_infer_id];
     auto iter_var = thread_var_vec_[cur_infer_id];
     auto thread_bounds = thread_bounds_vec_[cur_infer_id];
+    auto buffer_oob = buffer_oob_vec_[cur_infer_id];
     // Double-check that 'next' is valid
     ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
                            << "] is null inside run_infer_step.";
@@ -100,8 +105,10 @@
         "required for layout inference.";
 
     // Run InferLayout
-    auto updates = next->InferLayout(
-        LayoutInferArgs{target_, thread_bounds, layout_map}, level);
+    auto updates =
+        next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
+                                          &analyzer_, buffer_oob},
+                          level);
 
     // Process the returned updates
     for (const auto &[buffer, layout] : updates) {
@@ -199,6 +206,9 @@
     ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
         << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
           "length.";
+    ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size())
+        << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
+           "length.";
 
     // If needed, you can also check that annotated_layout_map_ is not empty, or
     // anything else relevant to your setup.
@@ -306,8 +316,7 @@ private:
         addToUseList(buffer.value());
       }
     }
-    infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
-    infer_list_.push_back(std::move(p));
+    // Compute thread_var_ and thread_bounds_
     thread_var_vec_.push_back(thread_var_);
     if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
       auto const_int_bound = analyzer_.const_int_bound(thread_var_);
@@ -320,6 +329,39 @@
     } else {
       thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
     }
+
+    // Compute buffer oob for each buffer in the op
+    if (const auto *copy = p.as<CopyNode>()) {
+      auto src_tensor = copy->src;
+      auto dst_tensor = copy->dst;
+      auto src_range = copy->src_range;
+      auto dst_range = copy->dst_range;
+      bool src_oob = false;
+      bool dst_oob = false;
+      for (size_t i = 0; i < src_range.size(); i++) {
+        if (!analyzer_.CanProve(src_range[i]->min + src_range[i]->extent <=
+                                    src_tensor->shape[i],
+                                arith::ProofStrength::kSymbolicBound)) {
+          src_oob = true;
+          break;
+        }
+      }
+      for (size_t i = 0; i < dst_range.size(); i++) {
+        if (!analyzer_.CanProve(dst_range[i]->min + dst_range[i]->extent <=
+                                    dst_tensor->shape[i],
+                                arith::ProofStrength::kSymbolicBound)) {
+          dst_oob = true;
+          break;
+        }
+      }
+      buffer_oob_vec_.push_back(src_oob || dst_oob);
+    } else {
+      buffer_oob_vec_.push_back(false);
+    }
+
+    // Add the tile operator to infer_list_
+    infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
+    infer_list_.push_back(std::move(p));
   }
 }
@@ -365,6 +407,7 @@
     } else {
       thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
     }
+    buffer_oob_vec_.push_back(false);
   } else {
     IRVisitorWithAnalyzer::VisitStmt(op->body);
   }
@@ -411,6 +454,7 @@
                                    IterVarType::kDataPar);
   std::vector<IterVar> thread_var_vec_;
   std::vector<Range> thread_bounds_vec_;
+  std::vector<bool> buffer_oob_vec_;
   Target target_;
   LayoutMap annotated_layout_map_;
   bool skip_thread_partition_{false};
@@ -556,6 +600,8 @@
       : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
        skip_thread_partition_(skip_thread_partition){};
 
+  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
+
   /**
    * @brief Visit and mutate a Block node to attach inferred layout information.
    *
@@ -605,43 +605,14 @@ public:
 class WSCodeEmitter : public StmtMutator {
 public:
-  /**
-   * @brief Construct a warp-specialized code emitter configured for producer or
-   * consumer emission.
-   *
-   * Initializes a WSCodeEmitter that will emit barrier-aware, role-filtered
-   * code for a single warp-specialized block. The emitter is configured with
-   * the loop/thread iteration variable, buffer mapping, role marker used to
-   * classify statements, and two flags that control emission behavior:
-   *
-   * - `mbarrier_only`: when true, emission is restricted to barrier-related
-   * operations only.
-   * - `only_has_wgmma`: when true, the emitter will account for the presence of
-   * WgMMA (workgroup MMA) operations when computing barrier/thread gating
-   * behavior.
-   *
-   * @param is_emitting_producer True to emit producer-side groups; false to
-   * emit consumer-side groups.
-   * @param thread_iv IterVar representing the thread iteration variable
-   * (threadIdx.*) whose Var is used for thread-index rewrites and gating.
-   * @param buffer_data_to_buffer Map from buffer data Var to the corresponding
-   * Buffer (used to resolve buffer references during emission).
-   * @param marker Role marker that classifies statements as
-   * producer/consumer/both; used to filter which statements are emitted on this
-   * path.
-   * @param mbarrier_only If true, restrict emission to mbarrier-related
-   * statements and helpers.
-   * @param only_has_wgmma If true, adjust emission and barrier-thread-count
-   * logic for blocks that contain WgMMA operations.
-   */
   WSCodeEmitter(bool is_emitting_producer, const IterVar &thread_iv,
                 Map<Var, Buffer> buffer_data_to_buffer,
                 const WarpSpecializedRoleMarker &marker,
-                bool mbarrier_only = false, bool only_has_wgmma = false)
+                bool mbarrier_only = false)
       : is_emitting_producer_(is_emitting_producer),
         buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
         marker_(marker), thread_var_(thread_iv->var),
-        mbarrier_only_(mbarrier_only), only_has_wgmma_(only_has_wgmma) {}
+        mbarrier_only_(mbarrier_only) {}
 
   /**
    * @brief Whether a SIMT-style bulk copy was detected.
@@ -654,18 +625,6 @@ public:
    */
   bool hasSimtCopy() const { return has_simt_copy_; }
 
-  /**
-   * @brief Whether this emitter contains only warp-group MMA (WgMMA)
-   * operations.
-   *
-   * Returns true if the emitter detected exclusively WgMMA usage in the region
-   * it analyzed.
-   *
-   * @return bool true when only WgMMA-based code paths are present; false
-   * otherwise.
-   */
-  bool onlyHasWgMMA() const { return only_has_wgmma_; }
-
 private:
   template <
       typename NodeType> /**
@@ -706,47 +665,6 @@
     }
   }
 
-  /**
-   * @brief Visit and transform a SeqStmt node, emitting grouped blocks with
-   * barrier synchronization according to producer/consumer roles.
-   *
-   * This method examines the sequence to determine whether producer-side
-   * synchronization is required (based on marker_ roles). If no producer sync
-   * is needed it delegates to FilterByRole. Otherwise it:
-   * - Recursively visits and transforms each child statement.
-   * - Extracts an acquire/release sync pattern for the sequence via
-   * ExtractSyncPattern.
-   * - For producer emission (is_emitting_producer_ == true):
-   *   - Skips consumer-only statements unless marker_ marks a statement as
-   * Both, in which case the statement is emitted as its own group.
-   *   - For each statement, inserts parity waits for acquire patterns, rewrites
-   * release statements with MbarrierRewriter using a computed barrier id,
-   * collects SimT-copy presence (setting has_simt_copy_ and inserting
-   * cp.async barriers when found), optionally emits arrive barriers for
-   * release-after events, and emits each resulting set of statements as a
-   * group block annotated with "stmt_group".
-   * - For consumer emission (is_emitting_producer_ == false):
-   *   - Skips producer-only statements.
-   *   - Inserts parity waits for acquire patterns, appends the transformed
-   * statement, and emits arrive barriers for release-after events. When
-   * only_has_wgmma_ is set, the arrive barrier uses a per-thread predicate
-   * (FloorMod(thread_var_,128)==0) with CTA=0; otherwise a full arrive is
-   * emitted.
-   * - Recomputes pipeline_info_ to drop producer-only ops.
-   *
-   * Side effects / state updates:
-   * - Increments num_barriers_ by (number of extracted patterns * num_stages_).
-   * - May set has_simt_copy_ when a SimT copy is detected in producer rewrites.
-   * - Inserts barrier ids into released_barrier_ for release-after events.
-   * - Updates pipeline_info_ for the consumer path to remove producer ops.
-   *
-   * The resulting statements are emitted as grouped blocks (via MakeGroupBlock)
-   * with the annotation "stmt_group" and returned as either a single Stmt (if
-   * there's only one group) or a SeqStmt containing the grouped blocks.
-   *
-   * @return Stmt The transformed statement (either a single group block or a
-   * SeqStmt of group blocks).
-   */
   Stmt VisitStmt_(const SeqStmtNode *op) final {
     bool has_producer = false;
@@ -855,11 +773,7 @@
           int pattern_idx = map.release[i][j];
           PrimExpr release_barrier_id =
              stage_ + num_barriers_ + num_stages_ * pattern_idx;
-          if (only_has_wgmma_)
-            block_stmt.push_back(makeArriveBarrier(
-                release_barrier_id, 0, EQ(FloorMod(thread_var_, 128), 0)));
-          else
-            block_stmt.push_back(makeArriveBarrier(release_barrier_id));
+          block_stmt.push_back(makeArriveBarrier(release_barrier_id));
           for (int s = 0; s < num_stages_; s++) {
             released_barrier_.insert(s + num_barriers_ +
                                      num_stages_ * pattern_idx);
@@ -1209,7 +1123,6 @@
   bool mbarrier_only_ = false;
   PipelineInfo pipeline_info_;
   friend class WarpSpecializedRewriter;
-  bool only_has_wgmma_ = false;
   bool has_simt_copy_ = false;
 };
@@ -1277,38 +1190,6 @@
     return for_node;
   }
 
-  /**
-   * @brief Rewrite a BlockRealize for warp specialization, inserting barriers
-   * and emitting producer/consumer bodies.
-   *
-   * This visitor handles BlockRealize nodes when a thread IterVar (thread_iv_)
-   * is defined and warp-specialization is applicable. It:
-   * - Determines producer/consumer roles via WarpSpecializedRoleMarker and
-   * returns the original block if no producer is detected.
-   * - If warp specialization is disabled, emits only mbarrier initialization
-   * and the mbarrier-only transformed body.
-   * - Otherwise, detects WgMMA usage for the block body and constructs separate
-   * WSCodeEmitter instances for producer and consumer paths (propagating the
-   * WgMMA flag to the consumer emitter).
-   * - Generates producer/consumer code, applies register hint calls
-   * (set_max_nreg) when available, and rewrites thread indices with
-   * ThreadIdxRewriter to partition threads between producer and consumer roles.
-   * - Computes and initializes a list of mbarrier handles with per-barrier
-   * arrive thread counts (taking SIMT-copy and WgMMA cases into account).
-   * - Wraps the transformed body in an IfThenElse that dispatches producer vs
-   * consumer based on thread index, and annotates the region with the
-   * "kWarpSpecializationScope" attribute that contains producer/consumer
-   * thread extents.
-   *
-   * Side effects:
-   * - May update member state: only_has_wgmma_, updated_thread_extent_,
-   * need_update_thread_extent_.
-   * - May abort via ICHECK if invariants (e.g., matching barrier counts) are
-   * violated.
-   *
-   * @return The possibly rewritten BlockRealize statement (original when no
-   * warp-specialization is applied or thread_iv_ is undefined).
-   */
   Stmt VisitStmt_(const BlockRealizeNode *op) final {
     BlockRealize block_realize =
         Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
@@ -1342,10 +1223,9 @@
       block_realize.CopyOnWrite()->block = block;
       return block_realize;
     }
-    only_has_wgmma_ = WgMMACollector::HasWgMMA(block->body);
     WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
     WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
-                           false, only_has_wgmma_);
+                           false);
     Stmt producer_code = producer(block->body);
     Stmt consumer_code = consumer(block->body);
     PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
@@ -1374,8 +1254,7 @@
       PrimExpr arrive_thread_count =
           producer.released_barrier_.count(i)
              ? (producer.hasSimtCopy() ? producer_thread_extent : 1)
-              : (only_has_wgmma_ ? FloorDiv(consumer_thread_extent, 128)
-                                 : consumer_thread_extent);
+              : consumer_thread_extent;
       barrier_num_threads.push_back(arrive_thread_count);
     }
@@ -1403,7 +1282,6 @@
   bool need_update_thread_extent_ = false;
   bool disable_warp_specialized_ = false;
   bool disable_shuffle_elect_ = false;
-  bool only_has_wgmma_ = false;
 };
 
 class WarpSpecializedDetector : public IRVisitorWithAnalyzer {