Unverified Commit b38bd69e authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Refactor `Operator` into `TileOperator` with tvm reflection (#763)

* Refactor operator classes to inherit from TileOperator and update layout inference methods

- Changed base class of several operator classes (AtomicAdd, Copy, Gemm, etc.) from Operator to TileOperator for better alignment with tile operations.
- Updated InferLayout and Lower methods to use 'override' specifier for clarity and consistency.
- Adjusted header inclusions to replace "op.h" with "operator.h" across multiple files for improved organization.
- Added missing layout inference implementations for Fill and Conv2DIm2ColOp.
- Removed deprecated op.cc and op.h files to streamline the codebase.

* lint fix

* Refactor operator classes to use Node pattern and improve memory management

- Updated several operator classes (AtomicAdd, Copy, Gemm, etc.) to utilize the Node pattern for better memory management and encapsulation.
- Changed constructors to initialize member variables through a node object, enhancing clarity and reducing direct member access.
- Updated Clone methods to return TileOperator instances instead of unique pointers, aligning with the new design.
- Refactored InferLayout and Lower methods to ensure consistency across operator implementations.
- Adjusted header files to reflect the new class structure and removed deprecated code for a cleaner codebase.

* Enhance Clone methods in AtomicAdd and Copy classes to support parallel operation cloning

- Updated the Clone methods in AtomicAddNode and CopyNode to ensure that the parallel operation (par_op_) is properly cloned when defined, improving the integrity of cloned objects.
- Refactored the FillNode class to use ParallelOp directly instead of std::make_unique, streamlining the creation of parallel operations.
- Made minor adjustments in layout inference and other related methods for consistency and clarity.

* Refactor FillNode::Lower method to remove unused global function call

- Eliminated the call to the global function "tl.fill.lower" in the FillNode::Lower method, streamlining the code and improving clarity.
- Retained the core functionality of the method while enhancing maintainability by reducing unnecessary dependencies.
parent 277ed53c
......@@ -4,8 +4,8 @@
* Define element-wise operators.
*/
#include "atomic_add.h"
#include "./atomic_add.h"
#include "./region.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
......@@ -34,7 +34,8 @@ static int GetArchInt(Target target) {
return arch_int;
}
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
......@@ -42,17 +43,26 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges();
bf[i] = region.GetBuffer();
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
}
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
coalesced_width = Downcast<IntImm>(args[2]);
node->coalesced_width = Downcast<IntImm>(args[2]);
}
data_ = std::move(node);
}
TileOperator AtomicAddNode::Clone() const {
auto op = make_object<AtomicAddNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
return AtomicAdd(op);
}
Array<IterVar> AtomicAdd::MakeIterVars() const {
Array<IterVar> AtomicAddNode::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
......@@ -68,7 +78,7 @@ Array<IterVar> AtomicAdd::MakeIterVars() const {
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs,
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
......@@ -87,9 +97,10 @@ Array<PrimExpr> AtomicAdd::MakeIndices(const Array<IterVar> &ivs,
return indices;
}
PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer,
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const {
Array<PrimExpr> extents,
int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
......@@ -117,7 +128,7 @@ PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer,
}
}
For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
if (is_scalar) {
......@@ -180,16 +191,16 @@ For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return Downcast<For>(body);
}
Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto par_op = std::make_unique<ParallelOp>(fused_loop);
auto par_op = ParallelOp(fused_loop);
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
(par_op)->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto loop_layout = par_op->GetLoopLayout();
......@@ -210,10 +221,11 @@ Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
}
LayoutMap AtomicAdd::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (par_op_ == nullptr) {
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (!par_op_.defined()) {
arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
}
if (T.layout_map.count(src) && T.layout_map.count(dst)) {
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
......@@ -236,10 +248,5 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// TVM_REGISTER_OP("tl.atomicadd")
// .set_num_inputs(2)
// .add_argument("ref", "Buffer", "The destination buffer")
// .add_argument("val", "Expr", "The value to be added atomically");
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -7,7 +7,7 @@
#ifndef TVM_TL_OP_ATOMIC_ADD_H_
#define TVM_TL_OP_ATOMIC_ADD_H_
#include "op.h"
#include "operator.h"
#include "parallel.h"
namespace tvm {
......@@ -15,26 +15,23 @@ namespace tl {
using namespace tir;
class AtomicAdd : public Operator {
class AtomicAddNode : public TileOperatorNode {
public:
AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
Array<PrimExpr> args_;
static const Op &Get();
Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;
AtomicAdd(const AtomicAdd &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width) {
// No clone nullptr
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<AtomicAdd>(*this);
}
mutable ParallelOp par_op_;
static constexpr const char *_type_key = "tl.AtomicAdd";
TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
static const Op &Get();
TileOperator Clone() const;
protected:
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
......@@ -46,14 +43,13 @@ protected:
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
};
Array<PrimExpr> args_;
Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;
std::unique_ptr<ParallelOp> par_op_;
class AtomicAdd : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
} // namespace tl
......
......@@ -7,7 +7,7 @@
#ifndef TVM_TL_OP_BUILTIN_H_
#define TVM_TL_OP_BUILTIN_H_
#include "op.h"
#include "operator.h"
#include <tvm/ir/transform.h>
namespace tvm {
......
......@@ -15,6 +15,7 @@
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "region.h"
#include "../target/cuda.h"
#include "../target/utils.h"
......@@ -111,7 +112,8 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
* operation. \param vmap BufferMap mapping original buffer names to new buffer
* names.
*/
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<CopyNode> node = make_object<CopyNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
......@@ -119,23 +121,32 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges();
bf[i] = region.GetBuffer();
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
}
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
auto coalesced_width = Downcast<IntImm>(args[2]);
if (coalesced_width->value > 0) {
this->coalesced_width = coalesced_width;
node->coalesced_width = coalesced_width;
}
}
if (args.size() >= 4) {
this->disable_tma = Downcast<Bool>(args[3]);
node->disable_tma = Downcast<Bool>(args[3]);
}
if (args.size() >= 5) {
this->eviction_policy = args[4].as<IntImmNode>()->value;
node->eviction_policy = args[4].as<IntImmNode>()->value;
}
data_ = std::move(node);
}
TileOperator CopyNode::Clone() const {
auto op = make_object<CopyNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
return Copy(op);
}
/*!
......@@ -144,7 +155,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
* > 1. \return Array of IterVar representing the iterator variables for the
* copy operation.
*/
Array<IterVar> Copy::MakeIterVars() const {
Array<IterVar> CopyNode::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
......@@ -167,7 +178,7 @@ Array<IterVar> Copy::MakeIterVars() const {
* dst_indices. \return Array of PrimExpr representing the indices for the copy
* operation.
*/
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
Array<PrimExpr> CopyNode::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
......@@ -195,9 +206,9 @@ Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
* of the copy operation. \param src_dst 0 for src_indices, 1 for dst_indices.
* \return PrimExpr representing the predicate for the copy operation.
*/
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs, Array<PrimExpr> extents,
int src_dst) const {
PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
......@@ -233,7 +244,7 @@ PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
* simplification. \return For representing the SIMT loop for the copy
* operation.
*/
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
if (is_scalar) {
......@@ -289,7 +300,7 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
* shared tensor. \return Layout representing the linear layout for the TMA
* copy.
*/
Layout Copy::ComputeLinearLayout(const Buffer &shared_tensor) const {
Layout CopyNode::ComputeLinearLayout(const Buffer &shared_tensor) const {
Array<PrimExpr> input_size = shared_tensor->shape;
Array<PrimExpr> forward_vars;
for (size_t i = 0; i < input_size.size(); i++) {
......@@ -316,7 +327,8 @@ Layout Copy::ComputeLinearLayout(const Buffer &shared_tensor) const {
* indicating the level of layout inference. \return LayoutMap containing the
* inferred layout.
*/
LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
auto target = T.target;
using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current();
......@@ -340,17 +352,15 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
}
}
// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
// Use parallel op to infer the layout
if (!par_op_) {
if (!par_op_.defined()) {
arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
par_op_ = ParallelOp((MakeSIMTLoop(&analyzer)));
}
return par_op_->InferLayout(T, level);
}
/*!
* \brief Check if the copy operation is a bulk load.
* This function verifies if the copy operation can be implemented using CUDA's
......@@ -359,7 +369,7 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
* same data type. \param target Target device. \return True if the copy
* operation is a bulk load, false otherwise.
*/
bool Copy::CheckBulkLoad(Target target) const {
bool CopyNode::CheckBulkLoad(Target target) const {
// 1. arch must have bulk copy support
if (!TargetHasBulkCopy(target))
return false;
......@@ -387,7 +397,7 @@ bool Copy::CheckBulkLoad(Target target) const {
* same data type. \param target Target device. \return True if the copy
* operation is a bulk store, false otherwise.
*/
bool Copy::CheckBulkStore(Target target) const {
bool CopyNode::CheckBulkStore(Target target) const {
// 1. arch must have bulk copy support
if (!TargetHasBulkCopy(target))
return false;
......@@ -415,7 +425,7 @@ bool Copy::CheckBulkStore(Target target) const {
* Target device. \return True if the copy operation is a LDSM copy, false
* otherwise.
*/
bool Copy::CheckLDSMCopy(Target target) const {
bool CopyNode::CheckLDSMCopy(Target target) const {
return TargetHasLdmatrix(target) &&
(src.scope() == "shared.dyn" || src.scope() == "shared") &&
dst.scope() == "local.fragment";
......@@ -429,7 +439,7 @@ bool Copy::CheckLDSMCopy(Target target) const {
* Target device. \return True if the copy operation is a STSM copy, false
* otherwise.
*/
bool Copy::CheckSTSMCopy(Target target) const {
bool CopyNode::CheckSTSMCopy(Target target) const {
return TargetHasStmatrix(target) && src.scope() == "local.fragment" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared");
}
......@@ -442,7 +452,7 @@ bool Copy::CheckSTSMCopy(Target target) const {
* copy if no specialized instruction is applicable. \param target Target
* device. \return CopyInst representing the copy instruction type.
*/
Copy::CopyInst Copy::GetCopyInst(Target target, bool disable_tma_lower) const {
CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower) const {
// disable_tma_lower is from pass_configs
// when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True,
// we will not use tma for bulk load/store
......@@ -471,7 +481,7 @@ Copy::CopyInst Copy::GetCopyInst(Target target, bool disable_tma_lower) const {
* \param analyzer Arithmetic analyzer for simplification.
* \return Stmt representing the PTX code for the copy operation.
*/
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current();
......@@ -502,7 +512,7 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
* map. \param analyzer Arithmetic analyzer for simplification. \return Stmt
* representing the normal copy code.
*/
Stmt Copy::LowerNormalCopy(const LowerArgs &T,
Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
arith::Analyzer *analyzer) const {
bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU;
auto simt_loop = MakeSIMTLoop(analyzer);
......@@ -512,7 +522,7 @@ Stmt Copy::LowerNormalCopy(const LowerArgs &T,
Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(transformed_loop);
auto par_op = ParallelOp(transformed_loop);
if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(transformed_loop);
......@@ -548,7 +558,7 @@ Stmt Copy::LowerNormalCopy(const LowerArgs &T,
* \param copy_inst CopyInst representing the copy instruction type.
* \return Stmt representing the LDSM/STSM copy code.
*/
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const {
ICHECK(copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM)
<< "Invalid copy inst " << static_cast<int>(copy_inst);
......@@ -741,7 +751,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
* copy_inst CopyInst representing the copy instruction type. \return Stmt
* representing the bulk copy code.
*/
Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
CopyInst copy_inst) const {
ICHECK(copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore)
<< "Invalid copy inst " << static_cast<int>(copy_inst);
......@@ -1153,15 +1163,22 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* buffer names to new buffer names.
*/
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
src = vmap[GetVarFromAccessPtr(args[0])];
dst = vmap[GetVarFromAccessPtr(args[1])];
nhw_step = args[2];
c_step = args[3];
kernel = args[4].as<IntImm>().value()->value;
stride = args[5].as<IntImm>().value()->value;
dilation = args[6].as<IntImm>().value()->value;
padding = args[7].as<IntImm>().value()->value;
eviction_policy = args[8].as<IntImm>().value()->value;
ObjectPtr<Conv2DIm2ColOpNode> node = make_object<Conv2DIm2ColOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->nhw_step = args[2];
node->c_step = args[3];
node->kernel = args[4].as<IntImm>().value()->value;
node->stride = args[5].as<IntImm>().value()->value;
node->dilation = args[6].as<IntImm>().value()->value;
node->padding = args[7].as<IntImm>().value()->value;
node->eviction_policy = args[8].as<IntImm>().value()->value;
data_ = std::move(node);
}
TileOperator Conv2DIm2ColOpNode::Clone() const {
auto op = make_object<Conv2DIm2ColOpNode>(*this);
return Conv2DIm2ColOp(op);
}
/*!
......@@ -1174,7 +1191,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
* \param analyzer Arithmetic analyzer for simplification.
* \return Stmt representing the PTX code for the Conv2DIm2ColOp.
*/
Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
ICHECK(TargetIsHopper(T.target));
ICHECK(src.scope() == "global" &&
......@@ -1343,6 +1360,11 @@ TIR_REGISTER_TL_OP(Copy, copy)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {};
}
// Register the Conv2DIm2Col operation with TVM's TIR system
// This operation performs im2col transformation for 2D convolutions using TMA
// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride,
......
......@@ -11,13 +11,24 @@
#ifndef TVM_TL_OP_COPY_H_
#define TVM_TL_OP_COPY_H_
#include "op.h"
#include "operator.h"
#include "parallel.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Copy instruction type.
*/
enum class CopyInst {
kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy
kLDSM = 1, // ldmatrix memory copy
kSTSM = 2, // stmatrix memory copy
kBulkLoad = 3, // utilize tma load
kBulkStore = 4, // utilize tma store
};
/*!
* \brief Descriptor for Tensor Memory Access (TMA) copy operations.
*
......@@ -83,44 +94,40 @@ struct TMAIm2ColDesc {
* block-wise or element-wise data transfer, possibly optimized with
* parallelization or TMA hardware acceleration.
*/
class Copy : public Operator {
class CopyNode : public TileOperatorNode {
public:
/*!
* \brief Constructor.
* \param args Expression arguments for the copy.
* \param vmap Buffer variable mapping.
*/
Copy(Array<PrimExpr> args, BufferMap vmap);
Array<PrimExpr> args_; // Copy parameters (indices, sizes, etc.)
Buffer src, dst; // Source and destination buffers
Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
IntImm coalesced_width; // Width (in elements) for coalesced memory access
Bool disable_tma = Bool(false); // Whether to disable TMA acceleration
mutable ParallelOp par_op_; // Optional associated parallelization operator
enum class EvictionPolicy {
kEvictNormal = 0,
kEvictFirst = 1,
kEvictLast = 2,
};
int eviction_policy; // Policy for cache eviction
static constexpr const char *_type_key = "tl.Copy";
TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
/*!
* \brief Lower the copy operator to a TIR statement.
* \param T Arguments for lowering.
* \param analyzer Analyzer for simplification and bounds checks.
*/
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
/*!
* \brief Infer buffer layouts after applying this operator.
* \param T Arguments for layout inference.
* \param level Level of inference (basic or detailed).
*/
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
/*!
* \brief Get the TVM Op handle corresponding to this Copy op.
*/
static const Op &Get();
/*!
* \brief Copy instruction type.
*/
enum class CopyInst {
kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy
kLDSM = 1, // ldmatrix memory copy
kSTSM = 2, // stmatrix memory copy
kBulkLoad = 3, // utilize tma load
kBulkStore = 4, // utilize tma store
};
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
/*!
* \brief Check if bulk copy is supported.
......@@ -147,26 +154,9 @@ public:
*/
CopyInst GetCopyInst(Target target, bool disable_tma_lower) const;
/*!
* \brief Copy constructor (deep clones ParallelOp if present).
*/
Copy(const Copy &other)
: args_(other.args_), src(other.src), dst(other.dst),
src_range(other.src_range), dst_range(other.dst_range),
coalesced_width(other.coalesced_width), disable_tma(other.disable_tma) {
// Deep copy ParallelOp if it exists
if (other.par_op_)
par_op_ = std::unique_ptr<ParallelOp>(
static_cast<ParallelOp *>(other.par_op_->Clone().release()));
}
/*!
* \brief Clone this copy operator.
*/
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Copy>(*this);
}
protected:
/*!
* \brief Generate lowering for bulk/global-to-shared copy.
......@@ -218,23 +208,24 @@ protected:
PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_; // Copy parameters (indices, sizes, etc.)
Buffer src, dst; // Source and destination buffers
Array<Range> src_range, dst_range; // Ranges for each dimension in src and dst
IntImm coalesced_width; // Width (in elements) for coalesced memory access
Bool disable_tma = Bool(false); // Whether to disable TMA acceleration
TileOperator Clone() const;
};
std::unique_ptr<ParallelOp>
par_op_; // Optional associated parallelization operator
class Copy : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode);
enum class EvictionPolicy {
kEvictNormal = 0,
kEvictFirst = 1,
kEvictLast = 2,
};
/*!
* \brief Constructor.
* \param args Expression arguments for the copy.
* \param vmap Buffer variable mapping.
*/
TVM_DLL Copy(Array<PrimExpr> args, BufferMap vmap);
int eviction_policy; // Policy for cache eviction
/*!
* \brief Get the TVM Op handle corresponding to this Copy op.
*/
static const Op &Get();
};
/*!
......@@ -243,41 +234,43 @@ protected:
* This operator converts input image layout into columnar format suitable
* for matrix multiplication-based convolution lowering.
*/
class Conv2DIm2ColOp : public Operator {
class Conv2DIm2ColOpNode : public TileOperatorNode {
public:
/*!
* \brief Constructor.
* \param args Op arguments (convolution parameters, shapes, etc.)
* \param vmap Variable buffer mapping.
*/
Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
Buffer src, dst; // Source (input feature map) and destination (im2col matrix)
int stride; // Stride for convolution
int padding; // Padding amount
int dilation; // Dilation factor
int kernel; // Kernel size
int eviction_policy; // Cache eviction policy
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
static constexpr const char *_type_key = "tl.Conv2DIm2Col";
TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode);
/*!
* \brief Lower to TIR statement.
*/
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
/*!
* \brief Get TVM Op handle.
* \brief Infer layout for this operator.
*/
static const Op &Get();
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
/*!
* \brief Clone this operator.
* \brief Get TVM Op handle.
*/
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Conv2DIm2ColOp>(*this);
}
static const Op &Get();
TileOperator Clone() const;
};
private:
Buffer src, dst; // Source (input feature map) and destination (im2col matrix)
int stride; // Stride for convolution
int padding; // Padding amount
int dilation; // Dilation factor
int kernel; // Kernel size
int eviction_policy; // Cache eviction policy
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
class Conv2DIm2ColOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator,
Conv2DIm2ColOpNode);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
} // namespace tl
......
......@@ -23,6 +23,7 @@ namespace tl {
using namespace tir;
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = make_object<FillNode>();
if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
......@@ -33,42 +34,49 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
const auto *lanes = ramp->lanes.as<IntImmNode>();
CHECK(lanes)
<< "Scalable vectors not supported in BufferRegion conversion";
region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
region.push_back(Range::FromMinExtent(index, 1));
node->region.push_back(Range::FromMinExtent(index, 1));
}
}
dst = buffer_load->buffer;
node->dst = buffer_load->buffer;
} else {
dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < dst->shape.size(); i++) {
region.push_back(Range(0, dst->shape[i]));
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
}
if (args[1]->dtype != dst->dtype) {
value = Cast(dst->dtype, args[1]);
if (args[1]->dtype != node->dst->dtype) {
node->value = Cast(node->dst->dtype, args[1]);
} else {
value = args[1];
node->value = args[1];
}
ICHECK(region.size() == dst->shape.size())
<< "region size = " << region.size() << " != " << dst->shape.size();
for (int i = 0; i < region.size(); i++) {
ICHECK(node->region.size() == node->dst->shape.size())
<< "region size = " << node->region.size()
<< " != " << node->dst->shape.size();
for (int i = 0; i < node->region.size(); i++) {
// bound check if region is static
if (region[i]->min.as<IntImm>()) {
int64_t min = Downcast<IntImm>(region[i]->min)->value;
if (node->region[i]->min.as<IntImm>()) {
int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
}
if (region[i]->extent.as<IntImm>()) {
int64_t extent = Downcast<IntImm>(region[i]->extent)->value;
ICHECK_LE(extent, Downcast<IntImm>(dst->shape[i])->value)
<< "region[" << i << "] = " << extent << " > " << dst->shape[i];
if (node->region[i]->extent.as<IntImm>()) {
int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
<< "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
}
}
data_ = std::move(node);
}
For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
TileOperator FillNode::Clone() const {
auto op = make_object<FillNode>(*this);
return Fill(op);
}
For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
int ndim = dst->shape.size();
Array<IterVar> loop_vars;
Array<PrimExpr> dst_indices;
......@@ -85,10 +93,9 @@ For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return Downcast<For>(body);
}
Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") {
auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
InferLevel::kFree);
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
......@@ -106,7 +113,7 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto vectorized_thread_loop = VectorizeLoop(init_loop);
return vectorized_thread_loop;
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared") {
auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
......@@ -122,6 +129,11 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
}
LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
return {};
}
TIR_REGISTER_TL_OP(Fill, fill)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
......@@ -7,7 +7,7 @@
#ifndef TVM_TL_OP_ELEM_H_
#define TVM_TL_OP_ELEM_H_
#include "op.h"
#include "operator.h"
#include "parallel.h"
namespace tvm {
......@@ -15,21 +15,29 @@ namespace tl {
using namespace tir;
class Fill : public Operator {
class FillNode : public TileOperatorNode {
public:
Fill(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
tir::Buffer dst;
PrimExpr value;
Array<Range> region;
static constexpr const char *_type_key = "tl.Fill";
TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Fill>(*this);
}
TileOperator Clone() const;
private:
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
tir::Buffer dst;
PrimExpr value;
Array<Range> region;
};
// Managed reference wrapper for FillNode (registered as the "tl.fill" op).
class Fill : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode);
  // Construct from TL call arguments and the Var -> Buffer map of the
  // enclosing PrimFunc.
  TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
  // The registered tl.fill intrinsic Op.
  static const Op &Get();
};
} // namespace tl
......
......@@ -34,35 +34,44 @@ static std::vector<int> toPrimeFactors(int x) {
}
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
Aptr = args[0];
Bptr = args[1];
Cptr = args[2];
A = vmap[GetVarFromAccessPtr(Aptr)];
B = vmap[GetVarFromAccessPtr(Bptr)];
C = vmap[GetVarFromAccessPtr(Cptr)];
trans_A = args[3].as<Bool>().value();
trans_B = args[4].as<Bool>().value();
M = args[5].as<IntImm>().value()->value;
N = args[6].as<IntImm>().value()->value;
K = args[7].as<IntImm>().value()->value;
policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
clear_accum = args[9].as<Bool>().value();
stride_A = args[10].as<IntImm>().value()->value;
stride_B = args[11].as<IntImm>().value()->value;
offset_A = args[12].as<IntImm>().value()->value;
offset_B = args[13].as<IntImm>().value()->value;
ObjectPtr<GemmNode> node = make_object<GemmNode>();
node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
node->trans_A = args[3].as<Bool>().value();
node->trans_B = args[4].as<Bool>().value();
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy =
static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<Bool>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
kPack = args[14].as<IntImm>().value()->value;
if (kPack != 1 && kPack != 2) {
node->kPack = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 15) {
wg_wait = args[15].as<IntImm>().value()->value;
node->wg_wait = args[15].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
TileOperator GemmNode::Clone() const {
  // Copy-construct a new node and hand ownership to a Gemm reference.
  ObjectPtr<GemmNode> copied = make_object<GemmNode>(*this);
  return Gemm(std::move(copied));
}
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
......@@ -87,10 +96,13 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
* per-warp tile sizes) and adapts the partition according to the configured
* GemmWarpPolicy (FullRow, FullCol, Square).
*
* @param block_size Total number of threads in the block (used to derive num_warps).
* @param block_size Total number of threads in the block (used to derive
* num_warps).
* @param gemm_inst The chosen GEMM implementation (e.g., kWGMMA, kMFMA, kMMA).
* @param target Target device information (used for warp size and target-specific rules).
* @return std::pair<int, int> {m_warp, n_warp} where m_warp * n_warp == num_warps.
* @param target Target device information (used for warp size and
* target-specific rules).
* @return std::pair<int, int> {m_warp, n_warp} where m_warp * n_warp ==
* num_warps.
*
* Constraints and behavior:
* - Each warp is assumed to cover 16 rows (M) and 8 columns (N). The function
......@@ -100,7 +112,8 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
* - num_warps must be a multiple of 4 (warp-groups of 4).
* - m_warp is always a multiple of 4.
* - The warp partition respects the GemmWarpPolicy:
* - FullRow: maximize warps on M (in multiples of 4) while keeping divisibility.
* - FullRow: maximize warps on M (in multiples of 4) while keeping
* divisibility.
* - FullCol: maximize warps on N, but if N is not evenly divisible, move
* whole warp-groups to M to achieve feasibility.
* - Square: choose a multiple-of-4 m_warp that best balances per-warp work
......@@ -118,7 +131,7 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const {
* divisibility or policy conditions are not met (e.g., M/N tile divisibility,
* invalid policy, or WGMMA-specific warp-group requirements).
*/
std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
std::pair<int, int> GemmNode::ComputeWarpPartition(int block_size,
GemmInst gemm_inst,
Target target) const {
int num_warps = block_size / TargetGetWarpSize(target);
......@@ -296,19 +309,21 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
* Supported combinations and constraints:
* - C=float16:
* - A=float16, B=float16: K % 16 == 0
* - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % 32 == 0
* - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K %
* 32 == 0
* - C=float32:
* - A=float16, B=float16: K % 16 == 0
* - A=bfloat16, B=bfloat16: K % 16 == 0
* - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0
* - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0
* - C=int32:
* - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) and K % 32 == 0
* - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B)
* and K % 32 == 0
*
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
*/
bool Gemm::CheckWGMMA() const {
bool GemmNode::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
return false;
}
......@@ -373,7 +388,7 @@ static int GetArchInt(Target target) {
return arch_int;
}
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target);
......@@ -425,7 +440,8 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
* - C.scope() must be "local.fragment".
*
* Postconditions / side effects:
* - Marks the operator's layout inference as completed (sets completed_ = true).
* - Marks the operator's layout inference as completed (sets completed_ =
* true).
* - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
* incompatible shape constraints.
*
......@@ -433,7 +449,8 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
* @param level Inference level (unused for side effects but retained for API).
* @return LayoutMap mapping each of A, B, and C to their inferred layouts.
*/
LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (completed_)
return {};
LayoutMap results;
......
......@@ -7,37 +7,21 @@
#ifndef TVM_TL_OP_GEMM_H_
#define TVM_TL_OP_GEMM_H_
#include "op.h"
#include "operator.h"
namespace tvm {
namespace tl {
using namespace tir;
class Gemm : public Operator {
public:
Gemm(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
enum class GemmWarpPolicy {
enum class GemmWarpPolicy {
kSquare = 0,
kFullRow = 1,
kFullCol = 2,
} policy;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<Gemm>(*this);
}
private:
// Target GEMM instruction
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;
std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst,
Target target) const;
};
class GemmNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
Array<PrimExpr> call_args;
tir::Buffer A, B, C;
......@@ -52,7 +36,33 @@ private:
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
bool completed_ = false;
GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.Gemm";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
TileOperator Clone() const;
private:
// Target GEMM instruction
enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;
std::pair<int, int> ComputeWarpPartition(int num_warps, GemmInst gemm_inst,
Target target) const;
mutable bool completed_ = false;
};
// Managed reference wrapper for GemmNode.
class Gemm : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode);
  // Construct from TL call arguments and the Var -> Buffer map of the
  // enclosing PrimFunc.
  TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);
  // The registered TL intrinsic Op backing this operator.
  static const Op &Get();
};
} // namespace tl
......
......@@ -32,30 +32,38 @@ static std::vector<int> toPrimeFactors(int x) {
}
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
A = vmap[GetVarFromAccessPtr(args[0])];
E = vmap[GetVarFromAccessPtr(args[1])];
B = vmap[GetVarFromAccessPtr(args[2])];
C = vmap[GetVarFromAccessPtr(args[3])];
trans_A = args[4].as<Bool>().value();
trans_B = args[5].as<Bool>().value();
M = args[6].as<IntImm>().value()->value;
N = args[7].as<IntImm>().value()->value;
K = args[8].as<IntImm>().value()->value;
policy = static_cast<GemmWarpPolicy>(args[9].as<IntImm>().value()->value);
clear_accum = args[10].as<Bool>().value();
ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>();
node->A = vmap[GetVarFromAccessPtr(args[0])];
node->E = vmap[GetVarFromAccessPtr(args[1])];
node->B = vmap[GetVarFromAccessPtr(args[2])];
node->C = vmap[GetVarFromAccessPtr(args[3])];
node->trans_A = args[4].as<Bool>().value();
node->trans_B = args[5].as<Bool>().value();
node->M = args[6].as<IntImm>().value()->value;
node->N = args[7].as<IntImm>().value()->value;
node->K = args[8].as<IntImm>().value()->value;
node->policy = static_cast<GemmSPNode::GemmWarpPolicy>(
args[9].as<IntImm>().value()->value);
node->clear_accum = args[10].as<Bool>().value();
if (args.size() > 11) {
kPack = args[11].as<IntImm>().value()->value;
if (kPack != 1 && kPack != 2) {
node->kPack = args[11].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 12) {
wg_wait = args[12].as<IntImm>().value()->value;
node->wg_wait = args[12].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
TileOperator GemmSPNode::Clone() const {
  // Shallow-copy this node into a new GemmSP reference.
  return GemmSP(make_object<GemmSPNode>(*this));
}
std::pair<int, int>
GemmSP::ComputeWarpPartition(int num_warps, Target target,
GemmSPNode::ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma) const {
int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
......@@ -212,7 +220,7 @@ GemmSP::ComputeWarpPartition(int num_warps, Target target,
return {m_warp, n_warp};
}
Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
auto block_size = *as_const_int(T.thread_bounds->extent);
......@@ -256,7 +264,8 @@ Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Evaluate(new_call);
}
LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, InferLevel level) {
LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (completed_)
return {};
LayoutMap results;
......@@ -308,6 +317,7 @@ LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, InferLevel level) {
completed_ = true;
return results;
}
TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
......@@ -7,30 +7,23 @@
#ifndef TVM_TL_OP_GEMM_SP_H_
#define TVM_TL_OP_GEMM_SP_H_
#include "op.h"
#include "operator.h"
namespace tvm {
namespace tl {
using namespace tir;
class GemmSP : public Operator {
class GemmSPNode : public TileOperatorNode {
public:
GemmSP(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
enum class GemmWarpPolicy {
kSquare = 0,
kFullRow = 1,
kFullCol = 2,
} policy;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<GemmSP>(*this);
}
private:
std::pair<int, int>
ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma = true) const;
......@@ -44,7 +37,18 @@ private:
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
bool completed_ = false;
TileOperator Clone() const;
private:
mutable bool completed_ = false;
};
// Managed reference wrapper for GemmSPNode (sparse GEMM tile operator).
class GemmSP : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(GemmSP, TileOperator, GemmSPNode);
  // Construct from TL call arguments and the Var -> Buffer map of the
  // enclosing PrimFunc.
  TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap);
  // The registered TL intrinsic Op backing this operator.
  static const Op &Get();
};
} // namespace tl
......
/*!
 * \file tl/op/operator.cc
 *
 * Define operators used in the tile library.
 */
#include "operator.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
namespace tvm {
namespace tl {
using namespace tir;
// Build the TileOperator registered for this call's op, or an undefined
// reference when the op has no "TLOpBuilder" attribute.
TileOperator ParseOperator(Call call, BufferMap vmap) {
  Op op = call->op.as<Op>().value();
  auto builder_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
  if (!builder_map.count(op)) {
    return TileOperator();  // not a tile-library op
  }
  TileOperator result = builder_map[op](call->args, vmap);
  ICHECK(result.defined());
  return result;
}
// Statement-level overload: recognizes `Evaluate(Call(...))` statements and
// forwards to the Call overload; returns an undefined TileOperator otherwise.
//
// The previous version downcast `stmt` three times (`as<Evaluate>()` plus two
// `as<EvaluateNode>()`) and re-cast the inner value; cache each cast once.
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
  if (const auto *eval = stmt.as<EvaluateNode>()) {
    if (const auto *call = eval->value.as<CallNode>()) {
      return ParseOperator(GetRef<Call>(call), vmap);
    }
  }
  return TileOperator();
}
// Extract the buffer variable from a builtin `tvm_access_ptr` call; the
// variable is the call's second argument.
Var GetVarFromAccessPtr(const PrimExpr &expr) {
  const CallNode *access_call = expr.as<CallNode>();
  ICHECK(access_call);
  ICHECK(access_call->op.same_as(builtin::tvm_access_ptr()));
  const VarNode *buffer_var = access_call->args[1].as<VarNode>();
  ICHECK(buffer_var);
  return GetRef<Var>(buffer_var);
}
} // namespace tl
} // namespace tvm
......@@ -11,6 +11,8 @@
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op_attr_types.h>
#include "../layout/layout.h"
......@@ -22,19 +24,6 @@ using namespace tir;
using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = ffi::TypedFunction<void *(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>("TLOpBuilder", \
[](Array<PrimExpr> a, BufferMap b) { \
return (void *)(new Entry(a, b)); \
})
enum class InferLevel {
kFree = 0,
......@@ -59,38 +48,48 @@ struct LayoutInferArgs {
Map<Buffer, Buffer> buffer_remap;
};
class Operator {
public:
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
virtual ~Operator() = default;
virtual std::unique_ptr<Operator> Clone() const = 0;
};
class TileOperatorNode;
class TileOperator;
class RegionOp : public Operator {
public:
RegionOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
class TileOperatorNode: public Object {
public:
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0;
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<RegionOp>(*this);
}
virtual LayoutMap InferLayout(const LayoutInferArgs& T,
InferLevel level) const = 0;
const Buffer &GetBuffer() const { return buffer_; }
const Array<Range> &GetRanges() const { return ranges_; }
int GetAccessMask() const { return access_mask_; }
bool IsFullRegion() const;
virtual TileOperator Clone() const = 0;
static constexpr const char* _type_key = "tl.TileOperator";
TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
};
private:
Buffer buffer_;
Array<Range> ranges_;
int access_mask_;
class TileOperator : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
};
Var GetVarFromAccessPtr(const PrimExpr &expr);
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap);
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap);
TileOperator ParseOperator(Call call, BufferMap vmap);
TileOperator ParseOperator(Stmt stmt, BufferMap vmap);
using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>("TLOpBuilder", \
[](Array<PrimExpr> args, BufferMap vmap) { \
return Entry(args, vmap); \
})
} // namespace tl
} // namespace tvm
......
......@@ -154,9 +154,21 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
StmtExprVisitor::VisitExpr_(op);
}
ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }
ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
V.VisitStmt(root);
}
TileOperator ParallelOpNode::Clone() const {
  // Uses ParallelOpNode's copy constructor, which re-walks root_ and carries
  // over any already-inferred loop layout and predicate.
  return ParallelOp(make_object<ParallelOpNode>(*this));
}
// Lowering emits the original For nest unchanged; thread partitioning is
// applied by callers (e.g. via PartitionLoop on GetRoot()).
Stmt ParallelOpNode::Lower(const LowerArgs &T,
                           arith::Analyzer *analyzer) const {
  return root_;
}
bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const {
auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
return StructuralEqual()(indice_map_[buffer], common_indice);
}
......@@ -179,7 +191,8 @@ bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
* Can generate new layouts based on vectorization and thread
* bounds. Used when maximum performance optimization is desired.
*/
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (loop_layout_.defined())
return {};
if (level == InferLevel::kStrict)
......@@ -355,7 +368,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
return results;
}
Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
if (predicate_.defined()) {
return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
} else {
......@@ -363,7 +376,7 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
}
}
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
ICHECK(loop_layout_.defined());
if (IsCommonAccessIndice(buffer)) {
return loop_layout_;
......
......@@ -10,7 +10,7 @@
#include <tvm/tir/stmt_functor.h>
#include "../layout/layout.h"
#include "op.h"
#include "operator.h"
namespace tvm {
namespace tl {
......@@ -31,58 +31,97 @@ bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
Array<PrimExpr> large_frag_indices,
arith::Analyzer &analyzer_);
class ParallelOp;
class ParallelOpNode;
class ParallelLoopNestVisitor : public StmtExprVisitor {
private:
ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
void VisitStmt_(const ForNode *op) final;
void VisitStmt_(const BufferStoreNode *op) final;
void VisitExpr_(const BufferLoadNode *op) final;
ParallelLoopNestVisitor(ParallelOpNode *op) : p(op){};
void VisitStmt_(const ForNode *op) override;
void VisitStmt_(const BufferStoreNode *op) override;
void VisitExpr_(const BufferLoadNode *op) override;
ParallelOp *p;
ParallelOpNode *p;
friend class ParallelOp;
friend class ParallelOpNode;
};
class ParallelOp : public Operator {
// ParallelOpNode represents a parallel for loop operator in TileLang.
// It is responsible for inferring layouts, holding loop structure, and managing
// predicates.
class ParallelOpNode : public TileOperatorNode {
public:
ParallelOp(For root);
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
// The inferred layout for the loop, mutable to allow lazy inference.
mutable Fragment loop_layout_;
// The predicate expression for the loop, if any, mutable for lazy
// construction.
mutable Optional<PrimExpr> predicate_;
ParallelOp(const ParallelOp &other) : ParallelOp(other.root_) {
// Type key for TVM object system.
static constexpr const char *_type_key = "tl.ParallelOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode);
// Construct from a root For loop.
ParallelOpNode(For root);
// Lower the operator to a TIR statement.
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
// Infer the layout for this parallel operator.
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
// Copy constructor for ParallelOpNode.
ParallelOpNode(const ParallelOpNode &other) : ParallelOpNode(other.root_) {
loop_layout_ = other.loop_layout_;
predicate_ = other.predicate_;
}
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<ParallelOp>(*this);
}
// Get the inferred loop layout.
Fragment GetLoopLayout() const { return loop_layout_; }
// Get the root For loop.
For GetRoot() const { return root_; }
// Get the mapping from buffer to access indices.
Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
// Get the predicate for a given thread variable.
Optional<PrimExpr> GetPredicate(Var thread_var) const;
// Clone this operator.
TileOperator Clone() const;
private:
Fragment CompleteBufferFragment(const Buffer &buffer);
// Complete the fragment layout for a given buffer.
Fragment CompleteBufferFragment(const Buffer &buffer) const;
// Check if the buffer is accessed with common indices (i.e., loop variables).
bool IsCommonAccessIndice(const Buffer &buffer) const;
void AddPredicate(PrimExpr expr) {
// Add a predicate to the current predicate expression.
void AddPredicate(PrimExpr expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}
// Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor;
// The root For loop node.
For root_;
// Visitor for collecting loop nest information.
ParallelLoopNestVisitor V;
// Mapping from buffer to their access indices in the loop.
Map<Buffer, Array<PrimExpr>> indice_map_;
// Set of buffers that are written to in the loop.
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
// The loop variables for the parallel loop nest.
Array<IterVar> loop_vars_;
Fragment loop_layout_;
// Analyzer for simplifying and analyzing expressions, mutable for lazy use.
mutable arith::Analyzer analyzer_;
Optional<PrimExpr> predicate_;
};
friend class ParallelLoopNestVisitor;
// Managed reference wrapper for ParallelOpNode.
class ParallelOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode);
  // Wrap a root For loop; the node's constructor walks the loop nest to
  // collect buffer access information.
  ParallelOp(For root) {
    auto op = make_object<ParallelOpNode>(root);
    data_ = std::move(op);
  }
};
} // namespace tl
......
......@@ -22,26 +22,38 @@ namespace tl {
using namespace tir;
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
src = vmap[GetVarFromAccessPtr(args[0])];
dst = vmap[GetVarFromAccessPtr(args[1])];
String reduce_type = args[2].as<StringImm>().value()->value;
dim = args[3].as<IntImm>().value()->value;
ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
std::string reduce_type = args[2].as<StringImm>().value()->value;
node->dim = args[3].as<IntImm>().value()->value;
if (reduce_type == "sum")
type = ReduceType::kSum;
node->type = ReduceType::kSum;
else if (reduce_type == "abssum")
type = ReduceType::kAbsSum;
node->type = ReduceType::kAbsSum;
else if (reduce_type == "absmax")
type = ReduceType::kAbsMax;
node->type = ReduceType::kAbsMax;
else if (reduce_type == "max")
type = ReduceType::kMax;
node->type = ReduceType::kMax;
else if (reduce_type == "min")
type = ReduceType::kMin;
node->type = ReduceType::kMin;
else
ICHECK(0) << "Unknown reduce type: " << reduce_type;
clear = args[4].as<Bool>().value();
node->clear = args[4].as<Bool>().value();
data_ = std::move(node);
}
PrimExpr ReduceOp::MakeInitValue() const {
TileOperator ReduceOpNode::Clone() const {
  // Return a ReduceOp wrapping a field-wise copy of this node.
  return ReduceOp(make_object<ReduceOpNode>(*this));
}
TileOperator CumSumOpNode::Clone() const {
  // Field-wise copy, handed off to a new CumSumOp reference.
  ObjectPtr<CumSumOpNode> copied = make_object<CumSumOpNode>(*this);
  return CumSumOp(std::move(copied));
}
PrimExpr ReduceOpNode::MakeInitValue() const {
auto dst_dtype = dst->dtype;
auto is_int = dst_dtype.is_int();
bool is_uint = dst_dtype.is_uint();
......@@ -75,7 +87,7 @@ PrimExpr ReduceOp::MakeInitValue() const {
}
}
PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
PrimExpr lhs = a, rhs = b;
if (lhs->dtype != rhs->dtype) {
rhs = Cast(lhs->dtype, rhs);
......@@ -97,7 +109,7 @@ PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
}
}
std::string ReduceOp::MakeCodegenReducer() const {
std::string ReduceOpNode::MakeCodegenReducer() const {
switch (type) {
case ReduceType::kSum:
return "tl::SumOp";
......@@ -115,7 +127,7 @@ std::string ReduceOp::MakeCodegenReducer() const {
}
}
Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment")
<< "Reduce for shared memory not implemented.";
......@@ -284,7 +296,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return body;
}
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (level >= InferLevel::kStrict)
return {};
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
......@@ -369,14 +382,16 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
reverse: whether to cumsum in reverse order
*/
CHECK_EQ(args.size(), 4);
src = vmap[GetVarFromAccessPtr(args[0])];
dst = vmap[GetVarFromAccessPtr(args[1])];
dim = args[2].as<IntImm>().value()->value;
reverse = args[3].as<Bool>().value();
CHECK_LT(dim, static_cast<int>(src->shape.size()));
ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->dim = args[2].as<IntImm>().value()->value;
node->reverse = args[3].as<Bool>().value();
CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()));
data_ = std::move(node);
}
Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment") {
LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
......@@ -402,7 +417,8 @@ Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Stmt();
}
LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
// CumSum imposes no layout constraints on its buffers.
LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  return {};
}
......
......@@ -7,56 +7,70 @@
#ifndef TVM_TL_OP_REDUCE_H_
#define TVM_TL_OP_REDUCE_H_
#include "op.h"
#include "operator.h"
namespace tvm {
namespace tl {
using namespace tir;
class ReduceOp : public Operator {
public:
ReduceOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<ReduceOp>(*this);
}
private:
tir::Buffer src, dst;
int dim;
enum class ReduceType {
enum class ReduceType {
kSum,
kAbsSum,
kMax,
kMin,
kAbsMax,
} type;
};
// Tile-level reduction over one dimension of a fragment buffer
// (Lower requires both src and dst in scope "local.fragment").
class ReduceOpNode : public TileOperatorNode {
public:
  tir::Buffer src, dst;  // source and destination fragment buffers
  int dim;               // dimension of src being reduced
  ReduceType type;       // reduction kind (sum/abssum/max/min/absmax)
  // Whether to clear/initialize dst before accumulating — presumably; set
  // from the 5th call argument. TODO(review): confirm against Lower.
  bool clear;
  static constexpr const char *_type_key = "tl.ReduceOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;

private:
  // Identity element for the reduction, in dst's dtype.
  PrimExpr MakeInitValue() const;
  // Combine two values according to `type` (casts rhs to lhs dtype).
  PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
  // Name of the codegen reducer functor (e.g. "tl::SumOp").
  std::string MakeCodegenReducer() const;
};
class CumSumOp : public Operator {
class ReduceOp : public TileOperator {
public:
CumSumOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<CumSumOp>(*this);
}
private:
// Tile-level cumulative sum along one dimension of a buffer.
class CumSumOpNode : public TileOperatorNode {
public:
  tir::Buffer src, dst;  // source and destination buffers
  int dim;               // dimension along which to accumulate
  bool reverse;          // whether to cumsum in reverse order
  static constexpr const char *_type_key = "tl.CumSumOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;
};
// Managed reference wrapper for CumSumOpNode.
class CumSumOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);
  // Construct from TL call arguments (src, dst, dim, reverse) and the
  // Var -> Buffer map of the enclosing PrimFunc.
  TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
  // The registered TL intrinsic Op backing this operator.
  static const Op &Get();
};
} // namespace tl
......
/*!
* \file tl/op/op.cc
* \file tl/op/region.cc
* \brief Define region operator.
*
* Define operators used in the tile library.
*/
#include "op.h"
#include <tvm/tir/builtin.h>
#include "region.h"
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
namespace tvm {
namespace tl {
using namespace tir;
TIR_REGISTER_TL_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value();
if (op_map.count(op)) {
Operator *ptr = static_cast<Operator *>(op_map[op](call->args, vmap));
ICHECK(ptr != nullptr);
return std::unique_ptr<Operator>(ptr);
}
return nullptr;
}
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(GetRef<Call>(call), vmap);
}
return nullptr;
}
// Extract the buffer variable from a `tvm_access_ptr` expression.
//
// The expression must be a call to builtin::tvm_access_ptr; its second
// argument (index 1) is the underlying buffer Var, which is returned as a
// strong reference.
Var GetVarFromAccessPtr(const PrimExpr &expr) {
  const CallNode *access = expr.as<CallNode>();
  ICHECK(access);
  ICHECK(access->op.same_as(builtin::tvm_access_ptr()));
  const VarNode *buffer_var = access->args[1].as<VarNode>();
  ICHECK(buffer_var);
  return GetRef<Var>(buffer_var);
}
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
size_t n = args.size();
size_t ndim = n - 2;
......@@ -55,16 +18,25 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
ICHECK(load);
ICHECK(load->indices.size() == ndim)
<< "load->indices.size() = " << load->indices << " ndim = " << ndim;
buffer_ = load->buffer;
access_mask_ = static_cast<int>(*as_const_int(args[1]));
Array<Range> ranges;
for (size_t i = 0; i < ndim; i++) {
PrimExpr min = load->indices[i];
PrimExpr extent = args[2 + i];
ranges_.push_back(Range::FromMinExtent(min, extent));
ranges.push_back(Range::FromMinExtent(min, extent));
}
ObjectPtr<RegionOpNode> node = make_object<RegionOpNode>();
node->buffer_ = load->buffer;
node->access_mask_ = static_cast<int>(*as_const_int(args[1]));
node->ranges_ = ranges;
data_ = std::move(node);
}
// Deep-copy this node and wrap the copy in a fresh RegionOp reference.
TileOperator RegionOpNode::Clone() const {
  return RegionOp(make_object<RegionOpNode>(*this));
}
bool RegionOp::IsFullRegion() const {
bool RegionOpNode::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min))
return false;
......@@ -74,14 +46,19 @@ bool RegionOp::IsFullRegion() const {
return true;
}
Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(0) << "Not Implemented Lower method.";
// A region only describes a buffer slice for other ops to consume;
// it lowers to a no-op statement.
Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  return Evaluate(0);
}
LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
// A region imposes no layout constraints of its own, at any infer level.
LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  return {};
}
TIR_REGISTER_TL_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
} // namespace tl
} // namespace tvm
/*!
 * \file tl/op/region.h
* \brief Tile library operations.
*
*/
#ifndef TVM_TL_OP_REGION_H_
#define TVM_TL_OP_REGION_H_
#include "./operator.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Node describing a sub-region (slice) of a buffer.
 *
 * A region is identified by the underlying buffer, one Range per
 * dimension, and an integer access mask parsed from the tl.region call
 * (presumably read/write bits — see the constructor in region.cc).
 */
class RegionOpNode : public TileOperatorNode {
public:
  Buffer buffer_;        // buffer the region refers to
  Array<Range> ranges_;  // per-dimension [min, min+extent) bounds
  int access_mask_;      // access-kind bits taken from the call's second arg
  static constexpr const char *_type_key = "tl.RegionOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(RegionOpNode, TileOperatorNode);
  // Regions lower to a no-op; they only carry slice information.
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  // Regions contribute no layout constraints.
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  const Buffer &GetBuffer() const { return buffer_; }
  const Array<Range> &GetRanges() const { return ranges_; }
  int GetAccessMask() const { return access_mask_; }
  // True when the ranges cover the whole buffer (every min is zero;
  // extent check lives in region.cc).
  bool IsFullRegion() const;
  // Deep-copy this node into a fresh TileOperator reference.
  TileOperator Clone() const;
};
/*! \brief Managed reference wrapping RegionOpNode. */
class RegionOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(RegionOp, TileOperator, RegionOpNode);
  // Parse the raw args of a tl.region call: a BufferLoad giving the buffer
  // and per-dim mins, an access mask, then one extent per dimension.
  TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_REGION_H_
......@@ -15,6 +15,7 @@
#include "../layout/utils.h"
#include "../op/parallel.h"
#include "../op/region.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
......@@ -79,8 +80,8 @@ public:
auto iter_var = thread_var_vec_[cur_infer_id];
auto thread_bounds = thread_bounds_vec_[cur_infer_id];
// Double-check that 'next' is valid
ICHECK(next != nullptr)
<< "infer_list_[" << cur_infer_id << "] is null inside run_infer_step.";
ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
<< "] is null inside run_infer_step.";
// Check iter_var->dom and dom->extent
ICHECK(iter_var.defined())
......@@ -100,6 +101,7 @@ public:
// Run InferLayout
auto updates = next->InferLayout(
LayoutInferArgs{target_, thread_bounds, layout_map}, level);
// Process the returned updates
for (const auto &[buffer, layout] : updates) {
// Basic validity checks
......@@ -112,7 +114,7 @@ public:
level != InferLevel::kStrict && !strict_layout_map.count(buffer)) {
// Actually this test has been done in ParallelOp::InferLayout
// already. Just do it again to avoid missing implementations in other
// `Operator`s.
// `TileOperator`s.
auto dst_layout = layout.as<Fragment>().value();
auto src_layout = layout_map[buffer].as<Fragment>().value();
ICHECK(dst_layout->InputDim() == src_layout->InputDim());
......@@ -210,7 +212,7 @@ public:
std::vector<bool> in_queue(num_infer, true);
for (int i = 0; i < num_infer; i++) {
// Check that each infer_list_ entry is valid
ICHECK(infer_list_[i] != nullptr)
ICHECK(infer_list_[i].defined())
<< "infer_list_[" << i
<< "] is null. The inference object is not allocated properly.";
......@@ -253,13 +255,13 @@ public:
ICHECK(infer_list_.size() == thread_var_vec_.size())
<< "infer_list_ and thread_var_vec_ size mismatch";
for (int i = 0; i < infer_list_.size(); i++) {
std::unique_ptr<Operator> base_infer = std::move(infer_list_[i]);
TileOperator base_infer = std::move(infer_list_[i]);
auto thread_var = thread_var_vec_[i];
// Check if base_infer is valid
ICHECK(base_infer != nullptr) << "Null pointer encountered in "
ICHECK(base_infer.defined()) << "Null pointer encountered in "
"infer_list_ while collecting for_map.";
if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
if (auto for_infer = base_infer.as<ParallelOpNode>()) {
// Check that the loop layout is defined
ICHECK(for_infer->GetLoopLayout().defined())
<< "The Layout for Parallel for cannot be inferred correctly:\n"
......@@ -297,7 +299,7 @@ private:
return;
auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
if (p != nullptr) {
if (p.defined()) {
for (const auto &arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) {
addToUseList(buffer.value());
......@@ -344,7 +346,7 @@ private:
void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) {
auto infer = std::make_unique<ParallelOp>(GetRef<For>(op));
auto infer = ParallelOp(GetRef<For>(op));
for (const auto &[buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer);
}
......@@ -399,7 +401,7 @@ private:
Map<Var, Buffer> buffer_data_to_buffer_;
std::vector<ObjectRef> infer_list_stmt_;
std::vector<std::unique_ptr<Operator>> infer_list_;
std::vector<TileOperator> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
use_list_;
// This is a workaround for cpu backend,
......@@ -412,8 +414,8 @@ private:
LayoutMap annotated_layout_map_;
bool skip_thread_partition_{false};
std::vector<std::unique_ptr<Operator>> BackupInferList() {
std::vector<std::unique_ptr<Operator>> back_infer_list;
std::vector<TileOperator> BackupInferList() {
std::vector<TileOperator> back_infer_list;
back_infer_list.reserve(infer_list_.size());
for (auto &&p : infer_list_) {
back_infer_list.push_back(p->Clone());
......@@ -443,20 +445,25 @@ private:
int root = uf.Find(i);
components[root].push_back(i);
}
// Create a map from root to buffers
std::unordered_map<int, std::vector<Buffer>> components_buffers;
for (const auto &[buffer, infer_indices] : use_list_) {
int root = uf.Find(infer_indices[0]);
components_buffers[root].push_back(buffer);
}
// Keep components_buffers for debug purpose
(void)components_buffers;
// For each component, try each op as root, and determine the least
// replicated one
std::queue<int> q;
std::vector<bool> in_queue(infer_list_.size(), false);
for (auto &&[root, members] : components) {
decltype(infer_list_) best_infer_list;
LayoutMap best_layout_map;
int64_t min_reg_num = INT64_MAX;
for (int attempt_infer_root : members) {
// backup infer_list_ in class member
auto back_infer_list = BackupInferList();
......@@ -470,7 +477,6 @@ private:
tmp_layout_map, strict_layout_map, q, in_queue);
FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
q, in_queue);
// Silly workaround: we have no clue if single root will iterate over
// the entire component, since the InferLayout implementations have
// complicated conditioning inside and we know nothing about it.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment