Unverified Commit 10911e28 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: default avatarXuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
......@@ -46,13 +46,13 @@ public:
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is.
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::pipeline_exec_scope ||
op->attr_key == tir::attr::device_scope) {
// These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target.
Stmt body = GetRef<Stmt>(op);
Stmt body = tvm::ffi::GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else {
// All other annotations are ignored
......@@ -90,11 +90,11 @@ tvm::transform::Pass AnnotateDeviceRegions() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions",
AnnotateDeviceRegions);
});
}
} // namespace tl
} // namespace tvm
......@@ -181,11 +181,11 @@ tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
AnnotateWarpGroupRegAlloc);
});
}
} // namespace tl
} // namespace tvm
......@@ -80,8 +80,8 @@ void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value,
Bind_(arg, value, arg_name, with_let);
}
void ArgBinder::BindArray(const Array<PrimExpr> &arg,
const Array<PrimExpr> &value,
void ArgBinder::BindArray(const ffi::Array<PrimExpr> &arg,
const ffi::Array<PrimExpr> &value,
const std::string &arg_name) {
ICHECK_EQ(arg.size(), value.size())
<< "Argument " << arg_name << " array size mismatch";
......@@ -250,7 +250,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
PrimExpr expect_stride = make_const(stype, 1);
Array<PrimExpr> conds;
ffi::Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
PrimExpr svalue =
......
......@@ -82,7 +82,8 @@ public:
* \param value The target expression value
* \param arg_name argument name.
*/
void BindArray(const Array<PrimExpr> &arg, const Array<PrimExpr> &value,
void BindArray(const ffi::Array<PrimExpr> &arg,
const ffi::Array<PrimExpr> &value,
const std::string &arg_name);
/*!
* \brief Bind symbolic buffer to another symbolic buffer
......@@ -149,7 +150,7 @@ public:
*/
const std::vector<Stmt> &init_nest() const { return init_nest_; }
/*! \return Handle data type of the data */
const Map<Var, PrimExpr> &def_handle_dtype() const {
const ffi::Map<Var, PrimExpr> &def_handle_dtype() const {
return def_handle_dtype_;
}
......@@ -164,7 +165,7 @@ private:
/*! \brief Initialize nest */
std::vector<Stmt> init_nest_;
/*! \brief handle data type in the defintiions */
Map<Var, PrimExpr> def_handle_dtype_;
ffi::Map<Var, PrimExpr> def_handle_dtype_;
/*! \brief asserts generated */
std::vector<Stmt> asserts_;
/*! \brief internal analyzer. */
......
......@@ -249,7 +249,6 @@ private:
new_args.push_back(dst_node);
new_args.push_back(value_node);
}
new_args.push_back(memory_order);
Call new_call =
......@@ -284,4 +283,4 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
}
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tvm
......@@ -10,6 +10,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tir {
......@@ -66,7 +68,8 @@ public:
}
if (mem_reuse_max > 0) {
std::string tag_str = cluster_tag; // Convert to std::string
std::string tag_str =
static_cast<std::string>(cluster_tag); // Convert to std::string
if (tag_str.rfind("blockIdx", 0) == 0) {
// starts with "blockIdx"
tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx"));
......@@ -74,7 +77,7 @@ public:
// Unexpected format — maybe just prefix
tag_str = "clusterIdx" + tag_str;
}
cluster_tag = tvm::ffi::String(tag_str); // Convert back
cluster_tag = String(tag_str); // Convert back
return WithAttr(f, cluster_tag, Integer(cluster_size_));
} else {
return f;
......@@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning);
});
}
} // namespace transform
} // namespace tir
......
......@@ -41,7 +41,7 @@ public:
return StmtMutator::VisitStmt_(op);
// Collect loop variables and ranges
auto for_node = GetRef<For>(op);
auto for_node = tvm::ffi::GetRef<For>(op);
Array<Var> loop_vars;
Array<PrimExpr> loop_extents;
Stmt body = op->body;
......@@ -81,7 +81,7 @@ public:
// post order visit the index
PostOrderVisit(index, [&](const ObjectRef &obj) {
if (const VarNode *v = obj.as<VarNode>()) {
used_vars.insert(GetRef<Var>(v));
used_vars.insert(tvm::ffi::GetRef<Var>(v));
}
});
if (used_vars.empty()) {
......
......@@ -211,7 +211,7 @@ public:
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
......@@ -265,7 +265,7 @@ public:
PrimExpr VisitExpr_(const NotNode *op) final {
PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return !(a);
}
......@@ -306,10 +306,10 @@ public:
PrimExpr value = this->VisitExpr(op->value);
if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Broadcast(op->value, op->lanes);
}
......@@ -321,7 +321,7 @@ public:
PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
......@@ -339,7 +339,7 @@ public:
PrimExpr VisitExpr_(const CastNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
if (value.dtype().is_scalable_vector()) {
return Cast(op->dtype.with_scalable_vscale_factor(
......@@ -352,20 +352,20 @@ public:
}
PrimExpr VisitExpr_(const FloatImmNode *op) final {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const StringImmNode *op) final {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
// Variable
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op);
Var var = tvm::ffi::GetRef<Var>(op);
if (var.same_as(var_)) {
return ramp_;
......@@ -382,13 +382,13 @@ public:
PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor();
......@@ -410,7 +410,7 @@ public:
ICHECK(op->op.same_as(builtin::reinterpret()));
PrimExpr value = this->VisitExpr(op->args[0]);
if (value.same_as(op->args[0])) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) {
......@@ -455,12 +455,12 @@ public:
auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
new_args.push_back(new_arg);
}
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Call(op->dtype, op->op, new_args);
}
......@@ -469,7 +469,7 @@ public:
Array<PrimExpr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Call(op->dtype.with_lanes(lane), op->op, new_args);
}
......@@ -477,7 +477,7 @@ public:
}
// BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = GetRef<BufferLoad>(op);
auto load = tvm::ffi::GetRef<BufferLoad>(op);
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
......@@ -514,7 +514,7 @@ public:
let_binding_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Let(op->var, value, body);
}
......@@ -522,7 +522,7 @@ public:
}
// BufferStore
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = GetRef<BufferStore>(op);
auto store = tvm::ffi::GetRef<BufferStore>(op);
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
......@@ -585,11 +585,11 @@ public:
ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
PrimExpr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
Stmt body = this->VisitStmt(op->body);
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return For(op->loop_var, op->min, extent, op->kind, body,
op->thread_binding, op->annotations);
......@@ -600,7 +600,7 @@ public:
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = std::nullopt;
......@@ -609,7 +609,7 @@ public:
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return IfThenElse(condition, then_case, else_case);
}
......@@ -634,7 +634,7 @@ public:
let_binding_[op->var] = op->var;
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return LetStmt(op->var, value, body);
}
......@@ -647,7 +647,7 @@ public:
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
// Mutate the extents
......@@ -657,7 +657,7 @@ public:
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
extents.push_back(new_ext);
}
......@@ -738,7 +738,7 @@ private:
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
......@@ -754,7 +754,7 @@ private:
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
......
......@@ -38,7 +38,7 @@ protected:
if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(_index_bitwidth_), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const CastNode *op) final {
......@@ -88,23 +88,23 @@ private:
PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op));
return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
......@@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() {
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth",
ConfigIndexBitwidth);
});
}
} // namespace tl
} // namespace tvm
......@@ -35,9 +35,7 @@ public:
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "thread_extent") {
const VarNode *var = nullptr;
if (op->node->IsInstance<VarNode>()) {
var = op->node.as<VarNode>();
if (const auto *var = op->node.as<VarNode>()) {
if (var->name_hint == "threadIdx.x") {
thread_extent_ = op;
}
......@@ -82,7 +80,7 @@ public:
}
Stmt VisitStmt_(const ForNode *op) final {
PostOrderVisit(GetRef<For>(op), [&](const ObjectRef &node) {
PostOrderVisit(tvm::ffi::GetRef<For>(op), [&](const ObjectRef &node) {
if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
......@@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() {
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier",
EliminateStorageSyncForMBarrier);
});
}
} // namespace transform
} // namespace tl
......
......@@ -75,23 +75,23 @@ private:
PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op));
return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
......@@ -115,7 +115,7 @@ private:
<< "All MatchBufferRegion should be removed in "
"tir.transform.LowerMatchBuffer.";
Block block = GetRef<Block>(op);
Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply(
......@@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer);
});
}
} // namespace tl
} // namespace tvm
......@@ -89,10 +89,10 @@ Pass LetInline() {
return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LetInline", LetInline);
});
}
} // namespace tl
} // namespace tvm
......@@ -33,7 +33,7 @@ private:
auto then_case = VisitStmt(op->then_case);
Optional<Stmt> else_case = op->else_case;
if (else_case.defined()) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
}
ICHECK(then_case.defined()) << "then_case must be defined";
ICHECK(!else_case.defined()) << "else_case must be undefined";
......@@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() {
return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding);
});
}
} // namespace tl
} // namespace tvm
......@@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes);
});
}
} // namespace tvm::tl
......@@ -319,10 +319,10 @@ tvm::transform::Pass InjectFenceProxy() {
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy);
});
}
} // namespace tl
} // namespace tvm
......@@ -37,7 +37,7 @@
namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
namespace software_pipeline {
/*!
......@@ -459,7 +459,8 @@ private:
* \return The resized buffer.
*/
Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
ObjectPtr<BufferNode> new_buffer =
tvm::ffi::make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (!new_buffer->strides.empty()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
......@@ -865,7 +866,7 @@ private:
const SeqStmtNode *pipeline_body_seq = nullptr;
std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
ObjectRef node = attr->node;
Any node = attr->node;
String attr_key = attr->attr_key;
PrimExpr value = attr->value;
Span span = attr->span;
......@@ -981,7 +982,7 @@ private:
// Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
GetRef<For>(op), pipeline_info)
tvm::ffi::GetRef<For>(op), pipeline_info)
.BuildPipeline();
auto apply_wrappers = [&](Stmt stmt) {
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
......@@ -1072,11 +1073,11 @@ tir::transform::Pass InjectSoftwarePipeline() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
InjectSoftwarePipeline);
});
}
} // namespace tl
} // namespace tvm
......@@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
});
}
} // namespace tl
} // namespace tvm
......@@ -204,9 +204,9 @@ private:
void VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
} else if (call->op.same_as(mbarrier_expect_tx())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
} else if (call->op.same_as(builtin::ptx_arrive_barrier())) {
PrimExpr barrier_id = call->args[0];
for (const auto &tma_call : pending_tma_ops_) {
......@@ -295,8 +295,9 @@ public:
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(mbarrier_expect_tx())) {
PrimExpr e =
tma_op_to_barrier_id_[GetRef<Call>(op)].as<CallNode>()->args[0];
PrimExpr e = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)]
.as<CallNode>()
->args[0];
auto int_set = arith::EvalSet(e, var_int_set_);
expect_.push_back(if_depth_ == 1);
sequence.push_back(0);
......@@ -406,7 +407,7 @@ public:
private:
Stmt VisitStmt_(const BlockNode *op) {
auto block = GetRef<Block>(op);
auto block = tvm::ffi::GetRef<Block>(op);
if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() &&
op->name_hint == MainBlockName) {
ICHECK(false) << "Please declare create_list_of_mbarrier.";
......@@ -453,9 +454,9 @@ private:
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
// check this must be in the tma_op_to_barrier_id_
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
<< "tma_load must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
auto new_args = op->args;
auto arg0 = op->args[0].as<Call>();
auto is_1d_tma_load =
......@@ -468,9 +469,9 @@ private:
}
return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(mbarrier_expect_tx())) {
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
<< "mbarrier_expect_tx must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
auto new_args = op->args;
new_args.Set(0, barrier_id);
if (!has_warp_specialization_)
......@@ -522,10 +523,10 @@ tvm::transform::Pass InjectTmaBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier);
});
}
} // namespace tl
} // namespace tvm
......@@ -330,7 +330,7 @@ private:
if (op->op.as<GlobalVarNode>())
return;
auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
auto p = ParseOperator(tvm::ffi::GetRef<Call>(op), buffer_data_to_buffer_);
if (p.defined()) {
for (const auto &arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) {
......@@ -381,7 +381,7 @@ private:
}
// Add the tile operator to infer_list_
infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(p));
}
}
......@@ -416,11 +416,11 @@ private:
void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) {
auto infer = ParallelOp(GetRef<For>(op));
auto infer = ParallelOp(tvm::ffi::GetRef<For>(op));
for (const auto &[buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer);
}
infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(infer));
thread_var_vec_.push_back(thread_var_);
if (thread_var_.defined() &&
......@@ -713,8 +713,8 @@ private:
.value();
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) {
auto root = GetRef<For>(op);
if (result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
auto root = tvm::ffi::GetRef<For>(op);
// This check is a workaround to support T.Parallel for local buffers.
// For example:
// for i in T.Parallel(1024):
......@@ -844,10 +844,10 @@ tvm::transform::Pass LayoutInference() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
});
}
} // namespace tl
} // namespace tvm
......@@ -362,10 +362,10 @@ tvm::transform::Pass LayoutReducer() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer);
});
}
} // namespace tl
} // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment