Unverified commit 3cfefc8e authored by Lei Wang, committed by GitHub

[Refactor] Support python reflection for tile operators (#783)

* Implement Fill operator and related reflection methods in TileLang

- Added Fill operator implementation in `fill.cc` and `fill.h` for element-wise filling of buffers.
- Introduced reflection methods for Fill, AtomicAdd, Copy, Conv2DIm2Col, FinalizeReducer, Gemm, and Parallel operators to enhance introspection capabilities.
- Updated relevant files to register reflection methods and ensure proper initialization in static blocks.
- Removed outdated comments and unnecessary code in various operator files to improve clarity and maintainability.
- Added new Python bindings for the Fill operator in `tilelang/ir/fill.py` and updated the module imports accordingly.
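
For reference, the reflection hook added for each operator follows the shape below; this is a condensed sketch using `FillNode` as the example (the real definitions appear in `fill.h`/`fill.cc` in the diff), with the surrounding tl/TVM headers and base classes assumed:

```cpp
// Sketch of the per-operator reflection hook added in this PR, using Fill as
// the example. Surrounding tl/TVM headers and base classes are assumed.
class FillNode : public TileOperatorNode {
 public:
  tir::Buffer dst;      // destination buffer
  PrimExpr value;       // value to fill with
  Array<Range> region;  // region to fill within dst

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<FillNode>()
        .def_ro("dst", &FillNode::dst)
        .def_ro("value", &FillNode::value)
        .def_ro("region", &FillNode::region);
  }
};

// Registration runs once at static-initialization time:
TVM_FFI_STATIC_INIT_BLOCK({ FillNode::RegisterReflection(); });
```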

* Refactor operator reflection methods and improve code clarity

- Updated reflection methods for AtomicAdd, Copy, FinalizeReducer, Gemm, and Parallel operators to enhance readability by using `empty()` instead of size checks.
- Consolidated static initialization blocks for various operators to a single line for improved consistency.
- Cleaned up whitespace and formatting in multiple files to adhere to coding standards and improve maintainability.
- Added new Python bindings for operators in the `tilelang/ir` module, ensuring proper registration and organization of imports.

* Refactor GEMM and AtomicAdd operations for improved clarity

- Updated the `GetArchInt` function in `atomic_add.cc` to use `std::string` and `std::stoi` for better readability and type safety.
- Removed unnecessary variables and comments in `gemm_sp.cc` and `gemm.cc` to streamline the `ComputeWarpPartition` method.
- Cleaned up the `layout_reducer.cc` file by removing unused variable declarations, enhancing code clarity.
- Added import for the `ir` module in `tilelang/__init__.py` to ensure proper organization of module imports.
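
The new arch parsing boils down to the snippet below; `ParseSmArch` is a hypothetical standalone helper extracted here for illustration (the real `GetArchInt` first reads the `"arch"` attribute from a `Target`):

```cpp
#include <string>

// Hypothetical standalone version of the new GetArchInt parsing logic.
static int ParseSmArch(const std::string &arch) {
  // rfind(prefix, 0) == 0 is the pre-C++20 idiom for "starts_with".
  if (arch.rfind("sm_", 0) == 0) {
    return std::stoi(arch.substr(3));  // e.g. "sm_90" -> 90
  }
  return 0;  // unknown or non-"sm_" arch string
}
```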

* Remove deprecated operator files from the tilelang IR module

- Deleted files for Fill, AtomicAdd, Copy, Gemm, GemmSP, FinalizeReducer, Parallel, Reduce, and Region operators to streamline the codebase.
- This cleanup enhances maintainability by removing unused code and improving overall organization of the module.

* Refactor imports in tilelang IR module for improved organization

- Updated import statements in `tilelang/ir.py` to reflect changes in the TVM library structure, enhancing clarity and maintainability of the codebase.

* lint fix

* Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability

- Updated the `Gemm` and `GemmSP` classes to utilize a new `GemmWarpPolicy` object for warp partitioning, improving encapsulation and readability.
- Removed deprecated `ComputeWarpPartition` methods and replaced them with calls to the new policy object, streamlining the code.
- Cleaned up comments and unnecessary code in `gemm.cc`, `gemm_sp.cc`, and related header files to enhance overall clarity.
- Introduced a new `GemmWarpPolicyNode` class to manage warp policy attributes and methods, facilitating better organization of related functionalities.
- Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
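
A minimal sketch of the new call pattern follows; the values are placeholders, and the real call sites are in `gemm.cc` and `gemm_sp.cc` below:

```cpp
// Placeholder values for illustration; the real call sites derive these from
// the lowering context (thread bounds, target, selected GEMM instruction).
int M = 128, N = 128, block_size = 256;
Target target("cuda");
bool use_wgmma = false;  // true only when WGMMA is selected on Hopper

// Warp partitioning now lives on the policy object instead of GemmNode.
GemmWarpPolicy policy(GemmWarpPolicyType::kSquare);
auto [warp_m, warp_n] =
    policy->ComputeWarpPartition(M, N, block_size, target, use_wgmma);
// The result is also cached on the policy node as policy->m_warp / policy->n_warp.
```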

* Refactor Reduce operation to utilize ReduceType class for improved clarity and maintainability

- Replaced multiple conditional checks for reduce types with a single ReduceType object, simplifying the code structure.
- Introduced a new ReduceTypeNode class to encapsulate reduce type logic and methods, enhancing organization.
- Updated MakeInitValue, MakeReduce, and Lower methods to leverage the new ReduceType class, improving readability.
- Added Python bindings for the ReduceType class in tilelang IR module to ensure proper registration and usability.
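
The `reduce.h`/`reduce.cc` hunks are not shown below, so the following is an illustrative sketch only; field and method names are guesses modeled on the `GemmWarpPolicyNode` pattern introduced in this PR:

```cpp
// Illustrative only: the actual ReduceTypeNode lives in reduce.h/reduce.cc,
// which are not included in the hunks shown here.
class ReduceTypeNode : public Object {
 public:
  String type;  // e.g. "sum", "max", "min", "abssum", "absmax"

  bool isAbsSum() const { return type == "abssum"; }
  bool isAbsMax() const { return type == "absmax"; }

  static constexpr const char *_type_key = "tl.ReduceType";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceTypeNode, Object);
};

class ReduceType : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(ReduceType, ObjectRef, ReduceTypeNode);
};
```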

* comment

* Refactor operator header files for improved readability

- Cleaned up formatting and whitespace in `atomic_add.h`, `copy.h`, `fill.h`, `reduce.cc`, and `reduce.h` to enhance code clarity.
- Consolidated comments and adjusted line breaks for better organization and maintainability across multiple operator definitions.

* Refactor MakeReduce method in ReduceOpNode for clarity

- Renamed the `MakeReduce` parameter from `rhs` to `b` and assigned it to a local `rhs` variable, improving readability.
- This change clarifies the method's intent and aligns with the broader refactoring of the Reduce operation.

* Update Reduce operation type checks for consistency

- Changed string comparisons for reduce types in the MakeReduce method from "abs_sum" to "abssum" and "abs_max" to "absmax" for uniformity.
- This adjustment enhances the clarity and consistency of the reduce type handling in the codebase.
parent 141e01fb
......@@ -37,9 +37,9 @@ static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
const char *arch_str = s.value().c_str();
if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
arch_int = atoi(&arch_str[3]);
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
} else {
arch_int = 0;
}
......@@ -255,7 +255,7 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
*/
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
bool is_scalar = loop_vars.empty();
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
......@@ -425,5 +425,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); });
} // namespace tl
} // namespace tvm
\ No newline at end of file
/*!
* \file tl/op/atomic_add.h
* \brief Define atomic add operator.
*
* \brief Atomic addition operations for concurrent memory updates
*/
#ifndef TVM_TL_OP_ATOMIC_ADD_H_
......@@ -10,91 +9,20 @@
#include "operator.h"
#include "parallel.h"
/**
* Lower this tile operator into a TIR statement for the given lowering context.
*
* @param T Lowering context containing mapped buffers and iteration
* information.
* @param analyzer Arithmetic analyzer used to simplify and reason about
* expressions.
* @return A TIR Stmt that implements the atomic-add tile operation for the
* provided context.
*/
/**
* Infer memory/layout mapping for tensors and buffers used by this operator.
*
* @param T Layout inference context providing buffer and shape information.
* @param level Inference aggressiveness level; higher levels may perform more
* speculative decisions.
* @return A LayoutMap describing inferred layouts for the operator's inputs and
* outputs.
*/
/**
* Get the Op registration that identifies this tile operator.
*
* @return A reference to the registered Op representing this operator.
*/
/**
* Create a deep copy of this tile operator node wrapped as a TileOperator.
*
* @return A TileOperator handle owning a cloned AtomicAddNode.
*/
/**
* Construct a SIMT-style For loop nest (thread/block mapping) appropriate for
* the operator.
*
* @param analyzer Arithmetic analyzer used to simplify loop bounds and
* predicates.
* @return A For loop node representing the SIMT-parallel loop structure.
*/
/**
* Create iteration variables used by this operator's loop nest.
*
* @return An array of IterVar objects describing the loop iteration axes.
*/
/**
* Produce index expressions for either source or destination buffer access
* based on iteration vars.
*
* @param ivs IterVars created by MakeIterVars().
* @param src_dst Selects which indices to produce: 0 for source indices, 1 for
* destination indices.
* @return An array of PrimExpr index expressions suitable for indexing the
* selected buffer.
*/
/**
* Build a predicate expression that guards out-of-bounds or conditional
* accesses for src or dst.
*
* @param analyzer Arithmetic analyzer used to simplify the predicate.
* @param ivs IterVars created by MakeIterVars().
* @param extents The loop extents corresponding to the itervars.
* @param src_dst Selects which side the predicate is for: 0 for source, 1 for
* destination.
* @return A PrimExpr boolean predicate that evaluates to true for valid
* iterations.
*/
/**
* Construct an AtomicAdd tile operator from operation arguments and a buffer
* mapping.
*
* @param args Operation arguments (e.g., values or indices) specific to the
* atomic-add semantics.
* @param vmap Mapping from buffer names to Buffer objects used by this
* operator.
*/
namespace tvm {
namespace tl {
using namespace tir;
/// Node class for atomic addition operations
class AtomicAddNode : public TileOperatorNode {
public:
Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;
Buffer src, dst; ///< Source and destination buffers
Array<Range> src_range,
dst_range; ///< Access ranges for source and destination
IntImm coalesced_width; ///< Width for memory coalescing optimization
mutable ParallelOp par_op_;
mutable ParallelOp par_op_; ///< Associated parallel operation
static constexpr const char *_type_key = "tl.AtomicAdd";
TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode);
......@@ -104,18 +32,47 @@ public:
static const Op &Get();
TileOperator Clone() const;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AtomicAddNode>()
.def_ro("src", &AtomicAddNode::src)
.def_ro("dst", &AtomicAddNode::dst)
.def_ro("src_range", &AtomicAddNode::src_range)
.def_ro("dst_range", &AtomicAddNode::dst_range)
.def_ro("coalesced_width", &AtomicAddNode::coalesced_width);
}
bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(coalesced_width, other->coalesced_width);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(coalesced_width);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
protected:
/// Create SIMT-style parallel loop structure
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
/// Generate iteration variables for loop nest
Array<IterVar> MakeIterVars() const;
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
/// Generate buffer indices from iteration variables
Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
/// Create boundary predicate for memory safety
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
};
/// Wrapper class for atomic addition operations
class AtomicAdd : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode);
......
......@@ -297,7 +297,7 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
*/
For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
bool is_scalar = loop_vars.empty();
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
......@@ -1197,7 +1197,7 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
int swizzle;
int max_dim;
};
static const SwizzleCheck swizzle_checks[] = {
static const std::vector<SwizzleCheck> swizzle_checks = {
{static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B), 32},
{static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B), 64},
{static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B), 128},
......@@ -1559,5 +1559,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({
CopyNode::RegisterReflection();
Conv2DIm2ColOpNode::RegisterReflection();
});
} // namespace tl
} // namespace tvm
\ No newline at end of file
/*!
* \file tl/op/elem.h
* \brief Define element-wise and copy-related operators for TVM TensorIR
* Lowering.
*
* This header declares the Copy operator and related operator descriptors
* such as TMADesc and TMAIm2ColDesc, as well as a Conv2DIm2Col special
* operator.
* \file tl/op/copy.h
* \brief Copy operations and Tensor Memory Access (TMA) descriptors
*/
#ifndef TVM_TL_OP_COPY_H_
......@@ -18,42 +13,30 @@ namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Copy instruction type.
*/
/// Copy instruction types for different memory access patterns
enum class CopyInst : uint8_t {
kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy
kLDSM = 1, // ldmatrix memory copy
kSTSM = 2, // stmatrix memory copy
kBulkLoad = 3, // utilize tma load
kBulkStore = 4, // utilize tma store
kNormal = 0, ///< Standard memory copy (ldg/stg/cpasync)
kLDSM = 1, ///< Load matrix instruction
kSTSM = 2, ///< Store matrix instruction
kBulkLoad = 3, ///< Tensor Memory Access load
kBulkStore = 4, ///< Tensor Memory Access store
};
/*!
* \brief Descriptor for Tensor Memory Access (TMA) copy operations.
*
* Contains meta-information required to perform global-to-shared memory copy
* using Tensor Memory Accelerator (TMA) hardware instructions. It is mainly
* used to describe the shape, strides, and data layout for both source and
* shared memory buffers.
*/
/// Descriptor for Tensor Memory Access (TMA) copy operations
struct TMADesc {
size_t rank; // Tensor rank (number of dimensions)
int data_type; // Data type identifier (numeric code)
Array<PrimExpr> global_shape; // Shape of the source tensor in global memory
Array<PrimExpr>
global_stride; // Strides of the source tensor in global memory
Array<PrimExpr> smem_box; // Block shape in shared memory
Array<PrimExpr> smem_stride; // Strides in shared memory layout
PrimExpr global_addr; // Base address in global memory
int swizzle; // Swizzle parameter for memory layout transform
int interleave; // Interleave parameter for optimization
int oob_fill; // Out-of-bound fill policy
int l2_promotion; // Whether to promote data to L2 cache
/*!
* \brief Encode descriptor fields into an argument array for runtime calls.
*/
size_t rank; ///< Tensor rank (number of dimensions)
int data_type; ///< Data type identifier
Array<PrimExpr> global_shape; ///< Shape in global memory
Array<PrimExpr> global_stride; ///< Strides in global memory
Array<PrimExpr> smem_box; ///< Block shape in shared memory
Array<PrimExpr> smem_stride; ///< Strides in shared memory
PrimExpr global_addr; ///< Base address in global memory
int swizzle; ///< Memory layout swizzle parameter
int interleave; ///< Memory interleave parameter
int oob_fill; ///< Out-of-bound fill policy
int l2_promotion; ///< L2 cache promotion flag
/// Encode descriptor fields into runtime call arguments
Array<PrimExpr> EncodeCallArgs() const;
};
......@@ -87,215 +70,6 @@ struct TMAIm2ColDesc {
Array<PrimExpr> EncodeCallArgs() const;
};
/*!
* \brief Copy operator for transferring data between buffers.
*
* Performs element- or block-wise copies between `src` and `dst` buffers for
* TensorIR lowering. The operator supports thread-level parallelization,
* shared-memory layouts, and hardware-accelerated paths (TMA/LDSM/STMATRIX)
* when available. Public fields describe the copy ranges and tuning knobs
* (coalesced width, eviction policy, disable_tma).
*/
/*!
* \brief Lower the copy operator to a TIR statement.
*
* Produces a TIR statement implementing the configured copy (normal, LDSM,
* STSM, or bulk TMA-based) for the given lowering context.
*
* \param T Lowering arguments that provide buffer bindings and context.
* \param analyzer Analyzer used for expression simplification and bounds
* checks. \return A TIR `Stmt` implementing the copy.
*/
/*!
* \brief Infer buffer layouts after applying this operator.
*
* Computes resulting layouts (shape/stride mappings) for buffers affected by
* this copy operation.
*
* \param T Arguments for layout inference (buffer maps, shapes).
* \param level Granularity of inference to perform.
* \return A LayoutMap describing inferred layouts.
*/
/*!
* \brief Check if bulk global->shared copy is supported on the target.
*
* Returns true if the target supports bulk (TMA) loads from global memory.
*
* \param target Target to query.
*/
/*!
* \brief Check if bulk shared->global store is supported on the target.
*
* Returns true if the target supports bulk (TMA) stores to global memory.
*
* \param target Target to query.
*/
/*!
* \brief Check if LDSM (LDMATRIX) memory-copy is supported on the target.
*
* \param target Target to query.
*/
/*!
* \brief Check if STSM (STMATRIX) memory-copy is supported on the target.
*
* \param target Target to query.
*/
/*!
* \brief Select the copy instruction type to use.
*
* Chooses between kNormal, kLDSM, kSTSM, kBulkLoad, and kBulkStore based on
* the target capabilities and whether TMA lowering is disabled.
*
* \param target Target to query.
* \param disable_tma_lower When true, force non-TMA copy paths.
* \return The selected CopyInst value.
*/
/*!
* \brief Clone this copy operator.
*
* Returns a TileOperator reference that is a shallow clone of this operator
* object suitable for further modifications in pass pipelines.
*/
/*!
* \brief Generate lowering for bulk (global-to-shared or shared-to-global)
* copy.
*
* Implements TMA-based bulk load/store lowering when `copy_inst` indicates a
* bulk path. The function encodes TMA descriptors and produces calls or
* loops required by the selected bulk mechanism.
*
* \param T Lowering context.
* \param analyzer Analyzer for simplification.
* \param copy_inst Copy instruction type indicating bulk load/store.
* \return A TIR `Stmt` implementing the bulk copy.
*/
/*!
* \brief Generate lowering for LDS matrix-copy paths (LDMATRIX/STMATRIX).
*
* Emits the lowering for LDS-based matrix-copy instructions when the chosen
* `copy_inst` is an LDSM or STSM variant.
*
* \param T Lowering context.
* \param analyzer Analyzer for simplification.
* \param copy_inst Copy instruction type indicating an LDS matrix path.
* \return A TIR `Stmt` implementing the matrix-copy.
*/
/*!
* \brief Generate lowering for the normal (non-bulk, scalar/vec) copy path.
*
* Emits element-wise or vectorized loads/stores using the computed iteration
* space and predicates to ensure in-bounds accesses.
*
* \param T Lowering context.
* \param analyzer Analyzer for simplification.
* \return A TIR `Stmt` implementing the normal copy.
*/
/*!
* \brief Generate a SIMT-style thread-level loop for the copy.
*
* Produces a `For` loop that distributes copy work across SIMD/warp lanes or
* CUDA threads according to the operator's iteration strategy.
*
* \param analyzer Analyzer for simplification.
* \return A `For` loop representing the thread-level iteration.
*/
/*!
* \brief Compute a linear shared-memory layout suitable for TMA copies.
*
* Returns a `Layout` that maps the shared-memory `shared_tensor` into a
* linearized representation required by bulk/TMA transfers.
*
* \param shared_tensor Buffer representing the shared-memory tensor.
* \return A `Layout` describing the linearized shared layout.
*/
/*!
* \brief Create iterator variables for multi-dimensional copy loops.
*
* The returned `IterVar` array enumerates the loop indices used to traverse
* the copy extents in each tensor dimension.
*
* \return Array of iterator variables.
*/
/*!
* \brief Calculate source or destination indices from iteration variables.
*
* Converts the iterator variables (from MakeIterVars) into concrete index
* expressions for either the source image or the destination tensor.
*
* \param ivs Iterator variables returned by MakeIterVars().
* \param src_dst 0 to produce source indices, 1 to produce destination indices.
* \return Array of `PrimExpr` index expressions.
*/
/*!
* \brief Construct the boundary predicate ensuring in-bounds accesses.
*
* Builds a boolean expression that guards loads/stores so they only occur
* when indices lie within the provided `extents`.
*
* \param analyzer Arithmetic analyzer used to simplify predicates.
* \param ivs Iterator variables.
* \param extents Extent expressions for the target buffer.
* \param src_dst 0 = predicate for source indices, 1 = predicate for
* destination. \return A `PrimExpr` boolean predicate.
*/
/*!
* \brief Constructor.
*
* \param args Expression arguments for the copy (indices, sizes, etc.).
* \param vmap Buffer variable mapping for source and destination.
*/
/*!
* \brief Get the TVM Op handle corresponding to this Copy op.
*/
/*!
* \brief Special operator for Conv2D im2col transformation.
*
* Converts an input feature map into an im2col matrix layout used for GEMM-
* based convolution lowering. Public fields configure kernel geometry,
* stride/padding/dilation, and cache eviction behavior.
*/
/*!
* \brief Lower to TIR statement.
*
* Emits TIR that performs the im2col extraction from `src` into `dst`
* according to kernel, stride, padding, and dilation parameters.
*
* \param T Lowering context with buffer bindings.
* \param analyzer Analyzer for expression simplification and bounds reasoning.
* \return A TIR `Stmt` performing the im2col transform.
*/
/*!
* \brief Infer layout for this operator.
*
* Produces the layout mapping for the destination im2col matrix given the
* source layout and convolution parameters.
*
* \param T Layout inference arguments.
* \param level Inference granularity level.
* \return A LayoutMap with inferred layouts for affected buffers.
*/
/*!
* \brief Get TVM Op handle for Conv2DIm2Col.
*/
......@@ -324,6 +98,33 @@ public:
static constexpr const char *_type_key = "tl.Copy";
TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CopyNode>()
.def_ro("src", &CopyNode::src)
.def_ro("dst", &CopyNode::dst)
.def_ro("src_range", &CopyNode::src_range)
.def_ro("dst_range", &CopyNode::dst_range)
.def_ro("coalesced_width", &CopyNode::coalesced_width);
}
bool SEqualReduce(const CopyNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(coalesced_width, other->coalesced_width);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(coalesced_width);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/*!
* \brief Lower the copy operator to a TIR statement.
* \param T Arguments for lowering.
......@@ -475,6 +276,38 @@ public:
static constexpr const char *_type_key = "tl.Conv2DIm2Col";
TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<Conv2DIm2ColOpNode>()
.def_ro("src", &Conv2DIm2ColOpNode::src)
.def_ro("dst", &Conv2DIm2ColOpNode::dst)
.def_ro("stride", &Conv2DIm2ColOpNode::stride)
.def_ro("padding", &Conv2DIm2ColOpNode::padding)
.def_ro("dilation", &Conv2DIm2ColOpNode::dilation)
.def_ro("kernel", &Conv2DIm2ColOpNode::kernel)
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy);
}
bool SEqualReduce(const Conv2DIm2ColOpNode *other,
SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(stride, other->stride) && equal(padding, other->padding) &&
equal(dilation, other->dilation) && equal(kernel, other->kernel) &&
equal(eviction_policy, other->eviction_policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(stride);
hash_reduce(padding);
hash_reduce(dilation);
hash_reduce(kernel);
hash_reduce(eviction_policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/*!
* \brief Lower to TIR statement.
*/
......
/*!
* \file tl/op/elem.h
* \brief Define elment-wise operators.
*
*/
#ifndef TVM_TL_OP_ELEM_H_
#define TVM_TL_OP_ELEM_H_
#include "operator.h"
#include "parallel.h"
/**
* Lower the Fill operator into TIR statements.
*
* Produces a TIR Stmt that implements element-wise filling of `dst` over
* `region` with `value`, using information from `T`.
*
* @param T Lowering inputs (buffers, shapes, and iteration info) used to
* generate the IR.
*/
/**
* Infer the memory layout mapping for the Fill operator.
*
* Returns a LayoutMap that describes how logical iteration axes map to memory
* dimensions for the destination buffer. `level` controls the aggressiveness
* of inference (e.g., relaxed vs. strict constraints).
*
* @param T Layout inference inputs (buffers, shapes, and related metadata).
* @param level Inference level controlling precision of the returned mapping.
*/
/**
* Return the global operator descriptor for tl.Fill.
*
* The returned Op can be used to look up operator-level metadata and to
* register or query the operator within the TVM operator registry.
*/
/**
* Create a copy of this operator node as a TileOperator reference.
*
* The returned TileOperator is an independent handle representing a clone of
* the underlying FillNode.
*/
/**
* Build a SIMT-style For loop that implements the fill.
*
* Constructs and returns a TIR `For` loop that iterates over the target region
* in a SIMT-friendly ordering appropriate for `dst` and `region`.
*/
/**
* Construct a Fill operator from argument expressions and a buffer mapping.
*
* @param args Positional PrimExpr arguments passed to the operator (e.g.,
* indices or shape expressions required by the operator's specification).
* @param vmap Mapping from named buffer parameters to concrete tir::Buffer
* instances used by this operator instance.
*/
/**
* Return the global operator descriptor for the public Fill wrapper.
*
* Mirrors FillNode::Get() and provides the operator descriptor for users of the
* public TileOperator API.
*/
namespace tvm {
namespace tl {
using namespace tir;
class FillNode : public TileOperatorNode {
public:
tir::Buffer dst;
PrimExpr value;
Array<Range> region;
static constexpr const char *_type_key = "tl.Fill";
TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
static const Op &Get();
TileOperator Clone() const;
private:
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
};
class Fill : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_ELEM_H_
\ No newline at end of file
/*!
* \file tl/op/elem.cc
* \file tl/op/fill.cc
*
* Define elment-wise operators.
*/
#include "elem.h"
#include "fill.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -225,5 +225,9 @@ TIR_REGISTER_TL_OP(Fill, fill)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({
FillNode::RegisterReflection();
});
} // namespace tl
} // namespace tvm
\ No newline at end of file
/*!
* \file tl/op/fill.h
* \brief Fill operations for tensor initialization
*/
#ifndef TVM_TL_OP_FILL_H_
#define TVM_TL_OP_FILL_H_
#include "operator.h"
#include "parallel.h"
namespace tvm {
namespace tl {
using namespace tir;
/// Node class for fill operations
class FillNode : public TileOperatorNode {
public:
tir::Buffer dst; ///< Destination buffer to fill
PrimExpr value; ///< Value to fill with
Array<Range> region; ///< Region to fill within the buffer
static constexpr const char *_type_key = "tl.Fill";
TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
static const Op &Get();
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FillNode>()
.def_ro("dst", &FillNode::dst)
.def_ro("value", &FillNode::value)
.def_ro("region", &FillNode::region);
}
bool SEqualReduce(const FillNode *other, SEqualReducer equal) const {
return equal(dst, other->dst) && equal(value, other->value) &&
equal(region, other->region);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dst);
hash_reduce(value);
hash_reduce(region);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
TileOperator Clone() const;
private:
/// Create SIMT-style parallel loop for filling
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
};
/// Wrapper class for fill operations
class Fill : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_FILL_H_
\ No newline at end of file
......@@ -160,5 +160,7 @@ TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ FinalizeReducerOpNode::RegisterReflection(); });
} // namespace tl
} // namespace tvm
......@@ -12,66 +12,6 @@
#include "../transform/layout_reducer.h"
#include "./operator.h"
/**
* FinalizeReducer operator node for Tile IR.
*
* Represents a TL-level operator that finalizes a reducer buffer into a
* result using a specified reducer operation.
*
* Public members:
* - reducer: the tir::Buffer that holds the intermediate reduction values.
* - op: the reducer operation to apply when finalizing values.
*/
/**
* Lower this operator to a TIR statement.
*
* @param T Lowering arguments (buffers, indices, and other lowering context).
* @param analyzer Arithmetic analyzer used to simplify expressions during
* lowering.
* @return A tir::Stmt that implements the finalize-reducer semantics for the
* provided lowering context.
*/
/**
* Infer layout mapping for this operator.
*
* Determines how input and output buffer layouts relate for the
* finalize-reducer operator at the given inference level.
*
* @param T Layout inference arguments (including operand layouts and shapes).
* @param level Inference precision level.
* @return A LayoutMap describing the inferred layouts.
*/
/**
* Get the singleton Op object representing this operator.
*
* @return A reference to the Op describing FinalizeReducer.
*/
/**
* Create a deep copy of this operator node as a TileOperator.
*
* @return A TileOperator handle that is an independent clone of this node.
*/
/**
* Public wrapper for FinalizeReducerOpNode.
*
* Provides the reference semantics and construction API used by callers.
*/
/**
* Construct a FinalizeReducerOp from TL-level arguments.
*
* @param args Positional primitive expressions that parameterize the operator
* (e.g., shapes, axis indices). Documented where their meaning is
* not obvious from name or type in call sites.
* @param vmap Mapping from operand names to tir::Buffer instances used by this
* operator.
*/
/**
* Get the Op singleton for the public FinalizeReducerOp handle.
*
......@@ -90,6 +30,25 @@ public:
static constexpr const char *_type_key = "tl.FinalizeReducerOp";
TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FinalizeReducerOpNode>()
.def_ro("reducer", &FinalizeReducerOpNode::reducer)
.def_ro("op", &FinalizeReducerOpNode::op);
}
bool SEqualReduce(const FinalizeReducerOpNode *other,
SEqualReducer equal) const {
return equal(reducer, other->reducer) && equal(op, other->op);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(reducer);
hash_reduce(op);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
......
/*!
* \file tl/op/gemm.cc
*
* Define gemm operator.
* \brief Implementation of General Matrix Multiplication (GEMM) operators
*/
#include "gemm.h"
......@@ -85,8 +84,7 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy =
static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<Bool>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
......@@ -117,26 +115,6 @@ TileOperator GemmNode::Clone() const {
return Gemm(op);
}
/**
* @brief Selects the GEMM implementation variant for a given block size and
* target.
*
* Determines which low-level GEMM instruction to use:
* - Returns kWGMMA when running on Hopper-class targets and the operator meets
* WGMMA constraints (M >= 64, number of warps is a multiple of 4, and
* CheckWGMMA() returns true).
* - Returns kMFMA for CDNA targets.
* - Returns kMMA for CUDA targets.
*
* @param block_size Number of threads in the CUDA/ROCm thread block used for
* the GEMM.
* @param target Target backend describing the hardware (used to detect
* architecture).
* @return GemmInst The chosen GEMM implementation enum value.
*
* @throws fatal error (ICHECK) If the target is not recognized/supported, this
* function triggers a runtime check failure.
*/
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
......@@ -153,63 +131,20 @@ GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
}
}
/**
* @brief Compute how warps are partitioned between the M and N GEMM dimensions.
*
* Determines the number of warps assigned to the M (rows) and N (columns)
* dimensions for a block given the selected GEMM implementation and target.
* The function enforces constraints required by the implementations (e.g.,
* per-warp tile sizes) and adapts the partition according to the configured
* GemmWarpPolicy (FullRow, FullCol, Square).
*
* @param block_size Total number of threads in the block (used to derive
* num_warps).
* @param gemm_inst The chosen GEMM implementation (e.g., kWGMMA, kMFMA, kMMA).
* @param target Target device information (used for warp size and
* target-specific rules).
* @return std::pair<int, int> {m_warp, n_warp} where m_warp * n_warp ==
* num_warps.
*
* Constraints and behavior:
* - Each warp is assumed to cover 16 rows (M) and 8 columns (N). The function
* checks that M % 16 == 0 and N % 8 == 0.
* - num_warps is computed as block_size / warp_size(target).
* - For WGMMA (kWGMMA):
* - num_warps must be a multiple of 4 (warp-groups of 4).
* - m_warp is always a multiple of 4.
* - The warp partition respects the GemmWarpPolicy:
* - FullRow: maximize warps on M (in multiples of 4) while keeping
* divisibility.
* - FullCol: maximize warps on N, but if N is not evenly divisible, move
* whole warp-groups to M to achieve feasibility.
* - Square: choose a multiple-of-4 m_warp that best balances per-warp work
* between M and N.
* - For non-WGMMA implementations:
* - FullRow: favor allocating warps to M first; if M cannot use all warps,
* remaining warps are placed on N.
* - FullCol: favor allocating warps to N first; if N cannot use all warps,
* remaining warps are placed on M.
* - Square: search for the m/n split that best balances per-warp work given
* integer warp counts and the per-warp tile sizes.
*
* Error handling:
* - The function performs internal checks (ICHECK) and will fail if required
* divisibility or policy conditions are not met (e.g., M/N tile divisibility,
* invalid policy, or WGMMA-specific warp-group requirements).
*/
std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
GemmInst gemm_inst,
Target target) const {
std::pair<int, int>
GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
Target target, bool use_wgmma) const {
int num_warps = block_size / TargetGetWarpSize(target);
int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp
ICHECK(this->M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << this->M;
ICHECK(this->N % kNPerWarp == 0)
<< "N must be divisible by " << kNPerWarp << ", but got " << this->N;
if (gemm_inst == GemmInst::kWGMMA) {
ICHECK(M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << M;
ICHECK(N % kNPerWarp == 0)
<< "N must be divisible by " << kNPerWarp << ", but got " << N;
if (use_wgmma) {
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
constexpr int kGroup = 4; // Number of warps in a warp-group
......@@ -217,22 +152,22 @@ std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
m_warp = kGroup; // Initially, only one warp-group on M dimension
n_warp = num_warps / m_warp; // Rest all on N dimension
if (this->policy == GemmWarpPolicy::kFullRow) {
if (this->isFullRow()) {
// Try to put as many warp-groups as possible on M dimension
// (decreasing multiples of 4, ensuring divisibility by M)
for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
if (this->M % (cand * kMPerWarp) == 0) {
if (M % (cand * kMPerWarp) == 0) {
m_warp = cand;
n_warp = num_warps / m_warp;
break;
}
}
} else if (this->policy == GemmWarpPolicy::kFullCol) {
} else if (this->isFullCol()) {
// Try to use warps on N dimension; if N is not divisible, split excess
// groups to M
int cand_n = n_warp; // Initially assume all on N
if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
int max_n = this->N / kNPerWarp;
if (N % (cand_n * kNPerWarp) != 0) { // N direction division fails
int max_n = N / kNPerWarp;
// Find a feasible n_warp from max possible downwards, ensuring
// num_warps/n_warp is multiple of 4
for (int n = std::min(cand_n, max_n); n >= 1; --n) {
......@@ -243,12 +178,12 @@ std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
}
}
}
} else if (this->policy == GemmWarpPolicy::kSquare) {
} else if (this->isSquare()) {
// Exhaustive search, but m must be multiple of 4
int max_m = this->M / kMPerWarp;
int max_n = this->N / kNPerWarp;
int max_m = M / kMPerWarp;
int max_n = N / kNPerWarp;
float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;
float ideal = N > 0 ? static_cast<float>(M) / N : 1.f;
float best_score = std::numeric_limits<float>::max();
int best_m = kGroup, best_n = n_warp;
......@@ -260,8 +195,8 @@ std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
if (n > max_n)
continue;
float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
float score = std::abs(m_per_warp / n_per_warp - ideal);
if (score < best_score) {
......@@ -278,58 +213,57 @@ std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
ICHECK(m_warp * n_warp == num_warps)
<< "m_warp * n_warp must equal num_warps";
// Store the computed values in the object's member variables
this->m_warp = m_warp;
this->n_warp = n_warp;
return {m_warp, n_warp};
}
if (this->policy == GemmWarpPolicy::kFullRow) {
if (this->isFullRow()) {
// Try to partition M first
m_warp = num_warps;
n_warp = 1;
// If M cannot be evenly divided by m_warp*16, try to split remaining warps
// to N
if (this->M % (m_warp * kMPerWarp) != 0) {
if (M % (m_warp * kMPerWarp) != 0) {
// Calculate how many warps we can use for M
int max_m_warps = this->M / kMPerWarp;
int max_m_warps = M / kMPerWarp;
m_warp = max_m_warps;
// Use remaining warps for N
n_warp = num_warps / m_warp;
if (n_warp == 0)
n_warp = 1;
}
} else if (this->policy == GemmWarpPolicy::kFullCol) {
} else if (this->isFullCol()) {
// Try to partition N first
m_warp = 1;
n_warp = num_warps;
// If N cannot be evenly divided by n_warp*8, try to split remaining warps
// to M
if (this->N % (n_warp * kNPerWarp) != 0) {
if (N % (n_warp * kNPerWarp) != 0) {
// Calculate how many warps we can use for N
int max_n_warps = this->N / kNPerWarp;
int max_n_warps = N / kNPerWarp;
n_warp = max_n_warps;
// Use remaining warps for M
m_warp = num_warps / n_warp;
if (m_warp == 0)
m_warp = 1;
}
} else if (this->policy == GemmWarpPolicy::kSquare) {
} else if (this->isSquare()) {
// First calculate the maximum possible warps for each dimension
int max_m_warps =
this->M / kMPerWarp; // Each warp needs at least 16 elements in M
int max_n_warps =
this->N / kNPerWarp; // Each warp needs at least 8 elements in N
M / kMPerWarp; // Each warp needs at least 16 elements in M
// Calculate the ideal ratio of M/N warps based on the matrix dimensions
float ideal_ratio = 1.0f;
if (this->N > 0) {
ideal_ratio = static_cast<float>(this->M) / this->N;
if (N > 0) {
ideal_ratio = static_cast<float>(M) / N;
}
// Start with a balanced initial guess
m_warp = 1;
n_warp = 1;
// Try to find the best balanced partition
int best_m = 1;
int best_n = 1;
......@@ -340,8 +274,8 @@ std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
int n = num_warps / m;
// Calculate how balanced this partition is
float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);
if (balance < best_balance) {
......@@ -356,6 +290,11 @@ std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
// Store the computed values in the object's member variables
this->m_warp = m_warp;
this->n_warp = n_warp;
return {m_warp, n_warp};
}
......@@ -459,9 +398,9 @@ static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
const char *arch_str = s.value().c_str();
if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
arch_int = atoi(&arch_str[3]);
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
} else {
arch_int = 0;
}
......@@ -484,7 +423,8 @@ static int GetArchInt(Target target) {
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);
std::stringstream ss;
std::string op_name = "tl::gemm_ss";
......@@ -546,7 +486,8 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);
if (TargetIsVolta(T.target)) {
auto fragment =
......
......@@ -10,88 +10,93 @@
#include "operator.h"
namespace tvm {
/**
* Check whether the target and configuration allow using WGMMA (wavefront-group
* MMA) for this GEMM.
*
* @returns true if WGMMA can be used for the current node configuration and
* target; false otherwise.
*/
/**
* Lower this GEMM operator to a TVM Stmt for the given lowering context.
*
* @param T Lowering arguments and context (tile mappings, target, etc.).
* @param analyzer Arithmetic analyzer used for symbolic simplification and
* bounds reasoning.
* @returns A lowered Stmt implementing the GEMM.
*/
/**
* Infer memory/layout mapping for GEMM inputs/outputs at the given inference
* level.
*
* @param T Layout inference inputs (buffers, shapes, constraints).
* @param level Inference level that controls how aggressive/specific the
* inferred layouts should be.
* @returns A LayoutMap describing how logical tensor axes map to storage/layout
* axes.
*/
/**
* Create a deep/shallow copy of this TileOperator node as a TileOperator
* reference.
*
* @returns A TileOperator reference that represents a clone of this GemmNode.
*/
/**
* Determine the specific GEMM instruction variant to use for the given block
* size and target.
*
* @param block_size The tile/block size (in elements or threads) used to select
* instruction variant.
* @param target The compilation target describing architecture and instruction
* set.
* @returns The GemmInst enum value representing the chosen GEMM instruction
* family.
*/
/**
* Compute how to partition work across warps for the given number of warps and
* GEMM instruction.
*
* The returned pair is (warp_rows, warp_cols), describing the per-warp tiling
* in row and column dimensions respectively.
*
* @param num_warps Total number of warps available for the block.
* @param gemm_inst The GEMM instruction variant selected for the target.
* @param target The compilation target which may constrain or influence
* partitioning.
* @returns A pair<int,int> = (warp_rows, warp_cols) describing the warp
* partition.
*/
/**
* Construct a Gemm operator handle from call arguments and a buffer mapping.
*
* @param args Array of call-time PrimExpr arguments passed to the operator.
* @param vmap Mapping from buffer names/indices to tir::Buffer objects used by
* this GEMM.
*/
/**
* Obtain the registered Op descriptor for the GEMM operator.
*
* @returns A const reference to the Op representing "tl.Gemm".
*/
namespace tl {
using namespace tir;
enum class GemmWarpPolicy : uint8_t {
enum class GemmWarpPolicyType : uint8_t {
kSquare = 0,
kFullRow = 1,
kFullCol = 2,
kFree = 3,
};
class GemmWarpPolicyNode : public Object {
public:
mutable int m_warp{0};
mutable int n_warp{0};
int policy_type;
static constexpr const char *_type_key = "tl.GemmWarpPolicy";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmWarpPolicyNode, Object);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmWarpPolicyNode>()
.def_ro("policy_type", &GemmWarpPolicyNode::policy_type)
.def_ro("m_warp", &GemmWarpPolicyNode::m_warp)
.def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
}
bool SEqualReduce(const GemmWarpPolicyNode *other,
SEqualReducer equal) const {
return equal(policy_type, other->policy_type) &&
equal(m_warp, other->m_warp) && equal(n_warp, other->n_warp);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(policy_type);
hash_reduce(m_warp);
hash_reduce(n_warp);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
Target target, bool use_wgmma) const;
bool isSquare() const {
return policy_type == int(GemmWarpPolicyType::kSquare);
}
bool isFullRow() const {
return policy_type == int(GemmWarpPolicyType::kFullRow);
}
bool isFullCol() const {
return policy_type == int(GemmWarpPolicyType::kFullCol);
}
bool isFree() const { return policy_type == int(GemmWarpPolicyType::kFree); }
};
class GemmWarpPolicy : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmWarpPolicy, ObjectRef, GemmWarpPolicyNode);
explicit GemmWarpPolicy(GemmWarpPolicyType policy_type) {
auto node = make_object<GemmWarpPolicyNode>();
node->policy_type = (int)policy_type;
data_ = std::move(node);
}
explicit GemmWarpPolicy(int policy_type) {
auto node = make_object<GemmWarpPolicyNode>();
node->policy_type = policy_type;
data_ = std::move(node);
}
explicit GemmWarpPolicy(int m_warp, int n_warp) {
auto node = make_object<GemmWarpPolicyNode>();
node->m_warp = m_warp;
node->n_warp = n_warp;
node->policy_type = (int)GemmWarpPolicyType::kFree;
data_ = std::move(node);
}
};
class GemmNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
Array<PrimExpr> call_args;
tir::Buffer A, B, C;
// pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr;
......@@ -104,11 +109,74 @@ public:
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
GemmWarpPolicy policy;
mutable GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.Gemm";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmNode>()
.def_ro("A", &GemmNode::A)
.def_ro("B", &GemmNode::B)
.def_ro("C", &GemmNode::C)
.def_ro("Aptr", &GemmNode::Aptr)
.def_ro("Bptr", &GemmNode::Bptr)
.def_ro("Cptr", &GemmNode::Cptr)
.def_ro("trans_A", &GemmNode::trans_A)
.def_ro("trans_B", &GemmNode::trans_B)
.def_ro("M", &GemmNode::M)
.def_ro("N", &GemmNode::N)
.def_ro("K", &GemmNode::K)
.def_ro("stride_A", &GemmNode::stride_A)
.def_ro("stride_B", &GemmNode::stride_B)
.def_ro("offset_A", &GemmNode::offset_A)
.def_ro("offset_B", &GemmNode::offset_B)
.def_ro("clear_accum", &GemmNode::clear_accum)
.def_ro("kPack", &GemmNode::kPack)
.def_ro("wg_wait", &GemmNode::wg_wait)
.def_ro("policy", &GemmNode::policy);
}
bool SEqualReduce(const GemmNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_A) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
......@@ -120,9 +188,6 @@ private:
enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;
std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst,
Target target) const;
mutable bool completed_ = false;
};
......
......@@ -74,8 +74,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
node->M = args[6].as<IntImm>().value()->value;
node->N = args[7].as<IntImm>().value()->value;
node->K = args[8].as<IntImm>().value()->value;
node->policy = static_cast<GemmSPNode::GemmWarpPolicy>(
args[9].as<IntImm>().value()->value);
node->policy = GemmWarpPolicy(args[9].as<IntImm>().value()->value);
node->clear_accum = args[10].as<Bool>().value();
if (args.size() > 11) {
node->kPack = args[11].as<IntImm>().value()->value;
......@@ -103,185 +102,6 @@ TileOperator GemmSPNode::Clone() const {
return GemmSP(op);
}
/**
* @brief Compute a partition of warps across the M and N GEMM dimensions.
*
* Computes (m_warp, n_warp) such that m_warp * n_warp == num_warps and the
* warp counts respect element-per-warp granularity and the configured
* GemmWarpPolicy. On Hopper targets, when `maybe_hopper_wgmma` is true and
* the problem size permits, a warp-group (WGMMA)-aware partitioning is used
* (groups of 4 warps).
*
* @param num_warps Total number of warps available for the block.
* @param target Hardware target used to decide target-specific strategies
* (e.g., Hopper WGMMA grouping).
* @param maybe_hopper_wgmma If true, allows using Hopper WGMMA-specific
* partitioning when the target and problem size
* permit.
* @return std::pair<int,int> A pair (m_warp, n_warp) giving the number of warp
* partitions along M and N, respectively.
*
* @note The function uses ICHECK to enforce invariants (e.g., unknown policy or
* invalid m_warp * n_warp), which will terminate on failure.
*/
std::pair<int, int>
GemmSPNode::ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma) const {
int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp
bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
(this->M >= 64) && (num_warps % 4 == 0);
if (allow_wgmma) {
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
constexpr int kGroup = 4; // Number of warps in a warp-group
m_warp = kGroup; // Initially, only one warp-group on M dimension
n_warp = num_warps / m_warp; // Rest all on N dimension
if (this->policy == GemmWarpPolicy::kFullRow) {
// Try to put as many warp-groups as possible on M dimension
// (decreasing multiples of 4, ensuring divisibility by M)
for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
if (this->M % (cand * kMPerWarp) == 0) {
m_warp = cand;
n_warp = num_warps / m_warp;
break;
}
}
} else if (this->policy == GemmWarpPolicy::kFullCol) {
// Try to use warps on N dimension; if N is not divisible, split excess
// groups to M
int cand_n = n_warp; // Initially assume all on N
if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
int max_n = this->N / kNPerWarp;
// Find a feasible n_warp from max possible downwards, ensuring
// num_warps/n_warp is multiple of 4
for (int n = std::min(cand_n, max_n); n >= 1; --n) {
if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
n_warp = n;
m_warp = num_warps / n_warp;
break;
}
}
}
} else if (this->policy == GemmWarpPolicy::kSquare) {
// Exhaustive search, but m must be multiple of 4
int max_m = this->M / kMPerWarp;
int max_n = this->N / kNPerWarp;
float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;
float best_score = std::numeric_limits<float>::max();
int best_m = kGroup, best_n = n_warp;
for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
if (num_warps % m)
continue;
int n = num_warps / m;
if (n > max_n)
continue;
float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
float score = std::abs(m_per_warp / n_per_warp - ideal);
if (score < best_score) {
best_score = score;
best_m = m;
best_n = n;
}
}
m_warp = best_m;
n_warp = best_n;
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
ICHECK(m_warp * n_warp == num_warps)
<< "m_warp * n_warp must equal num_warps";
return {m_warp, n_warp};
}
if (this->policy == GemmWarpPolicy::kFullRow) {
// Try to partition M first
m_warp = num_warps;
n_warp = 1;
// If M cannot be evenly divided by m_warp*16, try to split remaining warps
// to N
if (this->M % (m_warp * kMPerWarp) != 0) {
// Calculate how many warps we can use for M
int max_m_warps = this->M / kMPerWarp;
m_warp = max_m_warps;
// Use remaining warps for N
n_warp = num_warps / m_warp;
if (n_warp == 0)
n_warp = 1;
}
} else if (this->policy == GemmWarpPolicy::kFullCol) {
// Try to partition N first
m_warp = 1;
n_warp = num_warps;
// If N cannot be evenly divided by n_warp*8, try to split remaining warps
// to M
if (this->N % (n_warp * kNPerWarp) != 0) {
// Calculate how many warps we can use for N
int max_n_warps = this->N / kNPerWarp;
n_warp = max_n_warps;
// Use remaining warps for M
m_warp = num_warps / n_warp;
if (m_warp == 0)
m_warp = 1;
}
} else if (this->policy == GemmWarpPolicy::kSquare) {
// First calculate the maximum possible warps for each dimension
int max_m_warps =
this->M / kMPerWarp; // Each warp needs at least 16 elements in M
int max_n_warps =
this->N / kNPerWarp; // Each warp needs at least 8 elements in N
// Calculate the ideal ratio of M/N warps based on the matrix dimensions
float ideal_ratio = 1.0f;
if (this->N > 0) {
ideal_ratio = static_cast<float>(this->M) / this->N;
}
// Start with a balanced initial guess
m_warp = 1;
n_warp = 1;
// Try to find the best balanced partition
int best_m = 1;
int best_n = 1;
float best_balance = std::numeric_limits<float>::max();
// Try all possible combinations that satisfy the constraints
for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
int n = num_warps / m;
// Calculate how balanced this partition is
float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);
if (balance < best_balance) {
best_balance = balance;
best_m = m;
best_n = n;
}
}
m_warp = best_m;
n_warp = best_n;
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
return {m_warp, n_warp};
}
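For a concrete sense of what the kSquare heuristic picks, the standalone sketch below reproduces the search loop outside of the node (illustrative only; it assumes kMPerWarp = 16 and kNPerWarp = 8, matching the per-warp tile sizes referenced in the comments above). For M = 128, N = 64 and 8 warps it settles on a 2x4 warp grid:
// Standalone sketch of the kSquare warp-partition search (not part of the diff).
#include <cmath>
#include <cstdio>
#include <limits>
int main() {
  const int M = 128, N = 64, num_warps = 8;
  const int kMPerWarp = 16, kNPerWarp = 8;  // assumed per-warp tile sizes
  const int max_m_warps = M / kMPerWarp;
  const float ideal_ratio = static_cast<float>(M) / N;
  int best_m = 1, best_n = 1;
  float best_balance = std::numeric_limits<float>::max();
  for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
    int n = num_warps / m;
    float m_per_warp = static_cast<float>(M) / (m * kMPerWarp);
    float n_per_warp = static_cast<float>(N) / (n * kNPerWarp);
    float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);
    if (balance < best_balance) {
      best_balance = balance;
      best_m = m;
      best_n = n;
    }
  }
  std::printf("m_warp=%d n_warp=%d\n", best_m, best_n);  // prints m_warp=2 n_warp=4
  return 0;
}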
/**
 * @brief Lower this GemmSP node to a TL (tile-level) intrinsic call.
*
......@@ -308,7 +128,7 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
(block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
policy->ComputeWarpPartition(M, N, block_size, T.target, maybe_wgmma);
std::stringstream ss;
std::string op_name = "tl::gemm_sp_ss";
......@@ -386,7 +206,7 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
bool maybe_wgmma =
(this->M >= wgmma_m) && (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
policy->ComputeWarpPartition(M, N, block_size, T.target, maybe_wgmma);
auto fragment =
maybe_wgmma
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
......@@ -397,8 +217,6 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
const int64_t continuity =
trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous,
mat_continuous, A->dtype.bits(),
trans_A ? 1 : 2));
......@@ -431,5 +249,7 @@ TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ GemmSPNode::RegisterReflection(); });
} // namespace tl
} // namespace tvm
......@@ -7,82 +7,17 @@
#ifndef TVM_TL_OP_GEMM_SP_H_
#define TVM_TL_OP_GEMM_SP_H_
#include "gemm.h"
#include "operator.h"
namespace tvm {
/**
* Lower the GemmSP operator into a TIR statement for the given lowering
* context.
*
* Produces the TIR Stmt that implements this operator using the provided
* lowering arguments. The `analyzer` is used for arithmetic simplifications and
* may be null.
*
* @param T Lowering context and arguments.
* @returns A TIR `Stmt` implementing the lowered operator.
*/
/**
* Infer memory/layout mapping for operands and outputs of this operator.
*
* Computes a LayoutMap describing how logical tensor layouts map to physical
* buffer layouts for the given inference `level`.
*
* @param T Layout inference inputs (shapes, buffer info, etc.).
* @param level Inference granularity/level.
* @returns A LayoutMap describing inferred layouts.
*/
/**
* Compute a warp-level partitioning (rows, cols) for the given number of warps.
*
* Returns a pair (warps_per_row, warps_per_col) describing how to tile the GEMM
 * across warps for the specified `target`. The optional `maybe_hopper_wgmma`
 * enables target-specific adjustments (e.g., Hopper WGMMA warp-group
 * partitioning) when set.
*
* @param num_warps Total number of warps available for the tile.
* @param target Target device/architecture used to guide partitioning choices.
* @param maybe_hopper_wgmma Enable target-specific WG/MMA adjustments when
* true.
* @returns Pair<int,int> of (warps_per_row, warps_per_col).
*/
/**
* Create a copy of this TileOperator node as a TileOperator reference.
*
* The returned TileOperator refers to a new node that is a copy of this node.
*
* @returns A TileOperator that is a clone of this node.
*/
/**
* Construct a GemmSP TileOperator from call arguments and a buffer map.
*
* @param args Array of PrimExpr specifying call-site arguments for the
* operator.
* @param vmap Mapping from buffer names to tir::Buffer objects for
* operands/outputs.
*/
/**
* Return the singleton Op descriptor for the GemmSP operator.
*
* @returns Reference to the operator's Op registration object.
*/
namespace tl {
using namespace tir;
class GemmSPNode : public TileOperatorNode {
public:
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
enum class GemmWarpPolicy : uint8_t {
kSquare = 0,
kFullRow = 1,
kFullCol = 2,
} policy;
std::pair<int, int>
ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma = true) const;
Array<PrimExpr> call_args;
tir::Buffer A, B, C, E;
bool trans_A, trans_B;
int M, N, K;
......@@ -92,8 +27,59 @@ public:
int kPack = 1;
int wg_wait = 0;
mutable GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.GemmSP";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
TileOperator Clone() const;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPNode>()
.def_ro("policy", &GemmSPNode::policy)
.def_ro("A", &GemmSPNode::A)
.def_ro("B", &GemmSPNode::B)
.def_ro("C", &GemmSPNode::C)
.def_ro("E", &GemmSPNode::E)
.def_ro("trans_A", &GemmSPNode::trans_A)
.def_ro("trans_B", &GemmSPNode::trans_B)
.def_ro("M", &GemmSPNode::M)
.def_ro("N", &GemmSPNode::N)
.def_ro("K", &GemmSPNode::K)
.def_ro("clear_accum", &GemmSPNode::clear_accum)
.def_ro("kPack", &GemmSPNode::kPack)
.def_ro("wg_wait", &GemmSPNode::wg_wait);
}
  bool SEqualReduce(const GemmSPNode *other, SEqualReducer equal) const {
    return equal(policy, other->policy) && equal(A, other->A) &&
           equal(B, other->B) && equal(C, other->C) &&
equal(E, other->E) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(policy);
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(E);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
}
private:
mutable bool completed_ = false;
};
......
......@@ -48,7 +48,6 @@ struct LayoutInferArgs {
Map<Buffer, Buffer> buffer_remap;
};
class TileOperatorNode;
class TileOperator;
class TileOperatorNode : public Object {
......
......@@ -378,9 +378,9 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
indice_map_[buffer], analyzer_)) {
std::ostringstream oss;
oss << "Layout infer conflict between " << buffer << " and "
<< source_buffer << " in T.Parallel loop:" << std::endl
<< " loop " << loop_layout_->DebugOutput() << std::endl
<< " fragment " << fragment->DebugOutput() << std::endl;
<< source_buffer << " in T.Parallel loop:" << '\n'
<< " loop " << loop_layout_->DebugOutput() << '\n'
<< " fragment " << fragment->DebugOutput() << '\n';
throw LayoutConflictException(oss.str());
}
} else {
......@@ -427,5 +427,7 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
->CondenseReplicateVar();
}
TVM_FFI_STATIC_INIT_BLOCK({ ParallelOpNode::RegisterReflection(); });
} // namespace tl
} // namespace tvm
......@@ -13,97 +13,6 @@
#include "../transform/layout_reducer.h"
#include "./operator.h"
/**
* Exception indicating a layout conflict during layout inference or validation.
* The stored message is returned by what().
*/
/**
* Verify that `small_frag` is contained within `large_frag` under the provided
* index mappings and using symbolic reasoning via `analyzer_`.
*
* @param small_frag Fragment describing the smaller layout fragment.
* @param large_frag Fragment describing the larger layout fragment.
* @param small_frag_indices Index expressions that map accesses into
* `small_frag`.
* @param large_frag_indices Index expressions that map accesses into
* `large_frag`.
* @param analyzer_ Analyzer used for symbolic simplification and proving
* relations.
* @return true if `small_frag` can be proven to be contained in `large_frag`
* given the index mappings and analyzer; false otherwise.
*/
/**
* Visitor that traverses a parallel loop nest to collect loop structure,
* buffer access patterns, and to populate the associated ParallelOpNode.
*/
/**
* Construct a ParallelOpNode from a root For loop.
*
* @param root The TIR For node that is the root of the parallel loop nest.
*/
/**
* Lower this ParallelOpNode to a TIR statement.
*
* Performs lowering of the operator (including any necessary predicates,
* reductions, and loop transformations) to produce an equivalent tir::Stmt.
*
* @param T Lowering options and context.
* @param analyzer Optional analyzer for symbolic simplification during
* lowering.
* @return A tir::Stmt representing the lowered operator.
*/
/**
* Infer layouts for buffers used by this parallel operator.
*
* This performs layout inference at the requested level and returns a mapping
* from buffers to their inferred layout fragments.
*
* @param T Layout inference arguments and context.
* @param level Granularity level for inference.
* @return LayoutMap mapping buffers to inferred fragments.
*/
/**
* Return an optional predicate expression associated with the given thread
* variable.
*
* If the loop nest imposes a condition on `thread_var` (e.g., bounds checks or
* tiling edge predicates), this returns the combined predicate; otherwise
* returns an empty Optional.
*
* @param thread_var The thread variable for which to retrieve the predicate.
* @return Optional containing the predicate expression if present.
*/
/**
 * Create and return a clone of this operator as a TileOperator (a deep copy of
 * the operator state needed for further transformations).
*
* @return A TileOperator referencing a cloned ParallelOpNode.
*/
/**
* Complete the layout fragment for `buffer` by filling in any missing
* dimension or stride information derived from access patterns in the loop
* nest.
*
* @param buffer The buffer whose fragment should be completed.
* @return A Fragment representing the completed layout for `buffer`.
*/
/**
* Determine whether `buffer` is accessed using only the loop-common indices
* (i.e., indices that correspond to the loop variables of this parallel nest).
*
* @param buffer The buffer to inspect.
* @return true if accesses use common loop indices; false otherwise.
*/
/**
* Conjoin `expr` into the operator's predicate (logical AND). If no predicate
* exists yet, `expr` becomes the predicate.
......@@ -148,6 +57,8 @@ private:
// predicates.
class ParallelOpNode : public TileOperatorNode {
public:
// The root For loop node.
For root_;
// The inferred layout for the loop, mutable to allow lazy inference.
mutable Fragment loop_layout_;
// The predicate expression for the loop, if any, mutable for lazy
......@@ -158,6 +69,28 @@ public:
static constexpr const char *_type_key = "tl.ParallelOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ParallelOpNode>()
.def_ro("root", &ParallelOpNode::root_)
.def_ro("loop_layout", &ParallelOpNode::loop_layout_)
.def_ro("predicate", &ParallelOpNode::predicate_);
}
bool SEqualReduce(const ParallelOpNode *other, SEqualReducer equal) const {
return equal(root_, other->root_) &&
equal(loop_layout_, other->loop_layout_) &&
equal(predicate_, other->predicate_);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(root_);
hash_reduce(loop_layout_);
hash_reduce(predicate_);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
// Construct from a root For loop.
ParallelOpNode(For root);
......@@ -198,8 +131,6 @@ private:
// Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor;
// The root For loop node.
For root_;
// Visitor for collecting loop nest information.
ParallelLoopNestVisitor V;
// Mapping from buffer to their access indices in the loop.
......
/*!
* \file tl/op/reduce.cc
*
* Define reduce operator.
* \brief Implementation of reduction operators
*/
#include "reduce.h"
......@@ -28,18 +27,7 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
node->dst = vmap[GetVarFromAccessPtr(args[1])];
std::string reduce_type = args[2].as<StringImm>().value()->value;
node->dim = args[3].as<IntImm>().value()->value;
if (reduce_type == "sum")
node->type = ReduceType::kSum;
else if (reduce_type == "abssum")
node->type = ReduceType::kAbsSum;
else if (reduce_type == "absmax")
node->type = ReduceType::kAbsMax;
else if (reduce_type == "max")
node->type = ReduceType::kMax;
else if (reduce_type == "min")
node->type = ReduceType::kMin;
else
ICHECK(0) << "Unknown reduce type: " << reduce_type;
node->type = ReduceType(reduce_type);
node->clear = args[4].as<Bool>().value();
data_ = std::move(node);
}
......@@ -60,12 +48,11 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
bool is_uint = dst_dtype.is_uint();
auto bits = dst_dtype.bits();
switch (type) {
case ReduceType::kSum:
if (type->isSum()) {
return make_zero(dst->dtype);
case ReduceType::kAbsSum:
} else if (type->isAbsSum()) {
return make_zero(dst->dtype);
case ReduceType::kMax:
} else if (type->isMax()) {
if (is_int) {
return make_const(dst->dtype, -(1 << (bits - 1)));
} else if (is_uint) {
......@@ -73,7 +60,7 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
} else {
return make_const(dst->dtype, -INFINITY);
}
case ReduceType::kMin:
} else if (type->isMin()) {
if (is_int) {
return make_const(dst->dtype, (1 << (bits - 1)) - 1);
} else if (is_uint) {
......@@ -81,49 +68,47 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
} else {
return make_const(dst->dtype, INFINITY);
}
case ReduceType::kAbsMax:
} else if (type->isAbsMax()) {
return make_const(dst->dtype, 0);
default:
ICHECK(0);
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
}
}
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
PrimExpr lhs = a, rhs = b;
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
const PrimExpr &b) const {
PrimExpr rhs = b;
if (lhs->dtype != rhs->dtype) {
rhs = Cast(lhs->dtype, rhs);
}
switch (type) {
case ReduceType::kSum:
if (type->isSum()) {
return lhs + rhs;
case ReduceType::kAbsSum:
} else if (type->isAbsSum()) {
return lhs + Max(rhs, -rhs);
case ReduceType::kMax:
} else if (type->isMax()) {
return Max(lhs, rhs);
case ReduceType::kMin:
} else if (type->isMin()) {
return Min(lhs, rhs);
case ReduceType::kAbsMax:
} else if (type->isAbsMax()) {
return Max(Max(lhs, rhs), -Min(lhs, rhs));
default:
ICHECK(0);
return PrimExpr(0);
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
}
}
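The absmax combine above, Max(Max(lhs, rhs), -Min(lhs, rhs)), folds the absolute value into the reduction itself: starting from the absmax init value of 0, the accumulator always holds the largest |x| seen so far. A standalone check of that identity (illustration only, plain C++ rather than TIR):
// Standalone check of the absmax combine rule used by MakeReduce.
#include <algorithm>
#include <cassert>
#include <cmath>
int main() {
  const float data[] = {0.5f, -3.0f, 2.0f, -1.5f};
  float acc = 0.0f;       // MakeInitValue for absmax
  float expected = 0.0f;
  for (float x : data) {
    acc = std::max(std::max(acc, x), -std::min(acc, x));  // absmax combine
    expected = std::max(expected, std::fabs(x));
  }
  assert(acc == expected);  // acc == 3.0f
  return 0;
}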
std::string ReduceOpNode::MakeCodegenReducer() const {
switch (type) {
case ReduceType::kSum:
if (type->isSum()) {
return "tl::SumOp";
case ReduceType::kAbsSum:
} else if (type->isAbsSum()) {
return "tl::SumOp";
case ReduceType::kMax:
} else if (type->isMax()) {
return "tl::MaxOp";
case ReduceType::kMin:
} else if (type->isMin()) {
return "tl::MinOp";
case ReduceType::kAbsMax:
} else if (type->isAbsMax()) {
return "tl::MaxOp";
default:
ICHECK(0);
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
return "";
}
}
......@@ -206,17 +191,17 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
bool require_init = this->clear;
// sum op must be cleared
if (this->type == ReduceType::kSum) {
if (this->type->isSum()) {
require_init = true;
} else if (this->type == ReduceType::kAbsSum) {
} else if (this->type->isAbsSum()) {
require_init = true;
}
Buffer clear_buffer = dst_buffer;
bool need_duplicate = false;
if (this->type == ReduceType::kSum && !this->clear) {
if (this->type->isSum() && !this->clear) {
need_duplicate = true;
} else if (this->type == ReduceType::kAbsSum && !this->clear) {
} else if (this->type->isAbsSum() && !this->clear) {
need_duplicate = true;
}
......@@ -303,18 +288,18 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
// copy clear_buffer to dst_buffer
if (need_duplicate) {
// if is reduce sum, we should add a copy from clear_buffer to dst_buffer
if (this->type == ReduceType::kSum) {
if (this->type->isSum()) {
stmts.push_back(BufferStore(dst_buffer,
Add(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else if (this->type == ReduceType::kAbsSum) {
} else if (this->type->isAbsSum()) {
stmts.push_back(BufferStore(dst_buffer,
Add(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else {
ICHECK(false) << "Unsupported reduce type: " << (int)this->type;
ICHECK(false) << "Unsupported reduce type: " << this->type->type;
}
}
// make the outer spatial loop
......@@ -410,13 +395,11 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce)
Integer(CallEffectKind::kOpaque));
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/*
CumSum arguments:
src: input buffer
dst: output buffer
dim: dimension to cumsum
reverse: whether to cumsum in reverse order
*/
/// CumSum constructor arguments:
/// - src: input buffer
/// - dst: output buffer
/// - dim: dimension to cumsum
/// - reverse: whether to cumsum in reverse order
CHECK_EQ(args.size(), 4);
ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
......
/*!
* \file tl/op/reduce.h
* \brief Define reduce operator.
*
* \brief Reduction operators for tensor computations
*/
#ifndef TVM_TL_OP_REDUCE_H_
......@@ -10,180 +9,128 @@
#include "operator.h"
namespace tvm {
/**
* Tile operator node that performs a reduction (sum, max, min, etc.) along a
* single tensor dimension.
*
* Represents a per-instance reduce operator with explicit source/destination
* buffers, target dimension, reduction type, and a flag controlling whether the
* destination is cleared before reduction.
*/
/**
 * Lower this ReduceOpNode into a TIR Stmt suitable for code generation.
*
* Produces the TIR statement(s) that implement the configured reduction.
*
* @return A TIR `Stmt` implementing the reduce operation.
*/
/**
* Infer input/output layouts for this reduce operator.
*
* Returns a LayoutMap describing how input and output buffer layouts relate
* for the configured reduction dimension.
*
* @param level Inference detail level that may affect how aggressively layouts
* are inferred.
* @return A LayoutMap mapping operator arguments to inferred layouts.
*/
/**
* Retrieve the global operator descriptor for the reduce operator.
*
* @return A reference to the Op descriptor corresponding to this operator type.
*/
/**
* Create a copy of this reduce operator as a TileOperator handle.
*
* The returned TileOperator preserves the node's configuration (buffers, dim,
* type, clear).
*
* @return A TileOperator wrapping a cloned ReduceOpNode.
*/
/**
* Construct the initial value used by the reduction (e.g., 0 for sum, -inf for
* max).
*
* @return A PrimExpr representing the reduction's identity/init value.
*/
/**
* Combine two partial values according to the configured reduction.
*
* Implements the binary reducer (for example, `a + b` for sum or `max(a, b)`
* for max).
*
* @return A PrimExpr representing the reduced result of `a` and `b`.
*/
/**
* Generate a string snippet suitable for code generation of the reducer
* expression.
*
* The returned code fragment should implement the binary reduction operation in
* the target backend's code string form.
*
* @return A std::string containing the codegen expression for the reducer.
*/
/**
* Reference wrapper for ReduceOpNode as a TileOperator.
*
* Construct a ReduceOp from explicit arguments and a buffer map.
*/
/**
* Construct a ReduceOp TileOperator from operator arguments and a buffer
* mapping.
*
* @param args Operator arguments (typically shapes, axes, or other prim exprs).
* @param vmap Mapping from argument names to tir::Buffer instances used by the
* operator.
*/
/**
* Tile operator node that computes a cumulative sum along a single tensor
* dimension.
*
* Contains source/destination buffers, the target dimension, and a flag to
* compute the cumulative sum in reverse order.
*/
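For reference on the `reverse` flag mentioned above: a forward cumulative sum of {1, 2, 3, 4} is {1, 3, 6, 10}, while the reverse variant accumulates from the tail and gives {10, 9, 7, 4}. A scalar sketch of the two orders (illustration only, unrelated to the operator's actual lowering):
// Scalar reference for forward vs. reverse cumulative sum (illustration only).
#include <cstdio>
#include <vector>
std::vector<int> cumsum(const std::vector<int> &src, bool reverse) {
  std::vector<int> dst(src.size());
  int acc = 0;
  if (!reverse) {
    for (size_t i = 0; i < src.size(); ++i) dst[i] = (acc += src[i]);
  } else {
    for (size_t i = src.size(); i-- > 0;) dst[i] = (acc += src[i]);
  }
  return dst;
}
int main() {
  for (int v : cumsum({1, 2, 3, 4}, /*reverse=*/true)) std::printf("%d ", v);
  // prints: 10 9 7 4
  return 0;
}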
/**
 * Lower this CumSumOpNode into a TIR Stmt suitable for code generation.
*
* Produces the TIR statement(s) that implement the configured cumulative-sum.
*
* @return A TIR `Stmt` implementing the cum-sum operation.
*/
/**
* Infer input/output layouts for this cumulative-sum operator.
*
* Returns a LayoutMap describing how input and output buffer layouts relate
* for the configured cumulative-sum dimension.
*
* @param level Inference detail level that may affect how aggressively layouts
* are inferred.
* @return A LayoutMap mapping operator arguments to inferred layouts.
*/
/**
* Retrieve the global operator descriptor for the cumulative-sum operator.
*
* @return A reference to the Op descriptor corresponding to this operator type.
*/
/**
* Create a copy of this cum-sum operator as a TileOperator handle.
*
* The returned TileOperator preserves the node's configuration (buffers, dim,
* reverse).
*
* @return A TileOperator wrapping a cloned CumSumOpNode.
*/
/**
* Reference wrapper for CumSumOpNode as a TileOperator.
*
* Construct a CumSumOp from explicit arguments and a buffer map.
*/
/**
* Construct a CumSumOp TileOperator from operator arguments and a buffer
* mapping.
*
* @param args Operator arguments (typically shapes, axes, or other prim exprs).
* @param vmap Mapping from argument names to tir::Buffer instances used by the
* operator.
*/
namespace tl {
using namespace tir;
enum class ReduceType : uint8_t {
kSum,
kAbsSum,
kMax,
kMin,
kAbsMax,
/// Supported reduction operation types
enum class ReduceTypeEnum : uint8_t {
kSum, ///< Sum reduction
kAbsSum, ///< Absolute sum reduction
kMax, ///< Maximum value reduction
kMin, ///< Minimum value reduction
kAbsMax, ///< Maximum absolute value reduction
};
/// Node class representing a reduction type
class ReduceTypeNode : public Object {
public:
int type{-1}; ///< Internal type identifier
static constexpr const char *_type_key = "tl.ReduceType";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceTypeNode, Object);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ReduceTypeNode>().def_ro("type", &ReduceTypeNode::type);
}
bool SEqualReduce(const ReduceTypeNode *other, SEqualReducer equal) const {
return equal(type, other->type);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(type); }
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/// Type checking methods
bool isSum() const { return type == int(ReduceTypeEnum::kSum); }
bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); }
bool isMax() const { return type == int(ReduceTypeEnum::kMax); }
bool isMin() const { return type == int(ReduceTypeEnum::kMin); }
bool isAbsMax() const { return type == int(ReduceTypeEnum::kAbsMax); }
};
/// Wrapper class for reduction type with string-based construction
class ReduceType : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ReduceType, ObjectRef, ReduceTypeNode);
TVM_DLL ReduceType(std::string type) {
auto node = make_object<ReduceTypeNode>();
if (type == "sum") {
node->type = int(ReduceTypeEnum::kSum);
} else if (type == "abssum") {
node->type = int(ReduceTypeEnum::kAbsSum);
} else if (type == "max") {
node->type = int(ReduceTypeEnum::kMax);
} else if (type == "absmax") {
node->type = int(ReduceTypeEnum::kAbsMax);
} else if (type == "min") {
node->type = int(ReduceTypeEnum::kMin);
} else {
LOG(FATAL) << "Invalid reduce type: " << type;
}
data_ = std::move(node);
}
};
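A minimal sketch of how the wrapper is intended to be used in place of the old raw-enum switch (assumes this header and the surrounding TileLang/TVM build; the helper below is hypothetical and not part of the diff):
// Sketch: construct a ReduceType from the frontend string and dispatch on it.
#include <string>
#include "reduce.h"  // assumed include path within the TileLang source tree
void DescribeReduce(const std::string &name) {
  tvm::tl::ReduceType ty(name);  // e.g. "sum", "abssum", "max", "min", "absmax"
  if (ty->isSum() || ty->isAbsSum()) {
    // sum-like reductions always start from a cleared accumulator
  } else if (ty->isMax() || ty->isMin() || ty->isAbsMax()) {
    // extremum reductions start from the dtype's identity (+/-inf or int limits)
  }
}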
/// Node class for reduction operations
class ReduceOpNode : public TileOperatorNode {
public:
tir::Buffer src, dst;
int dim;
ReduceType type;
bool clear;
tir::Buffer src, dst; ///< Source and destination buffers
int dim; ///< Dimension to reduce along
ReduceType type; ///< Type of reduction operation
bool clear; ///< Whether to clear destination before reduction
static constexpr const char *_type_key = "tl.ReduceOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ReduceOpNode>()
.def_ro("src", &ReduceOpNode::src)
.def_ro("dst", &ReduceOpNode::dst)
.def_ro("dim", &ReduceOpNode::dim)
.def_ro("type", &ReduceOpNode::type)
.def_ro("clear", &ReduceOpNode::clear);
}
bool SEqualReduce(const ReduceOpNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(dim, other->dim) && equal(type, other->type) &&
equal(clear, other->clear);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(dim);
hash_reduce(type);
hash_reduce(clear);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/// Lower the operator to TIR statements
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
/// Infer memory layout for buffers
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
static const Op &Get();
TileOperator Clone() const;
private:
/// Generate initial value for reduction
PrimExpr MakeInitValue() const;
/// Generate reduction expression
PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
/// Generate codegen reducer string
std::string MakeCodegenReducer() const;
};
/// Wrapper class for reduction operations
class ReduceOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
......@@ -191,11 +138,12 @@ public:
static const Op &Get();
};
/// Node class for cumulative sum operations
class CumSumOpNode : public TileOperatorNode {
public:
tir::Buffer src, dst;
int dim;
bool reverse;
tir::Buffer src, dst; ///< Source and destination buffers
int dim; ///< Dimension along which to compute cumulative sum
bool reverse; ///< Whether to compute in reverse order
static constexpr const char *_type_key = "tl.CumSumOp";
TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);
......@@ -206,6 +154,7 @@ public:
TileOperator Clone() const;
};
/// Wrapper class for cumulative sum operations
class CumSumOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);
......
......@@ -93,6 +93,28 @@ public:
bool IsFullRegion() const;
TileOperator Clone() const override;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<RegionOpNode>()
.def_ro("buffer", &RegionOpNode::buffer_)
.def_ro("ranges", &RegionOpNode::ranges_)
.def_ro("access_mask", &RegionOpNode::access_mask_);
}
bool SEqualReduce(const RegionOpNode *other, SEqualReducer equal) const {
return equal(buffer_, other->buffer_) && equal(ranges_, other->ranges_) &&
equal(access_mask_, other->access_mask_);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer_);
hash_reduce(ranges_);
hash_reduce(access_mask_);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
};
class RegionOp : public TileOperator {
......
......@@ -12,7 +12,7 @@
#include <tvm/tir/transform.h>
#include "../layout/layout.h"
#include "../op/elem.h"
#include "../op/fill.h"
#include "../op/finalize_reducer.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "layout_reducer.h"
......@@ -132,7 +132,7 @@ private:
.value_or(Map<Var, Layout>());
for (auto &&[k, v] : new_layout_map_)
layout_map.Set(k, v);
if (layout_map.size())
if (!layout_map.empty())
p_result->annotations.Set(attr::kLayoutMap, layout_map);
new_layout_map_.clear();
return result;
......@@ -178,7 +178,7 @@ private:
Stmt VisitStmt_(const ForNode *op) final {
// only annotate the outermost loop
bool should_annotate = false;
if (inside_reducer_range_.size() > 0 && !already_annotated_) {
if (!inside_reducer_range_.empty() && !already_annotated_) {
should_annotate = true;
already_annotated_ = true;
}
......@@ -202,7 +202,6 @@ private:
ICHECK(thread_var_.defined());
ICHECK(analyzer_->const_int_bound.IsBound(thread_var_->var));
auto const_int_bound = analyzer_->const_int_bound(thread_var_);
auto dtype = thread_var_->var.dtype();
int thread_min = const_int_bound->min_value;
int thread_extent =
const_int_bound->max_value - const_int_bound->min_value + 1;
......@@ -274,7 +273,7 @@ private:
auto op_ref = IRMutatorWithAnalyzer::VisitExpr_(op_).as<Call>().value();
auto op = op_ref.CopyOnWrite();
if (op->op.same_as(Fill::Get())) {
ICHECK(op->args.size() > 0);
ICHECK(!op->args.empty());
if (auto arg0_call = op->args[0].as<Call>();
arg0_call &&
arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
......