Unverified Commit 10911e28 authored by Lei Wang, committed by GitHub

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
......@@ -66,17 +66,17 @@ struct ReducerInfoNode : Object {
ReducerInfoNode() = default;
ReducerInfoNode(const String &op_str, const String &rep_str);
static constexpr const char *_type_key = "tl.ReducerInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(ReducerInfoNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReducerInfo", ReducerInfoNode, Object);
};
struct ReducerInfo : ObjectRef {
public:
TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) {
data_ = make_object<ReducerInfoNode>(op_str, rep_str);
data_ = tvm::ffi::make_object<ReducerInfoNode>(op_str, rep_str);
}
TVM_DEFINE_OBJECT_REF_METHODS(ReducerInfo, ObjectRef, ReducerInfoNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReducerInfo, ObjectRef,
ReducerInfoNode);
};
namespace attr {
......
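For readers following the migration, here is a minimal sketch of the object-declaration pattern the hunk above moves to. It uses a hypothetical `FooNode`/`Foo` pair rather than symbols from this commit: the type key is passed to `TVM_FFI_DECLARE_OBJECT_INFO_FINAL`, node construction goes through `tvm::ffi::make_object`, and the reference class uses the `_NULLABLE` ref-methods macro, mirroring the `ReducerInfo` change.

```cpp
// Illustrative only: FooNode/Foo are hypothetical placeholders that mirror
// the ReducerInfoNode/ReducerInfo hunk above.
struct FooNode : Object {
  String value;
  FooNode() = default;
  explicit FooNode(String value) : value(std::move(value)) {}
  static constexpr const char *_type_key = "tl.Foo";
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Foo", FooNode, Object);
};

struct Foo : ObjectRef {
 public:
  TVM_DLL explicit Foo(String value) {
    data_ = tvm::ffi::make_object<FooNode>(std::move(value));
  }
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Foo, ObjectRef, FooNode);
};
```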
......@@ -38,7 +38,7 @@ private:
StmtVisitor::VisitStmt(op->body);
if (!has_child_for_) {
leaf_for_nodes.push_back(GetRef<For>(op));
leaf_for_nodes.push_back(tvm::ffi::GetRef<For>(op));
}
parent_has_child_for_ = parent_has_child_for;
......@@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess",
LegalizeSafeMemoryAccess);
});
}
} // namespace tl
} // namespace tvm
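Every pass-registration site in this diff is rewritten the same way: the old form passed the body in braces as a macro argument, `TVM_FFI_STATIC_INIT_BLOCK({ ... });`, while the new form is `TVM_FFI_STATIC_INIT_BLOCK() { ... }` with a reflection-based `GlobalDef().def(...)` inside. A minimal sketch with a hypothetical pass name:

```cpp
// Sketch of the registration idiom used throughout this diff;
// "tl.transform.MyPass" and MyPass are placeholders, not symbols
// from this commit.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.MyPass", MyPass);
}
```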
......@@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop",
LegalizeVectorizedLoop);
});
}
} // namespace tl
} // namespace tvm
......@@ -173,7 +173,7 @@ private:
if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) {
return StmtExprMutator::VisitStmt_(node);
}
For new_for = GetRef<For>(node);
For new_for = tvm::ffi::GetRef<For>(node);
auto for_ptr = new_for.CopyOnWrite();
for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
for_ptr->kind = ForKind::kUnrolled;
......
......@@ -240,8 +240,9 @@ int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer) {
// 1. if var doesn't exist, it is independent
bool used_var = UsesVar(
expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
bool used_var = UsesVar(expr, [&](const VarNode *v) {
return tvm::ffi::GetRef<Var>(v).same_as(var);
});
if (!used_var) {
return true;
}
......
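Another mechanical change repeated across these files is spelling out the namespace on `GetRef`, e.g. `tvm::ffi::GetRef<Var>(v)` instead of the unqualified `GetRef<Var>(v)`. A minimal sketch of the pattern in a visitor method (hypothetical body, only the qualification is the point):

```cpp
// Hypothetical visitor fragment mirroring the GetRef changes in the
// hunks above; the tvm::ffi:: qualification is the migration step.
PrimExpr VisitExpr_(const VarNode *op) final {
  Var var = tvm::ffi::GetRef<Var>(op);  // previously: GetRef<Var>(op)
  return var;
}
```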
......@@ -231,10 +231,10 @@ private:
if (flag) {
return thenexpr;
} else {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
} else {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
}
......@@ -535,11 +535,11 @@ tvm::transform::Pass LoopVectorizeDynamic() {
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LoopVectorizeDynamic",
LoopVectorizeDynamic);
});
}
} // namespace tl
} // namespace tvm
......@@ -36,7 +36,7 @@ namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
namespace {
struct KernelInfo {
// The device on which the PrimFunc runs
......@@ -372,8 +372,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
IRModule updates;
for (const auto &[gvar, base_func] : mod->functions) {
if (auto *ptr = base_func.as<PrimFuncNode>()) {
auto prim_func =
mutator.RewriteKernelLaunchSite(gvar, GetRef<PrimFunc>(ptr));
auto prim_func = mutator.RewriteKernelLaunchSite(
gvar, tvm::ffi::GetRef<PrimFunc>(ptr));
if (!prim_func.same_as(base_func)) {
updates->Add(gvar, prim_func);
}
......@@ -388,8 +388,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
IRModule updates;
for (const auto &[gvar, base_func] : mod->functions) {
if (auto *ptr = base_func.as<PrimFuncNode>()) {
auto prim_func =
mutator.UpdateKernelAttributes(gvar, GetRef<PrimFunc>(ptr));
auto prim_func = mutator.UpdateKernelAttributes(
gvar, tvm::ffi::GetRef<PrimFunc>(ptr));
if (!prim_func.same_as(base_func)) {
updates->Add(gvar, prim_func);
}
......@@ -407,11 +407,11 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
"tl.LowerDeviceKernelLaunch", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch",
LowerDeviceKernelLaunch);
});
}
} // namespace transform
} // namespace tl
......
......@@ -143,11 +143,11 @@ Pass LowerDeviceStorageAccessInfo() {
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerDeviceStorageAccessInfo",
LowerDeviceStorageAccessInfo);
});
}
} // namespace transform
} // namespace tl
......
......@@ -113,14 +113,14 @@ public:
if (call->op.same_as(create_tma_descriptor()) ||
call->op.same_as(create_tma_im2col_descriptor())) {
Var var;
auto iter = desc_map_.find(GetRef<Call>(call));
auto iter = desc_map_.find(tvm::ffi::GetRef<Call>(call));
if (iter != desc_map_.end()) {
var = iter->second;
} else {
String name = call->args[2].as<Var>().value()->name_hint;
var = Var(name + "_desc",
PointerType(PrimType(cuTensorMapType()), "grid_constant"));
desc_map_[GetRef<Call>(call)] = var;
desc_map_[tvm::ffi::GetRef<Call>(call)] = var;
prefetch_calls_.push_back(
Evaluate(Call(DataType::Handle(), builtin::call_extern(),
{StringImm("tl::prefetch_tma_descriptor"), var})));
......@@ -161,10 +161,10 @@ tvm::transform::Pass LowerHopperIntrin() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin);
});
}
#endif // (CUDA_MAJOR_VERSION >= 12)
} // namespace tl
......
......@@ -37,6 +37,7 @@
namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public:
......@@ -70,9 +71,9 @@ public:
PrimExpr VisitExpr_(const CallNode *op) final {
if (auto *ptr_op = op->op.as<OpNode>()) {
for (const auto &f_attr_map : attr_maps_) {
FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr);
FLowerGeneral f = f_attr_map.get(tvm::ffi::GetRef<Op>(ptr_op), nullptr);
if (f != nullptr) {
PrimExpr e = GetRef<PrimExpr>(op);
PrimExpr e = tvm::ffi::GetRef<PrimExpr>(op);
PrimExpr r = f(e);
ICHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
......@@ -99,7 +100,7 @@ public:
// We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions
PrimExpr VisitExpr_(const FloorDivNode *op) final {
auto e = GetRef<PrimExpr>(op);
auto e = tvm::ffi::GetRef<PrimExpr>(op);
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDivNode>();
if (op == nullptr)
......@@ -305,7 +306,7 @@ public:
using namespace arith;
PVar<PrimExpr> x, y;
PVar<IntImm> c;
auto e = GetRef<PrimExpr>(op);
auto e = tvm::ffi::GetRef<PrimExpr>(op);
if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
......@@ -316,7 +317,7 @@ public:
PrimExpr VisitExpr_(const EQNode *op) final {
using namespace arith;
PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op);
auto e = tvm::ffi::GetRef<PrimExpr>(op);
if ((floormod(x, y) == 0).Match(e)) {
return VisitExpr((truncmod(x, y) == 0).Eval());
}
......@@ -326,7 +327,7 @@ public:
PrimExpr VisitExpr_(const NENode *op) final {
using namespace arith;
PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op);
auto e = tvm::ffi::GetRef<PrimExpr>(op);
if ((floormod(x, y) != 0).Match(e)) {
return VisitExpr((truncmod(x, y) != 0).Eval());
}
......@@ -413,10 +414,10 @@ tir::transform::Pass LowerIntrin() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerIntrin", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerIntrin", LowerIntrin);
});
}
} // namespace transform
......
......@@ -98,10 +98,10 @@ tvm::transform::Pass LowerL2Persistent() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerL2Persistent", LowerL2Persistent);
});
}
} // namespace tl
} // namespace tvm
......@@ -151,7 +151,7 @@ private:
}
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op);
Var var = tvm::ffi::GetRef<Var>(op);
auto it = unit_loop_vars_.find(var);
if (it == unit_loop_vars_.end()) {
return var;
......@@ -286,10 +286,10 @@ tir::transform::Pass LowerOpaqueBlock() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock);
});
}
} // namespace tl
} // namespace tvm
......@@ -32,7 +32,7 @@ private:
: disable_shuffle_elect_(disable_shuffle_elect) {}
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
// Record the mapping from buffer data var to buffer for later lookup
......@@ -204,10 +204,10 @@ tvm::transform::Pass LowerSharedBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerSharedBarrier", LowerSharedBarrier);
});
}
} // namespace transform
} // namespace tl
......
......@@ -30,7 +30,7 @@ public:
private:
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
if (op->annotations.count(attr::kLayoutMap)) {
auto layout_map = op->annotations.Get(attr::kLayoutMap);
......@@ -300,10 +300,10 @@ tvm::transform::Pass LowerSharedTmem() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedTmem", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerSharedTmem", LowerSharedTmem);
});
}
} // namespace transform
} // namespace tl
......
......@@ -39,6 +39,7 @@
namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
using runtime::StorageRank;
using runtime::StorageScope;
......@@ -944,11 +945,11 @@ tvm::transform::Pass LowerThreadAllreduce() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerThreadAllreduce", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerThreadAllreduce",
LowerThreadAllreduce);
});
}
} // namespace transform
} // namespace tl
......
......@@ -435,7 +435,7 @@ private:
return expr;
}
if (const auto *var_node = expr.as<VarNode>()) {
Var var = GetRef<Var>(var_node);
Var var = tvm::ffi::GetRef<Var>(var_node);
auto it = let_bindings_.find(var);
if (it != let_bindings_.end()) {
return it->second;
......@@ -611,7 +611,7 @@ private:
let_bindings_.erase(op->var);
}
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
n->value = value;
......@@ -652,7 +652,8 @@ private:
if (call && call->op.as<GlobalVarNode>())
return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_);
auto tile_op =
ParseOperator(tvm::ffi::GetRef<Stmt>(op), buffer_data_to_buffer_);
if (!tile_op.defined())
return IRMutatorWithAnalyzer::VisitStmt_(op);
AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
......@@ -730,10 +731,10 @@ tvm::transform::Pass LowerTileOp() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerTileOp", LowerTileOp);
});
}
} // namespace transform
} // namespace tl
......
......@@ -42,6 +42,7 @@
namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
static constexpr const char *kDeviceContextVar = "device_api_context";
namespace {
......@@ -168,7 +169,7 @@ private:
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (auto *gvar_ptr = node->op.as<GlobalVarNode>()) {
auto gvar = GetRef<GlobalVar>(gvar_ptr);
auto gvar = tvm::ffi::GetRef<GlobalVar>(gvar_ptr);
if (auto symbol = packed_func_methods.Get(gvar)) {
Array<PrimExpr> cpacked_args;
cpacked_args.push_back(tir::StringImm(symbol.value()));
......@@ -220,7 +221,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) {
// Internal function calls do not need the PackedFunc API
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (!global_symbol.defined()) {
if (!global_symbol) {
return std::nullopt;
}
......@@ -229,7 +230,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) {
PrimFunc MakePackedAPI(PrimFunc func) {
auto global_symbol = RequiresPackedAPI(func);
if (!global_symbol.defined()) {
if (!global_symbol) {
return func;
}
std::string name_hint = global_symbol.value();
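The `MakePackedAPI` hunks above also drop `.defined()` on `Optional<String>` in favor of the boolean conversion. A minimal sketch of the idiom, assuming (as the surrounding code suggests) that the new `Optional` converts to `bool` and that an absent value maps to `std::nullopt`; the attribute key is a placeholder:

```cpp
// Sketch only; "some.attr" is a placeholder attribute key.
Optional<String> symbol = func->GetAttr<String>("some.attr");
if (!symbol) {           // previously: if (!symbol.defined())
  return std::nullopt;   // attribute absent: nothing to rewrite
}
std::string name_hint = symbol.value();
```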
......@@ -406,7 +407,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
StringImm(name_hint + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
ObjectRef node = String("default");
auto node = String("default");
seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop));
seq_check.push_back(
AttrStmt(node, tir::attr::device_type, device_type, nop));
......@@ -513,11 +514,11 @@ tvm::transform::Pass MakePackedAPI() {
return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MakePackedAPI",
[]() { return MakePackedAPI(); });
});
}
} // namespace tl
} // namespace tvm
......@@ -98,10 +98,10 @@ tvm::transform::Pass MergeIfStmt() {
return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt);
});
}
} // namespace tl
} // namespace tvm
......@@ -162,7 +162,7 @@ public:
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
if (IsAppropriateSharedMemory(tvm::ffi::GetRef<Var>(buf))) {
// set into scope_.size() - 1 for aggressive memory reuse
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
......@@ -209,7 +209,7 @@ public:
// the merged allocator can reason about their lifetime correctly.
ICHECK_LE(it->second.level, scope_.size())
<< "Load memory in places other than store.";
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
if (IsAppropriateSharedMemory(tvm::ffi::GetRef<Var>(buf))) {
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
......@@ -233,7 +233,7 @@ public:
// emitted at the allocation level after flattening, so accept them and
// record the touch for liveness planning.
ICHECK_LE(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
if (IsAppropriateSharedMemory(tvm::ffi::GetRef<Var>(buf))) {
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
......@@ -372,7 +372,7 @@ private:
void VisitExpr_(const VarNode *op) {
auto ptr_type = op->type_annotation.as<PointerTypeNode>();
if (ptr_type && under_alignment_scope_) {
auto scope = GetPtrStorageScope(GetRef<Var>(op));
auto scope = GetPtrStorageScope(tvm::ffi::GetRef<Var>(op));
if (scope == "shared" || scope == "shared.dyn") {
auto target = Target::Current();
ICHECK(target.defined()) << "Target is not defined";
......@@ -1343,11 +1343,11 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeSharedMemoryAllocations",
MergeSharedMemoryAllocations);
});
}
} // namespace transform
} // namespace tl
......
......@@ -57,7 +57,7 @@ public:
// Check reads from global
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ GetRef<Stmt>(op));
/*body*/ tvm::ffi::GetRef<Stmt>(op));
auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0];
Role role = Role::kProducer;
......@@ -253,7 +253,8 @@ private:
}
static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
ObjectPtr<BufferNode> new_buffer =
tvm::ffi::make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (!new_buffer->strides.empty()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
......@@ -493,10 +494,10 @@ tvm::transform::Pass MultiVersionBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer);
});
}
} // namespace tl
} // namespace tvm
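Finally, `RewriteAllocBuffer` above shows the clone-and-modify pattern with the qualified `tvm::ffi::make_object`: copy the node, adjust the copy, then rewrap it. A minimal standalone sketch under those assumptions, with a hypothetical helper name:

```cpp
// Illustrative helper, not from the commit: clone a BufferNode via
// tvm::ffi::make_object, prepend a version dimension, and rewrap it.
static Buffer PrependVersionDim(const Buffer &buffer, int num_versions) {
  ObjectPtr<BufferNode> new_buffer =
      tvm::ffi::make_object<BufferNode>(*(buffer.get()));
  new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
  return Buffer(new_buffer);
}
```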