Unverified Commit 9a869396 authored by coderabbitai[bot]'s avatar coderabbitai[bot] Committed by GitHub
Browse files

📝 Add docstrings to `reducer_0825` (#772)

* 📝 Add docstrings to `reducer_0825`

Docstrings generation was requested by @LeiWang1999.

* https://github.com/tile-ai/tilelang/pull/757#issuecomment-3219088118



The following files were modified:

* `setup.py`
* `src/op/builtin.h`
* `src/op/finalize_reducer.cc`
* `src/op/finalize_reducer.h`
* `src/op/parallel.cc`
* `src/op/parallel.h`
* `src/op/reduce.cc`
* `src/target/codegen_cuda.cc`
* `src/tl_templates/cuda/common.h`
* `src/transform/layout_inference.cc`
* `src/transform/layout_reducer.cc`
* `src/transform/layout_reducer.h`
* `src/transform/merge_shared_memory_allocations.cc`
* `src/transform/storage_access.cc`
* `src/transform/warp_specialized_rewriter.cc`
* `testing/python/autotune/test_tilelang_autotune_with_inputs.py`
* `tilelang/engine/phase.py`
* `tilelang/language/customize.py`
* `tilelang/language/reduce.py`
* `tilelang/transform/__init__.py`

* lint fix

* lint fix

---------
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent a7a29c09
...@@ -749,9 +749,20 @@ class TilelangExtensionBuild(build_ext): ...@@ -749,9 +749,20 @@ class TilelangExtensionBuild(build_ext):
def build_cmake(self, ext): def build_cmake(self, ext):
""" """
Build a single CMake-based extension. Build a single CMake-based extension by generating a CMake config and invoking CMake/Ninja.
:param ext: The extension (an instance of CMakeExtension). Generates or updates a config.cmake in the build directory (based on the extension's sourcedir),
injecting LLVM/CUDA/ROCm and Python settings, then runs CMake to configure and build the target.
When running an in-place build the resulting library is placed under ./tilelang/lib; otherwise the
standard extension output directory is used.
Parameters:
ext: The CMakeExtension to build; its `sourcedir` should contain the TVM/CMake `config.cmake`
template under `3rdparty/tvm/cmake/`.
Raises:
subprocess.CalledProcessError: If the CMake configuration or build commands fail.
OSError: If filesystem operations (read/write) fail.
""" """
# Only setup LLVM if it's enabled # Only setup LLVM if it's enabled
llvm_config_path = "OFF" llvm_config_path = "OFF"
......
...@@ -11,6 +11,14 @@ ...@@ -11,6 +11,14 @@
#include <tvm/ir/transform.h> #include <tvm/ir/transform.h>
namespace tvm { namespace tvm {
/*!
* \brief Create the TVM intrinsic that initializes a PTX fence barrier.
*
* Initializes a PTX fence-style barrier used to coordinate asynchronous memory
* operations (for example, TMA/TMA_STORE). Returns the Op representing this
* intrinsic for use in TIR lowering and code generation.
*
*/
namespace tl { namespace tl {
namespace attr { namespace attr {
......
...@@ -18,6 +18,20 @@ namespace tl { ...@@ -18,6 +18,20 @@ namespace tl {
using namespace tir; using namespace tir;
/**
* @brief Construct a FinalizeReducerOp from TL operator arguments and a buffer
* map.
*
* Extracts the reducer Buffer from `vmap` using the variable referenced by
* `args[0]` and sets the reduction operation type from the integer code in
* `args[1]`.
*
* @param args TL operator arguments: expects at least two elements where
* `args[0]` is an access pointer identifying the reducer variable
* and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min).
* @param vmap Mapping from variables to Buffers used to look up the reducer
* Buffer.
*/
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) { FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
auto node = make_object<FinalizeReducerOpNode>(); auto node = make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])]; node->reducer = vmap[GetVarFromAccessPtr(args[0])];
...@@ -25,6 +39,33 @@ FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -25,6 +39,33 @@ FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node); data_ = std::move(node);
} }
/**
* @brief Lower the finalize_reducer TL operator to a TIR statement.
*
* Lowers the operator that finalizes a reducer by performing a thread-wide
* AllReduce across the reducer's output elements and writing the reduced value
* back into the reducer buffer. The function:
* - Fetches the reducer buffer and expects its layout to be a Fragment.
* - Builds index Vars for each output dimension.
* - Reads the layout's ReplicateExtent and:
* - if extent == 1, emits a no-op Evaluate(0);
* - otherwise constructs an AllReduce extern call (uses `run_hopper` when the
* compilation target is Hopper) with an optional workspace (allocated via
* T.AddWorkspace when reducing_threads >= 32) and stores the result via
* BufferStore.
* - Wraps the store in parallel outer For loops over each output dimension.
*
* @param T Lowering context containing buffer remapping, layout map, thread
* bounds, target, and helper methods (e.g., AddWorkspace).
* @param analyzer Arithmetic analyzer (unused by this implementation but
* provided for consistency with lowering API).
* @return Stmt The lowered TIR statement representing the AllReduce and
* surrounding loops.
*
* @note The function ICHECKs that the reducer layout is present and a Fragment,
* and that ReplicateExtent is either 1 or equal to the thread block
* extent; violations cause a fatal check failure.
*/
Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T, Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer) const {
auto buffer = T.buffer_remap[reducer]; auto buffer = T.buffer_remap[reducer];
...@@ -81,6 +122,19 @@ Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T, ...@@ -81,6 +122,19 @@ Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T,
return body; return body;
} }
/**
* @brief Infer and return the layout mapping for the reducer buffer.
*
* Copies the existing layout for the reducer from the provided LayoutInferArgs
* into a new LayoutMap and returns it. The inference does not modify the
* layout; it preserves the reducer's current layout.
*
* @param T Provides the input layout map from which the reducer's layout is
* copied.
* @param level Unused by this operator; present for API compatibility.
* @return LayoutMap A map that contains the reducer buffer mapped to its
* original layout.
*/
LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T, LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const { InferLevel level) const {
LayoutMap layout_map; LayoutMap layout_map;
...@@ -88,6 +142,15 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -88,6 +142,15 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
return layout_map; return layout_map;
} }
/**
* @brief Create a deep copy of this FinalizeReducerOpNode and wrap it as a
* TileOperator.
*
* Constructs a new FinalizeReducerOpNode by copying the current node state and
* returns a TileOperator that owns the copied node.
*
* @return TileOperator A TileOperator that contains a deep copy of this node.
*/
TileOperator FinalizeReducerOpNode::Clone() const { TileOperator FinalizeReducerOpNode::Clone() const {
auto node = make_object<FinalizeReducerOpNode>(*this); auto node = make_object<FinalizeReducerOpNode>(*this);
return TileOperator(node); return TileOperator(node);
......
...@@ -12,6 +12,71 @@ ...@@ -12,6 +12,71 @@
#include "../transform/layout_reducer.h" #include "../transform/layout_reducer.h"
#include "./operator.h" #include "./operator.h"
/**
* FinalizeReducer operator node for Tile IR.
*
* Represents a TL-level operator that finalizes a reducer buffer into a
* result using a specified reducer operation.
*
* Public members:
* - reducer: the tir::Buffer that holds the intermediate reduction values.
* - op: the reducer operation to apply when finalizing values.
*/
/**
* Lower this operator to a TIR statement.
*
* @param T Lowering arguments (buffers, indices, and other lowering context).
* @param analyzer Arithmetic analyzer used to simplify expressions during
* lowering.
* @return A tir::Stmt that implements the finalize-reducer semantics for the
* provided lowering context.
*/
/**
* Infer layout mapping for this operator.
*
* Determines how input and output buffer layouts relate for the
* finalize-reducer operator at the given inference level.
*
* @param T Layout inference arguments (including operand layouts and shapes).
* @param level Inference precision level.
* @return A LayoutMap describing the inferred layouts.
*/
/**
* Get the singleton Op object representing this operator.
*
* @return A reference to the Op describing FinalizeReducer.
*/
/**
* Create a deep copy of this operator node as a TileOperator.
*
* @return A TileOperator handle that is an independent clone of this node.
*/
/**
* Public wrapper for FinalizeReducerOpNode.
*
* Provides the reference semantics and construction API used by callers.
*/
/**
* Construct a FinalizeReducerOp from TL-level arguments.
*
* @param args Positional primitive expressions that parameterize the operator
* (e.g., shapes, axis indices). Documented where their meaning is
* not obvious from name or type in call sites.
* @param vmap Mapping from operand names to tir::Buffer instances used by this
* operator.
*/
/**
* Get the Op singleton for the public FinalizeReducerOp handle.
*
* @return A reference to the Op describing FinalizeReducer.
*/
namespace tvm { namespace tvm {
namespace tl { namespace tl {
......
...@@ -119,6 +119,14 @@ private: ...@@ -119,6 +119,14 @@ private:
Map<Buffer, Layout> layout_map_; Map<Buffer, Layout> layout_map_;
}; };
/**
* @brief Handle a parallel For node during traversal, collecting loop metadata.
*
* Visits a parallel loop, asserts the loop is parallel, records a data-parallel
* IterVar for the loop, binds the loop variable range into the analyzer scope,
* and extracts any reducer information from the loop's annotations into the
* visitor's reducer_info_map_. Continues traversal into the loop body.
*/
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) { void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
ICHECK(op->kind == ForKind::kParallel); ICHECK(op->kind == ForKind::kParallel);
p->loop_vars_.push_back( p->loop_vars_.push_back(
...@@ -147,19 +155,6 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) { ...@@ -147,19 +155,6 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
} }
/**
* @brief Visit a BufferLoad node and record/validate index mapping for
* fragment-local buffers.
*
* If the loaded buffer's scope is "local.fragment", this records the load
* indices in the visitor's indice_map_ when seen for the first time. If an
* entry already exists, the previously recorded indices are asserted
* structurally equal to the current indices.
*
* This ensures all accesses to the same fragment-local buffer within the
* parallel loop use a consistent index map. The function then continues
* standard expression visitation.
*/
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) { void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
if (op->buffer.scope() == "local.fragment") { if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) { if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
...@@ -173,91 +168,42 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) { ...@@ -173,91 +168,42 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
StmtExprVisitor::VisitExpr_(op); StmtExprVisitor::VisitExpr_(op);
} }
/**
* @brief Construct a ParallelOpNode from a parallel loop nest root.
*
* Initializes the node with the given For loop as the root of the parallel
* operator and immediately runs the internal ParallelLoopNestVisitor to collect
* loop and buffer access information from the nested body.
*
* @param root The root For node representing the parallel loop nest to be
* analyzed.
*/
ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) { ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
V.VisitStmt(root); V.VisitStmt(root);
} }
/**
* @brief Create a copy of this ParallelOpNode wrapped as a TileOperator.
*
* Returns a new TileOperator that holds a deep copy of this ParallelOpNode.
*
* @return TileOperator A TileOperator owning a copy of this node.
*/
TileOperator ParallelOpNode::Clone() const { TileOperator ParallelOpNode::Clone() const {
auto op = make_object<ParallelOpNode>(*this); auto op = make_object<ParallelOpNode>(*this);
return ParallelOp(op); return ParallelOp(op);
} }
/**
* @brief No-op lowering: return the stored root statement unchanged.
*
* This implementation does not perform any transformation and returns the
* operator's original root For statement as-is.
*
* @param T Lowering arguments (unused).
* @return Stmt The original root statement held by this ParallelOpNode.
*/
Stmt ParallelOpNode::Lower(const LowerArgs &T, Stmt ParallelOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer) const {
return root_; return root_;
} }
/**
* @brief Check whether a buffer is indexed by the loop's canonical (common)
* iteration variables.
*
* Returns true if the recorded index mapping for `buffer` is structurally equal
* to the sequence of loop iteration variables for this parallel op (i.e., the
* buffer is accessed using the common access indices of the loop nest).
*
* @param buffer The buffer to check.
* @return true if the buffer's index map equals the loop's iteration variables;
* false otherwise.
*/
bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const { bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const {
auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; }); auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
return StructuralEqual()(indice_map_[buffer], common_indice); return StructuralEqual()(indice_map_[buffer], common_indice);
} }
/** /*! \brief Infer the layout for parallel operations based on different inference
* @brief Infer buffer layouts for a Parallel operator based on the chosen * levels
* inference level.
* *
* Attempts to compute a consistent LayoutMap for buffers accessed by a parallel * The inference level controls how aggressively we try to infer and optimize
* loop (root_) using explicit input layouts (T.layout_map), thread bounds * layouts:
* (T.thread_bounds), and optional buffer remapping/vectorization information in * - kStrict (2): Most conservative level. Only allows explicitly defined
* T. Behavior depends on the supplied InferLevel: * layouts. Returns empty layout map if loop_layout_ is not already defined.
* - kStrict: only accept pre-existing loop_layout_ (no inference). * Used when exact layout control is required.
* - kCommon: allow inference from explicit buffer fragments when available.
* - kFree: attempt more aggressive inference (derive loop partition from
* read/write fragments, plan partitioning from vectorization/thread bounds, and
* add predicates to constrain replication when necessary).
* *
* This method may mutate the node's internal state (sets loop_layout_ when * - kCommon (1): Intermediate level between strict and free.
* inferred and registers predicates via AddPredicate) and consults analyzer_ * Allows common layout patterns while maintaining some
* for symbolic proofs. * constraints.
* *
* @param T Container of auxiliary inputs used for inference (buffer_remap, * - kFree (0): Most permissive level. Allows maximum optimization freedom.
* layout_map, and thread_bounds). The function uses T.layout_map for source * Will attempt layout inference even without source buffers.
* fragments and T.thread_bounds to bind thread-range information in inferred * Can generate new layouts based on vectorization and thread
* fragments. * bounds. Used when maximum performance optimization is desired.
* @param level Controls inference aggressiveness (kStrict, kCommon, kFree).
* @return LayoutMap A map of buffers to inferred Fragment layouts for buffers
* that did not already have layouts in T.layout_map. Returns an empty map when
* no inference was performed.
* @throws LayoutConflictException If a computed loop partition conflicts with
* an existing buffer fragment (incompatible thread mappings).
*/ */
LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const { InferLevel level) const {
...@@ -446,20 +392,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -446,20 +392,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
return results; return results;
} }
/**
* @brief Retrieve the loop's thread predicate with the thread variable
* substituted.
*
* If a predicate is set for this ParallelOpNode, returns a copy of that
* predicate where the placeholder input (InputPlaceholder(0)) is replaced by
* the provided thread_var. If no predicate is defined, returns an empty
* Optional.
*
* @param thread_var The thread loop variable to substitute for the predicate's
* input placeholder.
* @return Optional<PrimExpr> The substituted predicate expression, or
* std::nullopt if none is defined.
*/
Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const { Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
if (predicate_.defined()) { if (predicate_.defined()) {
return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}}); return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
...@@ -468,32 +400,6 @@ Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const { ...@@ -468,32 +400,6 @@ Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
} }
} }
/**
* @brief Construct the complete fragment layout for a buffer within the
* parallel loop.
*
* Given a buffer referenced inside the parallel loop, return a Fragment that
* maps the buffer's logical indices to the loop's thread space and replication
* extent.
*
* Detailed behavior:
* - Precondition: a loop layout (loop_layout_) must be defined.
* - If the buffer uses the common access indices of the loop, the loop's
* fragment is returned directly.
* - Otherwise, the function:
* - Computes the buffer's bijective index by appending the flattened
* replication expression for unused iterators.
* - Inverts that bijection to obtain the replication extent of the buffer's
* index space and combines it with the loop's replication extent to produce the
* destination replication extent.
* - Builds forward index placeholders for the buffer elements and maps them
* through the inverted layout and the loop layout to derive the thread binding.
* - Returns a Fragment with the computed thread binding and combined
* replication extent, with replicate variables condensed.
*
* @return Fragment The completed fragment describing thread binding and
* replication extent for `buffer`.
*/
Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
ICHECK(loop_layout_.defined()); ICHECK(loop_layout_.defined());
if (IsCommonAccessIndice(buffer)) { if (IsCommonAccessIndice(buffer)) {
......
...@@ -14,138 +14,101 @@ ...@@ -14,138 +14,101 @@
#include "./operator.h" #include "./operator.h"
/** /**
* Exception representing a layout conflict detected during layout inference. * Exception indicating a layout conflict during layout inference or validation.
* * The stored message is returned by what().
* Stores an explanatory message retrievable via what().
*/ */
/** /**
* Determine whether `small_frag` is guaranteed to be contained within * Verify that `small_frag` is contained within `large_frag` under the provided
* `large_frag` under the given index mappings and using the provided arithmetic * index mappings and using symbolic reasoning via `analyzer_`.
* analyzer.
* *
* @param small_frag The smaller fragment to test for containment. * @param small_frag Fragment describing the smaller layout fragment.
* @param large_frag The larger fragment that may contain `small_frag`. * @param large_frag Fragment describing the larger layout fragment.
* @param small_frag_indices Index expressions mapping the small fragment into * @param small_frag_indices Index expressions that map accesses into
* buffer space. * `small_frag`.
* @param large_frag_indices Index expressions mapping the large fragment into * @param large_frag_indices Index expressions that map accesses into
* buffer space. * `large_frag`.
* @param analyzer_ Arithmetic analyzer used to simplify and prove index * @param analyzer_ Analyzer used for symbolic simplification and proving
* relations. * relations.
* @return true if containment can be proven; false otherwise. * @return true if `small_frag` can be proven to be contained in `large_frag`
* given the index mappings and analyzer; false otherwise.
*/ */
/** /**
* Visitor that traverses a parallel loop nest to collect buffer access and * Visitor that traverses a parallel loop nest to collect loop structure,
* loop-structure information for a ParallelOpNode. * buffer access patterns, and to populate the associated ParallelOpNode.
*
* The visitor records loop variables, buffer read/write accesses, and builds
* predicates as it encounters BufferLoad/BufferStore and For nodes.
*/ */
/** /**
* Represents a parallel for-loop operator in TileLang. * Construct a ParallelOpNode from a root For loop.
* *
* Holds the root For loop, collects and exposes loop layout and access-index * @param root The TIR For node that is the root of the parallel loop nest.
* information, and provides layout inference and lowering to TIR.
*
* Public methods expose the inferred loop layout, root loop, buffer index
* mappings, and any per-thread predicate; Lower and InferLayout perform the
* operator's lowering and layout inference respectively.
*/ */
/** /**
* Create a ParallelOpNode from a root For loop. * Lower this ParallelOpNode to a TIR statement.
* *
* @param root The root For node representing the parallel loop nest. * Performs lowering of the operator (including any necessary predicates,
*/ * reductions, and loop transformations) to produce an equivalent tir::Stmt.
/**
* Lower this parallel operator into a TIR statement suitable for codegen.
* *
* @param T Lowering arguments and context. * @param T Lowering options and context.
* @param analyzer Arithmetic analyzer for expression simplification during * @param analyzer Optional analyzer for symbolic simplification during
* lowering. * lowering.
* @return A TIR statement representing the lowered parallel loop. * @return A tir::Stmt representing the lowered operator.
*/ */
/** /**
* Infer the layout mapping for this parallel operator at the specified level. * Infer layouts for buffers used by this parallel operator.
* *
* @param T Arguments and context for layout inference. * This performs layout inference at the requested level and returns a mapping
* @param level Inference granularity level. * from buffers to their inferred layout fragments.
* @return A LayoutMap describing inferred buffer/layout relationships for the
* operator.
*/
/**
* Copy-construct a ParallelOpNode, preserving inferred layout and predicate.
*/
/**
* Get the inferred loop layout fragment.
* *
* @return The Fragment representing the loop's inferred layout (may be lazily * @param T Layout inference arguments and context.
* computed). * @param level Granularity level for inference.
* @return LayoutMap mapping buffers to inferred fragments.
*/ */
/** /**
* Get the root For loop of this operator. * Return an optional predicate expression associated with the given thread
* variable.
* *
* @return The root For AST node. * If the loop nest imposes a condition on `thread_var` (e.g., bounds checks or
*/ * tiling edge predicates), this returns the combined predicate; otherwise
* returns an empty Optional.
/**
* Get the mapping from each buffer to the array of index expressions used to
* access it within the loop nest.
*
* @return A Map from Buffer to Array<PrimExpr> of access indices.
*/
/**
* Retrieve the predicate expression associated with a given thread variable, if
* any.
*
* @param thread_var The thread variable whose predicate is requested.
* @return An Optional<PrimExpr> containing the predicate when present.
*/
/**
* Create a deep copy of this operator as a TileOperator handle.
*
* @return A TileOperator that references a copy of this node.
*/
/**
* Visitor helper: complete the fragment layout for a buffer (internal).
* *
* (Private helper — not part of the public API.) * @param thread_var The thread variable for which to retrieve the predicate.
* @return Optional containing the predicate expression if present.
*/ */
/** /**
* Helper to check whether a buffer's access indices are the common loop indices * Create and return a clone of this operator as a TileOperator (deep copy of
* (internal). * operator state necessary for further transformations).
* *
* (Private helper — not part of the public API.) * @return A TileOperator referencing a cloned ParallelOpNode.
*/ */
/** /**
* Add `expr` to the current predicate by logical AND; sets predicate if none * Complete the layout fragment for `buffer` by filling in any missing
* exists. * dimension or stride information derived from access patterns in the loop
* nest.
* *
* (Private helper — not part of the public API.) * @param buffer The buffer whose fragment should be completed.
* @return A Fragment representing the completed layout for `buffer`.
*/ */
/** /**
* Thin handle type exposing ParallelOpNode as a TileOperator. * Determine whether `buffer` is accessed using only the loop-common indices
* (i.e., indices that correspond to the loop variables of this parallel nest).
* *
* Construct from a root For loop to create and own a ParallelOpNode instance. * @param buffer The buffer to inspect.
* @return true if accesses use common loop indices; false otherwise.
*/ */
/** /**
* Construct a ParallelOp handle from a root For loop. * Conjoin `expr` into the operator's predicate (logical AND). If no predicate
* exists yet, `expr` becomes the predicate.
* *
* @param root The root For node representing the parallel loop nest. * @param expr Predicate expression to add.
*/ */
namespace tvm { namespace tvm {
namespace tl { namespace tl {
......
...@@ -22,25 +22,6 @@ namespace tl { ...@@ -22,25 +22,6 @@ namespace tl {
using namespace tir; using namespace tir;
/**
* @brief Construct a ReduceOp from raw TL arguments and a buffer mapping.
*
* Interprets `args` and `vmap` to populate an internal ReduceOpNode:
* - args[0]: access pointer for the source buffer
* - args[1]: access pointer for the destination buffer
* - args[2]: string literal specifying the reduce type: "sum", "abssum",
* "absmax", "max", or "min"
* - args[3]: integer literal for the reduction dimension (axis)
* - args[4]: boolean literal indicating whether to clear/init the destination
*
* The constructor resolves the access pointers via `vmap`, maps the reduce
* type string to the ReduceType enum, assigns the reduction dimension and
* clear flag, and stores the constructed node in `data_`. An invalid reduce
* type triggers a fatal check.
*
* @param args Array of TL prim-expr arguments as described above.
* @param vmap Mapping from variables (from access pointers) to Buffer objects.
*/
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>(); ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])]; node->src = vmap[GetVarFromAccessPtr(args[0])];
...@@ -63,52 +44,16 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -63,52 +44,16 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node); data_ = std::move(node);
} }
/**
* @brief Create a copy of this ReduceOpNode wrapped as a TileOperator.
*
* Returns a new TileOperator holding a freshly allocated ReduceOpNode
* constructed as a copy of this node.
*
* @return TileOperator A tile operator that owns the cloned ReduceOpNode.
*/
TileOperator ReduceOpNode::Clone() const { TileOperator ReduceOpNode::Clone() const {
auto op = make_object<ReduceOpNode>(*this); auto op = make_object<ReduceOpNode>(*this);
return ReduceOp(op); return ReduceOp(op);
} }
/**
* @brief Create a deep copy of this CumSum op node wrapped as a TileOperator.
*
* Returns a new TileOperator whose underlying CumSumOpNode is a copy of
* the current node. Useful for cloning operators when building or
* transforming computation graphs.
*
* @return TileOperator A TileOperator containing a copy of this node.
*/
TileOperator CumSumOpNode::Clone() const { TileOperator CumSumOpNode::Clone() const {
auto op = make_object<CumSumOpNode>(*this); auto op = make_object<CumSumOpNode>(*this);
return CumSumOp(op); return CumSumOp(op);
} }
/**
* @brief Create the initial accumulator value for the destination buffer based
* on reduction type.
*
* Returns the PrimExpr representing the initial value stored in the destination
* accumulator before any source elements are combined. The returned value
* depends on the destination dtype and the node's reduction type:
* - kSum, kAbsSum: zero of the destination dtype.
* - kMax: minimum representable value for signed integers, zero for unsigned
* integers, and -INFINITY for floating point.
* - kMin: maximum representable value for signed integers, all-ones (max) for
* unsigned integers, and +INFINITY for floating point.
* - kAbsMax: zero of the destination dtype.
*
* The function will abort (ICHECK failure) if the reduction type is
* unrecognized.
*
* @return PrimExpr initial value appropriate for `dst->dtype` and `type`.
*/
PrimExpr ReduceOpNode::MakeInitValue() const { PrimExpr ReduceOpNode::MakeInitValue() const {
auto dst_dtype = dst->dtype; auto dst_dtype = dst->dtype;
auto is_int = dst_dtype.is_int(); auto is_int = dst_dtype.is_int();
...@@ -143,24 +88,6 @@ PrimExpr ReduceOpNode::MakeInitValue() const { ...@@ -143,24 +88,6 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
} }
} }
/**
* @brief Combine two scalar expressions according to this node's reduction
* type.
*
* Casts the right operand to the left operand's dtype if they differ, then
* returns the reduction of `a` and `b` using the operator specified by `type`:
* - kSum: `a + b`
* - kAbsSum: `a + max(b, -b)`
* - kMax: `max(a, b)`
* - kMin: `min(a, b)`
* - kAbsMax: `max(max(a, b), -min(a, b))`
*
* @param a Left-hand operand (result dtype drives the output dtype).
* @param b Right-hand operand (will be cast to `a`'s dtype if needed).
* @return PrimExpr The combined expression with dtype equal to `a.dtype`.
*
* @note The function DCHECKs/ICHECKs on an unknown/unsupported reduction type.
*/
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const { PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
PrimExpr lhs = a, rhs = b; PrimExpr lhs = a, rhs = b;
if (lhs->dtype != rhs->dtype) { if (lhs->dtype != rhs->dtype) {
...@@ -183,20 +110,6 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const { ...@@ -183,20 +110,6 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
} }
} }
/**
* @brief Map the reduction type to the codegen reducer name used by external
* ALL-Reduce/CUDA helpers.
*
* Returns the string identifier of the code-generation reducer corresponding to
* this ReduceOpNode's `type`. Mapping:
* - kSum, kAbsSum -> "tl::SumOp"
* - kMax, kAbsMax -> "tl::MaxOp"
* - kMin -> "tl::MinOp"
*
* The function terminates with a check failure if `type` is unknown.
*
* @return std::string Reducer name used by codegen extern calls.
*/
std::string ReduceOpNode::MakeCodegenReducer() const { std::string ReduceOpNode::MakeCodegenReducer() const {
switch (type) { switch (type) {
case ReduceType::kSum: case ReduceType::kSum:
...@@ -216,30 +129,40 @@ std::string ReduceOpNode::MakeCodegenReducer() const { ...@@ -216,30 +129,40 @@ std::string ReduceOpNode::MakeCodegenReducer() const {
} }
/** /**
* @brief Lower the Reduce operator node to a TIR statement. * @brief Lower the Reduce operator to a TIR statement.
* *
* Lowers a ReduceOpNode that targets fragment-local buffers into a sequence of * Lowers a ReduceOpNode operating on fragment-scoped buffers into a sequence of
* TIR statements implementing: per-thread local reduction, inter-thread * TIR statements implementing: optional initialization, thread-local reduction
* AllReduce (when needed), and final writeback (with an optional duplicate * (unrolled inner loops), inter-thread reduction via a runtime AllReduce call
* clear buffer to avoid in-place conflicts). Supports reduction kinds * (Hopper-specific `run_hopper` variant when TargetIsHopper(T.target) is true),
* (sum/abs-sum/max/min/abs-max) and handles layout-driven index mapping and * and an optional accumulation or copy back to the destination buffer when a
* loop partitioning to thread axes. * temporary clear buffer is used.
* *
* @param T Lowering context providing buffer remapping, layout map, target and * Behavior notes:
* thread bounds, and workspace allocation helper. Must contain * - Only supports src and dst in "local.fragment" scope; otherwise it checks
* fragment-local mappings for both src and dst. * and aborts with "Reduce for shared memory not implemented.".
* @param analyzer Symbolic analyzer used to simplify and compress iterators. * - Supports both 1D reductions (scalar output) and reductions along a single
* @return Stmt The constructed TIR statement implementing the reduction. * extra dimension; validates layout dimensionality consistency.
* * - If `clear` is set (or for sum/abssum reductions), an initial value is
* Preconditions: * written to the clear buffer; for non-clearing sum/abssum a duplicate
* - src and dst buffers must be in "local.fragment" scope. * temporary buffer is allocated and accumulated back into dst after
* - The layouts must have compatible input/output dimensions for the * reduction.
* specified reduction axis. * - Performs iterator compression for local reduction loops using `analyzer`.
* * - Detects parallel thread splitting from the normalized iterator sum and
* Failure modes: * emits a call to a templated `tl::AllReduce<...>::run` (or `run_hopper`)
* - The function uses ICHECK to enforce unsupported scopes, dimension * via `builtin::call_extern`. For sufficiently large reducing thread counts
* mismatches, unknown reduction types, and other invariants; violations * (>= 32) a workspace is allocated via T.AddWorkspace and passed to the
* will trigger a fatal check failure. * AllReduce call.
* - The final body is wrapped in parallel loops over the destination spatial
* dimensions and partitioned by the lowering thread variable. If a temporary
* clear buffer is used, it is allocated for the body.
*
* @param T Lowering context providing buffer and layout maps, thread bounds,
* target information, thread variable, and workspace allocation
* helper.
* @param analyzer Analyzer used for iterator compression and arithmetic
* normalization.
* @return Stmt Lowered TIR statement implementing the reduction.
*/ */
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(this->src.scope() == "local.fragment" && ICHECK(this->src.scope() == "local.fragment" &&
...@@ -409,38 +332,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -409,38 +332,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return body; return body;
} }
/**
* @brief Infer a layout mapping for the destination buffer of a Reduce
* operator.
*
* When inference level is below `kStrict`, and both source and destination
* buffers live in `local.fragment` with a known source fragment layout, this
* computes a candidate destination Fragment layout that accounts for
* replication over the reduction dimension and binds thread ranges from
* `T.thread_bounds`.
*
* Behavior:
* - Constructs a destination Fragment whose replicate extent equals
* src.shape[dim] * src_fragment.ReplicateExtent(), and whose threading is
* derived from the source fragment with the reduction dimension folded out.
* - If no layout exists for `dst` in `T.layout_map`, returns a map {dst ->
* inferred}.
* - If `dst` already has a layout, validates that the existing layout strictly
* contains the computed layout (shapes match and fragment containment holds).
* If compatible but the computed replicate extent is larger, returns the new
* layout.
* - In all other cases (strict inference level, unsupported scopes, or no src
* layout), returns an empty map.
*
* @param T Layout inference context containing `layout_map` and
* `thread_bounds`.
* @param level Inference strictness; no inference is performed at or above
* `kStrict`.
* @return LayoutMap A mapping for `dst` to an inferred Fragment layout, or
* empty.
* @throws LayoutConflictException if an existing `dst` layout conflicts with
* the computed layout (not containable or incompatible replication extents).
*/
LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const { InferLevel level) const {
if (level >= InferLevel::kStrict) if (level >= InferLevel::kStrict)
...@@ -518,22 +409,6 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce) ...@@ -518,22 +409,6 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
/**
* @brief Construct a CumSumOp from a list of arguments and a buffer map.
*
* Expects args to contain exactly four PrimExprs in this order:
* 0: access pointer to source buffer (src),
* 1: access pointer to destination buffer (dst),
* 2: integer dimension to perform the cumulative sum along (dim),
* 3: boolean flag indicating whether to compute the cumsum in reverse
* (reverse).
*
* The constructor resolves src and dst from the provided BufferMap and stores
* the parsed dim and reverse values on the node. It verifies that args.size()
* == 4 and that dim is a valid axis for the source buffer shape.
*
* @param args Array of PrimExpr as described above.
*/
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/* /*
CumSum arguments: CumSum arguments:
...@@ -552,28 +427,6 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -552,28 +427,6 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node); data_ = std::move(node);
} }
/**
* @brief Lower the CumSum operator to TIR.
*
* Produces a TIR statement implementing cumulative sum depending on buffer
* scopes:
* - For shared/shared.dyn scopes: emits an extern call to
* `tl::CumSum2D<threads, dim, reverse>::run` with arguments [function_name,
* src.access_ptr(1), dst.access_ptr(3), src.shape...]. The number of threads is
* taken from `T.thread_bounds->extent`. Returns an Evaluate(Call(...))
* statement.
* - For local.fragment scopes on both src and dst: fatal error (not
* implemented).
* - For any other scope combinations: fails with an assertion.
*
* The `analyzer` parameter is accepted for interface compatibility but is not
* used by this lowering.
*
* @param T Lowering arguments (provides thread bounds and other lowering
* context).
* @return Stmt A TIR statement representing the lowered cumulative-sum
* operation.
*/
Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (this->src.scope() == "local.fragment" && if (this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment") { this->dst.scope() == "local.fragment") {
...@@ -600,17 +453,6 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -600,17 +453,6 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Stmt(); return Stmt();
} }
/**
* @brief Layout inference for CumSum operator.
*
* CumSum does not perform any layout inference; this function always returns
* an empty mapping. The operator's lowering expects shared-memory semantics
* and layout decisions are handled elsewhere.
*
* @param T Layout inference inputs (buffers, existing layouts, etc.).
* @param level Inference strictness level (unused).
* @return LayoutMap Empty map indicating no inferred layouts.
*/
LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T, LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const { InferLevel level) const {
return {}; return {};
......
...@@ -924,6 +924,38 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, ...@@ -924,6 +924,38 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
return os.str(); return os.str();
} }
/**
* @brief Emit CUDA/TensorLib-specific code for a call expression.
*
* This visitor handles CallNode intrinsics and builtins that require emitting
* CUDA/TL-specific code (inline PTX/ASM sequences, TensorLanguage runtime
* calls, WMMA/TMA helpers, barriers, cp.async primitives, index-map based
* stores, reinterpret/packing helpers, and various mma/ldmatrix patterns). The
* function writes the generated code to the provided output stream and falls
* back to the C codegen for unrecognized calls.
*
* The method recognizes and emits code for (non-exhaustive): cp.async and its
* commit/wait variants, tma_load/store and im2col variants, ptX
* ldmatrix/stmatrix helpers, mbarrier APIs, cooperative grid sync, WMMA/legacy
* MMA intrinsics (fill/load/store/mma/bmma/ptx_mma/ptx_mma_sp), low-level PTX
* asm helpers (ldg32, cp_async bulk/init/arrive/wait barriers), reinterpret
* paths for special small-float encodings (e.g., float4 e2m1fn), tl::tl_gemm
* and related external calls, and other TL runtime calls.
*
* Side effects:
* - Emits to `os` and the internal codegen output stream.
* - May set internal feature flags (e.g., need_cooperative_groups_,
* need_mma_h_, need_cast_smem_ptr_to_int_, enable_sparse_gemm_).
* - May open/close SSA scopes and mutate internal variable mappings.
* - May call LOG(FATAL) / CHECK / ICHECK on invalid or unsupported argument
* patterns.
*
* @param op The call node to generate code for; the function inspects op->op
* and op->args to determine the appropriate emission.
* @param os Output stream to receive expression-level output when the caller
* expects an expression result (some paths write directly to the
* member stream instead).
*/
void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
auto print_extern_call_stmt = [&](std::string name, size_t start = 0, auto print_extern_call_stmt = [&](std::string name, size_t start = 0,
size_t end = 0) { size_t end = 0) {
......
...@@ -109,7 +109,19 @@ TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) { ...@@ -109,7 +109,19 @@ TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr)); return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
} }
// Helper to cast SMEM pointer to unsigned /**
* Convert a shared-memory pointer to a 32-bit unsigned integer address.
*
* Casts the given pointer (expected to reference shared memory) into a 32-bit
* unsigned integer using the device address-space conversion required for
* shared-memory pointers.
*
* @param smem_ptr Pointer into shared memory.
* @return 32-bit unsigned integer representation of the shared-memory address.
*
* @note The pointer must refer to shared memory; behavior is undefined for
* pointers in other address spaces.
*/
TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) { TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
unsigned int smem_int; unsigned int smem_int;
asm volatile("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; " asm volatile("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
...@@ -123,7 +135,16 @@ template <typename T> struct normalize_atomic_type { ...@@ -123,7 +135,16 @@ template <typename T> struct normalize_atomic_type {
using type = T; using type = T;
}; };
template <> struct normalize_atomic_type<half_t> { template <> /**
* Map the public half_t alias to the native `half` type for atomic
* operations.
*
* Used by the atomic utilities to normalize externally exposed
* typedefs (e.g., Cutlass half_t) to the compiler's native `half`
* representation so correct atomic intrinsics or `cuda::atomic_ref`
* specializations can be selected.
*/
struct normalize_atomic_type<half_t> {
using type = half; using type = half;
}; };
...@@ -221,7 +242,25 @@ template <typename T> TL_DEVICE T AtomicLoad(T *address, int memory_order) { ...@@ -221,7 +242,25 @@ template <typename T> TL_DEVICE T AtomicLoad(T *address, int memory_order) {
} }
template <typename T1, typename T2> template <typename T1, typename T2>
TL_DEVICE void AtomicStore(T1 *address, T2 value, int memory_order) { TL_DEVICE /**
* Atomically stores a value into the given address using the
* specified memory ordering.
*
* The value is converted to the normalized atomic storage type for T1
* before being stored (for example, vectorized or reduced-width types
* such as FP16/BF16 are mapped to their underlying hardware
* representation). `memory_order` must be an `int` representation of
* a `cuda::memory_order` value (e.g.,
* `int(cuda::memory_order_relaxed)`).
*
* @param address Pointer to the destination atomic object.
* @param value Value to store; will be cast to the atomic storage
* type.
* @param memory_order Memory ordering for the atomic store (as an
* `int`-cast `cuda::memory_order`).
*/
void
AtomicStore(T1 *address, T2 value, int memory_order) {
using NT1 = typename normalize_atomic_type<T1>::type; using NT1 = typename normalize_atomic_type<T1>::type;
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address); cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order)); aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
...@@ -229,7 +268,25 @@ TL_DEVICE void AtomicStore(T1 *address, T2 value, int memory_order) { ...@@ -229,7 +268,25 @@ TL_DEVICE void AtomicStore(T1 *address, T2 value, int memory_order) {
// DP4A // DP4A
template <typename InDatatype, typename OutDatatype> template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) { TL_DEVICE /**
* Compute a 4×8-bit dot-product-accumulate using the CUDA DP4A
* intrinsic.
*
* Reads 32-bit packed values from `a` and `b` (each containing four
* signed 8-bit lanes), applies the __dp4a operation (dot product of
* the four lane pairs added to an accumulator), and stores the 32-bit
* integer result through `c`.
*
* @param a Pointer to a 32-bit packed input containing four signed
* 8-bit elements.
* @param b Pointer to a 32-bit packed input containing four signed
* 8-bit elements.
* @param c Pointer to a 32-bit accumulator; its current value is used
* as the initial accumulator and overwritten with the resulting int32
* sum.
*/
void
DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
const int a_int = *((int *)a); const int a_int = *((int *)a);
const int b_int = *((int *)b); const int b_int = *((int *)b);
const int c_int = *((int *)c); const int c_int = *((int *)c);
......
...@@ -64,37 +64,6 @@ public: ...@@ -64,37 +64,6 @@ public:
BufferUseDefCollector(bool skip_thread_partition) BufferUseDefCollector(bool skip_thread_partition)
: skip_thread_partition_(skip_thread_partition) {} : skip_thread_partition_(skip_thread_partition) {}
/**
* @brief Execute a single layout-inference step for the infer node at the
* given index.
*
* Runs InferLayout on the TileOperator at cur_infer_id with the provided
* InferLevel and thread bounds, applies returned buffer->layout updates into
* layout_map (respecting strict_layout_map constraints for fragment buffers),
* and optionally propagates changes to dependent infer nodes by enqueueing
* them into q and marking in_queue.
*
* The function mutates layout_map and, when update_queue is true, may modify
* q and in_queue. It performs internal sanity checks via ICHECK and will
* LOG(WARNING) for buffers that cannot be propagated; ICHECK failures abort
* execution.
*
* @param cur_infer_id Index of the infer operator in infer_list_ to run (must
* be within range).
* @param level Inference relaxation level to pass to the operator's
* InferLayout.
* @param update_queue If true, discovered layout changes will enqueue
* dependent infer nodes.
* @param layout_map Mutable map of inferred layouts that will be updated with
* returned layouts.
* @param strict_layout_map Read-only map of layouts produced in the strict
* phase; used to enforce containment checks for local.fragment buffers when
* relaxing.
* @param q BFS queue used to propagate dependent inference indices; new
* indices may be pushed.
* @param in_queue Parallel boolean vector tracking queued status; entries
* corresponding to enqueued indices will be set to true.
*/
void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue, void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
LayoutMap &layout_map, const LayoutMap &strict_layout_map, LayoutMap &layout_map, const LayoutMap &strict_layout_map,
std::queue<int> &q, std::vector<bool> &in_queue) { std::queue<int> &q, std::vector<bool> &in_queue) {
...@@ -221,30 +190,6 @@ public: ...@@ -221,30 +190,6 @@ public:
} }
}; };
/**
* @brief Run the multi-stage layout inference and return the collected
* results.
*
* Performs layout inference over the collected TileOperator entries in three
* phases: (1) strict per-operator inference, (2) common inference via a BFS
* propagation queue, and (3) a free-mode relaxation phase that explores
* alternative root orderings within connected components to reduce register
* footprint. After inference completes, verifies that all local.fragment
* buffers have inferred layouts and collects loop (For) -> Fragment layouts
* and any per-loop predicates discovered during inference.
*
* The method consumes/permutes internal inference state (notably moves
* entries out of infer_list_) and returns a LayoutInferenceResult containing:
* - layout_map: inferred Layout for each Buffer,
* - for_map: mapping from For nodes to their inferred Fragment layout,
* - predicate_map: optional loop predicates keyed by For nodes.
*
* The function performs internal consistency checks (ICHECK) on sizes and
* required definitions; violations will terminate via ICHECK failure.
*
* @return LayoutInferenceResult A tuple-like struct with the inferred
* layout_map, for_map, and predicate_map.
*/
LayoutInferenceResult Run() { LayoutInferenceResult Run() {
// Basic consistency check: infer_list_ and thread_var_vec_ should have the // Basic consistency check: infer_list_ and thread_var_vec_ should have the
// same size // same size
...@@ -348,23 +293,6 @@ public: ...@@ -348,23 +293,6 @@ public:
} }
private: private:
/**
* @brief Visits a Call expression to collect tile-operator-based inference
* inputs.
*
* Processes non-global function calls by parsing them into a TileOperator
* (via ParseOperator). If the parse succeeds, records:
* - buffers referenced by call arguments into the collector's use lists,
* - the call AST node into infer_list_stmt_,
* - the parsed TileOperator into infer_list_,
* - the current thread IterVar into thread_var_vec_,
* - the thread iteration bounds into thread_bounds_vec_ (uses analyzer const
* bounds when available; otherwise [0,1]).
*
* Calls to global functions (where op->op is a GlobalVar) are ignored.
*
* @param op The Call node being visited.
*/
void VisitExpr_(const CallNode *op) final { void VisitExpr_(const CallNode *op) final {
IRVisitorWithAnalyzer::VisitExpr_(op); IRVisitorWithAnalyzer::VisitExpr_(op);
// Do not analysis the call node to the global function. // Do not analysis the call node to the global function.
...@@ -417,25 +345,6 @@ private: ...@@ -417,25 +345,6 @@ private:
use_list_[buffer].push_back(infer_idx); use_list_[buffer].push_back(infer_idx);
} }
/**
* @brief Handles For nodes during IR traversal.
*
* When the loop is a parallel loop (ForKind::kParallel), records it as a
* ParallelOp:
* - constructs a ParallelOp for the loop and appends it to the internal infer
* lists (infer_list_ and infer_list_stmt_),
* - registers all buffers referenced by the loop indices with use-list
* bookkeeping,
* - captures the current thread IterVar context and its compile-time extent
* (if available) into thread_var_vec_ and thread_bounds_vec_ (falls back to
* range [0,1] when unknown).
*
* For non-parallel loops, continues recursive traversal into the loop body.
*
* Side effects:
* - Mutates infer_list_, infer_list_stmt_, use_list_ (via addToUseList),
* thread_var_vec_, and thread_bounds_vec_.
*/
void VisitStmt_(const ForNode *op) final { void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) { if (op->kind == ForKind::kParallel) {
auto infer = ParallelOp(GetRef<For>(op)); auto infer = ParallelOp(GetRef<For>(op));
...@@ -506,15 +415,6 @@ private: ...@@ -506,15 +415,6 @@ private:
LayoutMap annotated_layout_map_; LayoutMap annotated_layout_map_;
bool skip_thread_partition_{false}; bool skip_thread_partition_{false};
/**
* @brief Create a deep copy of the current inference operator list.
*
* Returns a vector containing clones of each TileOperator in the collector's
* internal infer_list_. The returned list is independent of the original so
* subsequent modifications to either do not affect the other.
*
* @return std::vector<TileOperator> Cloned copy of infer_list_.
*/
std::vector<TileOperator> BackupInferList() { std::vector<TileOperator> BackupInferList() {
std::vector<TileOperator> back_infer_list; std::vector<TileOperator> back_infer_list;
back_infer_list.reserve(infer_list_.size()); back_infer_list.reserve(infer_list_.size());
...@@ -524,48 +424,6 @@ private: ...@@ -524,48 +424,6 @@ private:
return back_infer_list; return back_infer_list;
} }
/**
* @brief Explore alternative inference orders within connected components to
* relax layouts.
*
* This function performs a "free-mode" exploration that attempts different
* root operators within each connected component of the operator-use graph in
* order to find a layout assignment with lower register (fragment) usage.
*
* Detailed behavior:
* - Builds connected components of infer_list_ by unioning operators that
* share buffer uses (use_list_).
* - For each component, iterates each member operator as a candidate root:
* - Backups the current infer_list_ and uses a temporary copy of
* layout_map.
* - Runs RunInferStep and FinishInferQueue in InferLevel::kFree starting
* from the candidate root and then (as a fallback) runs the remaining members
* to try to cover the whole component.
* - If inference succeeds, computes a coarse register usage metric by
* summing the product of OutputShape dimensions for all Fragment layouts
* in the temporary layout map.
* - Tracks the candidate that yields the smallest register usage.
* - If a better plan is found for a component, replaces the global
* infer_list_ and updates layout_map with the best layout_map found.
*
* Side effects:
* - Mutates layout_map to the best-found free-mode layout assignment when a
* better plan is discovered.
* - Mutates the member infer_list_ (backed up and restored during attempts;
* finally set to the best plan if found).
*
* Notes:
* - LayoutConflictException and NormalizeIterException raised during attempts
* are caught and treated as failed attempts; they do not propagate out of
* this function.
* - The register-usage metric is a heuristic (sum of fragment output element
* counts) used to prefer less-replicated layouts.
*
* @param layout_map[in,out] The current global layout map to be updated with
* a better free-mode result if found.
* @param strict_layout_map Read-only map of layouts inferred in strict mode,
* used to constrain free-mode inference.
*/
void InferInFreeMode(LayoutMap &layout_map, void InferInFreeMode(LayoutMap &layout_map,
const LayoutMap &strict_layout_map) { const LayoutMap &strict_layout_map) {
// Group operators into connected components // Group operators into connected components
...@@ -698,6 +556,20 @@ private: ...@@ -698,6 +556,20 @@ private:
: arith::IRMutatorWithAnalyzer(analyzer), result_(result), : arith::IRMutatorWithAnalyzer(analyzer), result_(result),
skip_thread_partition_(skip_thread_partition){}; skip_thread_partition_(skip_thread_partition){};
/**
* @brief Visit and mutate a Block node to attach inferred layout information.
*
* Converts the visited Block via the base visitor, asserts that every buffer
* allocated with scope "local.framgent" has an inferred layout in
* result_.layout_map, and attaches result_.layout_map to the Block's
* annotations under attr::kLayoutMap.
*
* If any "local.framgent" buffer lacks an entry in result_.layout_map an
* ICHECK will fail with the offending buffer printed.
*
* @return Stmt The (possibly modified) Block statement with the layout-map
* annotation set.
*/
Stmt VisitStmt_(const BlockNode *op) final { Stmt VisitStmt_(const BlockNode *op) final {
Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op)); Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
...@@ -712,6 +584,41 @@ private: ...@@ -712,6 +584,41 @@ private:
return block; return block;
} }
/**
* @brief Visit and transform For nodes according to inferred layout
* information.
*
* If the For node is present in result_.for_map, this method applies
* loop-level layout-driven transformations: it optionally partitions the loop
* across the thread index, vectorizes the loop body, and wraps the loop with
* a predicate if one was inferred for the loop root.
*
* Detailed behavior:
* - Reads reducer information from the For node's attr::kReducerInfo
* annotation (if present) to detect reduction targets.
* - Detects register-local buffer stores (buffers with scope "local") in the
* original loop body; if only register-local stores are present the loop is
* treated as a register-local scenario and is not partitioned across
* threads.
* - Obtains the loop layout from result_.for_map[root] and, unless the loop
* is register-local or skip_thread_partition_ is set, partitions the loop via
* PartitionLoop using thread_var_ and analyzer_.
* - Scans the transformed loop body to determine whether it accesses any
* non-local buffers (scopes other than "local" or "local.fragment").
* - Scans the transformed loop body to detect reducers (based on
* reducer_info). If a reducer is present the loop is NOT vectorized
* (reduction axes are excluded from vectorization as a conservative
* workaround).
* - If the loop has non-local accesses and no reducer, the loop is vectorized
* via VectorizeLoop.
* - If a predicate exists in result_.predicate_map for the loop root and the
* loop was partitioned, the method returns an IfThenElse surrounding the
* (possibly partitioned/vectorized) loop with that predicate; otherwise it
* returns the transformed For.
*
* @return The possibly transformed For statement (or an IfThenElse wrapping
* it)
*/
Stmt VisitStmt_(const ForNode *op) final { Stmt VisitStmt_(const ForNode *op) final {
Map<Var, ReducerInfo> reducer_info; Map<Var, ReducerInfo> reducer_info;
if (op->annotations.count(attr::kReducerInfo)) if (op->annotations.count(attr::kReducerInfo))
......
...@@ -24,6 +24,18 @@ using namespace tir; ...@@ -24,6 +24,18 @@ using namespace tir;
using namespace tir::transform; using namespace tir::transform;
using arith::IRMutatorWithAnalyzer; using arith::IRMutatorWithAnalyzer;
/**
* @brief Construct a ReducerInfoNode from textual op and replication
* descriptors.
*
* Maps op_str to a ReducerOpType ("sum" → SUM, "max" → MAX, "min" → MIN) and
* rep_str to a ReducerRepType ("all" → ALL, "none" → NONE).
*
* @param op_str String identifying the reducer operation.
* @param rep_str String identifying the replication behavior.
* @throws RuntimeError if op_str or rep_str is not one of the supported values
* (triggers ICHECK).
*/
ReducerInfoNode::ReducerInfoNode(const String &op_str, const String &rep_str) { ReducerInfoNode::ReducerInfoNode(const String &op_str, const String &rep_str) {
if (op_str == "sum") if (op_str == "sum")
op = ReducerOpType::SUM; op = ReducerOpType::SUM;
...@@ -45,6 +57,23 @@ ReducerInfoNode::ReducerInfoNode(const String &op_str, const String &rep_str) { ...@@ -45,6 +57,23 @@ ReducerInfoNode::ReducerInfoNode(const String &op_str, const String &rep_str) {
class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer { class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
public: public:
private: private:
/**
* @brief Visit an attribute statement and capture the IterVar for
* threadIdx.x.
*
* If the attribute key is `tir::attr::thread_extent` and the node is an
* `IterVar` whose `thread_tag` equals `"threadIdx.x"`, this sets the
* mutator's `thread_var_` to that IterVar (after asserting the iterator's
* extent is an `IntImm`). The previous `thread_var_` is preserved and
* restored after delegating to the base visitor. Delegates all traversal work
* to `IRMutatorWithAnalyzer::VisitStmt_`.
*
* Side effects:
* - Temporarily updates the member `thread_var_` during traversal of the
* child statement so subsequent visitors can read the thread index IterVar.
*
* @return The possibly mutated statement returned by the base visitor.
*/
Stmt VisitStmt_(const AttrStmtNode *op) final { Stmt VisitStmt_(const AttrStmtNode *op) final {
auto prev_thread_var = thread_var_; auto prev_thread_var = thread_var_;
if (op->attr_key == tir::attr::thread_extent) { if (op->attr_key == tir::attr::thread_extent) {
...@@ -59,6 +88,28 @@ private: ...@@ -59,6 +88,28 @@ private:
return result; return result;
} }
/**
* @brief Visits a TIR Block node to collect reducer metadata and apply
* discovered buffer layouts.
*
* This method:
* - Extracts reducer information from the block's `attr::kReducerInfo`
* annotation and populates the internal reducer_info_map_.
* - Registers allocated buffers by mapping each buffer's data Var to its
* Buffer in var_to_buffer_.
* - Recursively visits and rewrites the block body via the base mutator.
* - Merges any layouts accumulated in new_layout_map_ into the block's
* `attr::kLayoutMap` annotation (creating or extending the annotation), then
* clears new_layout_map_ for subsequent blocks.
*
* Side effects:
* - Updates reducer_info_map_, var_to_buffer_, and may set the block-level
* `kLayoutMap` annotation.
* - Clears new_layout_map_ after merging.
*
* @param op The Block node being visited.
* @return Stmt The potentially modified Block statement (as a Stmt).
*/
Stmt VisitStmt_(const BlockNode *op) final { Stmt VisitStmt_(const BlockNode *op) final {
// Record annotations // Record annotations
if (op->annotations.count(attr::kReducerInfo)) { if (op->annotations.count(attr::kReducerInfo)) {
...@@ -87,6 +138,43 @@ private: ...@@ -87,6 +138,43 @@ private:
return result; return result;
} }
/**
* @brief Visit and possibly annotate a For node for reducer layout lowering.
*
* Visits a For node via the base mutator and, if the traversal is currently
* inside a reduction region (tracked by inside_reducer_range_) and this is
* the outermost loop of that region, annotates the loop with reducer
* information and derives per-buffer layout fragments for each reducer
* buffer.
*
* When annotating:
* - Sets the block-level `attr::kReducerInfo` annotation to the current
* inside_reducer_range_ map on the loop.
* - For each reducer buffer, reads the bound of `thread_var_` (requires the
* analyzer to have a const-int bound for it) and creates a Fragment:
* - If the reducer's replication type is ALL, creates a replication
* fragment across the thread extent.
* - If the replication type is NONE, builds a flattened index expression
* across buffer indices, reduces it modulo the thread extent, adds the
* thread minimum offset, and uses that as the fragment index.
* - Records the constructed Fragments into new_layout_map_ keyed by the
* buffer's data Var.
*
* Side effects:
* - May set `attr::kReducerInfo` on the For node's annotations.
* - Updates `new_layout_map_`.
* - Reads and relies on `thread_var_`, `analyzer_->const_int_bound`, and
* `var_to_buffer_`.
*
* Preconditions and checks:
* - `thread_var_` must be defined and have a constant-int bound when
* annotating.
* - Each reducer Var in inside_reducer_range_ must map to an allocated Buffer
* in var_to_buffer_ (ICHECK enforced).
*
* @param op The original For node being visited.
* @return The (possibly) transformed For statement.
*/
Stmt VisitStmt_(const ForNode *op) final { Stmt VisitStmt_(const ForNode *op) final {
// only annotate the outermost loop // only annotate the outermost loop
bool should_annotate = false; bool should_annotate = false;
...@@ -140,11 +228,48 @@ private: ...@@ -140,11 +228,48 @@ private:
return result; return result;
} }
/**
* @brief Handle BufferStore statements during IR mutation.
*
* This override is the visit hook for BufferStoreNode. Currently it delegates
* to the base IRMutatorWithAnalyzer implementation. Intended as the place to
* perform reducer-specific viability checks for stores (e.g., validating
* operations against reducer metadata); such checks are TODO and are not yet
* implemented.
*
* @return Stmt The (possibly transformed) statement returned by the base
* mutator.
*/
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
//! TODO: check store viable according to info->op //! TODO: check store viable according to info->op
return IRMutatorWithAnalyzer::VisitStmt_(op); return IRMutatorWithAnalyzer::VisitStmt_(op);
} }
/**
* @brief Processes Call expressions to track reducer ranges and finalize
* reducer operations.
*
* Visits call nodes, detects T.fill calls that target reducer buffers and
* records their reducer metadata in inside_reducer_range_ until the matching
* T.finalize_reducer is seen. When a FinalizeReducerOp call is encountered,
* this method appends the reducer operation enum value to the call arguments
* and removes the corresponding entry from inside_reducer_range_.
*
* Side effects:
* - Inserts and removes entries in inside_reducer_range_.
* - Mutates the FinalizeReducerOp call by pushing the reducer op enum as an
* extra argument.
*
* Failure modes:
* - ICHECK fails if a T.fill targets a reducer already recorded in
* inside_reducer_range_ (i.e., a prior T.fill without an intervening
* T.finalize_reducer).
* - ICHECK fails if T.finalize_reducer has no matching T.fill (no entry in
* inside_reducer_range_).
*
* @param op_ The CallNode being visited.
* @return PrimExpr The (possibly modified) call expression.
*/
PrimExpr VisitExpr_(const CallNode *op_) final { PrimExpr VisitExpr_(const CallNode *op_) final {
auto op_ref = IRMutatorWithAnalyzer::VisitExpr_(op_).as<Call>().value(); auto op_ref = IRMutatorWithAnalyzer::VisitExpr_(op_).as<Call>().value();
auto op = op_ref.CopyOnWrite(); auto op = op_ref.CopyOnWrite();
...@@ -175,6 +300,15 @@ private: ...@@ -175,6 +300,15 @@ private:
return op_ref; return op_ref;
} }
/**
* @brief Construct a ReducerLayoutAnnotator with an arithmetic analyzer.
*
* Initializes the annotator's base IRMutatorWithAnalyzer with the provided
* arith::Analyzer, which the mutator uses to query symbolic bounds and
* simplify integer expressions during layout inference.
*
* @param analyzer Pointer to an arith::Analyzer used for symbolic analysis.
*/
ReducerLayoutAnnotator(arith::Analyzer *analyzer) ReducerLayoutAnnotator(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer) {} : IRMutatorWithAnalyzer(analyzer) {}
...@@ -186,6 +320,19 @@ private: ...@@ -186,6 +320,19 @@ private:
Map<Var, Layout> new_layout_map_; Map<Var, Layout> new_layout_map_;
public: public:
/**
* @brief Apply reducer layout substitution to a PrimFunc.
*
* Runs the ReducerLayoutAnnotator over the function body to collect reducer
* metadata, insert layout mappings for reducer buffers, and lower
* local.reducer usage to local.fragment-compatible forms. Returns a new
* PrimFunc whose body is the transformed IR.
*
* @param f The PrimFunc to transform; passed by value and returned with an
* updated body.
* @return PrimFunc The transformed PrimFunc with reducer layouts and related
* rewrites applied.
*/
static PrimFunc Substitute(PrimFunc f) { static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
ReducerLayoutAnnotator substituter(&analyzer); ReducerLayoutAnnotator substituter(&analyzer);
...@@ -195,6 +342,18 @@ public: ...@@ -195,6 +342,18 @@ public:
} }
}; };
/**
* @brief Create a TVM transform pass that lowers local.reducer buffers to
* local.fragment layouts.
*
* This pass runs ReducerLayoutAnnotator::Substitute on a PrimFunc to collect
* reducer metadata, compute per-buffer layout fragments for reducer buffers,
* and annotate blocks with the resulting layout map. It is exposed as a
* PrimFunc-level pass named "tl.LayoutReducer".
*
* @return tvm::transform::Pass A prim-function pass that applies the
* layout-reduction substitution.
*/
tvm::transform::Pass LayoutReducer() { tvm::transform::Pass LayoutReducer() {
using namespace tir::transform; using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
......
...@@ -10,6 +10,51 @@ ...@@ -10,6 +10,51 @@
#include "../layout/layout.h" #include "../layout/layout.h"
namespace tvm { namespace tvm {
/**
* Types of reduction operations supported by TL transforms.
*
* SUM - arithmetic sum reduction.
* MAX - elementwise maximum reduction.
* MIN - elementwise minimum reduction.
*/
/**
* Representation semantics for a reducer.
*
* ALL - reducer collapses all elements along the reduced axes.
* NONE - reducer does not collapse (used to represent a placeholder/no-op).
*/
/**
* Holds metadata describing a reducer used in layout transforms.
*
* Contains the reduction operation (`op`) and its representation semantics
* (`rep`).
*/
/**
* Construct a ReducerInfoNode from textual identifiers.
*
* @param op_str String identifier for the reduction operation (e.g., "sum",
* "max", "min").
* @param rep_str String identifier for the representation semantics (e.g.,
* "all", "none").
*/
/**
* Handle type for ReducerInfoNode (ObjectRef wrapper).
*
* Constructed from string identifiers for operation and representation.
*
* @param op_str String identifier for the reduction operation (e.g., "sum",
* "max", "min").
* @param rep_str String identifier for the representation semantics (e.g.,
* "all", "none").
*/
/**
* Attribute key used to attach ReducerInfo to IR nodes or other attribute maps.
*/
namespace tl { namespace tl {
enum class ReducerOpType { SUM, MAX, MIN }; enum class ReducerOpType { SUM, MAX, MIN };
......
...@@ -950,9 +950,25 @@ private: ...@@ -950,9 +950,25 @@ private:
return entry; return entry;
} }
/*! /*!
* \brief find the storage entry in the free list for the allocate * @brief Locate or create a storage entry from free lists to satisfy an
* \param op the allocate node * AllocateNode.
* \return the storage entry *
* Finds a reusable StorageEntry for the given AllocateNode (constant or
* symbolic size) using two-tiered strategies:
* - For constant-size allocations (>0): prefer a free entry that is >=
* required size; if none, coalesce smaller free constant-size entries until
* the sum meets the request and return a new StorageEntry representing the
* merged space. Very small constant allocations (<= 32 bits) are not reused
* and will allocate a fresh entry.
* - For symbolic-size (unknown at compile time): pick and remove an arbitrary
* entry from the symbolic free list.
*
* If no suitable free entry is found, a fresh StorageEntry is created via
* NewAlloc.
*
* @param op Pointer to the AllocateNode to satisfy. Must be non-null.
* @return StorageEntry* A storage entry that will hold the allocation (may be
* newly created).
*/ */
StorageEntry *FindAlloc(const AllocateNode *op) { StorageEntry *FindAlloc(const AllocateNode *op) {
ICHECK(op != nullptr); ICHECK(op != nullptr);
......
...@@ -218,6 +218,30 @@ bool IsThreadInvariant(const PrimExpr &cond) { ...@@ -218,6 +218,30 @@ bool IsThreadInvariant(const PrimExpr &cond) {
return false; return false;
} }
/**
* @brief Visit an IfThenElse statement and collect storage access summaries for
* its branches.
*
* Visits the if-then-else node's condition and both branches to summarize
* buffer reads, writes, and synchronization events under the condition's
* constraints. If the condition is not thread-invariant, increments an internal
* condition counter for the duration of processing.
*
* Behavior and side effects:
* - Evaluates the condition expression (using ExtractRealCondition) and applies
* it as a constraint while summarizing the then-branch.
* - For the else-branch (when present), applies the negated,
* analyzer-simplified condition
* (analyzer_.rewrite_simplify(Not(real_condition))) as the constraint.
* - Accumulates summarized StmtEntry access information for the then/else
* branches and appends a combined StmtEntry for the IfThenElseNode into the
* current scope.
* - Temporarily toggles allow_append_ and clears curr_stmt_.access during
* condition evaluation and branch summarization.
* - Modifies internal state: scope_ (push/pop of temporary branch scopes),
* curr_stmt_.access, and condition_counter_ (incremented/decremented when the
* condition is not thread-invariant).
*/
void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) { void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
bool is_thread_invariant = IsThreadInvariant(op->condition); bool is_thread_invariant = IsThreadInvariant(op->condition);
if (!is_thread_invariant) { if (!is_thread_invariant) {
......
...@@ -649,10 +649,44 @@ public: ...@@ -649,10 +649,44 @@ public:
*/ */
bool hasSimtCopy() const { return has_simt_copy_; } bool hasSimtCopy() const { return has_simt_copy_; }
/**
* @brief Whether this emitter contains only warp-group MMA (WgMMA)
* operations.
*
* Returns true if the emitter detected exclusively WgMMA usage in the region
* it analyzed.
*
* @return bool true when only WgMMA-based code paths are present; false
* otherwise.
*/
bool onlyHasWgMMA() const { return only_has_wgmma_; } bool onlyHasWgMMA() const { return only_has_wgmma_; }
private: private:
template <typename NodeType> Stmt FilterByRole(const NodeType *op) { template <
typename NodeType> /**
* @brief Filter a statement by its producer/consumer
* role for emission.
*
* Returns one of:
* - the original statement (unchanged) when this
* emitter should emit it,
* - the result of visiting the statement (to descend
* into it) when mbarrier-only mode requires full
* traversal for non-producer roles,
* - an empty evaluate (`Evaluate(0)`) when the
* statement should be omitted.
*
* The decision is based on the role of `op` as
* reported by `marker_`, the emitter mode
* (`is_emitting_producer_`), and the `mbarrier_only_`
* flag.
*
* @param op The statement node to filter; its role is
* queried via `marker_`.
* @return Stmt The statement to place into the emitted
* IR (possibly transformed or an empty evaluate).
*/
Stmt FilterByRole(const NodeType *op) {
Role role = marker_.GetRole(op); Role role = marker_.GetRole(op);
if (mbarrier_only_) { if (mbarrier_only_) {
if (role != Role::kProducer) if (role != Role::kProducer)
......
...@@ -131,6 +131,12 @@ def run_autotune(M: int, N: int, K: int): ...@@ -131,6 +131,12 @@ def run_autotune(M: int, N: int, K: int):
def test_autotune_matmul(): def test_autotune_matmul():
"""
Run the autotuning validation for the matmul kernel on a 1024x1024x1024 problem.
This test constructs random CUDA tensors, autotunes the JIT-compiled block-level matrix-multiplication kernel,
executes it, and asserts the result matches a reference CPU implementation within tolerances.
"""
run_autotune(1024, 1024, 1024) run_autotune(1024, 1024, 1024)
......
...@@ -63,6 +63,26 @@ def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None, ...@@ -63,6 +63,26 @@ def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Bind the target device information to the module # Bind the target device information to the module
"""
Bind target information and progressively legalize and lower frontend Tile IR into a form suitable for downstream optimization and codegen.
This pass pipeline:
- Binds the provided target to the module.
- Legalizes frontend Tile IR into TVM-compatible constructs.
- Simplifies expressions.
- Configures reducer layouts and performs layout inference for fragments and shared memory.
- Lowers high-level tile operations and L2 persistent maps.
- Legalizes vectorized loops and inserts safety checks for memory accesses.
- Re-simplifies to remove redundancies introduced by safety checks.
- Attempts loop vectorization for dynamic-shaped loops.
Parameters:
mod (IRModule): The input IR module containing frontend Tile IR.
target (Target): Target device information to bind into the module.
Returns:
IRModule: The transformed module, ready for target-specific optimization passes.
"""
mod = tir.transform.BindTarget(target)(mod) mod = tir.transform.BindTarget(target)(mod)
# Legalize the frontend IR to make it compatible with TVM # Legalize the frontend IR to make it compatible with TVM
......
...@@ -18,15 +18,22 @@ _MEMORY_ORDER_ID_MAP = { ...@@ -18,15 +18,22 @@ _MEMORY_ORDER_ID_MAP = {
def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): def region(buffer: BufferLoad, access_type: str, *args: PrimExpr):
"""Create a memory region descriptor for tile operations. """
Create a tile memory-region descriptor for a BufferLoad.
Args:
buffer (tir.BufferLoad): The buffer to create a region for Maps access_type ('r', 'w', 'rw') to the numeric codes expected by the `tl.region` intrinsic
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write (1, 2, 3 respectively) and returns a tir.Call representing the region with the provided extents.
*args (tir.PrimExpr): Extent expressions defining the region size
Parameters:
buffer (tir.BufferLoad): The BufferLoad that identifies the underlying buffer and indices.
access_type (str): One of 'r', 'w', or 'rw' indicating read, write, or read-write access.
*args (tir.PrimExpr): Extent expressions for each region dimension.
Returns: Returns:
tir.Call: A region descriptor for tile operations tir.Call: A call to the `tl.region` intrinsic describing the memory region.
Raises:
KeyError: If access_type is not one of 'r', 'w', or 'rw'.
""" """
access_type = {"r": 1, "w": 2, "rw": 3}[access_type] access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args) return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args)
...@@ -74,15 +81,20 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List ...@@ -74,15 +81,20 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List
def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str, def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
extents: List[PrimExpr]): extents: List[PrimExpr]):
"""Convert a buffer region to a tile region descriptor.
Args:
buffer_region (tir.BufferRegion): The buffer region to convert
access_type (str): Type of access - 'r' for read, 'w' for write, 'rw' for read-write
Returns:
tir.Call: A region descriptor for the specified buffer region
""" """
Create a tl region descriptor for the given BufferRegion.
Parameters:
buffer_region (tir.BufferRegion): Source buffer region whose `region` items provide mins and extents.
access_type (str): Access mode: "r", "w", or "rw".
extents (List[PrimExpr]): Requested extents; must have length <= the number of extents in buffer_region.region.
Returns:
tir.Call: A tile-region descriptor (tl.region) covering the buffer_region.
Raises:
AssertionError: If the number of extents in buffer_region.region is smaller than len(extents).
"""
mins = [x.min for x in buffer_region.region] mins = [x.min for x in buffer_region.region]
region_extents = [x.extent for x in buffer_region.region] region_extents = [x.extent for x in buffer_region.region]
assert len(region_extents) >= len( assert len(region_extents) >= len(
...@@ -93,14 +105,19 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str, ...@@ -93,14 +105,19 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr: def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
"""Perform an atomic maximum operation. """
Perform an atomic maximum on the value stored at dst with an optional memory-order.
Args:
dst (Buffer): Destination buffer where the atomic maximum will be performed If memory_order is None the runtime extern "AtomicMax" is called without an explicit memory-order id; otherwise the provided memory_order string is mapped to a numeric id using the module's memory-order map and passed to the extern.
value (PrimExpr): Value to be atomically added
Parameters:
dst (Buffer): Destination buffer/address to apply the atomic max.
value (PrimExpr): Value to compare/store atomically.
memory_order (str | None): Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst").
If provided, it is translated to the corresponding numeric memory-order id before the call.
Returns: Returns:
PrimExpr: Handle to the atomic maximum operation PrimExpr: A handle/expression representing the issued atomic maximum operation.
""" """
if memory_order is None: if memory_order is None:
return T.call_extern("handle", "AtomicMax", T.address_of(dst), value) return T.call_extern("handle", "AtomicMax", T.address_of(dst), value)
...@@ -110,14 +127,18 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> ...@@ -110,14 +127,18 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr: def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
"""Perform an atomic minimum operation. """
Atomically update the value at dst to the minimum of its current value and value.
Args:
dst (Buffer): Destination buffer where the atomic minimum will be performed If memory_order is provided, it selects the memory-order semantic used by the underlying extern call;
value (PrimExpr): Value to be atomically added allowed names are "relaxed", "consume", "acquire", "release", "acq_rel", and "seq_cst" (mapped internally
to integer IDs). If memory_order is None, the extern is invoked without an explicit memory-order argument.
Parameters:
memory_order (str | None): Optional memory-order name controlling the atomic operation's ordering.
Returns: Returns:
PrimExpr: Handle to the atomic minimum operation PrimExpr: A handle expression representing the atomic-min operation.
""" """
if memory_order is None: if memory_order is None:
return T.call_extern("handle", "AtomicMin", T.address_of(dst), value) return T.call_extern("handle", "AtomicMin", T.address_of(dst), value)
...@@ -127,17 +148,26 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> ...@@ -127,17 +148,26 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr: def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
"""Perform an atomic addition operation. """
Atomically add `value` into `dst`, returning a handle to the operation.
Args:
dst (Buffer): Destination buffer where the atomic addition will be performed Supports scalar/addressed extern atomic add when neither argument exposes extents, or tile-region-based atomic add for Buffer/BufferRegion/BufferLoad inputs. If both arguments are plain Buffers their shapes must be structurally equal. If at least one side exposes extents, extents are aligned (missing dimensions are treated as size 1); an assertion is raised if extents cannot be deduced. The optional `memory_order` (one of "relaxed","consume","acquire","release","acq_rel","seq_cst") is used only for the direct extern `AtomicAdd` path when no extents are available — otherwise the tile-region path ignores `memory_order`.
value (PrimExpr): Value to be atomically added
Returns: Returns:
PrimExpr: Handle to the atomic addition operation PrimExpr: A handle representing the atomic addition operation.
""" """
def get_extent(data): def get_extent(data):
"""
Return the inferred extent (shape) of a buffer-like object.
If `data` is a Var bound to a let value, the let value is resolved before inspection.
Parameters:
data: A Var, Buffer, or BufferRegion to inspect.
Returns:
The shape/extents as a list-like of PrimExpr (Buffer.shape or list of region item extents), or None if the extent cannot be determined.
"""
if isinstance(data, Var) and T.has_let_value(data): if isinstance(data, Var) and T.has_let_value(data):
data = T.get_let_value(data) data = T.get_let_value(data)
if isinstance(data, Buffer): if isinstance(data, Buffer):
...@@ -252,16 +282,11 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer: ...@@ -252,16 +282,11 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
def view(src: Buffer, def view(src: Buffer,
shape: Union[List[PrimExpr], None] = None, shape: Union[List[PrimExpr], None] = None,
dtype: Union[str, None] = None) -> Buffer: dtype: Union[str, None] = None) -> Buffer:
"""Views the input buffer with optionally modified shape and dtype.
Args:
src (Buffer): Input buffer to be viewed
shape (Union[List[PrimExpr], None], optional): New shape for the buffer. Defaults to None.
dtype (Union[str, None], optional): New dtype for the buffer. Defaults to None.
Returns:
Buffer: A new buffer view with the specified shape and dtype
""" """
Return a Tensor view of the input buffer with an optional new shape and dtype.
If `shape` is None the source buffer's shape is used; if `dtype` is None the source buffer's dtype is used. The returned buffer shares the same underlying data as `src` (no copy).
"""
if shape is None: if shape is None:
shape = src.shape shape = src.shape
if dtype is None: if dtype is None:
...@@ -270,29 +295,34 @@ def view(src: Buffer, ...@@ -270,29 +295,34 @@ def view(src: Buffer,
def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr: def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
"""Loads a value from the input buffer with specified memory_order. """
Load a value from the given buffer using the specified atomic memory ordering.
Args:
src (Buffer): Input buffer to load from Performs an atomic load from `src` and returns a PrimExpr representing the loaded value.
memory_order (str, optional): Atomicity level for the load operation. Defaults to "seq_cst". memory_order selects the ordering and must be one of: "relaxed", "consume", "acquire",
"release", "acq_rel", or "seq_cst" (default).
Returns: Raises KeyError if an unknown memory_order is provided.
PrimExpr: The loaded value from the buffer
""" """
return T.call_extern(src.dtype, "AtomicLoad", T.address_of(src), return T.call_extern(src.dtype, "AtomicLoad", T.address_of(src),
_MEMORY_ORDER_ID_MAP[memory_order]) _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr: def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr:
"""Stores a value to the input buffer with specified memory_order. """
Perform an atomic store of `src` into `dst` with the given memory ordering.
Args:
dst (Buffer): Input buffer to store to Parameters:
src (PrimExpr): Value to store dst (Buffer): Destination buffer to store into.
memory_order (str, optional): Atomicity level for the load operation. Defaults to "seq_cst". src (PrimExpr): Value to store.
memory_order (str, optional): Memory ordering name; one of "relaxed", "consume",
"acquire", "release", "acq_rel", or "seq_cst". Defaults to "seq_cst".
The name is mapped to an internal numeric ID used by the underlying runtime.
Returns: Returns:
PrimExpr: The handle of the store operation PrimExpr: A handle representing the issued atomic store operation.
Raises:
KeyError: If `memory_order` is not one of the supported names.
""" """
return T.call_extern("handle", "AtomicStore", T.address_of(dst), src, return T.call_extern("handle", "AtomicStore", T.address_of(dst), src,
_MEMORY_ORDER_ID_MAP[memory_order]) _MEMORY_ORDER_ID_MAP[memory_order])
...@@ -155,16 +155,13 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) - ...@@ -155,16 +155,13 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False): def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
"""Perform cumulative sum on input buffer, store the result to output buffer. """
Compute the cumulative sum of `src` along `dim`, writing results to `dst`.
Args:
src (tir.Buffer): The input buffer Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is performed in-place into `src`. Raises ValueError when `dim` is out of bounds for `src.shape`. When `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the `tl.cumsum` intrinsic.
dst (tir.Buffer, optional): The output buffer. Defaults to None.
dim (int, optional): The dimension to perform cumulative sum on. Defaults to 0.
reverse (bool, optional): Whether to perform reverse cumulative sum. Defaults to False.
Returns: Returns:
tir.Call: Handle to the cumulative sum operation tir.Call: A handle to the emitted cumulative-sum operation.
""" """
shape = src.shape shape = src.shape
...@@ -188,13 +185,17 @@ def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reve ...@@ -188,13 +185,17 @@ def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reve
def finalize_reducer(reducer: tir.Buffer): def finalize_reducer(reducer: tir.Buffer):
"""Finalize the reducer buffer. """
Finalize a reducer buffer by emitting the `tl.finalize_reducer` intrinsic.
Args:
reducer (tir.Buffer): The reducer buffer This returns a TVM `tir.Call` handle that finalizes the given reducer using its writable pointer.
The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR.
Parameters:
reducer (tir.Buffer): Reducer buffer whose writable pointer will be finalized.
Returns: Returns:
tir.Call: Handle to the finalize reducer operation tir.Call: Handle to the finalize reducer intrinsic call.
""" """
return tir.call_intrin( return tir.call_intrin(
"handle", "handle",
......
...@@ -416,12 +416,25 @@ def LowerThreadAllreduce(): ...@@ -416,12 +416,25 @@ def LowerThreadAllreduce():
def LowerDeviceKernelLaunch(): def LowerDeviceKernelLaunch():
"""LowerDeviceKernelLaunch """
Create and return a transform pass that lowers device kernel launch constructs to target-specific IR.
This pass transforms high-level device kernel launch and related intrinsics into lower-level
IR suitable for backend code generation and device-side lowering.
Returns:
tvm.transform.Pass: The transform pass that performs device kernel launch lowering.
""" """
return _ffi_api.LowerDeviceKernelLaunch() # type: ignore return _ffi_api.LowerDeviceKernelLaunch() # type: ignore
def LayoutReducer(): def LayoutReducer():
"""LayoutReducer """
Return a TVM transform pass that performs layout reduction/normalization.
This wrapper delegates to the underlying FFI implementation and returns a pass object suitable for use in a PassContext or pass pipeline. The pass is intended to simplify or reduce tensor/layout-related representations during relay/tile transformations.
Returns:
The transform pass object produced by the FFI backend.
""" """
return _ffi_api.LayoutReducer() # type: ignore return _ffi_api.LayoutReducer() # type: ignore
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment