Unverified commit 10911e28, authored by Lei Wang and committed by GitHub
Browse files

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
...@@ -25,8 +25,8 @@ public: ...@@ -25,8 +25,8 @@ public:
IntImm memory_order; ///< Memory order for atomic operations IntImm memory_order; ///< Memory order for atomic operations
mutable ParallelOp par_op_; ///< Associated parallel operation mutable ParallelOp par_op_; ///< Associated parallel operation
static constexpr const char *_type_key = "tl.AtomicAdd"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode,
TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode); TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
...@@ -46,28 +46,6 @@ public: ...@@ -46,28 +46,6 @@ public:
.def_ro("memory_order", &AtomicAddNode::memory_order); .def_ro("memory_order", &AtomicAddNode::memory_order);
} }
bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(use_tma, other->use_tma) &&
equal(coalesced_width, other->coalesced_width) &&
equal(memory_order, other->memory_order);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(use_tma);
hash_reduce(coalesced_width);
hash_reduce(memory_order);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
protected: protected:
/// Create SIMT-style parallel loop structure /// Create SIMT-style parallel loop structure
For MakeSIMTLoop(arith::Analyzer *analyzer) const; For MakeSIMTLoop(arith::Analyzer *analyzer) const;
...@@ -85,7 +63,8 @@ protected: ...@@ -85,7 +63,8 @@ protected:
/// Wrapper class for atomic addition operations /// Wrapper class for atomic addition operations
class AtomicAdd : public TileOperator { class AtomicAdd : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap); TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -130,7 +130,7 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) { ...@@ -130,7 +130,7 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
* @param vmap BufferMap used to resolve RegionOp buffers and ranges. * @param vmap BufferMap used to resolve RegionOp buffers and ranges.
*/ */
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) { Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<CopyNode> node = make_object<CopyNode>(); ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>();
Array<Range> rgs[2]; Array<Range> rgs[2];
Buffer bf[2]; Buffer bf[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
...@@ -169,7 +169,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) { ...@@ -169,7 +169,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator owning the cloned CopyNode. * @return TileOperator A TileOperator owning the cloned CopyNode.
*/ */
TileOperator CopyNode::Clone() const { TileOperator CopyNode::Clone() const {
auto op = make_object<CopyNode>(*this); auto op = tvm::ffi::make_object<CopyNode>(*this);
if (par_op_.defined()) { if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone()); op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
} }
...@@ -401,7 +401,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, ...@@ -401,7 +401,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
using namespace tvm::transform; using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current(); PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower = bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value(); pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, T.analyzer, T.buffer_oob); T.layout_map, T.analyzer, T.buffer_oob);
...@@ -793,7 +793,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -793,7 +793,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
using namespace tvm::transform; using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current(); PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower = bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value(); pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, analyzer); T.layout_map, analyzer);
if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
...@@ -1722,7 +1722,8 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const { ...@@ -1722,7 +1722,8 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* @param vmap Mapping from original buffer variables to actual Buffer objects. * @param vmap Mapping from original buffer variables to actual Buffer objects.
*/ */
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<Conv2DIm2ColOpNode> node = make_object<Conv2DIm2ColOpNode>(); ObjectPtr<Conv2DIm2ColOpNode> node =
tvm::ffi::make_object<Conv2DIm2ColOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])]; node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])]; node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->nhw_step = args[2]; node->nhw_step = args[2];
...@@ -1747,7 +1748,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -1747,7 +1748,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode. * @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode.
*/ */
TileOperator Conv2DIm2ColOpNode::Clone() const { TileOperator Conv2DIm2ColOpNode::Clone() const {
auto op = make_object<Conv2DIm2ColOpNode>(*this); auto op = tvm::ffi::make_object<Conv2DIm2ColOpNode>(*this);
return Conv2DIm2ColOp(op); return Conv2DIm2ColOp(op);
} }
...@@ -1973,9 +1974,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col) ...@@ -1973,9 +1974,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
CopyNode::RegisterReflection(); CopyNode::RegisterReflection();
Conv2DIm2ColOpNode::RegisterReflection(); Conv2DIm2ColOpNode::RegisterReflection();
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -101,8 +101,7 @@ public: ...@@ -101,8 +101,7 @@ public:
}; };
uint8_t eviction_policy; // Policy for cache eviction uint8_t eviction_policy; // Policy for cache eviction
static constexpr const char *_type_key = "tl.Copy"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Copy", CopyNode, TileOperatorNode);
TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
...@@ -114,23 +113,6 @@ public: ...@@ -114,23 +113,6 @@ public:
.def_ro("coalesced_width", &CopyNode::coalesced_width); .def_ro("coalesced_width", &CopyNode::coalesced_width);
} }
bool SEqualReduce(const CopyNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(coalesced_width, other->coalesced_width);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(coalesced_width);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/*! /*!
* \brief Lower the copy operator to a TIR statement. * \brief Lower the copy operator to a TIR statement.
* \param T Arguments for lowering. * \param T Arguments for lowering.
...@@ -291,7 +273,7 @@ protected: ...@@ -291,7 +273,7 @@ protected:
class Copy : public TileOperator { class Copy : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Copy, TileOperator, CopyNode);
/*! /*!
* \brief Constructor. * \brief Constructor.
...@@ -323,8 +305,8 @@ public: ...@@ -323,8 +305,8 @@ public:
PrimExpr nhw_step; // Step size in NHW dimensions PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension PrimExpr c_step; // Step size in channel dimension
static constexpr const char *_type_key = "tl.Conv2DIm2Col"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode,
TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode); TileOperatorNode);
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
...@@ -338,26 +320,6 @@ public: ...@@ -338,26 +320,6 @@ public:
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy); .def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy);
} }
bool SEqualReduce(const Conv2DIm2ColOpNode *other,
SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(stride, other->stride) && equal(padding, other->padding) &&
equal(dilation, other->dilation) && equal(kernel, other->kernel) &&
equal(eviction_policy, other->eviction_policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(stride);
hash_reduce(padding);
hash_reduce(dilation);
hash_reduce(kernel);
hash_reduce(eviction_policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/*! /*!
* \brief Lower to TIR statement. * \brief Lower to TIR statement.
*/ */
...@@ -378,7 +340,7 @@ public: ...@@ -378,7 +340,7 @@ public:
class Conv2DIm2ColOp : public TileOperator { class Conv2DIm2ColOp : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator,
Conv2DIm2ColOpNode); Conv2DIm2ColOpNode);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get(); static const Op &Get();
......
...@@ -60,7 +60,7 @@ using namespace tir; ...@@ -60,7 +60,7 @@ using namespace tir;
* of bounds. * of bounds.
*/ */
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = make_object<FillNode>(); ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
if (args[0]->IsInstance<BufferLoadNode>()) { if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]); auto buffer_load = Downcast<BufferLoad>(args[0]);
...@@ -117,7 +117,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { ...@@ -117,7 +117,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator that owns the copied FillNode. * @return TileOperator A TileOperator that owns the copied FillNode.
*/ */
TileOperator FillNode::Clone() const { TileOperator FillNode::Clone() const {
auto op = make_object<FillNode>(*this); auto op = tvm::ffi::make_object<FillNode>(*this);
return Fill(op); return Fill(op);
} }
...@@ -226,7 +226,7 @@ TIR_REGISTER_TL_OP(Fill, fill) ...@@ -226,7 +226,7 @@ TIR_REGISTER_TL_OP(Fill, fill)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ FillNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -20,8 +20,7 @@ public: ...@@ -20,8 +20,7 @@ public:
tir::Buffer dst; ///< Destination buffer to fill tir::Buffer dst; ///< Destination buffer to fill
PrimExpr value; ///< Value to fill with PrimExpr value; ///< Value to fill with
Array<Range> region; ///< Region to fill within the buffer Array<Range> region; ///< Region to fill within the buffer
static constexpr const char *_type_key = "tl.Fill"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fill", FillNode, TileOperatorNode);
TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
...@@ -35,19 +34,6 @@ public: ...@@ -35,19 +34,6 @@ public:
.def_ro("region", &FillNode::region); .def_ro("region", &FillNode::region);
} }
bool SEqualReduce(const FillNode *other, SEqualReducer equal) const {
return equal(dst, other->dst) && equal(value, other->value) &&
equal(region, other->region);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dst);
hash_reduce(value);
hash_reduce(region);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
TileOperator Clone() const; TileOperator Clone() const;
private: private:
...@@ -58,7 +44,7 @@ private: ...@@ -58,7 +44,7 @@ private:
/// Wrapper class for fill operations /// Wrapper class for fill operations
class Fill : public TileOperator { class Fill : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -33,7 +33,7 @@ using namespace tir; ...@@ -33,7 +33,7 @@ using namespace tir;
* Buffer. * Buffer.
*/ */
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) { FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
auto node = make_object<FinalizeReducerOpNode>(); auto node = tvm::ffi::make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])]; node->reducer = vmap[GetVarFromAccessPtr(args[0])];
node->op = (ReducerOpType)*as_const_int(args[1]); node->op = (ReducerOpType)*as_const_int(args[1]);
data_ = std::move(node); data_ = std::move(node);
...@@ -152,7 +152,7 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -152,7 +152,7 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
* @return TileOperator A TileOperator that contains a deep copy of this node. * @return TileOperator A TileOperator that contains a deep copy of this node.
*/ */
TileOperator FinalizeReducerOpNode::Clone() const { TileOperator FinalizeReducerOpNode::Clone() const {
auto node = make_object<FinalizeReducerOpNode>(*this); auto node = tvm::ffi::make_object<FinalizeReducerOpNode>(*this);
return TileOperator(node); return TileOperator(node);
} }
...@@ -161,6 +161,6 @@ TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer) ...@@ -161,6 +161,6 @@ TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ FinalizeReducerOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK() { FinalizeReducerOpNode::RegisterReflection(); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -27,8 +27,8 @@ public: ...@@ -27,8 +27,8 @@ public:
tir::Buffer reducer; tir::Buffer reducer;
ReducerOpType op; ReducerOpType op;
static constexpr const char *_type_key = "tl.FinalizeReducerOp"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.FinalizeReducerOp",
TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode); FinalizeReducerOpNode, TileOperatorNode);
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
...@@ -37,18 +37,6 @@ public: ...@@ -37,18 +37,6 @@ public:
.def_ro("op", &FinalizeReducerOpNode::op); .def_ro("op", &FinalizeReducerOpNode::op);
} }
bool SEqualReduce(const FinalizeReducerOpNode *other,
SEqualReducer equal) const {
return equal(reducer, other->reducer) && equal(op, other->op);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(reducer);
hash_reduce(op);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T, LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override; InferLevel level) const override;
...@@ -58,7 +46,7 @@ public: ...@@ -58,7 +46,7 @@ public:
class FinalizeReducerOp : public TileOperator { class FinalizeReducerOp : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(FinalizeReducerOp, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode); FinalizeReducerOpNode);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap); TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get(); static const Op &Get();
......
...@@ -112,7 +112,7 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { ...@@ -112,7 +112,7 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
* performed here. * performed here.
*/ */
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmNode> node = make_object<GemmNode>(); ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
node->Aptr = args[0]; node->Aptr = args[0];
node->Bptr = args[1]; node->Bptr = args[1];
...@@ -160,7 +160,7 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { ...@@ -160,7 +160,7 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A Gemm operator that owns a copy of this node. * @return TileOperator A Gemm operator that owns a copy of this node.
*/ */
TileOperator GemmNode::Clone() const { TileOperator GemmNode::Clone() const {
auto op = make_object<GemmNode>(*this); auto op = tvm::ffi::make_object<GemmNode>(*this);
return Gemm(op); return Gemm(op);
} }
...@@ -476,8 +476,8 @@ bool GemmNode::CheckWGMMA() const { ...@@ -476,8 +476,8 @@ bool GemmNode::CheckWGMMA() const {
*/ */
static int GetArchInt(Target target) { static int GetArchInt(Target target) {
int arch_int = 0; int arch_int = 0;
auto s = target->GetAttr<String>("arch"); auto s = target->GetAttr<tvm::ffi::String>("arch");
ICHECK(s.defined()); ICHECK(s.has_value());
std::string arch = s.value(); std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) { if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3)); arch_int = std::stoi(arch.substr(3));
...@@ -874,7 +874,7 @@ TIR_REGISTER_TL_OP(Gemm, gemm) ...@@ -874,7 +874,7 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
TVM_REGISTER_OP("tl.GemmWarpPolicy") TVM_REGISTER_OP("tl.GemmWarpPolicy")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy"); .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
GemmNode::RegisterReflection(); GemmNode::RegisterReflection();
GemmWarpPolicyNode::RegisterReflection(); GemmWarpPolicyNode::RegisterReflection();
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
...@@ -883,9 +883,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ ...@@ -883,9 +883,8 @@ TVM_FFI_STATIC_INIT_BLOCK({
Target target, GemmInst gemm_inst) { Target target, GemmInst gemm_inst) {
policy->ComputeWarpPartition(M, N, block_size, target, policy->ComputeWarpPartition(M, N, block_size, target,
gemm_inst); gemm_inst);
return;
}); });
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -30,8 +30,7 @@ public: ...@@ -30,8 +30,7 @@ public:
mutable int n_warp{0}; mutable int n_warp{0};
int policy_type; int policy_type;
static constexpr const char *_type_key = "tl.GemmWarpPolicy"; TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmWarpPolicy", GemmWarpPolicyNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(GemmWarpPolicyNode, Object);
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
...@@ -41,21 +40,6 @@ public: ...@@ -41,21 +40,6 @@ public:
.def_ro("n_warp", &GemmWarpPolicyNode::n_warp); .def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
} }
bool SEqualReduce(const GemmWarpPolicyNode *other,
SEqualReducer equal) const {
return equal(policy_type, other->policy_type) &&
equal(m_warp, other->m_warp) && equal(n_warp, other->n_warp);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(policy_type);
hash_reduce(m_warp);
hash_reduce(n_warp);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size, std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
Target target, Target target,
GemmInst gemm_inst) const; GemmInst gemm_inst) const;
...@@ -74,22 +58,23 @@ public: ...@@ -74,22 +58,23 @@ public:
class GemmWarpPolicy : public ObjectRef { class GemmWarpPolicy : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmWarpPolicy, ObjectRef, GemmWarpPolicyNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmWarpPolicy, ObjectRef,
GemmWarpPolicyNode);
explicit GemmWarpPolicy(GemmWarpPolicyType policy_type) { explicit GemmWarpPolicy(GemmWarpPolicyType policy_type) {
auto node = make_object<GemmWarpPolicyNode>(); auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
node->policy_type = (int)policy_type; node->policy_type = (int)policy_type;
data_ = std::move(node); data_ = std::move(node);
} }
explicit GemmWarpPolicy(int policy_type) { explicit GemmWarpPolicy(int policy_type) {
auto node = make_object<GemmWarpPolicyNode>(); auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
node->policy_type = policy_type; node->policy_type = policy_type;
data_ = std::move(node); data_ = std::move(node);
} }
explicit GemmWarpPolicy(int m_warp, int n_warp) { explicit GemmWarpPolicy(int m_warp, int n_warp) {
auto node = make_object<GemmWarpPolicyNode>(); auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
node->m_warp = m_warp; node->m_warp = m_warp;
node->n_warp = n_warp; node->n_warp = n_warp;
node->policy_type = (int)GemmWarpPolicyType::kFree; node->policy_type = (int)GemmWarpPolicyType::kFree;
...@@ -116,9 +101,7 @@ public: ...@@ -116,9 +101,7 @@ public:
std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> C_coords; Array<PrimExpr> C_coords;
mutable GemmWarpPolicy policy; mutable GemmWarpPolicy policy;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);
static constexpr const char *_type_key = "tl.Gemm";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode);
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
...@@ -144,45 +127,6 @@ public: ...@@ -144,45 +127,6 @@ public:
.def_ro("policy", &GemmNode::policy); .def_ro("policy", &GemmNode::policy);
} }
bool SEqualReduce(const GemmNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_A) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T, LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override; InferLevel level) const override;
...@@ -199,7 +143,7 @@ private: ...@@ -199,7 +143,7 @@ private:
class Gemm : public TileOperator { class Gemm : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode);
TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../support/ffi_aliases.h"
#include "../target/utils.h" #include "../target/utils.h"
#include "tvm/ffi/string.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -48,7 +48,7 @@ using namespace tir; ...@@ -48,7 +48,7 @@ using namespace tir;
* performed here. * performed here.
*/ */
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>(); ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->Aptr = args[0]; node->Aptr = args[0];
node->Bptr = args[1]; node->Bptr = args[1];
...@@ -88,7 +88,7 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { ...@@ -88,7 +88,7 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A Gemm operator that owns a copy of this node. * @return TileOperator A Gemm operator that owns a copy of this node.
*/ */
TileOperator GemmPyNode::Clone() const { TileOperator GemmPyNode::Clone() const {
auto op = make_object<GemmPyNode>(*this); auto op = tvm::ffi::make_object<GemmPyNode>(*this);
return GemmPy(op); return GemmPy(op);
} }
...@@ -208,8 +208,8 @@ bool GemmPyNode::CheckWGMMA() const { ...@@ -208,8 +208,8 @@ bool GemmPyNode::CheckWGMMA() const {
*/ */
static int GetArchInt(Target target) { static int GetArchInt(Target target) {
int arch_int = 0; int arch_int = 0;
auto s = target->GetAttr<String>("arch"); auto s = target->GetAttr<tvm::ffi::String>("arch");
ICHECK(s.defined()); ICHECK(s.has_value());
std::string arch = s.value(); std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) { if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3)); arch_int = std::stoi(arch.substr(3));
...@@ -228,11 +228,12 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -228,11 +228,12 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func = auto prim_func =
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target, Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmPy>(this), T.layout_map,
T.thread_bounds, T.thread_var)); T.target, T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined()); ICHECK(prim_func->attrs.defined());
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol"); auto global_symbol =
ICHECK(global_symbol.defined()); prim_func->attrs.GetAttr<tvm::ffi::String>("global_symbol");
ICHECK(global_symbol.has_value());
if (prim_func->body.as<BlockRealizeNode>()) { if (prim_func->body.as<BlockRealizeNode>()) {
BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body); BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
auto block = block_realize->block; auto block = block_realize->block;
...@@ -265,7 +266,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, ...@@ -265,7 +266,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) { if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
results = Downcast<LayoutMap>( results = Downcast<LayoutMap>(
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds)); (*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds));
} else { } else {
LOG(FATAL) << "No infer layout function found for gemm_py"; LOG(FATAL) << "No infer layout function found for gemm_py";
} }
...@@ -279,15 +280,15 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py) ...@@ -279,15 +280,15 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK() { GemmPyNode::RegisterReflection(); }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmPyGemmInst", refl::GlobalDef().def("tl.GemmPyGemmInst",
[](GemmPy gemm_py, int block_size, Target target) { [](GemmPy gemm_py, int block_size, Target target) {
return gemm_py->GetGemmInst(block_size, target); return gemm_py->GetGemmInst(block_size, target);
}); });
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -33,8 +33,7 @@ public: ...@@ -33,8 +33,7 @@ public:
int wg_wait = 0; int wg_wait = 0;
mutable GemmWarpPolicy policy; mutable GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.GemmPy"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode);
TVM_DECLARE_FINAL_OBJECT_INFO(GemmPyNode, TileOperatorNode);
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
...@@ -60,45 +59,6 @@ public: ...@@ -60,45 +59,6 @@ public:
.def_ro("policy", &GemmPyNode::policy); .def_ro("policy", &GemmPyNode::policy);
} }
/**
 * @brief Structural equality check against another GemmPyNode.
 *
 * Each field of this node is compared with the same-named field of
 * `other` through the supplied reducer. The `&&` chain short-circuits on
 * the first mismatch.
 *
 * @param other Node to compare against.
 * @param equal Reducer used to compare each pair of fields.
 * @return true when every compared field pair is reported equal.
 */
bool SEqualReduce(const GemmPyNode *other, SEqualReducer equal) const {
  return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
         equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
         equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
         equal(trans_B, other->trans_B) && equal(M, other->M) &&
         equal(N, other->N) && equal(K, other->K) &&
         equal(stride_A, other->stride_A) &&
         equal(stride_B, other->stride_B) &&
         // offset_A is paired with its own counterpart, not with offset_B.
         equal(offset_A, other->offset_A) &&
         equal(offset_B, other->offset_B) &&
         equal(clear_accum, other->clear_accum) &&
         equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
         equal(policy, other->policy);
}
/**
 * @brief Feed this node's fields into a structural hash reducer.
 *
 * The fields are hashed one by one in a fixed order: the buffers and
 * pointers (A, B, C, Aptr, Bptr, Cptr), the transpose flags, the M/N/K
 * extents, the strides and offsets, then clear_accum, kPack, wg_wait and
 * policy. This is the same set of fields that SEqualReduce compares.
 *
 * @param hash_reduce Reducer that receives each field in turn.
 */
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T, LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override; InferLevel level) const override;
...@@ -114,7 +74,7 @@ private: ...@@ -114,7 +74,7 @@ private:
class GemmPy : public TileOperator { class GemmPy : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmPy, TileOperator, GemmPyNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode);
TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap); TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -84,7 +84,7 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N, ...@@ -84,7 +84,7 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2. * @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
*/ */
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) { GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>(); ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
node->A = vmap[GetVarFromAccessPtr(args[0])]; node->A = vmap[GetVarFromAccessPtr(args[0])];
node->E = vmap[GetVarFromAccessPtr(args[1])]; node->E = vmap[GetVarFromAccessPtr(args[1])];
node->B = vmap[GetVarFromAccessPtr(args[2])]; node->B = vmap[GetVarFromAccessPtr(args[2])];
...@@ -118,7 +118,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) { ...@@ -118,7 +118,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator holding a cloned GemmSPNode. * @return TileOperator A TileOperator holding a cloned GemmSPNode.
*/ */
TileOperator GemmSPNode::Clone() const { TileOperator GemmSPNode::Clone() const {
auto op = make_object<GemmSPNode>(*this); auto op = tvm::ffi::make_object<GemmSPNode>(*this);
return GemmSP(op); return GemmSP(op);
} }
...@@ -303,7 +303,7 @@ TIR_REGISTER_TL_OP(GemmSP, gemm_sp) ...@@ -303,7 +303,7 @@ TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ GemmSPNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -21,27 +21,29 @@ public: ...@@ -21,27 +21,29 @@ public:
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size, std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
Target target, bool use_wgmma, Target target, bool use_wgmma,
int bits) const; int bits) const;
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
GemmWarpPolicyNode);
}; };
class GemmSPWarpPolicy : public ObjectRef { class GemmSPWarpPolicy : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmSPWarpPolicy, ObjectRef, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPWarpPolicy, ObjectRef,
GemmSPWarpPolicyNode); GemmSPWarpPolicyNode);
explicit GemmSPWarpPolicy(GemmWarpPolicyType policy_type) { explicit GemmSPWarpPolicy(GemmWarpPolicyType policy_type) {
auto node = make_object<GemmSPWarpPolicyNode>(); auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
node->policy_type = (int)policy_type; node->policy_type = (int)policy_type;
data_ = std::move(node); data_ = std::move(node);
} }
explicit GemmSPWarpPolicy(int policy_type) { explicit GemmSPWarpPolicy(int policy_type) {
auto node = make_object<GemmSPWarpPolicyNode>(); auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
node->policy_type = policy_type; node->policy_type = policy_type;
data_ = std::move(node); data_ = std::move(node);
} }
explicit GemmSPWarpPolicy(int m_warp, int n_warp) { explicit GemmSPWarpPolicy(int m_warp, int n_warp) {
auto node = make_object<GemmSPWarpPolicyNode>(); auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
node->m_warp = m_warp; node->m_warp = m_warp;
node->n_warp = n_warp; node->n_warp = n_warp;
node->policy_type = (int)GemmWarpPolicyType::kFree; node->policy_type = (int)GemmWarpPolicyType::kFree;
...@@ -62,8 +64,7 @@ public: ...@@ -62,8 +64,7 @@ public:
mutable GemmSPWarpPolicy policy; mutable GemmSPWarpPolicy policy;
static constexpr const char *_type_key = "tl.GemmSP"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSP", GemmSPNode, TileOperatorNode);
TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T, LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override; InferLevel level) const override;
...@@ -88,38 +89,13 @@ public: ...@@ -88,38 +89,13 @@ public:
.def_ro("wg_wait", &GemmSPNode::wg_wait); .def_ro("wg_wait", &GemmSPNode::wg_wait);
} }
bool SEqualReduce(const GemmSPNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(E, other->E) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait);
}
/**
 * @brief Feed this node's fields into a structural hash reducer.
 *
 * The fields are hashed one by one in a fixed order: policy first, then
 * the buffers A, B, C and E, the transpose flags, the M/N/K extents, and
 * finally clear_accum, kPack and wg_wait.
 *
 * @param hash_reduce Reducer that receives each field in turn.
 */
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(policy);
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(E);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
}
private: private:
mutable bool completed_ = false; mutable bool completed_ = false;
}; };
class GemmSP : public TileOperator { class GemmSP : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmSP, TileOperator, GemmSPNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode);
TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap); TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
......
...@@ -55,7 +55,7 @@ TileOperator ParseOperator(Call call, BufferMap vmap) { ...@@ -55,7 +55,7 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) { TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) { if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>(); auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(GetRef<Call>(call), vmap); return ParseOperator(tvm::ffi::GetRef<Call>(call), vmap);
} }
return TileOperator(); return TileOperator();
} }
...@@ -77,7 +77,7 @@ Var GetVarFromAccessPtr(const PrimExpr &expr) { ...@@ -77,7 +77,7 @@ Var GetVarFromAccessPtr(const PrimExpr &expr) {
ICHECK(call->op.same_as(builtin::tvm_access_ptr())); ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
auto var = call->args[1].as<VarNode>(); auto var = call->args[1].as<VarNode>();
ICHECK(var); ICHECK(var);
return GetRef<Var>(var); return tvm::ffi::GetRef<Var>(var);
} }
} // namespace tl } // namespace tl
......
...@@ -62,14 +62,13 @@ public: ...@@ -62,14 +62,13 @@ public:
virtual TileOperator Clone() const = 0; virtual TileOperator Clone() const = 0;
static constexpr const char *_type_key = "tl.TileOperator"; TVM_FFI_DECLARE_OBJECT_INFO("tl.TileOperator", TileOperatorNode, Object);
TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
}; };
class TileOperator : public ObjectRef { class TileOperator : public ObjectRef {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileOperator, ObjectRef,
TileOperatorNode);
}; };
Var GetVarFromAccessPtr(const PrimExpr &expr); Var GetVarFromAccessPtr(const PrimExpr &expr);
......
...@@ -178,7 +178,7 @@ ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) { ...@@ -178,7 +178,7 @@ ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
} }
TileOperator ParallelOpNode::Clone() const { TileOperator ParallelOpNode::Clone() const {
auto op = make_object<ParallelOpNode>(*this); auto op = tvm::ffi::make_object<ParallelOpNode>(*this);
return ParallelOp(op); return ParallelOp(op);
} }
...@@ -642,7 +642,7 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { ...@@ -642,7 +642,7 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
->CondenseReplicateVar(); ->CondenseReplicateVar();
} }
TVM_FFI_STATIC_INIT_BLOCK({ ParallelOpNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK() { ParallelOpNode::RegisterReflection(); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -66,8 +66,8 @@ public: ...@@ -66,8 +66,8 @@ public:
mutable Optional<PrimExpr> predicate_; mutable Optional<PrimExpr> predicate_;
// Type key for TVM object system. // Type key for TVM object system.
static constexpr const char *_type_key = "tl.ParallelOp"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ParallelOp", ParallelOpNode,
TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode); TileOperatorNode);
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
...@@ -77,20 +77,6 @@ public: ...@@ -77,20 +77,6 @@ public:
.def_ro("predicate", &ParallelOpNode::predicate_); .def_ro("predicate", &ParallelOpNode::predicate_);
} }
/**
 * @brief Structural equality check against another ParallelOpNode.
 *
 * Compares root_, loop_layout_ and predicate_ in that order and stops at
 * the first pair the reducer reports as different.
 *
 * @param other Node to compare against.
 * @param equal Reducer used to compare each pair of fields.
 * @return true when all three field pairs are reported equal.
 */
bool SEqualReduce(const ParallelOpNode *other, SEqualReducer equal) const {
  if (!equal(root_, other->root_)) {
    return false;
  }
  if (!equal(loop_layout_, other->loop_layout_)) {
    return false;
  }
  return equal(predicate_, other->predicate_);
}
/**
 * @brief Feed this node's fields into a structural hash reducer.
 *
 * Hashes root_, loop_layout_ and predicate_, in that order.
 *
 * @param hash_reduce Reducer that receives each field in turn.
 */
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(root_);
hash_reduce(loop_layout_);
hash_reduce(predicate_);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
// Construct from a root For loop. // Construct from a root For loop.
ParallelOpNode(For root); ParallelOpNode(For root);
...@@ -150,10 +136,11 @@ private: ...@@ -150,10 +136,11 @@ private:
class ParallelOp : public TileOperator { class ParallelOp : public TileOperator {
public: public:
TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParallelOp, TileOperator,
ParallelOpNode);
ParallelOp(const For &root) { ParallelOp(const For &root) {
auto op = make_object<ParallelOpNode>(root); auto op = tvm::ffi::make_object<ParallelOpNode>(root);
data_ = std::move(op); data_ = std::move(op);
} }
}; };
......
...@@ -22,7 +22,7 @@ namespace tl { ...@@ -22,7 +22,7 @@ namespace tl {
using namespace tir; using namespace tir;
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>(); ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])]; node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])]; node->dst = vmap[GetVarFromAccessPtr(args[1])];
std::string reduce_type = args[2].as<StringImm>().value()->value; std::string reduce_type = args[2].as<StringImm>().value()->value;
...@@ -33,12 +33,12 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -33,12 +33,12 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
} }
TileOperator ReduceOpNode::Clone() const { TileOperator ReduceOpNode::Clone() const {
auto op = make_object<ReduceOpNode>(*this); auto op = tvm::ffi::make_object<ReduceOpNode>(*this);
return ReduceOp(op); return ReduceOp(op);
} }
TileOperator CumSumOpNode::Clone() const { TileOperator CumSumOpNode::Clone() const {
auto op = make_object<CumSumOpNode>(*this); auto op = tvm::ffi::make_object<CumSumOpNode>(*this);
return CumSumOp(op); return CumSumOp(op);
} }
...@@ -85,6 +85,7 @@ PrimExpr ReduceOpNode::MakeInitValue() const { ...@@ -85,6 +85,7 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
return make_zero(dst->dtype); return make_zero(dst->dtype);
} else { } else {
LOG(FATAL) << "Unsupported reduce type: " << type->type; LOG(FATAL) << "Unsupported reduce type: " << type->type;
return PrimExpr();
} }
} }
...@@ -512,7 +513,7 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -512,7 +513,7 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/// - dim: dimension to cumsum /// - dim: dimension to cumsum
/// - reverse: whether to cumsum in reverse order /// - reverse: whether to cumsum in reverse order
CHECK_EQ(args.size(), 4); CHECK_EQ(args.size(), 4);
ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>(); ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])]; node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])]; node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->dim = args[2].as<IntImm>().value()->value; node->dim = args[2].as<IntImm>().value()->value;
...@@ -567,5 +568,12 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum) ...@@ -567,5 +568,12 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum)
.set_num_inputs(4) .set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
// Calls RegisterReflection() on the three reduce-related node types
// (ReduceOpNode, CumSumOpNode, ReduceTypeNode) from within a TVM-FFI
// static-init block.
TVM_FFI_STATIC_INIT_BLOCK() {
ReduceOpNode::RegisterReflection();
CumSumOpNode::RegisterReflection();
ReduceTypeNode::RegisterReflection();
}
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment