Unverified commit b38bd69e authored by Lei Wang, committed by GitHub
Browse files

[Refactor] Refactor `Operator` into `TileOperator` with tvm reflection (#763)

* Refactor operator classes to inherit from TileOperator and update layout inference methods

- Changed base class of several operator classes (AtomicAdd, Copy, Gemm, etc.) from Operator to TileOperator for better alignment with tile operations.
- Updated InferLayout and Lower methods to use 'override' specifier for clarity and consistency.
- Adjusted header inclusions to replace "op.h" with "operator.h" across multiple files for improved organization.
- Added missing layout inference implementations for Fill and Conv2DIm2ColOp.
- Removed deprecated op.cc and op.h files to streamline the codebase.

* lint fix

* Refactor operator classes to use Node pattern and improve memory management

- Updated several operator classes (AtomicAdd, Copy, Gemm, etc.) to utilize the Node pattern for better memory management and encapsulation.
- Changed constructors to initialize member variables through a node object, enhancing clarity and reducing direct member access.
- Updated Clone methods to return TileOperator instances instead of unique pointers, aligning with the new design.
- Refactored InferLayout and Lower methods to ensure consistency across operator implementations.
- Adjusted header files to reflect the new class structure and removed deprecated code for a cleaner codebase.

* Enhance Clone methods in AtomicAdd and Copy classes to support parallel operation cloning

- Updated the Clone methods in AtomicAddNode and CopyNode to ensure that the parallel operation (par_op_) is properly cloned when defined, improving the integrity of cloned objects.
- Refactored the FillNode class to use ParallelOp directly instead of std::make_unique, streamlining the creation of parallel operations.
- Made minor adjustments in layout inference and other related methods for consistency and clarity.

* Refactor FillNode::Lower method to remove unused global function call

- Eliminated the call to the global function "tl.fill.lower" in the FillNode::Lower method, streamlining the code and improving clarity.
- Retained the core functionality of the method while enhancing maintainability by reducing unnecessary dependencies.
parent 277ed53c
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* Define element-wise operators. * Define element-wise operators.
*/ */
#include "atomic_add.h" #include "./atomic_add.h"
#include "./region.h"
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
...@@ -34,7 +34,8 @@ static int GetArchInt(Target target) { ...@@ -34,7 +34,8 @@ static int GetArchInt(Target target) {
return arch_int; return arch_int;
} }
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) : args_(args) { AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
Array<Range> rgs[2]; Array<Range> rgs[2];
Buffer bf[2]; Buffer bf[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
...@@ -42,17 +43,26 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) : args_(args) { ...@@ -42,17 +43,26 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
auto call = expr.as<CallNode>(); auto call = expr.as<CallNode>();
ICHECK(call); ICHECK(call);
auto region = RegionOp(call->args, vmap); auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges(); rgs[i] = region->GetRanges();
bf[i] = region.GetBuffer(); bf[i] = region->GetBuffer();
} }
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]); std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) { if (args.size() >= 3) {
coalesced_width = Downcast<IntImm>(args[2]); node->coalesced_width = Downcast<IntImm>(args[2]);
} }
data_ = std::move(node);
} }
Array<IterVar> AtomicAdd::MakeIterVars() const { TileOperator AtomicAddNode::Clone() const {
auto op = make_object<AtomicAddNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
return AtomicAdd(op);
}
Array<IterVar> AtomicAddNode::MakeIterVars() const {
Array<IterVar> loop_vars; Array<IterVar> loop_vars;
size_t idx = 0; size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) { for (size_t i = 0; i < src_range.size(); i++) {
...@@ -68,8 +78,8 @@ Array<IterVar> AtomicAdd::MakeIterVars() const { ...@@ -68,8 +78,8 @@ Array<IterVar> AtomicAdd::MakeIterVars() const {
// ivs: itervars returned by MakeIterVars() // ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices // src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs, Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const { int src_dst) const {
Array<PrimExpr> indices; Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range; Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0; size_t idx = 0;
...@@ -87,9 +97,10 @@ Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs, ...@@ -87,9 +97,10 @@ Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs,
return indices; return indices;
} }
PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer, PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const { Array<PrimExpr> extents,
int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range; Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list; Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
...@@ -117,7 +128,7 @@ PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer, ...@@ -117,7 +128,7 @@ PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer,
} }
} }
For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const { For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars(); Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0; bool is_scalar = loop_vars.size() == 0;
if (is_scalar) { if (is_scalar) {
...@@ -180,16 +191,16 @@ For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -180,16 +191,16 @@ For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return Downcast<For>(body); return Downcast<For>(body);
} }
Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target; Target target = T.target;
auto simt_loop = MakeSIMTLoop(analyzer); auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop)); auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto par_op = std::make_unique<ParallelOp>(fused_loop); auto par_op = ParallelOp(fused_loop);
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict, std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree}; InferLevel::kFree};
for (auto level : levels) { for (auto level : levels) {
par_op->InferLayout( (par_op)->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level); {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
} }
auto loop_layout = par_op->GetLoopLayout(); auto loop_layout = par_op->GetLoopLayout();
...@@ -210,10 +221,11 @@ Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -210,10 +221,11 @@ Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop; return vectorized_thread_loop;
} }
LayoutMap AtomicAdd::InferLayout(const LayoutInferArgs &T, InferLevel level) { LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
if (par_op_ == nullptr) { InferLevel level) const {
if (!par_op_.defined()) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer)); par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
} }
if (T.layout_map.count(src) && T.layout_map.count(dst)) { if (T.layout_map.count(src) && T.layout_map.count(dst)) {
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") { if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
...@@ -236,10 +248,5 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) ...@@ -236,10 +248,5 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
// TVM_REGISTER_OP("tl.atomicadd")
// .set_num_inputs(2)
// .add_argument("ref", "Buffer", "The destination buffer")
// .add_argument("val", "Expr", "The value to be added atomically");
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#ifndef TVM_TL_OP_ATOMIC_ADD_H_ #ifndef TVM_TL_OP_ATOMIC_ADD_H_
#define TVM_TL_OP_ATOMIC_ADD_H_ #define TVM_TL_OP_ATOMIC_ADD_H_
#include "op.h" #include "operator.h"
#include "parallel.h" #include "parallel.h"
namespace tvm { namespace tvm {
...@@ -15,26 +15,23 @@ namespace tl { ...@@ -15,26 +15,23 @@ namespace tl {
using namespace tir; using namespace tir;
class AtomicAdd : public Operator { class AtomicAddNode : public TileOperatorNode {
public: public:
AtomicAdd(Array<PrimExpr> args, BufferMap vmap); Array<PrimExpr> args_;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get(); Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;
AtomicAdd(const AtomicAdd &other) mutable ParallelOp par_op_;
: args_(other.args_), src(other.src), dst(other.dst), static constexpr const char *_type_key = "tl.AtomicAdd";
src_range(other.src_range), dst_range(other.dst_range), TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode);
coalesced_width(other.coalesced_width) {
// No clone nullptr Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
if (other.par_op_) LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release())); static const Op &Get();
} TileOperator Clone() const;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<AtomicAdd>(*this);
}
protected: protected:
For MakeSIMTLoop(arith::Analyzer *analyzer) const; For MakeSIMTLoop(arith::Analyzer *analyzer) const;
...@@ -46,14 +43,13 @@ protected: ...@@ -46,14 +43,13 @@ protected:
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs, PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const; Array<PrimExpr> extents, int src_dst) const;
};
Array<PrimExpr> args_; class AtomicAdd : public TileOperator {
public:
Buffer src, dst; TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode);
Array<Range> src_range, dst_range; TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
IntImm coalesced_width; static const Op &Get();
std::unique_ptr<ParallelOp> par_op_;
}; };
} // namespace tl } // namespace tl
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#ifndef TVM_TL_OP_BUILTIN_H_ #ifndef TVM_TL_OP_BUILTIN_H_
#define TVM_TL_OP_BUILTIN_H_ #define TVM_TL_OP_BUILTIN_H_
#include "op.h" #include "operator.h"
#include <tvm/ir/transform.h> #include <tvm/ir/transform.h>
namespace tvm { namespace tvm {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h" #include "../transform/loop_vectorize.h"
#include "region.h"
#include "../target/cuda.h" #include "../target/cuda.h"
#include "../target/utils.h" #include "../target/utils.h"
...@@ -111,7 +112,8 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) { ...@@ -111,7 +112,8 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
* operation. \param vmap BufferMap mapping original buffer names to new buffer * operation. \param vmap BufferMap mapping original buffer names to new buffer
* names. * names.
*/ */
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) { Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<CopyNode> node = make_object<CopyNode>();
Array<Range> rgs[2]; Array<Range> rgs[2];
Buffer bf[2]; Buffer bf[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
...@@ -119,23 +121,32 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) { ...@@ -119,23 +121,32 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
auto call = expr.as<CallNode>(); auto call = expr.as<CallNode>();
ICHECK(call); ICHECK(call);
auto region = RegionOp(call->args, vmap); auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges(); rgs[i] = region->GetRanges();
bf[i] = region.GetBuffer(); bf[i] = region->GetBuffer();
} }
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]); std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) { if (args.size() >= 3) {
auto coalesced_width = Downcast<IntImm>(args[2]); auto coalesced_width = Downcast<IntImm>(args[2]);
if (coalesced_width->value > 0) { if (coalesced_width->value > 0) {
this->coalesced_width = coalesced_width; node->coalesced_width = coalesced_width;
} }
} }
if (args.size() >= 4) { if (args.size() >= 4) {
this->disable_tma = Downcast<Bool>(args[3]); node->disable_tma = Downcast<Bool>(args[3]);
} }
if (args.size() >= 5) { if (args.size() >= 5) {
this->eviction_policy = args[4].as<IntImmNode>()->value; node->eviction_policy = args[4].as<IntImmNode>()->value;
} }
data_ = std::move(node);
}
TileOperator CopyNode::Clone() const {
auto op = make_object<CopyNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
return Copy(op);
} }
/*! /*!
...@@ -144,7 +155,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) { ...@@ -144,7 +155,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
* > 1. \return Array of IterVar representing the iterator variables for the * > 1. \return Array of IterVar representing the iterator variables for the
* copy operation. * copy operation.
*/ */
Array<IterVar> Copy::MakeIterVars() const { Array<IterVar> CopyNode::MakeIterVars() const {
Array<IterVar> loop_vars; Array<IterVar> loop_vars;
size_t idx = 0; size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) { for (size_t i = 0; i < src_range.size(); i++) {
...@@ -167,8 +178,8 @@ Array<IterVar> Copy::MakeIterVars() const { ...@@ -167,8 +178,8 @@ Array<IterVar> Copy::MakeIterVars() const {
* dst_indices. \return Array of PrimExpr representing the indices for the copy * dst_indices. \return Array of PrimExpr representing the indices for the copy
* operation. * operation.
*/ */
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs, Array<PrimExpr> CopyNode::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const { int src_dst) const {
Array<PrimExpr> indices; Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range; Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0; size_t idx = 0;
...@@ -195,9 +206,9 @@ Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs, ...@@ -195,9 +206,9 @@ Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
* of the copy operation. \param src_dst 0 for src_indices, 1 for dst_indices. * of the copy operation. \param src_dst 0 for src_indices, 1 for dst_indices.
* \return PrimExpr representing the predicate for the copy operation. * \return PrimExpr representing the predicate for the copy operation.
*/ */
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer, PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs, Array<PrimExpr> extents, const Array<IterVar> &ivs,
int src_dst) const { Array<PrimExpr> extents, int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range; Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list; Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
...@@ -233,7 +244,7 @@ PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer, ...@@ -233,7 +244,7 @@ PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
* simplification. \return For representing the SIMT loop for the copy * simplification. \return For representing the SIMT loop for the copy
* operation. * operation.
*/ */
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const { For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars(); Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0; bool is_scalar = loop_vars.size() == 0;
if (is_scalar) { if (is_scalar) {
...@@ -289,7 +300,7 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -289,7 +300,7 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
* shared tensor. \return Layout representing the linear layout for the TMA * shared tensor. \return Layout representing the linear layout for the TMA
* copy. * copy.
*/ */
Layout Copy::ComputeLinearLayout(const Buffer &shared_tensor) const { Layout CopyNode::ComputeLinearLayout(const Buffer &shared_tensor) const {
Array<PrimExpr> input_size = shared_tensor->shape; Array<PrimExpr> input_size = shared_tensor->shape;
Array<PrimExpr> forward_vars; Array<PrimExpr> forward_vars;
for (size_t i = 0; i < input_size.size(); i++) { for (size_t i = 0; i < input_size.size(); i++) {
...@@ -316,7 +327,8 @@ Layout Copy::ComputeLinearLayout(const Buffer &shared_tensor) const { ...@@ -316,7 +327,8 @@ Layout Copy::ComputeLinearLayout(const Buffer &shared_tensor) const {
* indicating the level of layout inference. \return LayoutMap containing the * indicating the level of layout inference. \return LayoutMap containing the
* inferred layout. * inferred layout.
*/ */
LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) { LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
auto target = T.target; auto target = T.target;
using namespace tvm::transform; using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current(); PassContext pass_ctx = PassContext::Current();
...@@ -340,17 +352,15 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -340,17 +352,15 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
return Map<Buffer, Layout>({{shared_tensor, linear_layout}}); return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
} }
} }
// for LDSM/STSM, the layout was deduced from register layout // for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy // so we can directly apply the layout of normal copy
// Use parallel op to infer the layout // Use parallel op to infer the layout
if (!par_op_) { if (!par_op_.defined()) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer)); par_op_ = ParallelOp((MakeSIMTLoop(&analyzer)));
} }
return par_op_->InferLayout(T, level); return par_op_->InferLayout(T, level);
} }
/*! /*!
* \brief Check if the copy operation is a bulk load. * \brief Check if the copy operation is a bulk load.
* This function verifies if the copy operation can be implemented using CUDA's * This function verifies if the copy operation can be implemented using CUDA's
...@@ -359,7 +369,7 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -359,7 +369,7 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
* same data type. \param target Target device. \return True if the copy * same data type. \param target Target device. \return True if the copy
* operation is a bulk load, false otherwise. * operation is a bulk load, false otherwise.
*/ */
bool Copy::CheckBulkLoad(Target target) const { bool CopyNode::CheckBulkLoad(Target target) const {
// 1. arch must have bulk copy support // 1. arch must have bulk copy support
if (!TargetHasBulkCopy(target)) if (!TargetHasBulkCopy(target))
return false; return false;
...@@ -387,7 +397,7 @@ bool Copy::CheckBulkLoad(Target target) const { ...@@ -387,7 +397,7 @@ bool Copy::CheckBulkLoad(Target target) const {
* same data type. \param target Target device. \return True if the copy * same data type. \param target Target device. \return True if the copy
* operation is a bulk store, false otherwise. * operation is a bulk store, false otherwise.
*/ */
bool Copy::CheckBulkStore(Target target) const { bool CopyNode::CheckBulkStore(Target target) const {
// 1. arch must have bulk copy support // 1. arch must have bulk copy support
if (!TargetHasBulkCopy(target)) if (!TargetHasBulkCopy(target))
return false; return false;
...@@ -415,7 +425,7 @@ bool Copy::CheckBulkStore(Target target) const { ...@@ -415,7 +425,7 @@ bool Copy::CheckBulkStore(Target target) const {
* Target device. \return True if the copy operation is a LDSM copy, false * Target device. \return True if the copy operation is a LDSM copy, false
* otherwise. * otherwise.
*/ */
bool Copy::CheckLDSMCopy(Target target) const { bool CopyNode::CheckLDSMCopy(Target target) const {
return TargetHasLdmatrix(target) && return TargetHasLdmatrix(target) &&
(src.scope() == "shared.dyn" || src.scope() == "shared") && (src.scope() == "shared.dyn" || src.scope() == "shared") &&
dst.scope() == "local.fragment"; dst.scope() == "local.fragment";
...@@ -429,7 +439,7 @@ bool Copy::CheckLDSMCopy(Target target) const { ...@@ -429,7 +439,7 @@ bool Copy::CheckLDSMCopy(Target target) const {
* Target device. \return True if the copy operation is a STSM copy, false * Target device. \return True if the copy operation is a STSM copy, false
* otherwise. * otherwise.
*/ */
bool Copy::CheckSTSMCopy(Target target) const { bool CopyNode::CheckSTSMCopy(Target target) const {
return TargetHasStmatrix(target) && src.scope() == "local.fragment" && return TargetHasStmatrix(target) && src.scope() == "local.fragment" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared"); (dst.scope() == "shared.dyn" || dst.scope() == "shared");
} }
...@@ -442,7 +452,7 @@ bool Copy::CheckSTSMCopy(Target target) const { ...@@ -442,7 +452,7 @@ bool Copy::CheckSTSMCopy(Target target) const {
* copy if no specialized instruction is applicable. \param target Target * copy if no specialized instruction is applicable. \param target Target
* device. \return CopyInst representing the copy instruction type. * device. \return CopyInst representing the copy instruction type.
*/ */
Copy::CopyInst Copy::GetCopyInst(Target target, bool disable_tma_lower) const { CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower) const {
// disable_tma_lower is from pass_configs // disable_tma_lower is from pass_configs
// when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True, // when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True,
// we will not use tma for bulk load/store // we will not use tma for bulk load/store
...@@ -471,7 +481,7 @@ Copy::CopyInst Copy::GetCopyInst(Target target, bool disable_tma_lower) const { ...@@ -471,7 +481,7 @@ Copy::CopyInst Copy::GetCopyInst(Target target, bool disable_tma_lower) const {
* \param analyzer Arithmetic analyzer for simplification. * \param analyzer Arithmetic analyzer for simplification.
* \return Stmt representing the PTX code for the copy operation. * \return Stmt representing the PTX code for the copy operation.
*/ */
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target; Target target = T.target;
using namespace tvm::transform; using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current(); PassContext pass_ctx = PassContext::Current();
...@@ -502,8 +512,8 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -502,8 +512,8 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
* map. \param analyzer Arithmetic analyzer for simplification. \return Stmt * map. \param analyzer Arithmetic analyzer for simplification. \return Stmt
* representing the normal copy code. * representing the normal copy code.
*/ */
Stmt Copy::LowerNormalCopy(const LowerArgs &T, Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer) const {
bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU; bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;
auto simt_loop = MakeSIMTLoop(analyzer); auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop)); auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
...@@ -512,7 +522,7 @@ Stmt Copy::LowerNormalCopy(const LowerArgs &T, ...@@ -512,7 +522,7 @@ Stmt Copy::LowerNormalCopy(const LowerArgs &T,
Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop)); Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
For vectorized_thread_loop; For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(transformed_loop); auto par_op = ParallelOp(transformed_loop);
if (is_cpu_target) { if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(transformed_loop); vectorized_thread_loop = VectorizeLoop(transformed_loop);
...@@ -548,8 +558,8 @@ Stmt Copy::LowerNormalCopy(const LowerArgs &T, ...@@ -548,8 +558,8 @@ Stmt Copy::LowerNormalCopy(const LowerArgs &T,
* \param copy_inst CopyInst representing the copy instruction type. * \param copy_inst CopyInst representing the copy instruction type.
* \return Stmt representing the LDSM/STSM copy code. * \return Stmt representing the LDSM/STSM copy code.
*/ */
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const { CopyInst copy_inst) const {
ICHECK(copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM) ICHECK(copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM)
<< "Invalid copy inst " << static_cast<int>(copy_inst); << "Invalid copy inst " << static_cast<int>(copy_inst);
bool is_ldmatrix = copy_inst == CopyInst::kLDSM; bool is_ldmatrix = copy_inst == CopyInst::kLDSM;
...@@ -741,8 +751,8 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, ...@@ -741,8 +751,8 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
* copy_inst CopyInst representing the copy instruction type. \return Stmt * copy_inst CopyInst representing the copy instruction type. \return Stmt
* representing the bulk copy code. * representing the bulk copy code.
*/ */
Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const { CopyInst copy_inst) const {
ICHECK(copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) ICHECK(copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore)
<< "Invalid copy inst " << static_cast<int>(copy_inst); << "Invalid copy inst " << static_cast<int>(copy_inst);
bool is_load = copy_inst == CopyInst::kBulkLoad; bool is_load = copy_inst == CopyInst::kBulkLoad;
...@@ -1153,15 +1163,22 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const { ...@@ -1153,15 +1163,22 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* buffer names to new buffer names. * buffer names to new buffer names.
*/ */
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
src = vmap[GetVarFromAccessPtr(args[0])]; ObjectPtr<Conv2DIm2ColOpNode> node = make_object<Conv2DIm2ColOpNode>();
dst = vmap[GetVarFromAccessPtr(args[1])]; node->src = vmap[GetVarFromAccessPtr(args[0])];
nhw_step = args[2]; node->dst = vmap[GetVarFromAccessPtr(args[1])];
c_step = args[3]; node->nhw_step = args[2];
kernel = args[4].as<IntImm>().value()->value; node->c_step = args[3];
stride = args[5].as<IntImm>().value()->value; node->kernel = args[4].as<IntImm>().value()->value;
dilation = args[6].as<IntImm>().value()->value; node->stride = args[5].as<IntImm>().value()->value;
padding = args[7].as<IntImm>().value()->value; node->dilation = args[6].as<IntImm>().value()->value;
eviction_policy = args[8].as<IntImm>().value()->value; node->padding = args[7].as<IntImm>().value()->value;
node->eviction_policy = args[8].as<IntImm>().value()->value;
data_ = std::move(node);
}
TileOperator Conv2DIm2ColOpNode::Clone() const {
auto op = make_object<Conv2DIm2ColOpNode>(*this);
return Conv2DIm2ColOp(op);
} }
/*! /*!
...@@ -1174,8 +1191,8 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -1174,8 +1191,8 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
* \param analyzer Arithmetic analyzer for simplification. * \param analyzer Arithmetic analyzer for simplification.
* \return Stmt representing the PTX code for the Conv2DIm2ColOp. * \return Stmt representing the PTX code for the Conv2DIm2ColOp.
*/ */
Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T, Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer) const {
ICHECK(TargetIsHopper(T.target)); ICHECK(TargetIsHopper(T.target));
ICHECK(src.scope() == "global" && ICHECK(src.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared")); (dst.scope() == "shared.dyn" || dst.scope() == "shared"));
...@@ -1343,6 +1360,11 @@ TIR_REGISTER_TL_OP(Copy, copy) ...@@ -1343,6 +1360,11 @@ TIR_REGISTER_TL_OP(Copy, copy)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {};
}
// Register the Conv2DIm2Col operation with TVM's TIR system // Register the Conv2DIm2Col operation with TVM's TIR system
// This operation performs im2col transformation for 2D convolutions using TMA // This operation performs im2col transformation for 2D convolutions using TMA
// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride, // - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride,
......
...@@ -11,13 +11,24 @@ ...@@ -11,13 +11,24 @@
#ifndef TVM_TL_OP_COPY_H_ #ifndef TVM_TL_OP_COPY_H_
#define TVM_TL_OP_COPY_H_ #define TVM_TL_OP_COPY_H_
#include "op.h" #include "operator.h"
#include "parallel.h" #include "parallel.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
/*!
* \brief Copy instruction type.
*/
enum class CopyInst {
kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy
kLDSM = 1, // ldmatrix memory copy
kSTSM = 2, // stmatrix memory copy
kBulkLoad = 3, // utilize tma load
kBulkStore = 4, // utilize tma store
};
/*! /*!
* \brief Descriptor for Tensor Memory Access (TMA) copy operations. * \brief Descriptor for Tensor Memory Access (TMA) copy operations.
* *
...@@ -83,44 +94,40 @@ struct TMAIm2ColDesc { ...@@ -83,44 +94,40 @@ struct TMAIm2ColDesc {
* block-wise or element-wise data transfer, possibly optimized with * block-wise or element-wise data transfer, possibly optimized with
* parallelization or TMA hardware acceleration. * parallelization or TMA hardware acceleration.
*/ */
class Copy : public Operator { class CopyNode : public TileOperatorNode {
public: public:
/*! Array<PrimExpr> args_; // Copy parameters (indices, sizes, etc.)
* \brief Constructor.
* \param args Expression arguments for the copy. Buffer src, dst; // Source and destination buffers
* \param vmap Buffer variable mapping. Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
*/ IntImm coalesced_width; // Width (in elements) for coalesced memory access
Copy(Array<PrimExpr> args, BufferMap vmap); Bool disable_tma = Bool(false); // Whether to disable TMA acceleration
mutable ParallelOp par_op_; // Optional associated parallelization operator
enum class EvictionPolicy {
kEvictNormal = 0,
kEvictFirst = 1,
kEvictLast = 2,
};
int eviction_policy; // Policy for cache eviction
static constexpr const char *_type_key = "tl.Copy";
TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
/*! /*!
* \brief Lower the copy operator to a TIR statement. * \brief Lower the copy operator to a TIR statement.
* \param T Arguments for lowering. * \param T Arguments for lowering.
* \param analyzer Analyzer for simplification and bounds checks. * \param analyzer Analyzer for simplification and bounds checks.
*/ */
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
/*! /*!
* \brief Infer buffer layouts after applying this operator. * \brief Infer buffer layouts after applying this operator.
* \param T Arguments for layout inference. * \param T Arguments for layout inference.
* \param level Level of inference (basic or detailed). * \param level Level of inference (basic or detailed).
*/ */
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
/*!
* \brief Get the TVM Op handle corresponding to this Copy op.
*/
static const Op &Get();
/*!
* \brief Copy instruction type.
*/
enum class CopyInst {
kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy
kLDSM = 1, // ldmatrix memory copy
kSTSM = 2, // stmatrix memory copy
kBulkLoad = 3, // utilize tma load
kBulkStore = 4, // utilize tma store
};
/*! /*!
* \brief Check if bulk copy is supported. * \brief Check if bulk copy is supported.
...@@ -147,26 +154,9 @@ public: ...@@ -147,26 +154,9 @@ public:
*/ */
CopyInst GetCopyInst(Target target, bool disable_tma_lower) const; CopyInst GetCopyInst(Target target, bool disable_tma_lower) const;
/*!
* \brief Copy constructor (deep clones ParallelOp if present).
*/
Copy(const Copy &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width), disable_tma(other.disable_tma) {
// Deep copy ParallelOp if it exists
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
/*! /*!
* \brief Clone this copy operator. * \brief Clone this copy operator.
*/ */
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Copy>(*this);
}
protected: protected:
/*! /*!
* \brief Generate lowering for bulk/global-to-shared copy. * \brief Generate lowering for bulk/global-to-shared copy.
...@@ -218,23 +208,24 @@ protected: ...@@ -218,23 +208,24 @@ protected:
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs, PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const; Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_; // Copy parameters (indices, sizes, etc.) TileOperator Clone() const;
};
Buffer src, dst; // Source and destination buffers
Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
IntImm coalesced_width; // Width (in elements) for coalesced memory access
Bool disable_tma = Bool(false); // Whether to disable TMA acceleration
std::unique_ptr<ParallelOp> class Copy : public TileOperator {
par_op_; // Optional associated parallelization operator public:
TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode);
enum class EvictionPolicy { /*!
kEvictNormal = 0, * \brief Constructor.
kEvictFirst = 1, * \param args Expression arguments for the copy.
kEvictLast = 2, * \param vmap Buffer variable mapping.
}; */
TVM_DLL Copy(Array<PrimExpr> args, BufferMap vmap);
int eviction_policy; // Policy for cache eviction /*!
* \brief Get the TVM Op handle corresponding to this Copy op.
*/
static const Op &Get();
}; };
/*! /*!
...@@ -243,41 +234,43 @@ protected: ...@@ -243,41 +234,43 @@ protected:
* This operator converts input image layout into columnar format suitable * This operator converts input image layout into columnar format suitable
* for matrix multiplication-based convolution lowering. * for matrix multiplication-based convolution lowering.
*/ */
class Conv2DIm2ColOp : public Operator { class Conv2DIm2ColOpNode : public TileOperatorNode {
public: public:
/*! Buffer src, dst; // Source (input feature map) and destination (im2col matrix)
* \brief Constructor. int stride; // Stride for convolution
* \param args Op arguments (convolution parameters, shapes, etc.) int padding; // Padding amount
* \param vmap Variable buffer mapping. int dilation; // Dilation factor
*/ int kernel; // Kernel size
Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap); int eviction_policy; // Cache eviction policy
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
static constexpr const char *_type_key = "tl.Conv2DIm2Col";
TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode);
/*! /*!
* \brief Lower to TIR statement. * \brief Lower to TIR statement.
*/ */
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
/*! /*!
* \brief Get TVM Op handle. * \brief Infer layout for this operator.
*/ */
static const Op &Get(); LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
/*! /*!
* \brief Clone this operator. * \brief Get TVM Op handle.
*/ */
std::unique_ptr<Operator> Clone() const final { static const Op &Get();
return std::make_unique<Conv2DIm2ColOp>(*this); TileOperator Clone() const;
} };
private: class Conv2DIm2ColOp : public TileOperator {
Buffer src, dst; // Source (input feature map) and destination (im2col matrix) public:
int stride; // Stride for convolution TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator,
int padding; // Padding amount Conv2DIm2ColOpNode);
int dilation; // Dilation factor TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
int kernel; // Kernel size static const Op &Get();
int eviction_policy; // Cache eviction policy
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
}; };
} // namespace tl } // namespace tl
......
...@@ -23,6 +23,7 @@ namespace tl { ...@@ -23,6 +23,7 @@ namespace tl {
using namespace tir; using namespace tir;
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = make_object<FillNode>();
if (args[0]->IsInstance<BufferLoadNode>()) { if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]); auto buffer_load = Downcast<BufferLoad>(args[0]);
...@@ -33,42 +34,49 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { ...@@ -33,42 +34,49 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
const auto *lanes = ramp->lanes.as<IntImmNode>(); const auto *lanes = ramp->lanes.as<IntImmNode>();
CHECK(lanes) CHECK(lanes)
<< "Scalable vectors not supported in BufferRegion conversion"; << "Scalable vectors not supported in BufferRegion conversion";
region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else { } else {
region.push_back(Range::FromMinExtent(index, 1)); node->region.push_back(Range::FromMinExtent(index, 1));
} }
} }
dst = buffer_load->buffer; node->dst = buffer_load->buffer;
} else { } else {
dst = vmap[GetVarFromAccessPtr(args[0])]; node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < dst->shape.size(); i++) { for (int i = 0; i < node->dst->shape.size(); i++) {
region.push_back(Range(0, dst->shape[i])); node->region.push_back(Range(0, node->dst->shape[i]));
} }
} }
if (args[1]->dtype != dst->dtype) { if (args[1]->dtype != node->dst->dtype) {
value = Cast(dst->dtype, args[1]); node->value = Cast(node->dst->dtype, args[1]);
} else { } else {
value = args[1]; node->value = args[1];
} }
ICHECK(region.size() == dst->shape.size()) ICHECK(node->region.size() == node->dst->shape.size())
<< "region size = " << region.size() << " != " << dst->shape.size(); << "region size = " << node->region.size()
for (int i = 0; i < region.size(); i++) { << " != " << node->dst->shape.size();
for (int i = 0; i < node->region.size(); i++) {
// bound check if region is static // bound check if region is static
if (region[i]->min.as<IntImm>()) { if (node->region[i]->min.as<IntImm>()) {
int64_t min = Downcast<IntImm>(region[i]->min)->value; int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0"; ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
} }
if (region[i]->extent.as<IntImm>()) { if (node->region[i]->extent.as<IntImm>()) {
int64_t extent = Downcast<IntImm>(region[i]->extent)->value; int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
ICHECK_LE(extent, Downcast<IntImm>(dst->shape[i])->value) ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
<< "region[" << i << "] = " << extent << " > " << dst->shape[i]; << "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
} }
} }
data_ = std::move(node);
} }
For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const { TileOperator FillNode::Clone() const {
auto op = make_object<FillNode>(*this);
return Fill(op);
}
For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
int ndim = dst->shape.size(); int ndim = dst->shape.size();
Array<IterVar> loop_vars; Array<IterVar> loop_vars;
Array<PrimExpr> dst_indices; Array<PrimExpr> dst_indices;
...@@ -85,10 +93,9 @@ For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -85,10 +93,9 @@ For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return Downcast<For>(body); return Downcast<For>(body);
} }
Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") { if (dst.scope() == "local.fragment") {
auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer)); auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
InferLevel::kFree); InferLevel::kFree);
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
...@@ -106,7 +113,7 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -106,7 +113,7 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto vectorized_thread_loop = VectorizeLoop(init_loop); auto vectorized_thread_loop = VectorizeLoop(init_loop);
return vectorized_thread_loop; return vectorized_thread_loop;
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared") { } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared") {
auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer)); auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
InferLevel::kFree); InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
...@@ -122,6 +129,11 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -122,6 +129,11 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
} }
} }
LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {};
}
TIR_REGISTER_TL_OP(Fill, fill) TIR_REGISTER_TL_OP(Fill, fill)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#ifndef TVM_TL_OP_ELEM_H_ #ifndef TVM_TL_OP_ELEM_H_
#define TVM_TL_OP_ELEM_H_ #define TVM_TL_OP_ELEM_H_
#include "op.h" #include "operator.h"
#include "parallel.h" #include "parallel.h"
namespace tvm { namespace tvm {
...@@ -15,21 +15,29 @@ namespace tl { ...@@ -15,21 +15,29 @@ namespace tl {
using namespace tir; using namespace tir;
class Fill : public Operator { class FillNode : public TileOperatorNode {
public: public:
Fill(Array<PrimExpr> args, BufferMap vmap); tir::Buffer dst;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; PrimExpr value;
Array<Range> region;
static constexpr const char *_type_key = "tl.Fill";
TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
static const Op &Get(); static const Op &Get();
std::unique_ptr<Operator> Clone() const final { TileOperator Clone() const;
return std::make_unique<Fill>(*this);
}
private: private:
For MakeSIMTLoop(arith::Analyzer *analyzer) const; For MakeSIMTLoop(arith::Analyzer *analyzer) const;
tir::Buffer dst; };
PrimExpr value;
Array<Range> region; class Fill : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
}; };
} // namespace tl } // namespace tl
......
...@@ -34,35 +34,44 @@ static std::vector<int> toPrimeFactors(int x) { ...@@ -34,35 +34,44 @@ static std::vector<int> toPrimeFactors(int x) {
} }
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
Aptr = args[0]; ObjectPtr<GemmNode> node = make_object<GemmNode>();
Bptr = args[1];
Cptr = args[2]; node->Aptr = args[0];
A = vmap[GetVarFromAccessPtr(Aptr)]; node->Bptr = args[1];
B = vmap[GetVarFromAccessPtr(Bptr)]; node->Cptr = args[2];
C = vmap[GetVarFromAccessPtr(Cptr)]; node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
trans_A = args[3].as<Bool>().value(); node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
trans_B = args[4].as<Bool>().value(); node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
M = args[5].as<IntImm>().value()->value; node->trans_A = args[3].as<Bool>().value();
N = args[6].as<IntImm>().value()->value; node->trans_B = args[4].as<Bool>().value();
K = args[7].as<IntImm>().value()->value; node->M = args[5].as<IntImm>().value()->value;
policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value); node->N = args[6].as<IntImm>().value()->value;
clear_accum = args[9].as<Bool>().value(); node->K = args[7].as<IntImm>().value()->value;
stride_A = args[10].as<IntImm>().value()->value; node->policy =
stride_B = args[11].as<IntImm>().value()->value; static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
offset_A = args[12].as<IntImm>().value()->value; node->clear_accum = args[9].as<Bool>().value();
offset_B = args[13].as<IntImm>().value()->value; node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
if (args.size() > 14) { if (args.size() > 14) {
kPack = args[14].as<IntImm>().value()->value; node->kPack = args[14].as<IntImm>().value()->value;
if (kPack != 1 && kPack != 2) { if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2"; ICHECK(false) << "kPack must be 1 or 2";
} }
} }
if (args.size() > 15) { if (args.size() > 15) {
wg_wait = args[15].as<IntImm>().value()->value; node->wg_wait = args[15].as<IntImm>().value()->value;
} }
data_ = std::move(node);
} }
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const { TileOperator GemmNode::Clone() const {
auto op = make_object<GemmNode>(*this);
return Gemm(op);
}
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
int warp_size = TargetGetWarpSize(target); int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size; int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
...@@ -87,10 +96,13 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const { ...@@ -87,10 +96,13 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
* per-warp tile sizes) and adapts the partition according to the configured * per-warp tile sizes) and adapts the partition according to the configured
* GemmWarpPolicy (FullRow, FullCol, Square). * GemmWarpPolicy (FullRow, FullCol, Square).
* *
* @param block_size Total number of threads in the block (used to derive num_warps). * @param block_size Total number of threads in the block (used to derive
* num_warps).
* @param gemm_inst The chosen GEMM implementation (e.g., kWGMMA, kMFMA, kMMA). * @param gemm_inst The chosen GEMM implementation (e.g., kWGMMA, kMFMA, kMMA).
* @param target Target device information (used for warp size and target-specific rules). * @param target Target device information (used for warp size and
* @return std::pair<int, int> {m_warp, n_warp} where m_warp * n_warp == num_warps. * target-specific rules).
* @return std::pair<int, int> {m_warp, n_warp} where m_warp * n_warp ==
* num_warps.
* *
* Constraints and behavior: * Constraints and behavior:
* - Each warp is assumed to cover 16 rows (M) and 8 columns (N). The function * - Each warp is assumed to cover 16 rows (M) and 8 columns (N). The function
...@@ -100,7 +112,8 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const { ...@@ -100,7 +112,8 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
* - num_warps must be a multiple of 4 (warp-groups of 4). * - num_warps must be a multiple of 4 (warp-groups of 4).
* - m_warp is always a multiple of 4. * - m_warp is always a multiple of 4.
* - The warp partition respects the GemmWarpPolicy: * - The warp partition respects the GemmWarpPolicy:
* - FullRow: maximize warps on M (in multiples of 4) while keeping divisibility. * - FullRow: maximize warps on M (in multiples of 4) while keeping
* divisibility.
* - FullCol: maximize warps on N, but if N is not evenly divisible, move * - FullCol: maximize warps on N, but if N is not evenly divisible, move
* whole warp-groups to M to achieve feasibility. * whole warp-groups to M to achieve feasibility.
* - Square: choose a multiple-of-4 m_warp that best balances per-warp work * - Square: choose a multiple-of-4 m_warp that best balances per-warp work
...@@ -118,9 +131,9 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const { ...@@ -118,9 +131,9 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
* divisibility or policy conditions are not met (e.g., M/N tile divisibility, * divisibility or policy conditions are not met (e.g., M/N tile divisibility,
* invalid policy, or WGMMA-specific warp-group requirements). * invalid policy, or WGMMA-specific warp-group requirements).
*/ */
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size, std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
GemmInst gemm_inst, GemmInst gemm_inst,
Target target) const { Target target) const {
int num_warps = block_size / TargetGetWarpSize(target); int num_warps = block_size / TargetGetWarpSize(target);
int m_warp = 1, n_warp = 1; int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp constexpr int kMPerWarp = 16; // Rows processed by a single warp
...@@ -296,19 +309,21 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size, ...@@ -296,19 +309,21 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
* Supported combinations and constraints: * Supported combinations and constraints:
* - C=float16: * - C=float16:
* - A=float16, B=float16: K % 16 == 0 * - A=float16, B=float16: K % 16 == 0
* - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % 32 == 0 * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
* 32 == 0
* - C=float32: * - C=float32:
* - A=float16, B=float16: K % 16 == 0 * - A=float16, B=float16: K % 16 == 0
* - A=bfloat16, B=bfloat16: K % 16 == 0 * - A=bfloat16, B=bfloat16: K % 16 == 0
* - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
* - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
* - C=int32: * - C=int32:
* - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) and K % 32 == 0 * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
* and K % 32 == 0
* *
* @return true if WGMMA is supported for the current buffers, dtypes, and * @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise. * transpose/shape constraints; false otherwise.
*/ */
bool Gemm::CheckWGMMA() const { bool GemmNode::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") { if (B.scope() != "shared.dyn" && B.scope() != "shared") {
return false; return false;
} }
...@@ -373,7 +388,7 @@ static int GetArchInt(Target target) { ...@@ -373,7 +388,7 @@ static int GetArchInt(Target target) {
return arch_int; return arch_int;
} }
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent); auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target); GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target); auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
...@@ -425,7 +440,8 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -425,7 +440,8 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
* - C.scope() must be "local.fragment". * - C.scope() must be "local.fragment".
* *
* Postconditions / side effects: * Postconditions / side effects:
* - Marks the operator's layout inference as completed (sets completed_ = true). * - Marks the operator's layout inference as completed (sets completed_ =
* true).
* - May abort via ICHECK on unsupported targets, invalid buffer scopes, or * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
* incompatible shape constraints. * incompatible shape constraints.
* *
...@@ -433,7 +449,8 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -433,7 +449,8 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
* @param level Inference level (unused for side effects but retained for API). * @param level Inference level (unused for side effects but retained for API).
* @return LayoutMap mapping each of A, B, and C to their inferred layouts. * @return LayoutMap mapping each of A, B, and C to their inferred layouts.
*/ */
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (completed_) if (completed_)
return {}; return {};
LayoutMap results; LayoutMap results;
......
...@@ -7,37 +7,21 @@ ...@@ -7,37 +7,21 @@
#ifndef TVM_TL_OP_GEMM_H_ #ifndef TVM_TL_OP_GEMM_H_
#define TVM_TL_OP_GEMM_H_ #define TVM_TL_OP_GEMM_H_
#include "op.h" #include "operator.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
class Gemm : public Operator { enum class GemmWarpPolicy {
public: kSquare = 0,
Gemm(Array<PrimExpr> args, BufferMap vmap); kFullRow = 1,
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; kFullCol = 2,
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; };
static const Op &Get();
enum class GemmWarpPolicy {
kSquare = 0,
kFullRow = 1,
kFullCol = 2,
} policy;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Gemm>(*this);
}
private:
// Target GEMM instruction
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;
std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst,
Target target) const;
class GemmNode : public TileOperatorNode {
public:
bool CheckWGMMA() const; bool CheckWGMMA() const;
Array<PrimExpr> call_args; Array<PrimExpr> call_args;
tir::Buffer A, B, C; tir::Buffer A, B, C;
...@@ -52,7 +36,33 @@ private: ...@@ -52,7 +36,33 @@ private:
// only will be enabled under cdna mfma instructions // only will be enabled under cdna mfma instructions
int kPack = 1; int kPack = 1;
int wg_wait = 0; int wg_wait = 0;
bool completed_ = false; GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.Gemm";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
TileOperator Clone() const;
private:
// Target GEMM instruction
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;
std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst,
Target target) const;
mutable bool completed_ = false;
};
class Gemm : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode);
TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
}; };
} // namespace tl } // namespace tl
......
...@@ -32,31 +32,39 @@ static std::vector<int> toPrimeFactors(int x) { ...@@ -32,31 +32,39 @@ static std::vector<int> toPrimeFactors(int x) {
} }
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) { GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
A = vmap[GetVarFromAccessPtr(args[0])]; ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>();
E = vmap[GetVarFromAccessPtr(args[1])]; node->A = vmap[GetVarFromAccessPtr(args[0])];
B = vmap[GetVarFromAccessPtr(args[2])]; node->E = vmap[GetVarFromAccessPtr(args[1])];
C = vmap[GetVarFromAccessPtr(args[3])]; node->B = vmap[GetVarFromAccessPtr(args[2])];
trans_A = args[4].as<Bool>().value(); node->C = vmap[GetVarFromAccessPtr(args[3])];
trans_B = args[5].as<Bool>().value(); node->trans_A = args[4].as<Bool>().value();
M = args[6].as<IntImm>().value()->value; node->trans_B = args[5].as<Bool>().value();
N = args[7].as<IntImm>().value()->value; node->M = args[6].as<IntImm>().value()->value;
K = args[8].as<IntImm>().value()->value; node->N = args[7].as<IntImm>().value()->value;
policy = static_cast<GemmWarpPolicy>(args[9].as<IntImm>().value()->value); node->K = args[8].as<IntImm>().value()->value;
clear_accum = args[10].as<Bool>().value(); node->policy = static_cast<GemmSPNode::GemmWarpPolicy>(
args[9].as<IntImm>().value()->value);
node->clear_accum = args[10].as<Bool>().value();
if (args.size() > 11) { if (args.size() > 11) {
kPack = args[11].as<IntImm>().value()->value; node->kPack = args[11].as<IntImm>().value()->value;
if (kPack != 1 && kPack != 2) { if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2"; ICHECK(false) << "kPack must be 1 or 2";
} }
} }
if (args.size() > 12) { if (args.size() > 12) {
wg_wait = args[12].as<IntImm>().value()->value; node->wg_wait = args[12].as<IntImm>().value()->value;
} }
data_ = std::move(node);
}
TileOperator GemmSPNode::Clone() const {
auto op = make_object<GemmSPNode>(*this);
return GemmSP(op);
} }
std::pair<int, int> std::pair<int, int>
GemmSP::ComputeWarpPartition(int num_warps, Target target, GemmSPNode::ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma) const { bool maybe_hopper_wgmma) const {
int m_warp = 1, n_warp = 1; int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp constexpr int kNPerWarp = 8; // Columns processed by a single warp
...@@ -212,7 +220,7 @@ GemmSP::ComputeWarpPartition(int num_warps, Target target, ...@@ -212,7 +220,7 @@ GemmSP::ComputeWarpPartition(int num_warps, Target target,
return {m_warp, n_warp}; return {m_warp, n_warp};
} }
Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32; int warp_size = 32;
auto block_size = *as_const_int(T.thread_bounds->extent); auto block_size = *as_const_int(T.thread_bounds->extent);
...@@ -256,7 +264,8 @@ Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -256,7 +264,8 @@ Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Evaluate(new_call); return Evaluate(new_call);
} }
LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, InferLevel level) { LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (completed_) if (completed_)
return {}; return {};
LayoutMap results; LayoutMap results;
...@@ -308,6 +317,7 @@ LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -308,6 +317,7 @@ LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, InferLevel level) {
completed_ = true; completed_ = true;
return results; return results;
} }
TIR_REGISTER_TL_OP(GemmSP, gemm_sp) TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
......
...@@ -7,30 +7,23 @@ ...@@ -7,30 +7,23 @@
#ifndef TVM_TL_OP_GEMM_SP_H_ #ifndef TVM_TL_OP_GEMM_SP_H_
#define TVM_TL_OP_GEMM_SP_H_ #define TVM_TL_OP_GEMM_SP_H_
#include "op.h" #include "operator.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
class GemmSP : public Operator { class GemmSPNode : public TileOperatorNode {
public: public:
GemmSP(Array<PrimExpr> args, BufferMap vmap); Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
enum class GemmWarpPolicy { enum class GemmWarpPolicy {
kSquare = 0, kSquare = 0,
kFullRow = 1, kFullRow = 1,
kFullCol = 2, kFullCol = 2,
} policy; } policy;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<GemmSP>(*this);
}
private:
std::pair<int, int> std::pair<int, int>
ComputeWarpPartition(int num_warps, Target target, ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma = true) const; bool maybe_hopper_wgmma = true) const;
...@@ -44,7 +37,18 @@ private: ...@@ -44,7 +37,18 @@ private:
// only will be enabled under cdna mfma instructions // only will be enabled under cdna mfma instructions
int kPack = 1; int kPack = 1;
int wg_wait = 0; int wg_wait = 0;
bool completed_ = false;
TileOperator Clone() const;
private:
mutable bool completed_ = false;
};
class GemmSP : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmSP, TileOperator, GemmSPNode);
TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
}; };
} // namespace tl } // namespace tl
......
/*!
 * \file tl/op/operator.cc
 *
 * Define operators used in tile library.
*/
#include "operator.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Build the TileOperator corresponding to a tile-intrinsic call.
 *
 * Builders are registered under the "TLOpBuilder" op attribute by the
 * TIR_REGISTER_TL_OP macro; this looks up the builder for the call's op
 * and invokes it with the call arguments and the buffer map.
 *
 * \param call The call expression whose op may be a tile-library op.
 * \param vmap Mapping from buffer data vars to their Buffer objects.
 * \return The constructed operator, or an undefined TileOperator when the
 *         call's op has no registered builder.
 */
TileOperator ParseOperator(Call call, BufferMap vmap) {
  auto builder_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
  Op op = call->op.as<Op>().value();
  if (!builder_map.count(op)) {
    // Not a tile-library operator.
    return TileOperator();
  }
  TileOperator tile_op = builder_map[op](call->args, vmap);
  ICHECK(tile_op.defined());
  return tile_op;
}
/*!
 * \brief Build the TileOperator for a statement, if it wraps a tile op call.
 *
 * Only statements of the form Evaluate(Call(...)) can carry a tile
 * operator; everything else yields an undefined TileOperator.
 *
 * \param stmt The statement to inspect.
 * \param vmap Mapping from buffer data vars to their Buffer objects.
 * \return The parsed operator, or an undefined TileOperator.
 */
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
  // Cast once per node instead of the previous repeated stmt.as<...>()
  // calls, which re-derived the same nodes on every check.
  if (const auto *eval = stmt.as<EvaluateNode>()) {
    if (const auto *call = eval->value.as<CallNode>()) {
      return ParseOperator(GetRef<Call>(call), vmap);
    }
  }
  return TileOperator();
}
/*!
 * \brief Extract the buffer variable from a tvm_access_ptr expression.
 *
 * The expression must be a call to builtin::tvm_access_ptr; the buffer
 * variable is taken from argument index 1 of that call.
 *
 * \param expr An access-pointer call expression.
 * \return The buffer data variable referenced by the access pointer.
 */
Var GetVarFromAccessPtr(const PrimExpr &expr) {
  const CallNode *access = expr.as<CallNode>();
  ICHECK(access);
  ICHECK(access->op.same_as(builtin::tvm_access_ptr()));
  const VarNode *buffer_var = access->args[1].as<VarNode>();
  ICHECK(buffer_var);
  return GetRef<Var>(buffer_var);
}
} // namespace tl
} // namespace tvm
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include <tvm/ir/op.h> #include <tvm/ir/op.h>
#include <tvm/target/target.h> #include <tvm/target/target.h>
#include <tvm/tir/buffer.h> #include <tvm/tir/buffer.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op_attr_types.h>
#include "../layout/layout.h" #include "../layout/layout.h"
...@@ -22,19 +24,6 @@ using namespace tir; ...@@ -22,19 +24,6 @@ using namespace tir;
using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>; using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>; using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>; using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = ffi::TypedFunction<void *(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>("TLOpBuilder", \
[](Array<PrimExpr> a, BufferMap b) { \
return (void *)(new Entry(a, b)); \
})
enum class InferLevel { enum class InferLevel {
kFree = 0, kFree = 0,
...@@ -59,38 +48,48 @@ struct LayoutInferArgs { ...@@ -59,38 +48,48 @@ struct LayoutInferArgs {
Map<Buffer, Buffer> buffer_remap; Map<Buffer, Buffer> buffer_remap;
}; };
class Operator { class TileOperatorNode;
public: class TileOperator;
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
virtual ~Operator() = default;
virtual std::unique_ptr<Operator> Clone() const = 0;
};
class RegionOp : public Operator { class TileOperatorNode: public Object {
public: public:
RegionOp(Array<PrimExpr> args, BufferMap vmap); virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0;
static const Op &Get();
std::unique_ptr<Operator> Clone() const final { virtual LayoutMap InferLayout(const LayoutInferArgs& T,
return std::make_unique<RegionOp>(*this); InferLevel level) const = 0;
}
const Buffer &GetBuffer() const { return buffer_; } virtual TileOperator Clone() const = 0;
const Array<Range> &GetRanges() const { return ranges_; }
int GetAccessMask() const { return access_mask_; } static constexpr const char* _type_key = "tl.TileOperator";
bool IsFullRegion() const;
TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
};
private: class TileOperator : public ObjectRef {
Buffer buffer_; public:
Array<Range> ranges_; TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
int access_mask_;
}; };
Var GetVarFromAccessPtr(const PrimExpr &expr); Var GetVarFromAccessPtr(const PrimExpr &expr);
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap); TileOperator ParseOperator(Call call, BufferMap vmap);
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap); TileOperator ParseOperator(Stmt stmt, BufferMap vmap);
using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>("TLOpBuilder", \
[](Array<PrimExpr> args, BufferMap vmap) { \
return Entry(args, vmap); \
})
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
...@@ -154,9 +154,21 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) { ...@@ -154,9 +154,21 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
StmtExprVisitor::VisitExpr_(op); StmtExprVisitor::VisitExpr_(op);
} }
ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); } ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
V.VisitStmt(root);
}
TileOperator ParallelOpNode::Clone() const {
auto op = make_object<ParallelOpNode>(*this);
return ParallelOp(op);
}
Stmt ParallelOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
return root_;
}
bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const { bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const {
auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; }); auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
return StructuralEqual()(indice_map_[buffer], common_indice); return StructuralEqual()(indice_map_[buffer], common_indice);
} }
...@@ -179,7 +191,8 @@ bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const { ...@@ -179,7 +191,8 @@ bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
* Can generate new layouts based on vectorization and thread * Can generate new layouts based on vectorization and thread
* bounds. Used when maximum performance optimization is desired. * bounds. Used when maximum performance optimization is desired.
*/ */
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (loop_layout_.defined()) if (loop_layout_.defined())
return {}; return {};
if (level == InferLevel::kStrict) if (level == InferLevel::kStrict)
...@@ -355,7 +368,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -355,7 +368,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
return results; return results;
} }
Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const { Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
if (predicate_.defined()) { if (predicate_.defined()) {
return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}}); return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
} else { } else {
...@@ -363,7 +376,7 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const { ...@@ -363,7 +376,7 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
} }
} }
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) { Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
ICHECK(loop_layout_.defined()); ICHECK(loop_layout_.defined());
if (IsCommonAccessIndice(buffer)) { if (IsCommonAccessIndice(buffer)) {
return loop_layout_; return loop_layout_;
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include "../layout/layout.h" #include "../layout/layout.h"
#include "op.h" #include "operator.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -31,58 +31,97 @@ bool ProveFragmentContains(Fragment small_frag, Fragment large_frag, ...@@ -31,58 +31,97 @@ bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
Array<PrimExpr> large_frag_indices, Array<PrimExpr> large_frag_indices,
arith::Analyzer &analyzer_); arith::Analyzer &analyzer_);
class ParallelOp; class ParallelOpNode;
class ParallelLoopNestVisitor : public StmtExprVisitor { class ParallelLoopNestVisitor : public StmtExprVisitor {
private: private:
ParallelLoopNestVisitor(ParallelOp *op) : p(op){}; ParallelLoopNestVisitor(ParallelOpNode *op) : p(op){};
void VisitStmt_(const ForNode *op) final; void VisitStmt_(const ForNode *op) override;
void VisitStmt_(const BufferStoreNode *op) final; void VisitStmt_(const BufferStoreNode *op) override;
void VisitExpr_(const BufferLoadNode *op) final; void VisitExpr_(const BufferLoadNode *op) override;
ParallelOp *p; ParallelOpNode *p;
friend class ParallelOp; friend class ParallelOpNode;
}; };
class ParallelOp : public Operator { // ParallelOpNode represents a parallel for loop operator in TileLang.
// It is responsible for inferring layouts, holding loop structure, and managing
// predicates.
class ParallelOpNode : public TileOperatorNode {
public: public:
ParallelOp(For root); // The inferred layout for the loop, mutable to allow lazy inference.
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; mutable Fragment loop_layout_;
// The predicate expression for the loop, if any, mutable for lazy
// construction.
mutable Optional<PrimExpr> predicate_;
ParallelOp(const ParallelOp &other) : ParallelOp(other.root_) { // Type key for TVM object system.
static constexpr const char *_type_key = "tl.ParallelOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode);
// Construct from a root For loop.
ParallelOpNode(For root);
// Lower the operator to a TIR statement.
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
// Infer the layout for this parallel operator.
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
// Copy constructor for ParallelOpNode.
ParallelOpNode(const ParallelOpNode &other) : ParallelOpNode(other.root_) {
loop_layout_ = other.loop_layout_; loop_layout_ = other.loop_layout_;
predicate_ = other.predicate_; predicate_ = other.predicate_;
} }
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<ParallelOp>(*this);
}
// Get the inferred loop layout.
Fragment GetLoopLayout() const { return loop_layout_; } Fragment GetLoopLayout() const { return loop_layout_; }
// Get the root For loop.
For GetRoot() const { return root_; } For GetRoot() const { return root_; }
// Get the mapping from buffer to access indices.
Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; } Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
// Get the predicate for a given thread variable.
Optional<PrimExpr> GetPredicate(Var thread_var) const; Optional<PrimExpr> GetPredicate(Var thread_var) const;
// Clone this operator.
TileOperator Clone() const;
private: private:
Fragment CompleteBufferFragment(const Buffer &buffer); // Complete the fragment layout for a given buffer.
Fragment CompleteBufferFragment(const Buffer &buffer) const;
// Check if the buffer is accessed with common indices (i.e., loop variables).
bool IsCommonAccessIndice(const Buffer &buffer) const; bool IsCommonAccessIndice(const Buffer &buffer) const;
void AddPredicate(PrimExpr expr) { // Add a predicate to the current predicate expression.
void AddPredicate(PrimExpr expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
} }
// Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor;
// The root For loop node.
For root_; For root_;
// Visitor for collecting loop nest information.
ParallelLoopNestVisitor V; ParallelLoopNestVisitor V;
// Mapping from buffer to their access indices in the loop.
Map<Buffer, Array<PrimExpr>> indice_map_; Map<Buffer, Array<PrimExpr>> indice_map_;
// Set of buffers that are written to in the loop.
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_; std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
// The loop variables for the parallel loop nest.
Array<IterVar> loop_vars_; Array<IterVar> loop_vars_;
// Analyzer for simplifying and analyzing expressions, mutable for lazy use.
Fragment loop_layout_;
mutable arith::Analyzer analyzer_; mutable arith::Analyzer analyzer_;
Optional<PrimExpr> predicate_; };
friend class ParallelLoopNestVisitor; class ParallelOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode);
ParallelOp(For root) {
auto op = make_object<ParallelOpNode>(root);
data_ = std::move(op);
}
}; };
} // namespace tl } // namespace tl
......
...@@ -22,26 +22,38 @@ namespace tl { ...@@ -22,26 +22,38 @@ namespace tl {
using namespace tir; using namespace tir;
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
src = vmap[GetVarFromAccessPtr(args[0])]; ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
dst = vmap[GetVarFromAccessPtr(args[1])]; node->src = vmap[GetVarFromAccessPtr(args[0])];
String reduce_type = args[2].as<StringImm>().value()->value; node->dst = vmap[GetVarFromAccessPtr(args[1])];
dim = args[3].as<IntImm>().value()->value; std::string reduce_type = args[2].as<StringImm>().value()->value;
node->dim = args[3].as<IntImm>().value()->value;
if (reduce_type == "sum") if (reduce_type == "sum")
type = ReduceType::kSum; node->type = ReduceType::kSum;
else if (reduce_type == "abssum") else if (reduce_type == "abssum")
type = ReduceType::kAbsSum; node->type = ReduceType::kAbsSum;
else if (reduce_type == "absmax") else if (reduce_type == "absmax")
type = ReduceType::kAbsMax; node->type = ReduceType::kAbsMax;
else if (reduce_type == "max") else if (reduce_type == "max")
type = ReduceType::kMax; node->type = ReduceType::kMax;
else if (reduce_type == "min") else if (reduce_type == "min")
type = ReduceType::kMin; node->type = ReduceType::kMin;
else else
ICHECK(0) << "Unknown reduce type: " << reduce_type; ICHECK(0) << "Unknown reduce type: " << reduce_type;
clear = args[4].as<Bool>().value(); node->clear = args[4].as<Bool>().value();
data_ = std::move(node);
} }
PrimExpr ReduceOp::MakeInitValue() const { TileOperator ReduceOpNode::Clone() const {
auto op = make_object<ReduceOpNode>(*this);
return ReduceOp(op);
}
TileOperator CumSumOpNode::Clone() const {
auto op = make_object<CumSumOpNode>(*this);
return CumSumOp(op);
}
PrimExpr ReduceOpNode::MakeInitValue() const {
auto dst_dtype = dst->dtype; auto dst_dtype = dst->dtype;
auto is_int = dst_dtype.is_int(); auto is_int = dst_dtype.is_int();
bool is_uint = dst_dtype.is_uint(); bool is_uint = dst_dtype.is_uint();
...@@ -75,7 +87,7 @@ PrimExpr ReduceOp::MakeInitValue() const { ...@@ -75,7 +87,7 @@ PrimExpr ReduceOp::MakeInitValue() const {
} }
} }
PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const { PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
PrimExpr lhs = a, rhs = b; PrimExpr lhs = a, rhs = b;
if (lhs->dtype != rhs->dtype) { if (lhs->dtype != rhs->dtype) {
rhs = Cast(lhs->dtype, rhs); rhs = Cast(lhs->dtype, rhs);
...@@ -97,7 +109,7 @@ PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const { ...@@ -97,7 +109,7 @@ PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
} }
} }
std::string ReduceOp::MakeCodegenReducer() const { std::string ReduceOpNode::MakeCodegenReducer() const {
switch (type) { switch (type) {
case ReduceType::kSum: case ReduceType::kSum:
return "tl::SumOp"; return "tl::SumOp";
...@@ -115,7 +127,7 @@ std::string ReduceOp::MakeCodegenReducer() const { ...@@ -115,7 +127,7 @@ std::string ReduceOp::MakeCodegenReducer() const {
} }
} }
Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(this->src.scope() == "local.fragment" && ICHECK(this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment") this->dst.scope() == "local.fragment")
<< "Reduce for shared memory not implemented."; << "Reduce for shared memory not implemented.";
...@@ -284,7 +296,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -284,7 +296,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return body; return body;
} }
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (level >= InferLevel::kStrict) if (level >= InferLevel::kStrict)
return {}; return {};
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
...@@ -369,14 +382,16 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -369,14 +382,16 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
reverse: whether to cumsum in reverse order reverse: whether to cumsum in reverse order
*/ */
CHECK_EQ(args.size(), 4); CHECK_EQ(args.size(), 4);
src = vmap[GetVarFromAccessPtr(args[0])]; ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
dst = vmap[GetVarFromAccessPtr(args[1])]; node->src = vmap[GetVarFromAccessPtr(args[0])];
dim = args[2].as<IntImm>().value()->value; node->dst = vmap[GetVarFromAccessPtr(args[1])];
reverse = args[3].as<Bool>().value(); node->dim = args[2].as<IntImm>().value()->value;
CHECK_LT(dim, static_cast<int>(src->shape.size())); node->reverse = args[3].as<Bool>().value();
CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()));
data_ = std::move(node);
} }
Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (this->src.scope() == "local.fragment" && if (this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment") { this->dst.scope() == "local.fragment") {
LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue " LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
...@@ -402,7 +417,8 @@ Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -402,7 +417,8 @@ Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Stmt(); return Stmt();
} }
LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {}; return {};
} }
......
...@@ -7,56 +7,70 @@ ...@@ -7,56 +7,70 @@
#ifndef TVM_TL_OP_REDUCE_H_ #ifndef TVM_TL_OP_REDUCE_H_
#define TVM_TL_OP_REDUCE_H_ #define TVM_TL_OP_REDUCE_H_
#include "op.h" #include "operator.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
class ReduceOp : public Operator { enum class ReduceType {
public: kSum,
ReduceOp(Array<PrimExpr> args, BufferMap vmap); kAbsSum,
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; kMax,
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; kMin,
static const Op &Get(); kAbsMax,
};
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<ReduceOp>(*this);
}
private: class ReduceOpNode : public TileOperatorNode {
public:
tir::Buffer src, dst; tir::Buffer src, dst;
int dim; int dim;
enum class ReduceType { ReduceType type;
kSum,
kAbsSum,
kMax,
kMin,
kAbsMax,
} type;
bool clear; bool clear;
static constexpr const char *_type_key = "tl.ReduceOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
static const Op &Get();
TileOperator Clone() const;
private:
PrimExpr MakeInitValue() const; PrimExpr MakeInitValue() const;
PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const; PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
std::string MakeCodegenReducer() const; std::string MakeCodegenReducer() const;
}; };
class CumSumOp : public Operator { class ReduceOp : public TileOperator {
public: public:
CumSumOp(Array<PrimExpr> args, BufferMap vmap); TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get(); static const Op &Get();
};
std::unique_ptr<Operator> Clone() const final { class CumSumOpNode : public TileOperatorNode {
return std::make_unique<CumSumOp>(*this); public:
}
private:
tir::Buffer src, dst; tir::Buffer src, dst;
int dim; int dim;
bool reverse; bool reverse;
static constexpr const char *_type_key = "tl.CumSumOp";
TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
static const Op &Get();
TileOperator Clone() const;
};
class CumSumOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);
TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
}; };
} // namespace tl } // namespace tl
......
/*! /*!
* \file tl/op/op.cc * \file tl/op/region.cc
* \brief Define region operator.
* *
* Define operators usd in tile library.
*/ */
#include "op.h" #include "region.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
TIR_REGISTER_TL_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value();
if (op_map.count(op)) {
Operator *ptr = static_cast<Operator *>(op_map[op](call->args, vmap));
ICHECK(ptr != nullptr);
return std::unique_ptr<Operator>(ptr);
}
return nullptr;
}
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(GetRef<Call>(call), vmap);
}
return nullptr;
}
Var GetVarFromAccessPtr(const PrimExpr &expr) {
auto call = expr.as<CallNode>();
ICHECK(call);
ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
auto var = call->args[1].as<VarNode>();
ICHECK(var);
return GetRef<Var>(var);
}
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) { RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
size_t n = args.size(); size_t n = args.size();
size_t ndim = n - 2; size_t ndim = n - 2;
...@@ -55,16 +18,25 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -55,16 +18,25 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
ICHECK(load); ICHECK(load);
ICHECK(load->indices.size() == ndim) ICHECK(load->indices.size() == ndim)
<< "load->indices.size() = " << load->indices << " ndim = " << ndim; << "load->indices.size() = " << load->indices << " ndim = " << ndim;
buffer_ = load->buffer; Array<Range> ranges;
access_mask_ = static_cast<int>(*as_const_int(args[1]));
for (size_t i = 0; i < ndim; i++) { for (size_t i = 0; i < ndim; i++) {
PrimExpr min = load->indices[i]; PrimExpr min = load->indices[i];
PrimExpr extent = args[2 + i]; PrimExpr extent = args[2 + i];
ranges_.push_back(Range::FromMinExtent(min, extent)); ranges.push_back(Range::FromMinExtent(min, extent));
} }
ObjectPtr<RegionOpNode> node = make_object<RegionOpNode>();
node->buffer_ = load->buffer;
node->access_mask_ = static_cast<int>(*as_const_int(args[1]));
node->ranges_ = ranges;
data_ = std::move(node);
}
TileOperator RegionOpNode::Clone() const {
auto op = make_object<RegionOpNode>(*this);
return RegionOp(op);
} }
bool RegionOp::IsFullRegion() const { bool RegionOpNode::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) { for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min)) if (!is_zero(ranges_[i]->min))
return false; return false;
...@@ -74,14 +46,19 @@ bool RegionOp::IsFullRegion() const { ...@@ -74,14 +46,19 @@ bool RegionOp::IsFullRegion() const {
return true; return true;
} }
Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(0) << "Not Implemented Lower method.";
return Evaluate(0); return Evaluate(0);
} }
LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) { LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {}; return {};
} }
TIR_REGISTER_TL_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
/*!
 * \file tl/op/region.h
 * \brief Tile library region operator.
*
*/
#ifndef TVM_TL_OP_REGION_H_
#define TVM_TL_OP_REGION_H_
#include "./operator.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Node describing a sub-region of a buffer.
 *
 * Holds the target buffer, one Range per buffer dimension, and an access
 * mask taken from the second argument of the tl.region call. Regions are
 * descriptors consumed by other operators: Lower emits a no-op and
 * InferLayout imposes no layout constraints.
 */
class RegionOpNode : public TileOperatorNode {
public:
  // The buffer this region refers to.
  Buffer buffer_;
  // Per-dimension [min, min + extent) ranges of the region.
  Array<Range> ranges_;
  // Access mask constant from the region call (read/write bits —
  // semantics defined by the callers that interpret it).
  int access_mask_;

  static constexpr const char *_type_key = "tl.RegionOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(RegionOpNode, TileOperatorNode);

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  const Buffer &GetBuffer() const { return buffer_; }
  const Array<Range> &GetRanges() const { return ranges_; }
  int GetAccessMask() const { return access_mask_; }
  // True iff every range starts at 0 and covers the full buffer extent.
  bool IsFullRegion() const;

  // Marked override for consistency with Lower/InferLayout; it implements
  // the pure virtual TileOperatorNode::Clone.
  TileOperator Clone() const override;
};
/*!
 * \brief Managed reference to RegionOpNode.
 *
 * Constructed from the tl.region call arguments and a buffer map; Get()
 * returns the registered tl.region Op (see TIR_REGISTER_TL_OP in region.cc).
 */
class RegionOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(RegionOp, TileOperator, RegionOpNode);
  TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_REGION_H_
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "../layout/utils.h" #include "../layout/utils.h"
#include "../op/parallel.h" #include "../op/parallel.h"
#include "../op/region.h"
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h" #include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h" #include "common/loop_fusion_utils.h"
...@@ -79,8 +80,8 @@ public: ...@@ -79,8 +80,8 @@ public:
auto iter_var = thread_var_vec_[cur_infer_id]; auto iter_var = thread_var_vec_[cur_infer_id];
auto thread_bounds = thread_bounds_vec_[cur_infer_id]; auto thread_bounds = thread_bounds_vec_[cur_infer_id];
// Double-check that 'next' is valid // Double-check that 'next' is valid
ICHECK(next != nullptr) ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
<< "infer_list_[" << cur_infer_id << "] is null inside run_infer_step."; << "] is null inside run_infer_step.";
// Check iter_var->dom and dom->extent // Check iter_var->dom and dom->extent
ICHECK(iter_var.defined()) ICHECK(iter_var.defined())
...@@ -100,6 +101,7 @@ public: ...@@ -100,6 +101,7 @@ public:
// Run InferLayout // Run InferLayout
auto updates = next->InferLayout( auto updates = next->InferLayout(
LayoutInferArgs{target_, thread_bounds, layout_map}, level); LayoutInferArgs{target_, thread_bounds, layout_map}, level);
// Process the returned updates // Process the returned updates
for (const auto &[buffer, layout] : updates) { for (const auto &[buffer, layout] : updates) {
// Basic validity checks // Basic validity checks
...@@ -112,7 +114,7 @@ public: ...@@ -112,7 +114,7 @@ public:
level != InferLevel::kStrict && !strict_layout_map.count(buffer)) { level != InferLevel::kStrict && !strict_layout_map.count(buffer)) {
// Actually this test has been done in ParallelOp::InferLayout // Actually this test has been done in ParallelOp::InferLayout
// already. Just do it again to avoid missing implementations in other // already. Just do it again to avoid missing implementations in other
// `Operator`s. // `TileOperator`s.
auto dst_layout = layout.as<Fragment>().value(); auto dst_layout = layout.as<Fragment>().value();
auto src_layout = layout_map[buffer].as<Fragment>().value(); auto src_layout = layout_map[buffer].as<Fragment>().value();
ICHECK(dst_layout->InputDim() == src_layout->InputDim()); ICHECK(dst_layout->InputDim() == src_layout->InputDim());
...@@ -210,7 +212,7 @@ public: ...@@ -210,7 +212,7 @@ public:
std::vector<bool> in_queue(num_infer, true); std::vector<bool> in_queue(num_infer, true);
for (int i = 0; i < num_infer; i++) { for (int i = 0; i < num_infer; i++) {
// Check that each infer_list_ entry is valid // Check that each infer_list_ entry is valid
ICHECK(infer_list_[i] != nullptr) ICHECK(infer_list_[i].defined())
<< "infer_list_[" << i << "infer_list_[" << i
<< "] is null. The inference object is not allocated properly."; << "] is null. The inference object is not allocated properly.";
...@@ -253,13 +255,13 @@ public: ...@@ -253,13 +255,13 @@ public:
ICHECK(infer_list_.size() == thread_var_vec_.size()) ICHECK(infer_list_.size() == thread_var_vec_.size())
<< "infer_list_ and thread_var_vec_ size mismatch"; << "infer_list_ and thread_var_vec_ size mismatch";
for (int i = 0; i < infer_list_.size(); i++) { for (int i = 0; i < infer_list_.size(); i++) {
std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]); TileOperator base_infer = std::move(infer_list_[i]);
auto thread_var = thread_var_vec_[i]; auto thread_var = thread_var_vec_[i];
// Check if base_infer is valid // Check if base_infer is valid
ICHECK(base_infer != nullptr) << "Null pointer encountered in " ICHECK(base_infer.defined()) << "Null pointer encountered in "
"infer_list_ while collecting for_map."; "infer_list_ while collecting for_map.";
if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) { if (auto for_infer = base_infer.as<ParallelOpNode>()) {
// Check that the loop layout is defined // Check that the loop layout is defined
ICHECK(for_infer->GetLoopLayout().defined()) ICHECK(for_infer->GetLoopLayout().defined())
<< "The Layout for Parallel for cannot be inferred correctly:\n" << "The Layout for Parallel for cannot be inferred correctly:\n"
...@@ -297,7 +299,7 @@ private: ...@@ -297,7 +299,7 @@ private:
return; return;
auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_); auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
if (p != nullptr) { if (p.defined()) {
for (const auto &arg : op->args) { for (const auto &arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) { if (auto buffer = getBufferFromAccessPtr(arg)) {
addToUseList(buffer.value()); addToUseList(buffer.value());
...@@ -344,7 +346,7 @@ private: ...@@ -344,7 +346,7 @@ private:
void VisitStmt_(const ForNode *op) final { void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) { if (op->kind == ForKind::kParallel) {
auto infer = std::make_unique<ParallelOp>(GetRef<For>(op)); auto infer = ParallelOp(GetRef<For>(op));
for (const auto &[buffer, _] : infer->GetIndiceMap()) { for (const auto &[buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer); addToUseList(buffer);
} }
...@@ -399,7 +401,7 @@ private: ...@@ -399,7 +401,7 @@ private:
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
std::vector<ObjectRef> infer_list_stmt_; std::vector<ObjectRef> infer_list_stmt_;
std::vector<std::unique_ptr<Operator>> infer_list_; std::vector<TileOperator> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual> std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
use_list_; use_list_;
// This is a workaround for cpu backend, // This is a workaround for cpu backend,
...@@ -412,8 +414,8 @@ private: ...@@ -412,8 +414,8 @@ private:
LayoutMap annotated_layout_map_; LayoutMap annotated_layout_map_;
bool skip_thread_partition_{false}; bool skip_thread_partition_{false};
std::vector<std::unique_ptr<Operator>> BackupInferList() { std::vector<TileOperator> BackupInferList() {
std::vector<std::unique_ptr<Operator>> back_infer_list; std::vector<TileOperator> back_infer_list;
back_infer_list.reserve(infer_list_.size()); back_infer_list.reserve(infer_list_.size());
for (auto &&p : infer_list_) { for (auto &&p : infer_list_) {
back_infer_list.push_back(p->Clone()); back_infer_list.push_back(p->Clone());
...@@ -443,20 +445,25 @@ private: ...@@ -443,20 +445,25 @@ private:
int root = uf.Find(i); int root = uf.Find(i);
components[root].push_back(i); components[root].push_back(i);
} }
// Create a map from root to buffers
std::unordered_map<int, std::vector<Buffer>> components_buffers; std::unordered_map<int, std::vector<Buffer>> components_buffers;
for (const auto &[buffer, infer_indices] : use_list_) { for (const auto &[buffer, infer_indices] : use_list_) {
int root = uf.Find(infer_indices[0]); int root = uf.Find(infer_indices[0]);
components_buffers[root].push_back(buffer); components_buffers[root].push_back(buffer);
} }
// Keep components_buffers for debug purpose
(void)components_buffers;
// For each component, try each op as root, and determine the least // For each component, try each op as root, and determine the least
// replicated one // replicated one
std::queue<int> q; std::queue<int> q;
std::vector<bool> in_queue(infer_list_.size(), false); std::vector<bool> in_queue(infer_list_.size(), false);
for (auto &&[root, members] : components) { for (auto &&[root, members] : components) {
decltype(infer_list_) best_infer_list; decltype(infer_list_) best_infer_list;
LayoutMap best_layout_map; LayoutMap best_layout_map;
int64_t min_reg_num = INT64_MAX; int64_t min_reg_num = INT64_MAX;
for (int attempt_infer_root : members) { for (int attempt_infer_root : members) {
// backup infer_list_ in class member // backup infer_list_ in class member
auto back_infer_list = BackupInferList(); auto back_infer_list = BackupInferList();
...@@ -470,7 +477,6 @@ private: ...@@ -470,7 +477,6 @@ private:
tmp_layout_map, strict_layout_map, q, in_queue); tmp_layout_map, strict_layout_map, q, in_queue);
FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map, FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
q, in_queue); q, in_queue);
// Silly workaround: we have no clue if single root will iterate over // Silly workaround: we have no clue if single root will iterate over
// the entire component, since the InferLayout implementations have // the entire component, since the InferLayout implementations have
// complicated conditioning inside and we know nothing about it. // complicated conditioning inside and we know nothing about it.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment