Unverified Commit 2af3f22e authored by coderabbitai[bot]'s avatar coderabbitai[bot] Committed by GitHub
Browse files

📝 Add docstrings to `pytile_0826` (#770)

* 📝 Add docstrings to `pytile_0826`

Docstrings generation was requested by @LeiWang1999.

* https://github.com/tile-ai/tilelang/pull/763#issuecomment-3224197814



The following files were modified:

* `src/op/atomic_add.cc`
* `src/op/atomic_add.h`
* `src/op/copy.cc`
* `src/op/copy.h`
* `src/op/elem.cc`
* `src/op/elem.h`
* `src/op/gemm.cc`
* `src/op/gemm.h`
* `src/op/gemm_sp.cc`
* `src/op/gemm_sp.h`
* `src/op/operator.cc`
* `src/op/operator.h`
* `src/op/parallel.cc`
* `src/op/parallel.h`
* `src/op/reduce.cc`
* `src/op/reduce.h`
* `src/op/region.cc`
* `src/op/region.h`
* `src/transform/layout_inference.cc`
* `src/transform/lower_tile_op.cc`

* lint fix

---------
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent 8eab7755
......@@ -21,6 +21,18 @@ namespace tl {
using namespace tir;
/**
* @brief Extracts a numeric architecture identifier from a Target's "arch"
* attribute.
*
* Reads the Target's "arch" string (must be defined) and, if it has the form
* "sm_<N>", parses and returns N as an integer. For any other arch string,
* returns 0.
*
* @param target Target whose "arch" attribute will be inspected (ICHECKs that
* the attribute is defined).
* @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
......@@ -34,6 +46,25 @@ static int GetArchInt(Target target) {
return arch_int;
}
/**
* @brief Construct an AtomicAdd operator from call arguments and a buffer map.
*
* Builds the internal AtomicAddNode, extracts the source and destination
* regions and their backing Buffers from the first two call-style expressions
* in `args` (via RegionOp), and stores them along with their ranges. If a third
* argument is provided, it is interpreted as an integer immediate and stored as
* the node's coalesced width.
*
* @param args Call-style PrimExprs where:
* - args[0] is the source region call,
* - args[1] is the destination region call,
* - args[2] (optional) is an IntImm specifying coalesced width.
* @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
*
* Notes:
* - The constructor checks that args[0] and args[1] are CallNodes.
* - The constructed node is stored in this->data_.
*/
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
Array<Range> rgs[2];
......@@ -54,6 +85,15 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}
/**
* @brief Create a deep copy of this AtomicAdd node wrapped as a TileOperator.
*
* Produces a new AtomicAddNode object copied from this node. If this node has
* an associated ParallelOp (par_op_), the parallel op is cloned and attached to
* the new node so the cloned operator preserves parallelization state.
*
* @return TileOperator A TileOperator owning the cloned AtomicAddNode.
*/
TileOperator AtomicAddNode::Clone() const {
auto op = make_object<AtomicAddNode>(*this);
if (par_op_.defined()) {
......@@ -62,6 +102,19 @@ TileOperator AtomicAddNode::Clone() const {
return AtomicAdd(op);
}
/**
* @brief Create data-parallel iteration variables for non-singleton dimensions
* of the source.
*
* Constructs an Array of IterVar corresponding to each dimension in `src_range`
* whose extent is not equal to 1. Each IterVar has domain Range(0, extent), a
* Var named sequentially ("i", "j", "k", ...) with the same dtype as the
* extent, and type IterVarType::kDataPar. The ordering of returned itervars
* matches the order of dimensions in `src_range`.
*
* @return Array<IterVar> Iteration variables for all non-singleton extents in
* `src_range`.
*/
Array<IterVar> AtomicAddNode::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
......@@ -77,7 +130,26 @@ Array<IterVar> AtomicAddNode::MakeIterVars() const {
}
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
/**
* @brief Build index expressions for either source or destination from loop
* iter vars.
*
* Given a list of iteration variables that correspond to the non-singleton
* extents of the selected region (source when src_dst == 0, destination when
* src_dst == 1), return an array of index expressions matching the full rank of
* that region. For dimensions with extent == 1, the corresponding index is the
* range's minimum; otherwise the index is `min + ivar`.
*
* @param ivs Iteration variables in order for all non-singleton dimensions of
* the chosen region.
* @param src_dst Selects which region to index: 0 for source (src_range), 1 for
* destination (dst_range).
* @return Array<PrimExpr> Index expressions for every dimension of the selected
* region, in original dimension order.
*
* @note The function checks that the number of provided iter vars equals the
* number of non-singleton extents; it will abort (ICHECK) if they differ.
*/
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
......@@ -97,6 +169,31 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
return indices;
}
/**
* @brief Build a combined bound-check predicate for indexed access.
*
* Constructs an AND'd predicate ensuring each non-singleton index (derived from
* `ivs`) stays within [0, extent) for the selected operand (source when
* `src_dst==0`, destination otherwise). For each non-unit Range in the chosen
* range list this produces two conditions:
* - range.min + iv >= 0
* - range.min + iv < extent
*
* Conditions that the analyzer can prove (with symbolic bounds) are omitted.
* If no uncertain conditions remain, an empty PrimExpr is returned.
*
* Note: the function ICHECKs that `extents.size()` equals the number of ranges
* for the selected operand.
*
* @param ivs Iteration variables corresponding to non-singleton extents (order
* matches the non-unit ranges of the chosen operand).
* @param extents Per-dimension upper bounds to check against; must have the
* same size as the selected range list.
* @param src_dst Selects which ranges to validate: 0 => `src_range`, else
* `dst_range`.
* @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or
* an empty PrimExpr when no checks are required.
*/
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents,
......@@ -128,6 +225,34 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
}
}
/**
* @brief Build a SIMT-style loop nest that performs element-wise atomic
* additions from src to dst.
*
* Constructs a nested loop (parallelized per iter var) that loads a value from
* the source buffer, optionally casts it to the destination dtype, and performs
* an extern atomic add into the destination buffer address. For scalar
* (zero-dimensional) operations a trivial serial For with a single BufferStore
* is returned.
*
* The method:
* - Creates iter vars for all non-singleton extents and binds them into the
* provided analyzer.
* - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
* - Computes indexed accesses and emits optional bound predicates;
* out-of-bounds accesses are masked to zero when predicates are uncertain.
* - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
* src_value)` call wrapped in an Evaluate statement.
* - Wraps the body with a parallel For at each loop level. If `coalesced_width`
* is defined it is attached as the "coalesced_width" annotation on each loop.
*
* Note: This function mutates the analyzer binding state by binding loop
* variables and may fail via ICHECK if internal assumptions about shapes are
* violated.
*
* @return A nested For loop (parallel loops) implementing the atomic-add
* kernel. For scalar cases a serial For of extent 1 is returned.
*/
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
......@@ -191,6 +316,41 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return Downcast<For>(body);
}
/**
* @brief Lower the atomic-add top-level operator into a parallel, vectorized
* TIR loop.
*
* Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs
* layout inference at multiple levels, partitions the root loop by the provided
* thread variable, vectorizes the thread loop, and returns the final
* (optionally predicate-guarded) statement.
*
* The lowering pipeline:
* - Build the SIMT loop via MakeSIMTLoop.
* - Fuse parallel loops into a single For and wrap as a ParallelOp.
* - Run layout inference at kCommon, kStrict, and kFree levels using fields
* from `T`.
* - Obtain the loop layout, partition the root loop with PartitionLoop by
* `T.thread_var`.
* - Vectorize the partitioned thread loop via VectorizeLoop.
* - If the ParallelOp produced a predicate for `T.thread_var`, return an
* IfThenElse that guards the vectorized loop with that predicate; otherwise
* return the vectorized loop.
*
* @param T Lowering context whose fields are used:
* - T.target: target architecture for layout inference and lowering
* decisions.
* - T.thread_var: the Var used to partition the outer loop for thread-level
* parallelism.
* - T.thread_bounds: bounds associated with the thread dimension (used during
* partitioning).
* - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
* during InferLayout.
* @param analyzer Analyzer used for symbolic reasoning during partitioning and
* folding (omitted from detailed param docs as a common analysis utility).
* @return Stmt A lowered TIR statement representing the parallelized and
* vectorized atomic-add.
*/
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
auto simt_loop = MakeSIMTLoop(analyzer);
......@@ -221,6 +381,25 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
}
/**
* @brief Infer and return the layout map for the atomic add operator.
*
* Constructs a cached ParallelOp (by building the SIMT loop) if not already
* present, validates that local.fragment layouts for src and dst match when
* both are provided, and then delegates layout inference to the underlying
* ParallelOp.
*
* @param T Layout inference inputs, including an optional mapping of buffers to
* layouts.
* @param level Inference strictness level.
* @return LayoutMap The inferred layout mapping for buffers used by this
* operator.
*
* @note This method mutates the AtomicAddNode by creating and storing a
* ParallelOp on first invocation.
* @throws If both src and dst have layouts in `local.fragment` and their
* fragment layouts differ, an ICHECK failure is raised with diagnostic output.
*/
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (!par_op_.defined()) {
......
......@@ -10,6 +10,79 @@
#include "operator.h"
#include "parallel.h"
/**
* Lower this tile operator into a TIR statement for the given lowering context.
*
* @param T Lowering context containing mapped buffers and iteration
* information.
* @param analyzer Arithmetic analyzer used to simplify and reason about
* expressions.
* @return A TIR Stmt that implements the atomic-add tile operation for the
* provided context.
*/
/**
* Infer memory/layout mapping for tensors and buffers used by this operator.
*
* @param T Layout inference context providing buffer and shape information.
* @param level Inference aggressiveness level; higher levels may perform more
* speculative decisions.
* @return A LayoutMap describing inferred layouts for the operator's inputs and
* outputs.
*/
/**
* Get the Op registration that identifies this tile operator.
*
* @return A reference to the registered Op representing this operator.
*/
/**
* Create a deep copy of this tile operator node wrapped as a TileOperator.
*
* @return A TileOperator handle owning a cloned AtomicAddNode.
*/
/**
* Construct a SIMT-style For loop nest (thread/block mapping) appropriate for
* the operator.
*
* @param analyzer Arithmetic analyzer used to simplify loop bounds and
* predicates.
* @return A For loop node representing the SIMT-parallel loop structure.
*/
/**
* Create iteration variables used by this operator's loop nest.
*
* @return An array of IterVar objects describing the loop iteration axes.
*/
/**
* Produce index expressions for either source or destination buffer access
* based on iteration vars.
*
* @param ivs IterVars created by MakeIterVars().
* @param src_dst Selects which indices to produce: 0 for source indices, 1 for
* destination indices.
* @return An array of PrimExpr index expressions suitable for indexing the
* selected buffer.
*/
/**
* Build a predicate expression that guards out-of-bounds or conditional
* accesses for src or dst.
*
* @param analyzer Arithmetic analyzer used to simplify the predicate.
* @param ivs IterVars created by MakeIterVars().
* @param extents The loop extents corresponding to the itervars.
* @param src_dst Selects which side the predicate is for: 0 for source, 1 for
* destination.
* @return A PrimExpr boolean predicate that evaluates to true for valid
* iterations.
*/
/**
* Construct an AtomicAdd tile operator from operation arguments and a buffer
* mapping.
*
* @param args Operation arguments (e.g., values or indices) specific to the
* atomic-add semantics.
* @param vmap Mapping from buffer names to Buffer objects used by this
* operator.
*/
namespace tvm {
namespace tl {
......
This diff is collapsed.
......@@ -90,9 +90,220 @@ struct TMAIm2ColDesc {
/*!
* \brief Copy operator for transferring data between buffers.
*
* This class implements a generic copy operator in TensorIR Lowering for
* block-wise or element-wise data transfer, possibly optimized with
* parallelization or TMA hardware acceleration.
* Performs element- or block-wise copies between `src` and `dst` buffers for
* TensorIR lowering. The operator supports thread-level parallelization,
* shared-memory layouts, and hardware-accelerated paths (TMA/LDSM/STMATRIX)
* when available. Public fields describe the copy ranges and tuning knobs
* (coalesced width, eviction policy, disable_tma).
*/
/*!
* \brief Lower the copy operator to a TIR statement.
*
* Produces a TIR statement implementing the configured copy (normal, LDSM,
* STSM, or bulk TMA-based) for the given lowering context.
*
* \param T Lowering arguments that provide buffer bindings and context.
* \param analyzer Analyzer used for expression simplification and bounds
* checks. \return A TIR `Stmt` implementing the copy.
*/
/*!
* \brief Infer buffer layouts after applying this operator.
*
* Computes resulting layouts (shape/stride mappings) for buffers affected by
* this copy operation.
*
* \param T Arguments for layout inference (buffer maps, shapes).
* \param level Granularity of inference to perform.
* \return A LayoutMap describing inferred layouts.
*/
/*!
* \brief Check if bulk global->shared copy is supported on the target.
*
* Returns true if the target supports bulk (TMA) loads from global memory.
*
* \param target Target to query.
*/
/*!
* \brief Check if bulk shared->global store is supported on the target.
*
* Returns true if the target supports bulk (TMA) stores to global memory.
*
* \param target Target to query.
*/
/*!
* \brief Check if LDSM (LDMATRIX) memory-copy is supported on the target.
*
* \param target Target to query.
*/
/*!
* \brief Check if STSM (STMATRIX) memory-copy is supported on the target.
*
* \param target Target to query.
*/
/*!
* \brief Select the copy instruction type to use.
*
* Chooses between kNormal, kLDSM, kSTSM, kBulkLoad, and kBulkStore based on
* the target capabilities and whether TMA lowering is disabled.
*
* \param target Target to query.
* \param disable_tma_lower When true, force non-TMA copy paths.
* \return The selected CopyInst value.
*/
/*!
* \brief Clone this copy operator.
*
* Returns a TileOperator reference that is a shallow clone of this operator
* object suitable for further modifications in pass pipelines.
*/
/*!
* \brief Generate lowering for bulk (global-to-shared or shared-to-global)
* copy.
*
* Implements TMA-based bulk load/store lowering when `copy_inst` indicates a
* bulk path. The function encodes TMA descriptors and produces calls or
* loops required by the selected bulk mechanism.
*
* \param T Lowering context.
* \param analyzer Analyzer for simplification.
* \param copy_inst Copy instruction type indicating bulk load/store.
* \return A TIR `Stmt` implementing the bulk copy.
*/
/*!
* \brief Generate lowering for LDS matrix-copy paths (LDMATRIX/STMATRIX).
*
* Emits the lowering for LDS-based matrix-copy instructions when the chosen
* `copy_inst` is an LDSM or STSM variant.
*
* \param T Lowering context.
* \param analyzer Analyzer for simplification.
* \param copy_inst Copy instruction type indicating an LDS matrix path.
* \return A TIR `Stmt` implementing the matrix-copy.
*/
/*!
* \brief Generate lowering for the normal (non-bulk, scalar/vec) copy path.
*
* Emits element-wise or vectorized loads/stores using the computed iteration
* space and predicates to ensure in-bounds accesses.
*
* \param T Lowering context.
* \param analyzer Analyzer for simplification.
* \return A TIR `Stmt` implementing the normal copy.
*/
/*!
* \brief Generate a SIMT-style thread-level loop for the copy.
*
* Produces a `For` loop that distributes copy work across SIMD/warp lanes or
* CUDA threads according to the operator's iteration strategy.
*
* \param analyzer Analyzer for simplification.
* \return A `For` loop representing the thread-level iteration.
*/
/*!
* \brief Compute a linear shared-memory layout suitable for TMA copies.
*
* Returns a `Layout` that maps the shared-memory `shared_tensor` into a
* linearized representation required by bulk/TMA transfers.
*
* \param shared_tensor Buffer representing the shared-memory tensor.
* \return A `Layout` describing the linearized shared layout.
*/
/*!
* \brief Create iterator variables for multi-dimensional copy loops.
*
* The returned `IterVar` array enumerates the loop indices used to traverse
* the copy extents in each tensor dimension.
*
* \return Array of iterator variables.
*/
/*!
* \brief Calculate source or destination indices from iteration variables.
*
* Converts the iterator variables (from MakeIterVars) into concrete index
* expressions for either the source image or the destination tensor.
*
* \param ivs Iterator variables returned by MakeIterVars().
* \param src_dst 0 to produce source indices, 1 to produce destination indices.
* \return Array of `PrimExpr` index expressions.
*/
/*!
* \brief Construct the boundary predicate ensuring in-bounds accesses.
*
* Builds a boolean expression that guards loads/stores so they only occur
* when indices lie within the provided `extents`.
*
* \param analyzer Arithmetic analyzer used to simplify predicates.
* \param ivs Iterator variables.
* \param extents Extent expressions for the target buffer.
* \param src_dst 0 = predicate for source indices, 1 = predicate for
* destination. \return A `PrimExpr` boolean predicate.
*/
/*!
* \brief Constructor.
*
* \param args Expression arguments for the copy (indices, sizes, etc.).
* \param vmap Buffer variable mapping for source and destination.
*/
/*!
* \brief Get the TVM Op handle corresponding to this Copy op.
*/
/*!
* \brief Special operator for Conv2D im2col transformation.
*
* Converts an input feature map into an im2col matrix layout used for GEMM-
* based convolution lowering. Public fields configure kernel geometry,
* stride/padding/dilation, and cache eviction behavior.
*/
/*!
* \brief Lower to TIR statement.
*
* Emits TIR that performs the im2col extraction from `src` into `dst`
* according to kernel, stride, padding, and dilation parameters.
*
* \param T Lowering context with buffer bindings.
* \param analyzer Analyzer for expression simplification and bounds reasoning.
* \return A TIR `Stmt` performing the im2col transform.
*/
/*!
* \brief Infer layout for this operator.
*
* Produces the layout mapping for the destination im2col matrix given the
* source layout and convolution parameters.
*
* \param T Layout inference arguments.
* \param level Inference granularity level.
* \return A LayoutMap with inferred layouts for affected buffers.
*/
/*!
* \brief Get TVM Op handle for Conv2DIm2Col.
*/
/*!
* \brief Clone this Conv2DIm2Col operator.
*
* Returns a TileOperator reference that is a shallow clone of this operator.
*/
class CopyNode : public TileOperatorNode {
public:
......@@ -208,6 +419,24 @@ protected:
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
/**
* \brief Create a deep copy of this operator.
*
* Returns a TileOperator that is a copy of the current node, preserving all
* configuration (buffers, parameters, and layout-related fields).
* @return A TileOperator owning the cloned operator node.
*/
/**
* \brief Constructor.
* \param args Expression arguments for the Conv2D im2col operator.
* \param vmap Buffer variable mapping.
*/
/**
* \brief Get the TVM Op handle corresponding to this Conv2DIm2Col operator.
* @return Reference to the singleton TVM Op representing this operator.
*/
TileOperator Clone() const;
};
......
......@@ -22,6 +22,42 @@ namespace tl {
using namespace tir;
/**
* @brief Construct a Fill operator node from call arguments and a buffer map.
*
* This constructor builds a FillNode describing an element-wise fill of a
* destination buffer region with a scalar/vector value and stores it in
* `data_`.
*
* Detailed behavior:
* - If `args[0]` is a `BufferLoad`, the loaded buffer becomes the destination
* and the load indices are converted to per-dimension ranges:
* - `Ramp(base, lanes, stride)` is converted to `Range(base, lanes)`. Only
* stride == 1 and constant `lanes` are supported.
* - Non-ramp indices become `Range(index, 1)`.
* - Otherwise `args[0]` is treated as an access pointer; the destination buffer
* is resolved via `vmap[GetVarFromAccessPtr(args[0])]` and the region is the
* full buffer shape for each dimension.
* - `args[1]` is used as the fill value; it is cast to the destination buffer's
* dtype if necessary.
* - Performs validation:
* - Region dimensionality must match destination rank.
* - For statically-known region mins and extents, checks that mins >= 0 and
* extents do not exceed the corresponding destination shape extents.
*
* Parameters:
* @param args Call arguments: expected layout is [dst_access_or_bufferload,
* value].
* - args[0]: destination access (BufferLoad or pointer expression).
* - args[1]: value to fill (scalar or vector).
* @param vmap Mapping from buffer variables to Buffer objects; used to resolve
* the destination when args[0] is not a BufferLoad.
*
* Notes:
* - The constructor enforces constraints (e.g., stride == 1 ramps, constant
* lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out
* of bounds.
*/
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = make_object<FillNode>();
......@@ -71,11 +107,31 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}
/**
* @brief Create a copy of this FillNode and return it as a TileOperator.
*
* Constructs a new FillNode by copying the current node and wraps the copy in a
* Fill TileOperator.
*
* @return TileOperator A TileOperator that owns the copied FillNode.
*/
TileOperator FillNode::Clone() const {
auto op = make_object<FillNode>(*this);
return Fill(op);
}
/**
* @brief Build a SIMT-style nested parallel loop that fills the destination
* buffer.
*
* Constructs per-dimension data-parallel loop iterators matching this node's
* region extents, emits a BufferStore that writes the node's `value` into `dst`
* at the loop indices, and nests the loops (innermost to outermost) as parallel
* `For` nodes. Returns the outermost `For` loop representing the complete
* multi-dimensional fill kernel.
*
* @return For Outermost parallel `For` loop of the generated nested SIMT loop.
*/
For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
int ndim = dst->shape.size();
Array<IterVar> loop_vars;
......@@ -93,6 +149,24 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return Downcast<For>(body);
}
/**
* @brief Lower this Fill operator to a TIR statement for the target.
*
* Lowers the FillNode into a Stmt according to the destination buffer scope:
* - "local.fragment" and shared ("shared", "shared.dyn"): create a parallel
* operation from a SIMT loop, infer its layout, partition the root loop by
* the thread variable, vectorize the resulting thread loop, and, if a
* per-thread predicate exists, guard the vectorized loop with that
* predicate.
* - "local": build a SIMT loop and return its vectorized form.
* - other scopes: fatal error.
*
* The lowering may query layout and thread information from @p T and uses the
* provided analyzer for any required arithmetic/layout analysis.
*
* @param T Lowering arguments (target, thread bounds, thread var, layout map).
* @return Stmt The lowered TIR statement implementing the fill.
*/
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
......@@ -129,6 +203,17 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
}
/**
* @brief Infer memory/layout mapping for the Fill operator.
*
* Returns the layout mapping produced by layout inference for this FillNode.
* Currently no layout inference is performed for Fill and the function returns
* an empty LayoutMap.
*
* @param T Context required for layout inference (unused).
* @param level The inference level requested (unused).
* @return LayoutMap Empty map indicating no inferred layouts for this operator.
*/
LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {};
......
......@@ -10,6 +10,63 @@
#include "operator.h"
#include "parallel.h"
/**
* Lower the Fill operator into TIR statements.
*
* Produces a TIR Stmt that implements element-wise filling of `dst` over
* `region` with `value`, using information from `T`.
*
* @param T Lowering inputs (buffers, shapes, and iteration info) used to
* generate the IR.
*/
/**
* Infer the memory layout mapping for the Fill operator.
*
* Returns a LayoutMap that describes how logical iteration axes map to memory
* dimensions for the destination buffer. `level` controls the aggressiveness
* of inference (e.g., relaxed vs. strict constraints).
*
* @param T Layout inference inputs (buffers, shapes, and related metadata).
* @param level Inference level controlling precision of the returned mapping.
*/
/**
* Return the global operator descriptor for tl.Fill.
*
* The returned Op can be used to look up operator-level metadata and to
* register or query the operator within the TVM operator registry.
*/
/**
* Create a copy of this operator node as a TileOperator reference.
*
* The returned TileOperator is an independent handle representing a clone of
* the underlying FillNode.
*/
/**
* Build a SIMT-style For loop that implements the fill.
*
* Constructs and returns a TIR `For` loop that iterates over the target region
* in a SIMT-friendly ordering appropriate for `dst` and `region`.
*/
/**
* Construct a Fill operator from argument expressions and a buffer mapping.
*
* @param args Positional PrimExpr arguments passed to the operator (e.g.,
* indices or shape expressions required by the operator's specification).
* @param vmap Mapping from named buffer parameters to concrete tir::Buffer
* instances used by this operator instance.
*/
/**
* Return the global operator descriptor for the public Fill wrapper.
*
* Mirrors FillNode::Get() and provides the operator descriptor for users of the
* public TileOperator API.
*/
namespace tvm {
namespace tl {
......
......@@ -19,6 +19,16 @@ namespace tl {
using namespace tir;
/**
* @brief Compute the prime factorization of an integer.
*
* Returns the prime factors of x in non-decreasing order by repeatedly dividing
* out the smallest possible factor.
*
* @param x Integer to factorize. If x <= 1, an empty vector is returned.
* @return std::vector<int> Prime factors of x (with multiplicity), in
* non-decreasing order.
*/
static std::vector<int> toPrimeFactors(int x) {
int i = 2;
std::vector<int> result;
......@@ -33,6 +43,34 @@ static std::vector<int> toPrimeFactors(int x) {
return result;
}
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
*
* This constructor deserializes operator parameters from `args` and resolves
* buffer references via `vmap`, populating an internal GemmNode with:
* - device pointers for A, B, C and their corresponding Buffer objects,
* - transpose flags for A and B,
* - matrix dimensions M, N, K,
* - warp allocation policy and clear_accum flag,
* - strides and memory offsets for A and B,
* - optional kPack (must be 1 or 2) and optional wg_wait.
*
* The populated GemmNode is stored into the wrapper's internal `data_`.
*
* @param args Positional serialized arguments produced by the TL frontend:
* expected layout is:
* [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmNode> node = make_object<GemmNode>();
......@@ -66,11 +104,39 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}
/**
* @brief Create a copy of this GemmNode as a TileOperator.
*
* Constructs a new GemmNode by copying the current node state and returns it
* wrapped in a Gemm TileOperator.
*
* @return TileOperator A Gemm operator that owns a copy of this node.
*/
TileOperator GemmNode::Clone() const {
auto op = make_object<GemmNode>(*this);
return Gemm(op);
}
/**
* @brief Selects the GEMM implementation variant for a given block size and
* target.
*
* Determines which low-level GEMM instruction to use:
* - Returns kWGMMA when running on Hopper-class targets and the operator meets
* WGMMA constraints (M >= 64, number of warps is a multiple of 4, and
* CheckWGMMA() returns true).
* - Returns kMFMA for CDNA targets.
* - Returns kMMA for CUDA targets.
*
* @param block_size Number of threads in the CUDA/ROCm thread block used for
* the GEMM.
* @param target Target backend describing the hardware (used to detect
* architecture).
* @return GemmInst The chosen GEMM implementation enum value.
*
* @throws fatal error (ICHECK) If the target is not recognized/supported, this
* function triggers a runtime check failure.
*/
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
......@@ -375,6 +441,20 @@ bool GemmNode::CheckWGMMA() const {
}
}
/**
* @brief Parse and return the numeric GPU architecture from a Target's "arch"
* attribute.
*
* Examines the target's "arch" string and, if it matches the pattern
* "sm_<num>", returns <num> as an int. If the attribute is present but does not
* match that pattern, returns 0.
*
* Preconditions: the target must have an "arch" attribute (this is checked via
* ICHECK).
*
* @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
* the arch string does not match "sm_<num>".
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
......@@ -388,6 +468,19 @@ static int GetArchInt(Target target) {
return arch_int;
}
/**
* @brief Lower the GEMM operator to a TL TIR call expression.
*
* Constructs a tl::gemm call string parameterized by M, N, K, warp partition,
* transpose flags, accumulation clearing, target-specific stride/offset/kPack
* and optional workgroup wait value, then returns an Evaluate(call) node
* invoking tl::tl_gemm with the composed string and the A/B/C buffer handles.
*
* @param T Contains lowering context including thread bounds and target.
* @param analyzer Optional arithmetic analyzer used by lowering (may be
* nullptr).
* @return Stmt A TIR statement representing the evaluated TL GEMM call.
*/
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
......@@ -426,28 +519,23 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
/**
* @brief Infer memory/layout mappings for A, B, and C buffers for this GEMM op.
* @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
*
* Generates and returns a LayoutMap that binds buffer A, B, and C to
* target- and architecture-specific fragment or shared-memory layouts based
* on the current target, thread bounds, warp partitioning, data types, and
* transpose flags. This performs target dispatch (Volta, Ampere/Turing/SM120,
* Hopper, CDNA), selects the appropriate fragment or shared layout creators,
* and binds fragment layouts to the thread range when buffers are local
* fragments.
* Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
* operator according to the target architecture, thread bounds, warp
* partitioning, data types, and transpose flags, then binds fragment layouts
* to the thread range when required.
*
* Preconditions:
* - C.scope() must be "local.fragment".
* - C.scope() == "local.fragment"
*
* Postconditions / side effects:
* - Marks the operator's layout inference as completed (sets completed_ =
* true).
* Side effects:
* - Marks layout inference as completed (sets completed_ = true).
* - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
* incompatible shape constraints.
*
* @param T Layout inference inputs (thread bounds and target).
* @param level Inference level (unused for side effects but retained for API).
* @return LayoutMap mapping each of A, B, and C to their inferred layouts.
* @param T Input layout-inference context (provides thread bounds and target).
* @return LayoutMap mapping A, B, and C to their inferred layouts.
*/
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
......
......@@ -10,6 +10,74 @@
#include "operator.h"
namespace tvm {
/**
* Check whether the target and configuration allow using WGMMA (wavefront-group
* MMA) for this GEMM.
*
* @returns true if WGMMA can be used for the current node configuration and
* target; false otherwise.
*/
/**
* Lower this GEMM operator to a TVM Stmt for the given lowering context.
*
* @param T Lowering arguments and context (tile mappings, target, etc.).
* @param analyzer Arithmetic analyzer used for symbolic simplification and
* bounds reasoning.
* @returns A lowered Stmt implementing the GEMM.
*/
/**
* Infer memory/layout mapping for GEMM inputs/outputs at the given inference
* level.
*
* @param T Layout inference inputs (buffers, shapes, constraints).
* @param level Inference level that controls how aggressive/specific the
* inferred layouts should be.
* @returns A LayoutMap describing how logical tensor axes map to storage/layout
* axes.
*/
/**
* Create a deep/shallow copy of this TileOperator node as a TileOperator
* reference.
*
* @returns A TileOperator reference that represents a clone of this GemmNode.
*/
/**
* Determine the specific GEMM instruction variant to use for the given block
* size and target.
*
* @param block_size The tile/block size (in elements or threads) used to select
* instruction variant.
* @param target The compilation target describing architecture and instruction
* set.
* @returns The GemmInst enum value representing the chosen GEMM instruction
* family.
*/
/**
* Compute how to partition work across warps for the given number of warps and
* GEMM instruction.
*
* The returned pair is (warp_rows, warp_cols), describing the per-warp tiling
* in row and column dimensions respectively.
*
* @param num_warps Total number of warps available for the block.
* @param gemm_inst The GEMM instruction variant selected for the target.
* @param target The compilation target which may constrain or influence
* partitioning.
* @returns A pair<int,int> = (warp_rows, warp_cols) describing the warp
* partition.
*/
/**
* Construct a Gemm operator handle from call arguments and a buffer mapping.
*
* @param args Array of call-time PrimExpr arguments passed to the operator.
* @param vmap Mapping from buffer names/indices to tir::Buffer objects used by
* this GEMM.
*/
/**
* Obtain the registered Op descriptor for the GEMM operator.
*
* @returns A const reference to the Op representing "tl.Gemm".
*/
namespace tl {
using namespace tir;
......
......@@ -17,6 +17,17 @@
namespace tvm {
namespace tl {
/**
* @brief Decomposes a positive integer into its prime factors.
*
* Returns the prime factorization of `x` as a vector of prime factors in
* non-decreasing order. If `x <= 1` the returned vector is empty.
*
* @param x Integer to factorize (expected non-negative; behavior: returns empty
* for values <= 1).
* @return std::vector<int> Prime factors of `x` (with repetition), e.g. 12 ->
* {2, 2, 3}.
*/
static std::vector<int> toPrimeFactors(int x) {
int i = 2;
std::vector<int> result;
......@@ -31,6 +42,27 @@ static std::vector<int> toPrimeFactors(int x) {
return result;
}
/**
* @brief Construct a GemmSP operator node from TL call arguments and a buffer
* map.
*
* Parses the expected call argument tuple and fills an internal GemmSPNode:
* - Buffers: A (args[0]), E (args[1]), B (args[2]), C (args[3]) are looked up
* in vmap.
* - Booleans: trans_A (args[4]), trans_B (args[5]).
* - Dimensions: M (args[6]), N (args[7]), K (args[8]) as integers.
* - Warp policy: policy (args[9]) mapped to GemmWarpPolicy.
* - clear_accum: boolean flag (args[10]).
* - Optional kPack (args[11]): must be 1 or 2 (checked via ICHECK).
* - Optional wg_wait (args[12]): integer workgroup wait parameter.
*
* The populated GemmSPNode is stored in the instance's internal data_ pointer.
*
* @param args Positional TL call arguments in the above order.
* @param vmap BufferMap mapping access pointers (from args) to Buffer objects.
*
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
*/
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>();
node->A = vmap[GetVarFromAccessPtr(args[0])];
......@@ -57,11 +89,41 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}
/**
* @brief Create a deep copy of this GemmSPNode wrapped as a TileOperator.
*
* Returns a new TileOperator that owns a copy of this node. The cloned node
* duplicates all fields of the original; subsequent modifications to the
* clone do not affect the original node.
*
* @return TileOperator A TileOperator holding a cloned GemmSPNode.
*/
TileOperator GemmSPNode::Clone() const {
auto op = make_object<GemmSPNode>(*this);
return GemmSP(op);
}
/**
* @brief Compute a partition of warps across the M and N GEMM dimensions.
*
* Computes (m_warp, n_warp) such that m_warp * n_warp == num_warps and the
* warp counts respect element-per-warp granularity and the configured
* GemmWarpPolicy. On Hopper targets, when `maybe_hopper_wgmma` is true and
* the problem size permits, a warp-group (WGMMA)-aware partitioning is used
* (groups of 4 warps).
*
* @param num_warps Total number of warps available for the block.
* @param target Hardware target used to decide target-specific strategies
* (e.g., Hopper WGMMA grouping).
* @param maybe_hopper_wgmma If true, allows using Hopper WGMMA-specific
* partitioning when the target and problem size
* permit.
* @return std::pair<int,int> A pair (m_warp, n_warp) giving the number of warp
* partitions along M and N, respectively.
*
* @note The function uses ICHECK to enforce invariants (e.g., unknown policy or
* invalid m_warp * n_warp), which will terminate on failure.
*/
std::pair<int, int>
GemmSPNode::ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma) const {
......@@ -220,6 +282,24 @@ GemmSPNode::ComputeWarpPartition(int num_warps, Target target,
return {m_warp, n_warp};
}
/**
* @brief Lower this GemmSP node to a TL (tensile-like) intrinsic call.
*
* Constructs and returns an Evaluate statement containing a call to the
* TL gemm_sp intrinsic that encodes this GEMM's template parameters
* (M, N, K, warp partition, transposition flags, clear_accum, and optional
* Hopper/WGMMA and wg_wait modifiers) and the remapped buffer access pointers.
*
* The function validates that A, B, and E reside in shared (or shared.dyn)
* memory (ICHECK failures otherwise), computes the warp partition based on
* the launch configuration and target, and emits a single tl::tl_gemm_sp call
* with a string template describing the configuration.
*
* @param T Lowering context containing thread bounds, target, and optional
* buffer remapping used to obtain the final buffer AccessPtr
* arguments for the TL call.
* @return Stmt An Evaluate wrapping the constructed tl::tl_gemm_sp call.
*/
Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
......@@ -264,6 +344,34 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Evaluate(new_call);
}
/**
* @brief Infers and returns the memory/layout mapping for the GemmSP operator.
*
* Infers thread-local fragment layout for C and shared-memory layouts for A and
* B based on the target (Hopper-only path), block/thread bounds in T,
* transposition flags, and matrix dimensions stored in the node. The function
* caches its work: if layout inference has already completed (completed_ ==
* true) it returns an empty LayoutMap.
*
* Precondition:
* - C.scope() must be "local.fragment".
*
* Behavior notes:
* - Only the Hopper target is supported; non-Hopper targets trigger a fatal
* check.
* - For Hopper, the function computes a warp partition from block size and may
* enable WGMMA-specific fragment creation when conditions on M and block size
* are met.
* - A and B must reside in "shared" or "shared.dyn"; otherwise the function
* aborts with a check failure.
* - The method sets completed_ = true before returning to avoid re-entrance.
*
* @param T LayoutInferArgs containing thread bounds and the target (used to
* select Hopper-specific layouts).
* @param level Currently unused inference detail level.
* @return LayoutMap mapping A, B, and C to their inferred layouts (or empty if
* inference was already completed).
*/
LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (completed_)
......
......@@ -10,6 +10,60 @@
#include "operator.h"
namespace tvm {
/**
* Lower the GemmSP operator into a TIR statement for the given lowering
* context.
*
* Produces the TIR Stmt that implements this operator using the provided
* lowering arguments. The `analyzer` is used for arithmetic simplifications and
* may be null.
*
* @param T Lowering context and arguments.
* @returns A TIR `Stmt` implementing the lowered operator.
*/
/**
* Infer memory/layout mapping for operands and outputs of this operator.
*
* Computes a LayoutMap describing how logical tensor layouts map to physical
* buffer layouts for the given inference `level`.
*
* @param T Layout inference inputs (shapes, buffer info, etc.).
* @param level Inference granularity/level.
* @returns A LayoutMap describing inferred layouts.
*/
/**
* Compute a warp-level partitioning (rows, cols) for the given number of warps.
*
* Returns a pair (warps_per_row, warps_per_col) describing how to tile the GEMM
* across warps for the specified `target`. The optional `maybe_hopper_wgmma`
* enables target-specific adjustments (e.g., CDNA WG/MMA variants) when set.
*
* @param num_warps Total number of warps available for the tile.
* @param target Target device/architecture used to guide partitioning choices.
* @param maybe_hopper_wgmma Enable target-specific WG/MMA adjustments when
* true.
* @returns Pair<int,int> of (warps_per_row, warps_per_col).
*/
/**
* Create a copy of this TileOperator node as a TileOperator reference.
*
* The returned TileOperator refers to a new node that is a copy of this node.
*
* @returns A TileOperator that is a clone of this node.
*/
/**
* Construct a GemmSP TileOperator from call arguments and a buffer map.
*
* @param args Array of PrimExpr specifying call-site arguments for the
* operator.
* @param vmap Mapping from buffer names to tir::Buffer objects for
* operands/outputs.
*/
/**
* Return the singleton Op descriptor for the GemmSP operator.
*
* @returns Reference to the operator's Op registration object.
*/
namespace tl {
using namespace tir;
......
......@@ -15,6 +15,21 @@ namespace tl {
using namespace tir;
/**
* @brief Construct a TileOperator from a TIR Call using a registered builder.
*
* Looks up a builder function in the "TLOpBuilder" Op attribute map for the
* operator referenced by `call` and invokes it to produce a TileOperator. If no
* builder is registered for the operator, returns a default-constructed (empty)
* TileOperator.
*
* @param call The TIR Call whose operator and arguments will be used to build
* the TileOperator.
* @param vmap Buffer mapping passed through to the builder to resolve buffer
* references.
* @return TileOperator The constructed TileOperator, or a default (empty)
* TileOperator if no builder exists.
*/
TileOperator ParseOperator(Call call, BufferMap vmap) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value();
......@@ -26,6 +41,18 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
return TileOperator();
}
/**
* @brief Parse a TileOperator from a TIR statement if it contains a call.
*
* If `stmt` is an Evaluate node whose value is a Call, delegates to
* ParseOperator(Call, BufferMap) and returns the resulting TileOperator.
* Otherwise returns a default-constructed (empty) TileOperator.
*
* @param stmt TIR statement to inspect; expected to be an Evaluate of a Call.
* @param vmap Mapping of buffer variables used when building the operator.
* @return TileOperator Parsed operator on success, or a default (empty)
* TileOperator if `stmt` is not an Evaluate(Call).
*/
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
......@@ -34,6 +61,17 @@ TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
return TileOperator();
}
/**
* @brief Extracts the Var referenced by a `tvm_access_ptr` call expression.
*
* The function expects `expr` to be a `Call` to the builtin `tvm_access_ptr`
* and returns the `Var` found in the call's second argument (`args[1]`). The
* function performs runtime checks and will abort if `expr` is not a call, the
* call is not `tvm_access_ptr`, or the second argument is not a `Var`.
*
* @param expr A `PrimExpr` representing a `tvm_access_ptr(...)` call.
* @return tvm::Var The `Var` referenced by the `tvm_access_ptr` call.
*/
Var GetVarFromAccessPtr(const PrimExpr &expr) {
auto call = expr.as<CallNode>();
ICHECK(call);
......
......@@ -11,8 +11,8 @@
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
#include "../layout/layout.h"
......@@ -51,32 +51,117 @@ struct LayoutInferArgs {
class TileOperatorNode;
class TileOperator;
class TileOperatorNode: public Object {
public:
/**
* Abstract base class for tile-level operators.
*
* Implementations must provide lowering to TIR, layout inference, and cloning.
*/
/**
* Lower this tile operator to a TIR statement.
*
* @param T Lowering context and utilities (target, thread bounds, layout
* mappings, buffer remapping, and AddWorkspace callback for requesting
* temporary buffers).
* @param analyzer Arithmetic analyzer used during lowering.
* @return A TIR Stmt representing the lowered operator.
*/
/**
* Infer buffer layouts for this operator.
*
* The returned LayoutMap associates input/output Buffers with inferred Layouts.
* The `level` controls how strictly layouts are determined (kFree, kCommon,
* kStrict).
*
* @param T Layout inference context (target, thread bounds, existing
* layout_map, buffer_remap).
* @param level Inference strictness level.
* @return A LayoutMap mapping Buffers to their inferred Layouts.
*/
/**
* Create a deep copy of this TileOperator.
*
* @return A TileOperator referencing a cloned operator instance.
*/
/**
* Reference wrapper for TileOperatorNode.
*
* Use this ObjectRef to hold and pass tile operator instances within the
* runtime.
*/
/**
* Extract the underlying Var from an access pointer expression.
*
* If `expr` represents an access pointer that directly refers to a variable,
* returns that Var; otherwise returns a null/default Var.
*
* @param expr The pointer/access expression to inspect.
* @return The extracted Var, or a null Var if none can be found.
*/
/**
* Parse a Call into a TileOperator using the provided buffer mapping.
*
* @param call The Call node representing a tile operator invocation.
* @param vmap Mapping from TIR Vars to Buffers for resolving buffer arguments.
* @return A TileOperator constructed from the call and buffer map.
*/
/**
* Parse a Stmt into a TileOperator using the provided buffer mapping.
*
* @param stmt The Stmt representing a tile operator region or call.
* @param vmap Mapping from TIR Vars to Buffers for resolving buffer references.
* @return A TileOperator constructed from the statement and buffer map.
*/
/**
* Function type for TL operator builders exposed to the FFI.
*
* Builder functions take an array of PrimExpr arguments and a BufferMap, and
* return a constructed TileOperator.
*/
/**
* Register a TL operator and its builder with TVM's op registry.
*
* Entry should be a type providing a static `Get()` and a constructor taking
* `(Array<PrimExpr>, BufferMap)`. This macro registers the operator under the
* name "tl.OpName" and sets an FFI builder attribute that constructs
* Entry(args, vmap).
*
* Usage: TIR_REGISTER_TL_OP(MyOpEntry, MyOp)
*/
class TileOperatorNode : public Object {
public:
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0;
virtual LayoutMap InferLayout(const LayoutInferArgs& T,
virtual LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const = 0;
virtual TileOperator Clone() const = 0;
static constexpr const char* _type_key = "tl.TileOperator";
static constexpr const char *_type_key = "tl.TileOperator";
TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
};
class TileOperator : public ObjectRef {
public:
public:
TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
};
Var GetVarFromAccessPtr(const PrimExpr &expr);
TileOperator ParseOperator(Call call, BufferMap vmap);
TileOperator ParseOperator(Stmt stmt, BufferMap vmap);
using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
using OpBuilderFunc =
ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op &Entry::Get() { \
......@@ -90,7 +175,6 @@ using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap
return Entry(args, vmap); \
})
} // namespace tl
} // namespace tvm
......
......@@ -147,6 +147,19 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
StmtExprVisitor::VisitStmt_(op);
}
/**
* @brief Visit a BufferLoad node and record/validate index mapping for
* fragment-local buffers.
*
* If the loaded buffer's scope is "local.fragment", this records the load
* indices in the visitor's indice_map_ when seen for the first time. If an
* entry already exists, the previously recorded indices are asserted
* structurally equal to the current indices.
*
* This ensures all accesses to the same fragment-local buffer within the
* parallel loop use a consistent index map. The function then continues
* standard expression visitation.
*/
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
......@@ -160,42 +173,91 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
StmtExprVisitor::VisitExpr_(op);
}
/**
* @brief Construct a ParallelOpNode from a parallel loop nest root.
*
* Initializes the node with the given For loop as the root of the parallel
* operator and immediately runs the internal ParallelLoopNestVisitor to collect
* loop and buffer access information from the nested body.
*
* @param root The root For node representing the parallel loop nest to be
* analyzed.
*/
ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
V.VisitStmt(root);
}
/**
* @brief Create a copy of this ParallelOpNode wrapped as a TileOperator.
*
* Returns a new TileOperator that holds a deep copy of this ParallelOpNode.
*
* @return TileOperator A TileOperator owning a copy of this node.
*/
TileOperator ParallelOpNode::Clone() const {
auto op = make_object<ParallelOpNode>(*this);
return ParallelOp(op);
}
/**
* @brief No-op lowering: return the stored root statement unchanged.
*
* This implementation does not perform any transformation and returns the
* operator's original root For statement as-is.
*
* @param T Lowering arguments (unused).
* @return Stmt The original root statement held by this ParallelOpNode.
*/
Stmt ParallelOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
return root_;
}
/**
* @brief Check whether a buffer is indexed by the loop's canonical (common)
* iteration variables.
*
* Returns true if the recorded index mapping for `buffer` is structurally equal
* to the sequence of loop iteration variables for this parallel op (i.e., the
* buffer is accessed using the common access indices of the loop nest).
*
* @param buffer The buffer to check.
* @return true if the buffer's index map equals the loop's iteration variables;
* false otherwise.
*/
bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const {
auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
return StructuralEqual()(indice_map_[buffer], common_indice);
}
/*! \brief Infer the layout for parallel operations based on different inference
* levels
/**
* @brief Infer buffer layouts for a Parallel operator based on the chosen
* inference level.
*
* The inference level controls how aggressively we try to infer and optimize
* layouts:
* - kStrict (2): Most conservative level. Only allows explicitly defined
* layouts. Returns empty layout map if loop_layout_ is not already defined.
* Used when exact layout control is required.
* Attempts to compute a consistent LayoutMap for buffers accessed by a parallel
* loop (root_) using explicit input layouts (T.layout_map), thread bounds
* (T.thread_bounds), and optional buffer remapping/vectorization information in
* T. Behavior depends on the supplied InferLevel:
* - kStrict: only accept pre-existing loop_layout_ (no inference).
* - kCommon: allow inference from explicit buffer fragments when available.
* - kFree: attempt more aggressive inference (derive loop partition from
* read/write fragments, plan partitioning from vectorization/thread bounds, and
* add predicates to constrain replication when necessary).
*
* - kCommon (1): Intermediate level between strict and free.
* Allows common layout patterns while maintaining some
* constraints.
* This method may mutate the node's internal state (sets loop_layout_ when
* inferred and registers predicates via AddPredicate) and consults analyzer_
* for symbolic proofs.
*
* - kFree (0): Most permissive level. Allows maximum optimization freedom.
* Will attempt layout inference even without source buffers.
* Can generate new layouts based on vectorization and thread
* bounds. Used when maximum performance optimization is desired.
* @param T Container of auxiliary inputs used for inference (buffer_remap,
* layout_map, and thread_bounds). The function uses T.layout_map for source
* fragments and T.thread_bounds to bind thread-range information in inferred
* fragments.
* @param level Controls inference aggressiveness (kStrict, kCommon, kFree).
* @return LayoutMap A map of buffers to inferred Fragment layouts for buffers
* that did not already have layouts in T.layout_map. Returns an empty map when
* no inference was performed.
* @throws LayoutConflictException If a computed loop partition conflicts with
* an existing buffer fragment (incompatible thread mappings).
*/
LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
......@@ -384,6 +446,20 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
return results;
}
/**
* @brief Retrieve the loop's thread predicate with the thread variable
* substituted.
*
* If a predicate is set for this ParallelOpNode, returns a copy of that
* predicate where the placeholder input (InputPlaceholder(0)) is replaced by
* the provided thread_var. If no predicate is defined, returns an empty
* Optional.
*
* @param thread_var The thread loop variable to substitute for the predicate's
* input placeholder.
* @return Optional<PrimExpr> The substituted predicate expression, or
* std::nullopt if none is defined.
*/
Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
if (predicate_.defined()) {
return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
......@@ -392,6 +468,32 @@ Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
}
}
/**
* @brief Construct the complete fragment layout for a buffer within the
* parallel loop.
*
* Given a buffer referenced inside the parallel loop, return a Fragment that
* maps the buffer's logical indices to the loop's thread space and replication
* extent.
*
* Detailed behavior:
* - Precondition: a loop layout (loop_layout_) must be defined.
* - If the buffer uses the common access indices of the loop, the loop's
* fragment is returned directly.
* - Otherwise, the function:
* - Computes the buffer's bijective index by appending the flattened
* replication expression for unused iterators.
* - Inverts that bijection to obtain the replication extent of the buffer's
* index space and combines it with the loop's replication extent to produce the
* destination replication extent.
* - Builds forward index placeholders for the buffer elements and maps them
* through the inverted layout and the loop layout to derive the thread binding.
* - Returns a Fragment with the computed thread binding and combined
* replication extent, with replicate variables condensed.
*
* @return Fragment The completed fragment describing thread binding and
* replication extent for `buffer`.
*/
Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
ICHECK(loop_layout_.defined());
if (IsCommonAccessIndice(buffer)) {
......
......@@ -13,6 +13,140 @@
#include "../transform/layout_reducer.h"
#include "./operator.h"
/**
* Exception representing a layout conflict detected during layout inference.
*
* Stores an explanatory message retrievable via what().
*/
/**
* Determine whether `small_frag` is guaranteed to be contained within
* `large_frag` under the given index mappings and using the provided arithmetic
* analyzer.
*
* @param small_frag The smaller fragment to test for containment.
* @param large_frag The larger fragment that may contain `small_frag`.
* @param small_frag_indices Index expressions mapping the small fragment into
* buffer space.
* @param large_frag_indices Index expressions mapping the large fragment into
* buffer space.
* @param analyzer_ Arithmetic analyzer used to simplify and prove index
* relations.
* @return true if containment can be proven; false otherwise.
*/
/**
* Visitor that traverses a parallel loop nest to collect buffer access and
* loop-structure information for a ParallelOpNode.
*
* The visitor records loop variables, buffer read/write accesses, and builds
* predicates as it encounters BufferLoad/BufferStore and For nodes.
*/
/**
* Represents a parallel for-loop operator in TileLang.
*
* Holds the root For loop, collects and exposes loop layout and access-index
* information, and provides layout inference and lowering to TIR.
*
* Public methods expose the inferred loop layout, root loop, buffer index
* mappings, and any per-thread predicate; Lower and InferLayout perform the
* operator's lowering and layout inference respectively.
*/
/**
* Create a ParallelOpNode from a root For loop.
*
* @param root The root For node representing the parallel loop nest.
*/
/**
* Lower this parallel operator into a TIR statement suitable for codegen.
*
* @param T Lowering arguments and context.
* @param analyzer Arithmetic analyzer for expression simplification during
* lowering.
* @return A TIR statement representing the lowered parallel loop.
*/
/**
* Infer the layout mapping for this parallel operator at the specified level.
*
* @param T Arguments and context for layout inference.
* @param level Inference granularity level.
* @return A LayoutMap describing inferred buffer/layout relationships for the
* operator.
*/
/**
* Copy-construct a ParallelOpNode, preserving inferred layout and predicate.
*/
/**
* Get the inferred loop layout fragment.
*
* @return The Fragment representing the loop's inferred layout (may be lazily
* computed).
*/
/**
* Get the root For loop of this operator.
*
* @return The root For AST node.
*/
/**
* Get the mapping from each buffer to the array of index expressions used to
* access it within the loop nest.
*
* @return A Map from Buffer to Array<PrimExpr> of access indices.
*/
/**
* Retrieve the predicate expression associated with a given thread variable, if
* any.
*
* @param thread_var The thread variable whose predicate is requested.
* @return An Optional<PrimExpr> containing the predicate when present.
*/
/**
* Create a deep copy of this operator as a TileOperator handle.
*
* @return A TileOperator that references a copy of this node.
*/
/**
* Visitor helper: complete the fragment layout for a buffer (internal).
*
* (Private helper — not part of the public API.)
*/
/**
* Helper to check whether a buffer's access indices are the common loop indices
* (internal).
*
* (Private helper — not part of the public API.)
*/
/**
* Add `expr` to the current predicate by logical AND; sets predicate if none
* exists.
*
* (Private helper — not part of the public API.)
*/
/**
* Thin handle type exposing ParallelOpNode as a TileOperator.
*
* Construct from a root For loop to create and own a ParallelOpNode instance.
*/
/**
* Construct a ParallelOp handle from a root For loop.
*
* @param root The root For node representing the parallel loop nest.
*/
namespace tvm {
namespace tl {
......
......@@ -22,6 +22,25 @@ namespace tl {
using namespace tir;
/**
* @brief Construct a ReduceOp from raw TL arguments and a buffer mapping.
*
* Interprets `args` and `vmap` to populate an internal ReduceOpNode:
* - args[0]: access pointer for the source buffer
* - args[1]: access pointer for the destination buffer
* - args[2]: string literal specifying the reduce type: "sum", "abssum",
* "absmax", "max", or "min"
* - args[3]: integer literal for the reduction dimension (axis)
* - args[4]: boolean literal indicating whether to clear/init the destination
*
* The constructor resolves the access pointers via `vmap`, maps the reduce
* type string to the ReduceType enum, assigns the reduction dimension and
* clear flag, and stores the constructed node in `data_`. An invalid reduce
* type triggers a fatal check.
*
* @param args Array of TL prim-expr arguments as described above.
* @param vmap Mapping from variables (from access pointers) to Buffer objects.
*/
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
......@@ -44,16 +63,52 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}
/**
* @brief Create a copy of this ReduceOpNode wrapped as a TileOperator.
*
* Returns a new TileOperator holding a freshly allocated ReduceOpNode
* constructed as a copy of this node.
*
* @return TileOperator A tile operator that owns the cloned ReduceOpNode.
*/
TileOperator ReduceOpNode::Clone() const {
auto op = make_object<ReduceOpNode>(*this);
return ReduceOp(op);
}
/**
* @brief Create a deep copy of this CumSum op node wrapped as a TileOperator.
*
* Returns a new TileOperator whose underlying CumSumOpNode is a copy of
* the current node. Useful for cloning operators when building or
* transforming computation graphs.
*
* @return TileOperator A TileOperator containing a copy of this node.
*/
TileOperator CumSumOpNode::Clone() const {
auto op = make_object<CumSumOpNode>(*this);
return CumSumOp(op);
}
/**
* @brief Create the initial accumulator value for the destination buffer based
* on reduction type.
*
* Returns the PrimExpr representing the initial value stored in the destination
* accumulator before any source elements are combined. The returned value
* depends on the destination dtype and the node's reduction type:
* - kSum, kAbsSum: zero of the destination dtype.
* - kMax: minimum representable value for signed integers, zero for unsigned
* integers, and -INFINITY for floating point.
* - kMin: maximum representable value for signed integers, all-ones (max) for
* unsigned integers, and +INFINITY for floating point.
* - kAbsMax: zero of the destination dtype.
*
* The function will abort (ICHECK failure) if the reduction type is
* unrecognized.
*
* @return PrimExpr initial value appropriate for `dst->dtype` and `type`.
*/
PrimExpr ReduceOpNode::MakeInitValue() const {
auto dst_dtype = dst->dtype;
auto is_int = dst_dtype.is_int();
......@@ -88,6 +143,24 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
}
}
/**
* @brief Combine two scalar expressions according to this node's reduction
* type.
*
* Casts the right operand to the left operand's dtype if they differ, then
* returns the reduction of `a` and `b` using the operator specified by `type`:
* - kSum: `a + b`
* - kAbsSum: `a + max(b, -b)`
* - kMax: `max(a, b)`
* - kMin: `min(a, b)`
* - kAbsMax: `max(max(a, b), -min(a, b))`
*
* @param a Left-hand operand (result dtype drives the output dtype).
* @param b Right-hand operand (will be cast to `a`'s dtype if needed).
* @return PrimExpr The combined expression with dtype equal to `a.dtype`.
*
* @note The function DCHECKs/ICHECKs on an unknown/unsupported reduction type.
*/
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
PrimExpr lhs = a, rhs = b;
if (lhs->dtype != rhs->dtype) {
......@@ -110,6 +183,20 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
}
}
/**
* @brief Map the reduction type to the codegen reducer name used by external
* ALL-Reduce/CUDA helpers.
*
* Returns the string identifier of the code-generation reducer corresponding to
* this ReduceOpNode's `type`. Mapping:
* - kSum, kAbsSum -> "tl::SumOp"
* - kMax, kAbsMax -> "tl::MaxOp"
* - kMin -> "tl::MinOp"
*
* The function terminates with a check failure if `type` is unknown.
*
* @return std::string Reducer name used by codegen extern calls.
*/
std::string ReduceOpNode::MakeCodegenReducer() const {
switch (type) {
case ReduceType::kSum:
......@@ -128,6 +215,32 @@ std::string ReduceOpNode::MakeCodegenReducer() const {
}
}
/**
* @brief Lower the Reduce operator node to a TIR statement.
*
* Lowers a ReduceOpNode that targets fragment-local buffers into a sequence of
* TIR statements implementing: per-thread local reduction, inter-thread
* AllReduce (when needed), and final writeback (with an optional duplicate
* clear buffer to avoid in-place conflicts). Supports reduction kinds
* (sum/abs-sum/max/min/abs-max) and handles layout-driven index mapping and
* loop partitioning to thread axes.
*
* @param T Lowering context providing buffer remapping, layout map, target and
* thread bounds, and workspace allocation helper. Must contain
* fragment-local mappings for both src and dst.
* @param analyzer Symbolic analyzer used to simplify and compress iterators.
* @return Stmt The constructed TIR statement implementing the reduction.
*
* Preconditions:
* - src and dst buffers must be in "local.fragment" scope.
* - The layouts must have compatible input/output dimensions for the
* specified reduction axis.
*
* Failure modes:
* - The function uses ICHECK to enforce unsupported scopes, dimension
* mismatches, unknown reduction types, and other invariants; violations
* will trigger a fatal check failure.
*/
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment")
......@@ -296,6 +409,38 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return body;
}
/**
* @brief Infer a layout mapping for the destination buffer of a Reduce
* operator.
*
* When inference level is below `kStrict`, and both source and destination
* buffers live in `local.fragment` with a known source fragment layout, this
* computes a candidate destination Fragment layout that accounts for
* replication over the reduction dimension and binds thread ranges from
* `T.thread_bounds`.
*
* Behavior:
* - Constructs a destination Fragment whose replicate extent equals
* src.shape[dim] * src_fragment.ReplicateExtent(), and whose threading is
* derived from the source fragment with the reduction dimension folded out.
* - If no layout exists for `dst` in `T.layout_map`, returns a map {dst ->
* inferred}.
* - If `dst` already has a layout, validates that the existing layout strictly
* contains the computed layout (shapes match and fragment containment holds).
* If compatible but the computed replicate extent is larger, returns the new
* layout.
* - In all other cases (strict inference level, unsupported scopes, or no src
* layout), returns an empty map.
*
* @param T Layout inference context containing `layout_map` and
* `thread_bounds`.
* @param level Inference strictness; no inference is performed at or above
* `kStrict`.
* @return LayoutMap A mapping for `dst` to an inferred Fragment layout, or
* empty.
* @throws LayoutConflictException if an existing `dst` layout conflicts with
* the computed layout (not containable or incompatible replication extents).
*/
LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (level >= InferLevel::kStrict)
......@@ -373,6 +518,22 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
/**
* @brief Construct a CumSumOp from a list of arguments and a buffer map.
*
* Expects args to contain exactly four PrimExprs in this order:
* 0: access pointer to source buffer (src),
* 1: access pointer to destination buffer (dst),
* 2: integer dimension to perform the cumulative sum along (dim),
* 3: boolean flag indicating whether to compute the cumsum in reverse
* (reverse).
*
* The constructor resolves src and dst from the provided BufferMap and stores
* the parsed dim and reverse values on the node. It verifies that args.size()
* == 4 and that dim is a valid axis for the source buffer shape.
*
* @param args Array of PrimExpr as described above.
*/
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/*
CumSum arguments:
......@@ -391,6 +552,28 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}
/**
* @brief Lower the CumSum operator to TIR.
*
* Produces a TIR statement implementing cumulative sum depending on buffer
* scopes:
* - For shared/shared.dyn scopes: emits an extern call to
* `tl::CumSum2D<threads, dim, reverse>::run` with arguments [function_name,
* src.access_ptr(1), dst.access_ptr(3), src.shape...]. The number of threads is
* taken from `T.thread_bounds->extent`. Returns an Evaluate(Call(...))
* statement.
* - For local.fragment scopes on both src and dst: fatal error (not
* implemented).
* - For any other scope combinations: fails with an assertion.
*
* The `analyzer` parameter is accepted for interface compatibility but is not
* used by this lowering.
*
* @param T Lowering arguments (provides thread bounds and other lowering
* context).
* @return Stmt A TIR statement representing the lowered cumulative-sum
* operation.
*/
Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment") {
......@@ -417,6 +600,17 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Stmt();
}
/**
* @brief Layout inference for CumSum operator.
*
* CumSum does not perform any layout inference; this function always returns
* an empty mapping. The operator's lowering expects shared-memory semantics
* and layout decisions are handled elsewhere.
*
* @param T Layout inference inputs (buffers, existing layouts, etc.).
* @param level Inference strictness level (unused).
* @return LayoutMap Empty map indicating no inferred layouts.
*/
LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {};
......
......@@ -10,6 +10,146 @@
#include "operator.h"
namespace tvm {
/**
* Tile operator node that performs a reduction (sum, max, min, etc.) along a
* single tensor dimension.
*
* Represents a per-instance reduce operator with explicit source/destination
* buffers, target dimension, reduction type, and a flag controlling whether the
* destination is cleared before reduction.
*/
/**
* Lower this ReduceOpNode into a Tir Stmt suitable for code generation.
*
* Produces the TIR statement(s) that implement the configured reduction.
*
* @return A TIR `Stmt` implementing the reduce operation.
*/
/**
* Infer input/output layouts for this reduce operator.
*
* Returns a LayoutMap describing how input and output buffer layouts relate
* for the configured reduction dimension.
*
* @param level Inference detail level that may affect how aggressively layouts
* are inferred.
* @return A LayoutMap mapping operator arguments to inferred layouts.
*/
/**
* Retrieve the global operator descriptor for the reduce operator.
*
* @return A reference to the Op descriptor corresponding to this operator type.
*/
/**
* Create a copy of this reduce operator as a TileOperator handle.
*
* The returned TileOperator preserves the node's configuration (buffers, dim,
* type, clear).
*
* @return A TileOperator wrapping a cloned ReduceOpNode.
*/
/**
* Construct the initial value used by the reduction (e.g., 0 for sum, -inf for
* max).
*
* @return A PrimExpr representing the reduction's identity/init value.
*/
/**
* Combine two partial values according to the configured reduction.
*
* Implements the binary reducer (for example, `a + b` for sum or `max(a, b)`
* for max).
*
* @return A PrimExpr representing the reduced result of `a` and `b`.
*/
/**
* Generate a string snippet suitable for code generation of the reducer
* expression.
*
* The returned code fragment should implement the binary reduction operation in
* the target backend's code string form.
*
* @return A std::string containing the codegen expression for the reducer.
*/
/**
* Reference wrapper for ReduceOpNode as a TileOperator.
*
* Construct a ReduceOp from explicit arguments and a buffer map.
*/
/**
* Construct a ReduceOp TileOperator from operator arguments and a buffer
* mapping.
*
* @param args Operator arguments (typically shapes, axes, or other prim exprs).
* @param vmap Mapping from argument names to tir::Buffer instances used by the
* operator.
*/
/**
* Tile operator node that computes a cumulative sum along a single tensor
* dimension.
*
* Contains source/destination buffers, the target dimension, and a flag to
* compute the cumulative sum in reverse order.
*/
/**
* Lower this CumSumOpNode into a Tir Stmt suitable for code generation.
*
* Produces the TIR statement(s) that implement the configured cumulative-sum.
*
* @return A TIR `Stmt` implementing the cum-sum operation.
*/
/**
* Infer input/output layouts for this cumulative-sum operator.
*
* Returns a LayoutMap describing how input and output buffer layouts relate
* for the configured cumulative-sum dimension.
*
* @param level Inference detail level that may affect how aggressively layouts
* are inferred.
* @return A LayoutMap mapping operator arguments to inferred layouts.
*/
/**
* Retrieve the global operator descriptor for the cumulative-sum operator.
*
* @return A reference to the Op descriptor corresponding to this operator type.
*/
/**
* Create a copy of this cum-sum operator as a TileOperator handle.
*
* The returned TileOperator preserves the node's configuration (buffers, dim,
* reverse).
*
* @return A TileOperator wrapping a cloned CumSumOpNode.
*/
/**
* Reference wrapper for CumSumOpNode as a TileOperator.
*
* Construct a CumSumOp from explicit arguments and a buffer map.
*/
/**
* Construct a CumSumOp TileOperator from operator arguments and a buffer
* mapping.
*
* @param args Operator arguments (typically shapes, axes, or other prim exprs).
* @param vmap Mapping from argument names to tir::Buffer instances used by the
* operator.
*/
namespace tl {
using namespace tir;
......
......@@ -11,6 +11,26 @@ namespace tvm {
namespace tl {
using namespace tir;
/**
* @brief Construct a RegionOp from TL operator arguments.
*
* Parses the TL `region` operator call arguments to populate the RegionOpNode:
* - Expects args[0] to be a `BufferLoad` whose `indices` are the per-dimension
* minima.
* - args[1] must be a constant integer used as the access mask.
* - args[2 + i] provides the extent for dimension `i`.
*
* The constructor validates that the number of load indices equals `args.size()
* - 2` and will abort via ICHECK on mismatch or if args[0] is not a
* `BufferLoad`.
*
* Parameters:
* - args: TL operator call arguments in the form
* [BufferLoad(min_i...), access_mask, extent_0, extent_1, ...,
* extent_{n-1}] where n = number of dimensions.
* - vmap: BufferMap passed through by the caller (not documented here as a
* generic utility).
*/
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
size_t n = args.size();
size_t ndim = n - 2;
......@@ -31,11 +51,26 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}
/**
* @brief Create a copy of this RegionOpNode and return it as a TileOperator.
*
* @return TileOperator A new TileOperator that owns a copied RegionOpNode.
*/
TileOperator RegionOpNode::Clone() const {
auto op = make_object<RegionOpNode>(*this);
return RegionOp(op);
}
/**
* @brief Check whether the region spans the entire underlying buffer.
*
* Returns true if for every dimension the range minimum is zero and the
* range extent is structurally equal to the corresponding buffer shape
* dimension. Otherwise returns false.
*
* @return true if the region covers the full buffer in all dimensions; false
* otherwise.
*/
bool RegionOpNode::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min))
......@@ -46,10 +81,33 @@ bool RegionOpNode::IsFullRegion() const {
return true;
}
/**
* @brief Lower the region operator to a TIR statement.
*
* Lowers this RegionOpNode into a TIR Stmt by delegating to the operator's
* evaluation path (currently `Evaluate(0)`).
*
* @param T Lowering context (provides buffers, producers/consumers and other
* environment required for lowering).
* @param analyzer Optional arithmetic analyzer used for simplification during
* lowering.
* @return Stmt The lowered TIR statement representing this region operation.
*/
Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Evaluate(0);
}
/**
* @brief Infers data layout for the region operator.
*
* This operator does not provide any layout inference; the function always
* returns an empty LayoutMap regardless of the provided arguments or inference
* level.
*
* @param T Layout inference arguments (ignored).
* @param level Inference granularity level (ignored).
* @return LayoutMap Empty map indicating no inferred layouts.
*/
LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {};
......
......@@ -13,6 +13,62 @@
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
/**
* Tile operator representing a memory region (buffer + ranges) used by TL
* passes.
*
* Encapsulates the target tir::Buffer, the region extents as an Array<Range>,
* and an access mask that indicates permitted or intended accesses for lowering
* and layout inference.
*/
/**
* Lower this RegionOp into a TIR statement representing the region access.
*
* @param T Lowering-time arguments (e.g., loop/build context and value
* mappings).
* @param analyzer Arithmetic analyzer used to simplify and reason about
* expressions.
* @return A tir::Stmt that implements the region access/mutation described by
* this operator.
*/
/**
* Infer the layout mapping for this region operator.
*
* Produces a LayoutMap describing how loop/axis indices map to buffer axes for
* layout-aware scheduling and subsequent operators.
*
* @param T Layout inference arguments (e.g., input layouts and shapes).
* @param level The inference detail level to use.
* @return A LayoutMap describing inferred mappings for the operator.
*/
/**
* Return true when this RegionOp represents the full buffer region (i.e.,
* ranges cover the entire buffer extent).
*/
/**
* Create a shallow copy of this operator as a TileOperator handle.
*
* @return A TileOperator that references a cloned RegionOpNode.
*/
/**
* Construct a RegionOp from argument expressions and a buffer map.
*
* @param args Positional expressions used to instantiate the operator
* (semantics depend on how RegionOp is invoked in TL pipelines).
* @param vmap Mapping from Buffer to replacement Buffer or buffer metadata used
* during creation.
*/
/**
* Return the global Op registration for RegionOp.
*
* @return Reference to the registered tvm::Op describing the RegionOp.
*/
namespace tvm {
namespace tl {
......
......@@ -64,6 +64,37 @@ public:
BufferUseDefCollector(bool skip_thread_partition)
: skip_thread_partition_(skip_thread_partition) {}
/**
* @brief Execute a single layout-inference step for the infer node at the
* given index.
*
* Runs InferLayout on the TileOperator at cur_infer_id with the provided
* InferLevel and thread bounds, applies returned buffer->layout updates into
* layout_map (respecting strict_layout_map constraints for fragment buffers),
* and optionally propagates changes to dependent infer nodes by enqueueing
* them into q and marking in_queue.
*
* The function mutates layout_map and, when update_queue is true, may modify
* q and in_queue. It performs internal sanity checks via ICHECK and will
* LOG(WARNING) for buffers that cannot be propagated; ICHECK failures abort
* execution.
*
* @param cur_infer_id Index of the infer operator in infer_list_ to run (must
* be within range).
* @param level Inference relaxation level to pass to the operator's
* InferLayout.
* @param update_queue If true, discovered layout changes will enqueue
* dependent infer nodes.
* @param layout_map Mutable map of inferred layouts that will be updated with
* returned layouts.
* @param strict_layout_map Read-only map of layouts produced in the strict
* phase; used to enforce containment checks for local.fragment buffers when
* relaxing.
* @param q BFS queue used to propagate dependent inference indices; new
* indices may be pushed.
* @param in_queue Parallel boolean vector tracking queued status; entries
* corresponding to enqueued indices will be set to true.
*/
void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
LayoutMap &layout_map, const LayoutMap &strict_layout_map,
std::queue<int> &q, std::vector<bool> &in_queue) {
......@@ -190,6 +221,30 @@ public:
}
};
/**
* @brief Run the multi-stage layout inference and return the collected
* results.
*
* Performs layout inference over the collected TileOperator entries in three
* phases: (1) strict per-operator inference, (2) common inference via a BFS
* propagation queue, and (3) a free-mode relaxation phase that explores
* alternative root orderings within connected components to reduce register
* footprint. After inference completes, verifies that all local.fragment
* buffers have inferred layouts and collects loop (For) -> Fragment layouts
* and any per-loop predicates discovered during inference.
*
* The method consumes/permutes internal inference state (notably moves
* entries out of infer_list_) and returns a LayoutInferenceResult containing:
* - layout_map: inferred Layout for each Buffer,
* - for_map: mapping from For nodes to their inferred Fragment layout,
* - predicate_map: optional loop predicates keyed by For nodes.
*
* The function performs internal consistency checks (ICHECK) on sizes and
* required definitions; violations will terminate via ICHECK failure.
*
* @return LayoutInferenceResult A tuple-like struct with the inferred
* layout_map, for_map, and predicate_map.
*/
LayoutInferenceResult Run() {
// Basic consistency check: infer_list_ and thread_var_vec_ should have the
// same size
......@@ -293,6 +348,23 @@ public:
}
private:
/**
* @brief Visits a Call expression to collect tile-operator-based inference
* inputs.
*
* Processes non-global function calls by parsing them into a TileOperator
* (via ParseOperator). If the parse succeeds, records:
* - buffers referenced by call arguments into the collector's use lists,
* - the call AST node into infer_list_stmt_,
* - the parsed TileOperator into infer_list_,
* - the current thread IterVar into thread_var_vec_,
* - the thread iteration bounds into thread_bounds_vec_ (uses analyzer const
* bounds when available; otherwise [0,1]).
*
* Calls to global functions (where op->op is a GlobalVar) are ignored.
*
* @param op The Call node being visited.
*/
void VisitExpr_(const CallNode *op) final {
IRVisitorWithAnalyzer::VisitExpr_(op);
// Do not analysis the call node to the global function.
......@@ -345,6 +417,25 @@ private:
use_list_[buffer].push_back(infer_idx);
}
/**
* @brief Handles For nodes during IR traversal.
*
* When the loop is a parallel loop (ForKind::kParallel), records it as a
* ParallelOp:
* - constructs a ParallelOp for the loop and appends it to the internal infer
* lists (infer_list_ and infer_list_stmt_),
* - registers all buffers referenced by the loop indices with use-list
* bookkeeping,
* - captures the current thread IterVar context and its compile-time extent
* (if available) into thread_var_vec_ and thread_bounds_vec_ (falls back to
* range [0,1] when unknown).
*
* For non-parallel loops, continues recursive traversal into the loop body.
*
* Side effects:
* - Mutates infer_list_, infer_list_stmt_, use_list_ (via addToUseList),
* thread_var_vec_, and thread_bounds_vec_.
*/
void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) {
auto infer = ParallelOp(GetRef<For>(op));
......@@ -415,6 +506,15 @@ private:
LayoutMap annotated_layout_map_;
bool skip_thread_partition_{false};
/**
* @brief Create a deep copy of the current inference operator list.
*
* Returns a vector containing clones of each TileOperator in the collector's
* internal infer_list_. The returned list is independent of the original so
* subsequent modifications to either do not affect the other.
*
* @return std::vector<TileOperator> Cloned copy of infer_list_.
*/
std::vector<TileOperator> BackupInferList() {
std::vector<TileOperator> back_infer_list;
back_infer_list.reserve(infer_list_.size());
......@@ -424,6 +524,48 @@ private:
return back_infer_list;
}
/**
* @brief Explore alternative inference orders within connected components to
* relax layouts.
*
* This function performs a "free-mode" exploration that attempts different
* root operators within each connected component of the operator-use graph in
* order to find a layout assignment with lower register (fragment) usage.
*
* Detailed behavior:
* - Builds connected components of infer_list_ by unioning operators that
* share buffer uses (use_list_).
* - For each component, iterates each member operator as a candidate root:
* - Backups the current infer_list_ and uses a temporary copy of
* layout_map.
* - Runs RunInferStep and FinishInferQueue in InferLevel::kFree starting
* from the candidate root and then (as a fallback) runs the remaining members
* to try to cover the whole component.
* - If inference succeeds, computes a coarse register usage metric by
* summing the product of OutputShape dimensions for all Fragment layouts
* in the temporary layout map.
* - Tracks the candidate that yields the smallest register usage.
* - If a better plan is found for a component, replaces the global
* infer_list_ and updates layout_map with the best layout_map found.
*
* Side effects:
* - Mutates layout_map to the best-found free-mode layout assignment when a
* better plan is discovered.
* - Mutates the member infer_list_ (backed up and restored during attempts;
* finally set to the best plan if found).
*
* Notes:
* - LayoutConflictException and NormalizeIterException raised during attempts
* are caught and treated as failed attempts; they do not propagate out of
* this function.
* - The register-usage metric is a heuristic (sum of fragment output element
* counts) used to prefer less-replicated layouts.
*
* @param layout_map[in,out] The current global layout map to be updated with
* a better free-mode result if found.
* @param strict_layout_map Read-only map of layouts inferred in strict mode,
* used to constrain free-mode inference.
*/
void InferInFreeMode(LayoutMap &layout_map,
const LayoutMap &strict_layout_map) {
// Group operators into connected components
......
......@@ -464,6 +464,32 @@ private:
return var;
}
/**
* @brief Handle an Evaluate node, lowering a detected tile operator to TIR.
*
* This visit implementation detects whether the Evaluate node represents a
* tile operator invocation (via ParseOperator). If no tile operator is found
* or the call targets a global function, the node is delegated to the base
* visitor.
*
* When a tile operator is present, the method:
* - Builds a workspace-allocation callback that creates a dynamic shared
* buffer named "workspace" (storage scope "shared.dyn") and returns its write
* access pointer.
* - Determines thread bounds for lowering from the analyzer's constant-int
* information for thread_var_; if unavailable, a default range [0,1) is
* used.
* - Invokes tile_op->Lower(...) with LowerArgs containing target, thread
* bounds, thread variable, the workspace callback, layout and buffer remap
* maps, and the list of GEMM-involved buffer vars; the analyzer is passed
* through for use during lowering.
*
* The lowered statement returned by the operator is then visited by the base
* IRMutatorWithAnalyzer and that result is returned.
*
* @return Stmt The (possibly transformed) statement after lowering or base
* visitor processing.
*/
Stmt VisitStmt_(const EvaluateNode *op) final {
// LOG(INFO) << "evaluate node: " << op->value;
const CallNode *call = op->value.as<CallNode>();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment