Unverified Commit 10911e28 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: default avatarXuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
...@@ -66,17 +66,17 @@ struct ReducerInfoNode : Object { ...@@ -66,17 +66,17 @@ struct ReducerInfoNode : Object {
ReducerInfoNode() = default; ReducerInfoNode() = default;
ReducerInfoNode(const String &op_str, const String &rep_str); ReducerInfoNode(const String &op_str, const String &rep_str);
static constexpr const char *_type_key = "tl.ReducerInfo"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReducerInfo", ReducerInfoNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ReducerInfoNode, Object);
}; };
struct ReducerInfo : ObjectRef { struct ReducerInfo : ObjectRef {
public: public:
TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) { TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) {
data_ = make_object<ReducerInfoNode>(op_str, rep_str); data_ = tvm::ffi::make_object<ReducerInfoNode>(op_str, rep_str);
} }
TVM_DEFINE_OBJECT_REF_METHODS(ReducerInfo, ObjectRef, ReducerInfoNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReducerInfo, ObjectRef,
ReducerInfoNode);
}; };
namespace attr { namespace attr {
......
...@@ -38,7 +38,7 @@ private: ...@@ -38,7 +38,7 @@ private:
StmtVisitor::VisitStmt(op->body); StmtVisitor::VisitStmt(op->body);
if (!has_child_for_) { if (!has_child_for_) {
leaf_for_nodes.push_back(GetRef<For>(op)); leaf_for_nodes.push_back(tvm::ffi::GetRef<For>(op));
} }
parent_has_child_for_ = parent_has_child_for; parent_has_child_for_ = parent_has_child_for;
...@@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() { ...@@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
} }
// Register the pass globally so it can be used in the compilation pipeline // Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess", refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess",
LegalizeSafeMemoryAccess); LegalizeSafeMemoryAccess);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() { ...@@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
} }
// Register the pass globally so it can be used in the compilation pipeline // Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop", refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop",
LegalizeVectorizedLoop); LegalizeVectorizedLoop);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -173,7 +173,7 @@ private: ...@@ -173,7 +173,7 @@ private:
if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) { if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) {
return StmtExprMutator::VisitStmt_(node); return StmtExprMutator::VisitStmt_(node);
} }
For new_for = GetRef<For>(node); For new_for = tvm::ffi::GetRef<For>(node);
auto for_ptr = new_for.CopyOnWrite(); auto for_ptr = new_for.CopyOnWrite();
for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false)); for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
for_ptr->kind = ForKind::kUnrolled; for_ptr->kind = ForKind::kUnrolled;
......
...@@ -240,8 +240,9 @@ int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); } ...@@ -240,8 +240,9 @@ int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
bool CanProveIndependent(const PrimExpr &expr, Var var, bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer) { arith::Analyzer *analyzer) {
// 1. if var doesn't exist, it is independent // 1. if var doesn't exist, it is independent
bool used_var = UsesVar( bool used_var = UsesVar(expr, [&](const VarNode *v) {
expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); }); return tvm::ffi::GetRef<Var>(v).same_as(var);
});
if (!used_var) { if (!used_var) {
return true; return true;
} }
......
...@@ -231,10 +231,10 @@ private: ...@@ -231,10 +231,10 @@ private:
if (flag) { if (flag) {
return thenexpr; return thenexpr;
} else { } else {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
} else { } else {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
} }
...@@ -535,11 +535,11 @@ tvm::transform::Pass LoopVectorizeDynamic() { ...@@ -535,11 +535,11 @@ tvm::transform::Pass LoopVectorizeDynamic() {
} }
// Register the pass globally so it can be used in the compilation pipeline // Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LoopVectorizeDynamic", refl::GlobalDef().def("tl.transform.LoopVectorizeDynamic",
LoopVectorizeDynamic); LoopVectorizeDynamic);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -36,7 +36,7 @@ namespace tvm { ...@@ -36,7 +36,7 @@ namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace ffi;
namespace { namespace {
struct KernelInfo { struct KernelInfo {
// The device on which the PrimFunc runs // The device on which the PrimFunc runs
...@@ -372,8 +372,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() { ...@@ -372,8 +372,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
IRModule updates; IRModule updates;
for (const auto &[gvar, base_func] : mod->functions) { for (const auto &[gvar, base_func] : mod->functions) {
if (auto *ptr = base_func.as<PrimFuncNode>()) { if (auto *ptr = base_func.as<PrimFuncNode>()) {
auto prim_func = auto prim_func = mutator.RewriteKernelLaunchSite(
mutator.RewriteKernelLaunchSite(gvar, GetRef<PrimFunc>(ptr)); gvar, tvm::ffi::GetRef<PrimFunc>(ptr));
if (!prim_func.same_as(base_func)) { if (!prim_func.same_as(base_func)) {
updates->Add(gvar, prim_func); updates->Add(gvar, prim_func);
} }
...@@ -388,8 +388,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() { ...@@ -388,8 +388,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
IRModule updates; IRModule updates;
for (const auto &[gvar, base_func] : mod->functions) { for (const auto &[gvar, base_func] : mod->functions) {
if (auto *ptr = base_func.as<PrimFuncNode>()) { if (auto *ptr = base_func.as<PrimFuncNode>()) {
auto prim_func = auto prim_func = mutator.UpdateKernelAttributes(
mutator.UpdateKernelAttributes(gvar, GetRef<PrimFunc>(ptr)); gvar, tvm::ffi::GetRef<PrimFunc>(ptr));
if (!prim_func.same_as(base_func)) { if (!prim_func.same_as(base_func)) {
updates->Add(gvar, prim_func); updates->Add(gvar, prim_func);
} }
...@@ -407,11 +407,11 @@ tvm::transform::Pass LowerDeviceKernelLaunch() { ...@@ -407,11 +407,11 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
"tl.LowerDeviceKernelLaunch", {}); "tl.LowerDeviceKernelLaunch", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch", refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch",
LowerDeviceKernelLaunch); LowerDeviceKernelLaunch);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -143,11 +143,11 @@ Pass LowerDeviceStorageAccessInfo() { ...@@ -143,11 +143,11 @@ Pass LowerDeviceStorageAccessInfo() {
{}); {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerDeviceStorageAccessInfo", refl::GlobalDef().def("tl.transform.LowerDeviceStorageAccessInfo",
LowerDeviceStorageAccessInfo); LowerDeviceStorageAccessInfo);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -113,14 +113,14 @@ public: ...@@ -113,14 +113,14 @@ public:
if (call->op.same_as(create_tma_descriptor()) || if (call->op.same_as(create_tma_descriptor()) ||
call->op.same_as(create_tma_im2col_descriptor())) { call->op.same_as(create_tma_im2col_descriptor())) {
Var var; Var var;
auto iter = desc_map_.find(GetRef<Call>(call)); auto iter = desc_map_.find(tvm::ffi::GetRef<Call>(call));
if (iter != desc_map_.end()) { if (iter != desc_map_.end()) {
var = iter->second; var = iter->second;
} else { } else {
String name = call->args[2].as<Var>().value()->name_hint; String name = call->args[2].as<Var>().value()->name_hint;
var = Var(name + "_desc", var = Var(name + "_desc",
PointerType(PrimType(cuTensorMapType()), "grid_constant")); PointerType(PrimType(cuTensorMapType()), "grid_constant"));
desc_map_[GetRef<Call>(call)] = var; desc_map_[tvm::ffi::GetRef<Call>(call)] = var;
prefetch_calls_.push_back( prefetch_calls_.push_back(
Evaluate(Call(DataType::Handle(), builtin::call_extern(), Evaluate(Call(DataType::Handle(), builtin::call_extern(),
{StringImm("tl::prefetch_tma_descriptor"), var}))); {StringImm("tl::prefetch_tma_descriptor"), var})));
...@@ -161,10 +161,10 @@ tvm::transform::Pass LowerHopperIntrin() { ...@@ -161,10 +161,10 @@ tvm::transform::Pass LowerHopperIntrin() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin); refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin);
}); }
#endif // (CUDA_MAJOR_VERSION >= 12) #endif // (CUDA_MAJOR_VERSION >= 12)
} // namespace tl } // namespace tl
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace ffi;
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public: public:
...@@ -70,9 +71,9 @@ public: ...@@ -70,9 +71,9 @@ public:
PrimExpr VisitExpr_(const CallNode *op) final { PrimExpr VisitExpr_(const CallNode *op) final {
if (auto *ptr_op = op->op.as<OpNode>()) { if (auto *ptr_op = op->op.as<OpNode>()) {
for (const auto &f_attr_map : attr_maps_) { for (const auto &f_attr_map : attr_maps_) {
FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr); FLowerGeneral f = f_attr_map.get(tvm::ffi::GetRef<Op>(ptr_op), nullptr);
if (f != nullptr) { if (f != nullptr) {
PrimExpr e = GetRef<PrimExpr>(op); PrimExpr e = tvm::ffi::GetRef<PrimExpr>(op);
PrimExpr r = f(e); PrimExpr r = f(e);
ICHECK(r.defined()) << "intrinsic rule must always return valid Expr"; ICHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) { if (!r.same_as(e)) {
...@@ -99,7 +100,7 @@ public: ...@@ -99,7 +100,7 @@ public:
// We use floordiv for integer analysis, // We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions // but will need to lower them to native truncdiv instructions
PrimExpr VisitExpr_(const FloorDivNode *op) final { PrimExpr VisitExpr_(const FloorDivNode *op) final {
auto e = GetRef<PrimExpr>(op); auto e = tvm::ffi::GetRef<PrimExpr>(op);
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDivNode>(); op = ret.as<FloorDivNode>();
if (op == nullptr) if (op == nullptr)
...@@ -305,7 +306,7 @@ public: ...@@ -305,7 +306,7 @@ public:
using namespace arith; using namespace arith;
PVar<PrimExpr> x, y; PVar<PrimExpr> x, y;
PVar<IntImm> c; PVar<IntImm> c;
auto e = GetRef<PrimExpr>(op); auto e = tvm::ffi::GetRef<PrimExpr>(op);
if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval()); return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
...@@ -316,7 +317,7 @@ public: ...@@ -316,7 +317,7 @@ public:
PrimExpr VisitExpr_(const EQNode *op) final { PrimExpr VisitExpr_(const EQNode *op) final {
using namespace arith; using namespace arith;
PVar<PrimExpr> x, y; PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op); auto e = tvm::ffi::GetRef<PrimExpr>(op);
if ((floormod(x, y) == 0).Match(e)) { if ((floormod(x, y) == 0).Match(e)) {
return VisitExpr((truncmod(x, y) == 0).Eval()); return VisitExpr((truncmod(x, y) == 0).Eval());
} }
...@@ -326,7 +327,7 @@ public: ...@@ -326,7 +327,7 @@ public:
PrimExpr VisitExpr_(const NENode *op) final { PrimExpr VisitExpr_(const NENode *op) final {
using namespace arith; using namespace arith;
PVar<PrimExpr> x, y; PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op); auto e = tvm::ffi::GetRef<PrimExpr>(op);
if ((floormod(x, y) != 0).Match(e)) { if ((floormod(x, y) != 0).Match(e)) {
return VisitExpr((truncmod(x, y) != 0).Eval()); return VisitExpr((truncmod(x, y) != 0).Eval());
} }
...@@ -413,10 +414,10 @@ tir::transform::Pass LowerIntrin() { ...@@ -413,10 +414,10 @@ tir::transform::Pass LowerIntrin() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerIntrin", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerIntrin", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerIntrin", LowerIntrin); refl::GlobalDef().def("tl.transform.LowerIntrin", LowerIntrin);
}); }
} // namespace transform } // namespace transform
......
...@@ -98,10 +98,10 @@ tvm::transform::Pass LowerL2Persistent() { ...@@ -98,10 +98,10 @@ tvm::transform::Pass LowerL2Persistent() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerL2Persistent", LowerL2Persistent); refl::GlobalDef().def("tl.transform.LowerL2Persistent", LowerL2Persistent);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -151,7 +151,7 @@ private: ...@@ -151,7 +151,7 @@ private:
} }
PrimExpr VisitExpr_(const VarNode *op) final { PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op); Var var = tvm::ffi::GetRef<Var>(op);
auto it = unit_loop_vars_.find(var); auto it = unit_loop_vars_.find(var);
if (it == unit_loop_vars_.end()) { if (it == unit_loop_vars_.end()) {
return var; return var;
...@@ -286,10 +286,10 @@ tir::transform::Pass LowerOpaqueBlock() { ...@@ -286,10 +286,10 @@ tir::transform::Pass LowerOpaqueBlock() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock); refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -32,7 +32,7 @@ private: ...@@ -32,7 +32,7 @@ private:
: disable_shuffle_elect_(disable_shuffle_elect) {} : disable_shuffle_elect_(disable_shuffle_elect) {}
Stmt VisitStmt_(const BlockNode *op) final { Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op); Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers; Array<Buffer> alloc_buffers = op->alloc_buffers;
// Record the mapping from buffer data var to buffer for later lookup // Record the mapping from buffer data var to buffer for later lookup
...@@ -204,10 +204,10 @@ tvm::transform::Pass LowerSharedBarrier() { ...@@ -204,10 +204,10 @@ tvm::transform::Pass LowerSharedBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerSharedBarrier", LowerSharedBarrier); refl::GlobalDef().def("tl.transform.LowerSharedBarrier", LowerSharedBarrier);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -30,7 +30,7 @@ public: ...@@ -30,7 +30,7 @@ public:
private: private:
Stmt VisitStmt_(const BlockNode *op) final { Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op); Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers; Array<Buffer> alloc_buffers = op->alloc_buffers;
if (op->annotations.count(attr::kLayoutMap)) { if (op->annotations.count(attr::kLayoutMap)) {
auto layout_map = op->annotations.Get(attr::kLayoutMap); auto layout_map = op->annotations.Get(attr::kLayoutMap);
...@@ -300,10 +300,10 @@ tvm::transform::Pass LowerSharedTmem() { ...@@ -300,10 +300,10 @@ tvm::transform::Pass LowerSharedTmem() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedTmem", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedTmem", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerSharedTmem", LowerSharedTmem); refl::GlobalDef().def("tl.transform.LowerSharedTmem", LowerSharedTmem);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace ffi;
using runtime::StorageRank; using runtime::StorageRank;
using runtime::StorageScope; using runtime::StorageScope;
...@@ -944,11 +945,11 @@ tvm::transform::Pass LowerThreadAllreduce() { ...@@ -944,11 +945,11 @@ tvm::transform::Pass LowerThreadAllreduce() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerThreadAllreduce", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerThreadAllreduce", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerThreadAllreduce", refl::GlobalDef().def("tl.transform.LowerThreadAllreduce",
LowerThreadAllreduce); LowerThreadAllreduce);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -435,7 +435,7 @@ private: ...@@ -435,7 +435,7 @@ private:
return expr; return expr;
} }
if (const auto *var_node = expr.as<VarNode>()) { if (const auto *var_node = expr.as<VarNode>()) {
Var var = GetRef<Var>(var_node); Var var = tvm::ffi::GetRef<Var>(var_node);
auto it = let_bindings_.find(var); auto it = let_bindings_.find(var);
if (it != let_bindings_.end()) { if (it != let_bindings_.end()) {
return it->second; return it->second;
...@@ -611,7 +611,7 @@ private: ...@@ -611,7 +611,7 @@ private:
let_bindings_.erase(op->var); let_bindings_.erase(op->var);
} }
if (value.same_as(op->value) && body.same_as(op->body)) { if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
auto n = this->CopyOnWrite(op); auto n = this->CopyOnWrite(op);
n->value = value; n->value = value;
...@@ -652,7 +652,8 @@ private: ...@@ -652,7 +652,8 @@ private:
if (call && call->op.as<GlobalVarNode>()) if (call && call->op.as<GlobalVarNode>())
return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op)); return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_); auto tile_op =
ParseOperator(tvm::ffi::GetRef<Stmt>(op), buffer_data_to_buffer_);
if (!tile_op.defined()) if (!tile_op.defined())
return IRMutatorWithAnalyzer::VisitStmt_(op); return IRMutatorWithAnalyzer::VisitStmt_(op);
AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) { AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
...@@ -730,10 +731,10 @@ tvm::transform::Pass LowerTileOp() { ...@@ -730,10 +731,10 @@ tvm::transform::Pass LowerTileOp() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerTileOp", LowerTileOp); refl::GlobalDef().def("tl.transform.LowerTileOp", LowerTileOp);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace ffi;
static constexpr const char *kDeviceContextVar = "device_api_context"; static constexpr const char *kDeviceContextVar = "device_api_context";
namespace { namespace {
...@@ -168,7 +169,7 @@ private: ...@@ -168,7 +169,7 @@ private:
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op)); auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (auto *gvar_ptr = node->op.as<GlobalVarNode>()) { if (auto *gvar_ptr = node->op.as<GlobalVarNode>()) {
auto gvar = GetRef<GlobalVar>(gvar_ptr); auto gvar = tvm::ffi::GetRef<GlobalVar>(gvar_ptr);
if (auto symbol = packed_func_methods.Get(gvar)) { if (auto symbol = packed_func_methods.Get(gvar)) {
Array<PrimExpr> cpacked_args; Array<PrimExpr> cpacked_args;
cpacked_args.push_back(tir::StringImm(symbol.value())); cpacked_args.push_back(tir::StringImm(symbol.value()));
...@@ -220,7 +221,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) { ...@@ -220,7 +221,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) {
// Internal function calls do not need the PackedFunc API // Internal function calls do not need the PackedFunc API
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol); auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (!global_symbol.defined()) { if (!global_symbol) {
return std::nullopt; return std::nullopt;
} }
...@@ -229,7 +230,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) { ...@@ -229,7 +230,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) {
PrimFunc MakePackedAPI(PrimFunc func) { PrimFunc MakePackedAPI(PrimFunc func) {
auto global_symbol = RequiresPackedAPI(func); auto global_symbol = RequiresPackedAPI(func);
if (!global_symbol.defined()) { if (!global_symbol) {
return func; return func;
} }
std::string name_hint = global_symbol.value(); std::string name_hint = global_symbol.value();
...@@ -406,7 +407,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -406,7 +407,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
StringImm(name_hint + "_compute_"), body); StringImm(name_hint + "_compute_"), body);
// Set device context // Set device context
if (vmap.count(device_id.get())) { if (vmap.count(device_id.get())) {
ObjectRef node = String("default"); auto node = String("default");
seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop)); seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop));
seq_check.push_back( seq_check.push_back(
AttrStmt(node, tir::attr::device_type, device_type, nop)); AttrStmt(node, tir::attr::device_type, device_type, nop));
...@@ -513,11 +514,11 @@ tvm::transform::Pass MakePackedAPI() { ...@@ -513,11 +514,11 @@ tvm::transform::Pass MakePackedAPI() {
return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {}); return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MakePackedAPI", refl::GlobalDef().def("tl.transform.MakePackedAPI",
[]() { return MakePackedAPI(); }); []() { return MakePackedAPI(); });
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -98,10 +98,10 @@ tvm::transform::Pass MergeIfStmt() { ...@@ -98,10 +98,10 @@ tvm::transform::Pass MergeIfStmt() {
return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {}); return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt); refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -162,7 +162,7 @@ public: ...@@ -162,7 +162,7 @@ public:
auto it = alloc_info_.find(buf); auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) { if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()); ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) { if (IsAppropriateSharedMemory(tvm::ffi::GetRef<Var>(buf))) {
// set into scope_.size() - 1 for aggressive memory reuse // set into scope_.size() - 1 for aggressive memory reuse
auto enable_aggressive_merge = enable_aggressive_merge_; auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) { if (enable_aggressive_merge) {
...@@ -209,7 +209,7 @@ public: ...@@ -209,7 +209,7 @@ public:
// the merged allocator can reason about their lifetime correctly. // the merged allocator can reason about their lifetime correctly.
ICHECK_LE(it->second.level, scope_.size()) ICHECK_LE(it->second.level, scope_.size())
<< "Load memory in places other than store."; << "Load memory in places other than store.";
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) { if (IsAppropriateSharedMemory(tvm::ffi::GetRef<Var>(buf))) {
auto enable_aggressive_merge = enable_aggressive_merge_; auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) { if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf); scope_[scope_.size() - 1].touched.push_back(buf);
...@@ -233,7 +233,7 @@ public: ...@@ -233,7 +233,7 @@ public:
// emitted at the allocation level after flattening, so accept them and // emitted at the allocation level after flattening, so accept them and
// record the touch for liveness planning. // record the touch for liveness planning.
ICHECK_LE(it->second.level, scope_.size()); ICHECK_LE(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) { if (IsAppropriateSharedMemory(tvm::ffi::GetRef<Var>(buf))) {
auto enable_aggressive_merge = enable_aggressive_merge_; auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) { if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf); scope_[scope_.size() - 1].touched.push_back(buf);
...@@ -372,7 +372,7 @@ private: ...@@ -372,7 +372,7 @@ private:
void VisitExpr_(const VarNode *op) { void VisitExpr_(const VarNode *op) {
auto ptr_type = op->type_annotation.as<PointerTypeNode>(); auto ptr_type = op->type_annotation.as<PointerTypeNode>();
if (ptr_type && under_alignment_scope_) { if (ptr_type && under_alignment_scope_) {
auto scope = GetPtrStorageScope(GetRef<Var>(op)); auto scope = GetPtrStorageScope(tvm::ffi::GetRef<Var>(op));
if (scope == "shared" || scope == "shared.dyn") { if (scope == "shared" || scope == "shared.dyn") {
auto target = Target::Current(); auto target = Target::Current();
ICHECK(target.defined()) << "Target is not defined"; ICHECK(target.defined()) << "Target is not defined";
...@@ -1343,11 +1343,11 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false, ...@@ -1343,11 +1343,11 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
{}); {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeSharedMemoryAllocations", refl::GlobalDef().def("tl.transform.MergeSharedMemoryAllocations",
MergeSharedMemoryAllocations); MergeSharedMemoryAllocations);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -57,7 +57,7 @@ public: ...@@ -57,7 +57,7 @@ public:
// Check reads from global // Check reads from global
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ GetRef<Stmt>(op)); /*body*/ tvm::ffi::GetRef<Stmt>(op));
auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0]; auto reads = access[0];
Role role = Role::kProducer; Role role = Role::kProducer;
...@@ -253,7 +253,8 @@ private: ...@@ -253,7 +253,8 @@ private:
} }
static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) { static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get())); ObjectPtr<BufferNode> new_buffer =
tvm::ffi::make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (!new_buffer->strides.empty()) { if (!new_buffer->strides.empty()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
...@@ -493,10 +494,10 @@ tvm::transform::Pass MultiVersionBuffer() { ...@@ -493,10 +494,10 @@ tvm::transform::Pass MultiVersionBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {}); return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer); refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment