Unverified Commit 10911e28 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: default avatarXuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
......@@ -25,8 +25,8 @@ public:
IntImm memory_order; ///< Memory order for atomic operations
mutable ParallelOp par_op_; ///< Associated parallel operation
static constexpr const char *_type_key = "tl.AtomicAdd";
TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode,
TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
......@@ -46,28 +46,6 @@ public:
.def_ro("memory_order", &AtomicAddNode::memory_order);
}
bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(use_tma, other->use_tma) &&
equal(coalesced_width, other->coalesced_width) &&
equal(memory_order, other->memory_order);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(use_tma);
hash_reduce(coalesced_width);
hash_reduce(memory_order);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
protected:
/// Create SIMT-style parallel loop structure
For MakeSIMTLoop(arith::Analyzer *analyzer) const;
......@@ -85,7 +63,8 @@ protected:
/// Wrapper class for atomic addition operations
class AtomicAdd : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......@@ -93,4 +72,4 @@ public:
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_ATOMIC_ADD_H_
\ No newline at end of file
#endif // TVM_TL_OP_ATOMIC_ADD_H_
......@@ -130,7 +130,7 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
* @param vmap BufferMap used to resolve RegionOp buffers and ranges.
*/
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<CopyNode> node = make_object<CopyNode>();
ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
......@@ -169,7 +169,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator owning the cloned CopyNode.
*/
TileOperator CopyNode::Clone() const {
auto op = make_object<CopyNode>(*this);
auto op = tvm::ffi::make_object<CopyNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
......@@ -401,7 +401,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, T.analyzer, T.buffer_oob);
......@@ -793,7 +793,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, analyzer);
if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
......@@ -1722,7 +1722,8 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* @param vmap Mapping from original buffer variables to actual Buffer objects.
*/
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<Conv2DIm2ColOpNode> node = make_object<Conv2DIm2ColOpNode>();
ObjectPtr<Conv2DIm2ColOpNode> node =
tvm::ffi::make_object<Conv2DIm2ColOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->nhw_step = args[2];
......@@ -1747,7 +1748,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode.
*/
TileOperator Conv2DIm2ColOpNode::Clone() const {
auto op = make_object<Conv2DIm2ColOpNode>(*this);
auto op = tvm::ffi::make_object<Conv2DIm2ColOpNode>(*this);
return Conv2DIm2ColOp(op);
}
......@@ -1973,9 +1974,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
CopyNode::RegisterReflection();
Conv2DIm2ColOpNode::RegisterReflection();
});
}
} // namespace tl
} // namespace tvm
......@@ -101,8 +101,7 @@ public:
};
uint8_t eviction_policy; // Policy for cache eviction
static constexpr const char *_type_key = "tl.Copy";
TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Copy", CopyNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -114,23 +113,6 @@ public:
.def_ro("coalesced_width", &CopyNode::coalesced_width);
}
bool SEqualReduce(const CopyNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(coalesced_width, other->coalesced_width);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(coalesced_width);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/*!
* \brief Lower the copy operator to a TIR statement.
* \param T Arguments for lowering.
......@@ -291,7 +273,7 @@ protected:
class Copy : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Copy, TileOperator, CopyNode);
/*!
* \brief Constructor.
......@@ -323,8 +305,8 @@ public:
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
static constexpr const char *_type_key = "tl.Conv2DIm2Col";
TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -338,26 +320,6 @@ public:
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy);
}
bool SEqualReduce(const Conv2DIm2ColOpNode *other,
SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(stride, other->stride) && equal(padding, other->padding) &&
equal(dilation, other->dilation) && equal(kernel, other->kernel) &&
equal(eviction_policy, other->eviction_policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(stride);
hash_reduce(padding);
hash_reduce(dilation);
hash_reduce(kernel);
hash_reduce(eviction_policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/*!
* \brief Lower to TIR statement.
*/
......@@ -378,8 +340,8 @@ public:
class Conv2DIm2ColOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator,
Conv2DIm2ColOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator,
Conv2DIm2ColOpNode);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......@@ -387,4 +349,4 @@ public:
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_COPY_H_
\ No newline at end of file
#endif // TVM_TL_OP_COPY_H_
......@@ -60,7 +60,7 @@ using namespace tir;
* of bounds.
*/
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = make_object<FillNode>();
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
......@@ -117,7 +117,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator that owns the copied FillNode.
*/
TileOperator FillNode::Clone() const {
auto op = make_object<FillNode>(*this);
auto op = tvm::ffi::make_object<FillNode>(*this);
return Fill(op);
}
......@@ -226,7 +226,7 @@ TIR_REGISTER_TL_OP(Fill, fill)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ FillNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -20,8 +20,7 @@ public:
tir::Buffer dst; ///< Destination buffer to fill
PrimExpr value; ///< Value to fill with
Array<Range> region; ///< Region to fill within the buffer
static constexpr const char *_type_key = "tl.Fill";
TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fill", FillNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
......@@ -35,19 +34,6 @@ public:
.def_ro("region", &FillNode::region);
}
bool SEqualReduce(const FillNode *other, SEqualReducer equal) const {
return equal(dst, other->dst) && equal(value, other->value) &&
equal(region, other->region);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dst);
hash_reduce(value);
hash_reduce(region);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
TileOperator Clone() const;
private:
......@@ -58,7 +44,7 @@ private:
/// Wrapper class for fill operations
class Fill : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......@@ -66,4 +52,4 @@ public:
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_FILL_H_
\ No newline at end of file
#endif // TVM_TL_OP_FILL_H_
......@@ -33,7 +33,7 @@ using namespace tir;
* Buffer.
*/
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
auto node = make_object<FinalizeReducerOpNode>();
auto node = tvm::ffi::make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])];
node->op = (ReducerOpType)*as_const_int(args[1]);
data_ = std::move(node);
......@@ -152,7 +152,7 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
* @return TileOperator A TileOperator that contains a deep copy of this node.
*/
TileOperator FinalizeReducerOpNode::Clone() const {
auto node = make_object<FinalizeReducerOpNode>(*this);
auto node = tvm::ffi::make_object<FinalizeReducerOpNode>(*this);
return TileOperator(node);
}
......@@ -161,6 +161,6 @@ TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ FinalizeReducerOpNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { FinalizeReducerOpNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
......@@ -27,8 +27,8 @@ public:
tir::Buffer reducer;
ReducerOpType op;
static constexpr const char *_type_key = "tl.FinalizeReducerOp";
TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.FinalizeReducerOp",
FinalizeReducerOpNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -37,18 +37,6 @@ public:
.def_ro("op", &FinalizeReducerOpNode::op);
}
bool SEqualReduce(const FinalizeReducerOpNode *other,
SEqualReducer equal) const {
return equal(reducer, other->reducer) && equal(op, other->op);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(reducer);
hash_reduce(op);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
......@@ -58,8 +46,8 @@ public:
class FinalizeReducerOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......@@ -67,4 +55,4 @@ public:
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_FINALIZE_REDUCER_H_
\ No newline at end of file
#endif // TVM_TL_OP_FINALIZE_REDUCER_H_
......@@ -112,7 +112,7 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
* performed here.
*/
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmNode> node = make_object<GemmNode>();
ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
node->Aptr = args[0];
node->Bptr = args[1];
......@@ -160,7 +160,7 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A Gemm operator that owns a copy of this node.
*/
TileOperator GemmNode::Clone() const {
auto op = make_object<GemmNode>(*this);
auto op = tvm::ffi::make_object<GemmNode>(*this);
return Gemm(op);
}
......@@ -476,8 +476,8 @@ bool GemmNode::CheckWGMMA() const {
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
auto s = target->GetAttr<tvm::ffi::String>("arch");
ICHECK(s.has_value());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
......@@ -874,7 +874,7 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
TVM_REGISTER_OP("tl.GemmWarpPolicy")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
GemmNode::RegisterReflection();
GemmWarpPolicyNode::RegisterReflection();
namespace refl = tvm::ffi::reflection;
......@@ -883,9 +883,8 @@ TVM_FFI_STATIC_INIT_BLOCK({
Target target, GemmInst gemm_inst) {
policy->ComputeWarpPartition(M, N, block_size, target,
gemm_inst);
return;
});
});
}
} // namespace tl
} // namespace tvm
......@@ -30,8 +30,7 @@ public:
mutable int n_warp{0};
int policy_type;
static constexpr const char *_type_key = "tl.GemmWarpPolicy";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmWarpPolicyNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmWarpPolicy", GemmWarpPolicyNode, Object);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -41,21 +40,6 @@ public:
.def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
}
bool SEqualReduce(const GemmWarpPolicyNode *other,
SEqualReducer equal) const {
return equal(policy_type, other->policy_type) &&
equal(m_warp, other->m_warp) && equal(n_warp, other->n_warp);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(policy_type);
hash_reduce(m_warp);
hash_reduce(n_warp);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
Target target,
GemmInst gemm_inst) const;
......@@ -74,22 +58,23 @@ public:
class GemmWarpPolicy : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmWarpPolicy, ObjectRef, GemmWarpPolicyNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmWarpPolicy, ObjectRef,
GemmWarpPolicyNode);
explicit GemmWarpPolicy(GemmWarpPolicyType policy_type) {
auto node = make_object<GemmWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
node->policy_type = (int)policy_type;
data_ = std::move(node);
}
explicit GemmWarpPolicy(int policy_type) {
auto node = make_object<GemmWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
node->policy_type = policy_type;
data_ = std::move(node);
}
explicit GemmWarpPolicy(int m_warp, int n_warp) {
auto node = make_object<GemmWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
node->m_warp = m_warp;
node->n_warp = n_warp;
node->policy_type = (int)GemmWarpPolicyType::kFree;
......@@ -116,9 +101,7 @@ public:
std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> C_coords;
mutable GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.Gemm";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -144,45 +127,6 @@ public:
.def_ro("policy", &GemmNode::policy);
}
bool SEqualReduce(const GemmNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_A) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
......@@ -199,7 +143,7 @@ private:
class Gemm : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode);
TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......@@ -207,4 +151,4 @@ public:
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_GEMM_H_
\ No newline at end of file
#endif // TVM_TL_OP_GEMM_H_
......@@ -11,8 +11,8 @@
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>
#include "../support/ffi_aliases.h"
#include "../target/utils.h"
#include "tvm/ffi/string.h"
namespace tvm {
namespace tl {
......@@ -48,7 +48,7 @@ using namespace tir;
* performed here.
*/
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->Aptr = args[0];
node->Bptr = args[1];
......@@ -88,7 +88,7 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A Gemm operator that owns a copy of this node.
*/
TileOperator GemmPyNode::Clone() const {
auto op = make_object<GemmPyNode>(*this);
auto op = tvm::ffi::make_object<GemmPyNode>(*this);
return GemmPy(op);
}
......@@ -208,8 +208,8 @@ bool GemmPyNode::CheckWGMMA() const {
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
auto s = target->GetAttr<tvm::ffi::String>("arch");
ICHECK(s.has_value());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
......@@ -228,11 +228,12 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func =
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target,
T.thread_bounds, T.thread_var));
Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmPy>(this), T.layout_map,
T.target, T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
ICHECK(global_symbol.defined());
auto global_symbol =
prim_func->attrs.GetAttr<tvm::ffi::String>("global_symbol");
ICHECK(global_symbol.has_value());
if (prim_func->body.as<BlockRealizeNode>()) {
BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
auto block = block_realize->block;
......@@ -265,7 +266,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
results = Downcast<LayoutMap>(
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds));
(*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds));
} else {
LOG(FATAL) << "No infer layout function found for gemm_py";
}
......@@ -279,15 +280,15 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { GemmPyNode::RegisterReflection(); }
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmPyGemmInst",
[](GemmPy gemm_py, int block_size, Target target) {
return gemm_py->GetGemmInst(block_size, target);
});
});
}
} // namespace tl
} // namespace tvm
......@@ -33,8 +33,7 @@ public:
int wg_wait = 0;
mutable GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.GemmPy";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmPyNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -60,45 +59,6 @@ public:
.def_ro("policy", &GemmPyNode::policy);
}
bool SEqualReduce(const GemmPyNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_B) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
......@@ -114,7 +74,7 @@ private:
class GemmPy : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmPy, TileOperator, GemmPyNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode);
TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......@@ -122,4 +82,4 @@ public:
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_GEMM_PY_H_
\ No newline at end of file
#endif // TVM_TL_OP_GEMM_PY_H_
......@@ -84,7 +84,7 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
*/
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>();
ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
node->A = vmap[GetVarFromAccessPtr(args[0])];
node->E = vmap[GetVarFromAccessPtr(args[1])];
node->B = vmap[GetVarFromAccessPtr(args[2])];
......@@ -118,7 +118,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator holding a cloned GemmSPNode.
*/
TileOperator GemmSPNode::Clone() const {
auto op = make_object<GemmSPNode>(*this);
auto op = tvm::ffi::make_object<GemmSPNode>(*this);
return GemmSP(op);
}
......@@ -303,7 +303,7 @@ TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ GemmSPNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
......@@ -21,27 +21,29 @@ public:
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
Target target, bool use_wgmma,
int bits) const;
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
GemmWarpPolicyNode);
};
class GemmSPWarpPolicy : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmSPWarpPolicy, ObjectRef,
GemmSPWarpPolicyNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPWarpPolicy, ObjectRef,
GemmSPWarpPolicyNode);
explicit GemmSPWarpPolicy(GemmWarpPolicyType policy_type) {
auto node = make_object<GemmSPWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
node->policy_type = (int)policy_type;
data_ = std::move(node);
}
explicit GemmSPWarpPolicy(int policy_type) {
auto node = make_object<GemmSPWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
node->policy_type = policy_type;
data_ = std::move(node);
}
explicit GemmSPWarpPolicy(int m_warp, int n_warp) {
auto node = make_object<GemmSPWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
node->m_warp = m_warp;
node->n_warp = n_warp;
node->policy_type = (int)GemmWarpPolicyType::kFree;
......@@ -62,8 +64,7 @@ public:
mutable GemmSPWarpPolicy policy;
static constexpr const char *_type_key = "tl.GemmSP";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSP", GemmSPNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
......@@ -88,38 +89,13 @@ public:
.def_ro("wg_wait", &GemmSPNode::wg_wait);
}
bool SEqualReduce(const GemmSPNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(E, other->E) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(policy);
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(E);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
}
private:
mutable bool completed_ = false;
};
class GemmSP : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmSP, TileOperator, GemmSPNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode);
TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......
......@@ -9,6 +9,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
using namespace tir;
......@@ -50,4 +52,4 @@ TVM_REGISTER_OP("tl.all_of")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", all_of_op);
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tvm
......@@ -9,6 +9,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
using namespace tir;
......
......@@ -55,7 +55,7 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(GetRef<Call>(call), vmap);
return ParseOperator(tvm::ffi::GetRef<Call>(call), vmap);
}
return TileOperator();
}
......@@ -77,7 +77,7 @@ Var GetVarFromAccessPtr(const PrimExpr &expr) {
ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
auto var = call->args[1].as<VarNode>();
ICHECK(var);
return GetRef<Var>(var);
return tvm::ffi::GetRef<Var>(var);
}
} // namespace tl
......
......@@ -62,14 +62,13 @@ public:
virtual TileOperator Clone() const = 0;
static constexpr const char *_type_key = "tl.TileOperator";
TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO("tl.TileOperator", TileOperatorNode, Object);
};
class TileOperator : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileOperator, ObjectRef,
TileOperatorNode);
};
Var GetVarFromAccessPtr(const PrimExpr &expr);
......
......@@ -178,7 +178,7 @@ ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
}
TileOperator ParallelOpNode::Clone() const {
auto op = make_object<ParallelOpNode>(*this);
auto op = tvm::ffi::make_object<ParallelOpNode>(*this);
return ParallelOp(op);
}
......@@ -642,7 +642,7 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
->CondenseReplicateVar();
}
TVM_FFI_STATIC_INIT_BLOCK({ ParallelOpNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { ParallelOpNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
......@@ -66,8 +66,8 @@ public:
mutable Optional<PrimExpr> predicate_;
// Type key for TVM object system.
static constexpr const char *_type_key = "tl.ParallelOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ParallelOp", ParallelOpNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -77,20 +77,6 @@ public:
.def_ro("predicate", &ParallelOpNode::predicate_);
}
bool SEqualReduce(const ParallelOpNode *other, SEqualReducer equal) const {
return equal(root_, other->root_) &&
equal(loop_layout_, other->loop_layout_) &&
equal(predicate_, other->predicate_);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(root_);
hash_reduce(loop_layout_);
hash_reduce(predicate_);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
// Construct from a root For loop.
ParallelOpNode(For root);
......@@ -150,10 +136,11 @@ private:
class ParallelOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParallelOp, TileOperator,
ParallelOpNode);
ParallelOp(const For &root) {
auto op = make_object<ParallelOpNode>(root);
auto op = tvm::ffi::make_object<ParallelOpNode>(root);
data_ = std::move(op);
}
};
......
......@@ -22,7 +22,7 @@ namespace tl {
using namespace tir;
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
std::string reduce_type = args[2].as<StringImm>().value()->value;
......@@ -33,12 +33,12 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
}
TileOperator ReduceOpNode::Clone() const {
auto op = make_object<ReduceOpNode>(*this);
auto op = tvm::ffi::make_object<ReduceOpNode>(*this);
return ReduceOp(op);
}
TileOperator CumSumOpNode::Clone() const {
auto op = make_object<CumSumOpNode>(*this);
auto op = tvm::ffi::make_object<CumSumOpNode>(*this);
return CumSumOp(op);
}
......@@ -85,6 +85,7 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
return make_zero(dst->dtype);
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
return PrimExpr();
}
}
......@@ -512,7 +513,7 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/// - dim: dimension to cumsum
/// - reverse: whether to cumsum in reverse order
CHECK_EQ(args.size(), 4);
ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->dim = args[2].as<IntImm>().value()->value;
......@@ -567,5 +568,12 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK() {
ReduceOpNode::RegisterReflection();
CumSumOpNode::RegisterReflection();
ReduceTypeNode::RegisterReflection();
}
} // namespace tl
} // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment