Unverified Commit 10911e28 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: default avatarXuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
...@@ -46,13 +46,13 @@ public: ...@@ -46,13 +46,13 @@ public:
Stmt VisitStmt_(const AttrStmtNode *op) final { Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::attr::kTarget) { if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is. // If a target attribute already exists, use it as-is.
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else if (op->attr_key == tir::attr::thread_extent || } else if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::pipeline_exec_scope || op->attr_key == tir::attr::pipeline_exec_scope ||
op->attr_key == tir::attr::device_scope) { op->attr_key == tir::attr::device_scope) {
// These attributes are only allowed in device-side code, so // These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target. // they should be annotated with the function's default target.
Stmt body = GetRef<Stmt>(op); Stmt body = tvm::ffi::GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else { } else {
// All other annotations are ignored // All other annotations are ignored
...@@ -90,11 +90,11 @@ tvm::transform::Pass AnnotateDeviceRegions() { ...@@ -90,11 +90,11 @@ tvm::transform::Pass AnnotateDeviceRegions() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {}); return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions", refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions",
AnnotateDeviceRegions); AnnotateDeviceRegions);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -181,11 +181,11 @@ tvm::transform::Pass AnnotateWarpGroupRegAlloc() { ...@@ -181,11 +181,11 @@ tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {}); return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc", refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
AnnotateWarpGroupRegAlloc); AnnotateWarpGroupRegAlloc);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -80,8 +80,8 @@ void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value, ...@@ -80,8 +80,8 @@ void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value,
Bind_(arg, value, arg_name, with_let); Bind_(arg, value, arg_name, with_let);
} }
void ArgBinder::BindArray(const Array<PrimExpr> &arg, void ArgBinder::BindArray(const ffi::Array<PrimExpr> &arg,
const Array<PrimExpr> &value, const ffi::Array<PrimExpr> &value,
const std::string &arg_name) { const std::string &arg_name) {
ICHECK_EQ(arg.size(), value.size()) ICHECK_EQ(arg.size(), value.size())
<< "Argument " << arg_name << " array size mismatch"; << "Argument " << arg_name << " array size mismatch";
...@@ -250,7 +250,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ...@@ -250,7 +250,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Assert the buffer is compact // Assert the buffer is compact
DataType stype = buffer->DefaultIndexType(); DataType stype = buffer->DefaultIndexType();
PrimExpr expect_stride = make_const(stype, 1); PrimExpr expect_stride = make_const(stype, 1);
Array<PrimExpr> conds; ffi::Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) { for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1; size_t k = i - 1;
PrimExpr svalue = PrimExpr svalue =
......
...@@ -82,7 +82,8 @@ public: ...@@ -82,7 +82,8 @@ public:
* \param value The target expression value * \param value The target expression value
* \param arg_name argument name. * \param arg_name argument name.
*/ */
void BindArray(const Array<PrimExpr> &arg, const Array<PrimExpr> &value, void BindArray(const ffi::Array<PrimExpr> &arg,
const ffi::Array<PrimExpr> &value,
const std::string &arg_name); const std::string &arg_name);
/*! /*!
* \brief Bind symbolic buffer to another symbolic buffer * \brief Bind symbolic buffer to another symbolic buffer
...@@ -149,7 +150,7 @@ public: ...@@ -149,7 +150,7 @@ public:
*/ */
const std::vector<Stmt> &init_nest() const { return init_nest_; } const std::vector<Stmt> &init_nest() const { return init_nest_; }
/*! \return Handle data type of the data */ /*! \return Handle data type of the data */
const Map<Var, PrimExpr> &def_handle_dtype() const { const ffi::Map<Var, PrimExpr> &def_handle_dtype() const {
return def_handle_dtype_; return def_handle_dtype_;
} }
...@@ -164,7 +165,7 @@ private: ...@@ -164,7 +165,7 @@ private:
/*! \brief Initialize nest */ /*! \brief Initialize nest */
std::vector<Stmt> init_nest_; std::vector<Stmt> init_nest_;
/*! \brief handle data type in the defintiions */ /*! \brief handle data type in the defintiions */
Map<Var, PrimExpr> def_handle_dtype_; ffi::Map<Var, PrimExpr> def_handle_dtype_;
/*! \brief asserts generated */ /*! \brief asserts generated */
std::vector<Stmt> asserts_; std::vector<Stmt> asserts_;
/*! \brief internal analyzer. */ /*! \brief internal analyzer. */
......
...@@ -249,7 +249,6 @@ private: ...@@ -249,7 +249,6 @@ private:
new_args.push_back(dst_node); new_args.push_back(dst_node);
new_args.push_back(value_node); new_args.push_back(value_node);
} }
new_args.push_back(memory_order); new_args.push_back(memory_order);
Call new_call = Call new_call =
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../support/ffi_aliases.h"
namespace tvm { namespace tvm {
namespace tir { namespace tir {
...@@ -66,7 +68,8 @@ public: ...@@ -66,7 +68,8 @@ public:
} }
if (mem_reuse_max > 0) { if (mem_reuse_max > 0) {
std::string tag_str = cluster_tag; // Convert to std::string std::string tag_str =
static_cast<std::string>(cluster_tag); // Convert to std::string
if (tag_str.rfind("blockIdx", 0) == 0) { if (tag_str.rfind("blockIdx", 0) == 0) {
// starts with "blockIdx" // starts with "blockIdx"
tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx")); tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx"));
...@@ -74,7 +77,7 @@ public: ...@@ -74,7 +77,7 @@ public:
// Unexpected format — maybe just prefix // Unexpected format — maybe just prefix
tag_str = "clusterIdx" + tag_str; tag_str = "clusterIdx" + tag_str;
} }
cluster_tag = tvm::ffi::String(tag_str); // Convert back cluster_tag = String(tag_str); // Convert back
return WithAttr(f, cluster_tag, Integer(cluster_size_)); return WithAttr(f, cluster_tag, Integer(cluster_size_));
} else { } else {
return f; return f;
...@@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() { ...@@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning); refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning);
}); }
} // namespace transform } // namespace transform
} // namespace tir } // namespace tir
......
...@@ -41,7 +41,7 @@ public: ...@@ -41,7 +41,7 @@ public:
return StmtMutator::VisitStmt_(op); return StmtMutator::VisitStmt_(op);
// Collect loop variables and ranges // Collect loop variables and ranges
auto for_node = GetRef<For>(op); auto for_node = tvm::ffi::GetRef<For>(op);
Array<Var> loop_vars; Array<Var> loop_vars;
Array<PrimExpr> loop_extents; Array<PrimExpr> loop_extents;
Stmt body = op->body; Stmt body = op->body;
...@@ -81,7 +81,7 @@ public: ...@@ -81,7 +81,7 @@ public:
// post order visit the index // post order visit the index
PostOrderVisit(index, [&](const ObjectRef &obj) { PostOrderVisit(index, [&](const ObjectRef &obj) {
if (const VarNode *v = obj.as<VarNode>()) { if (const VarNode *v = obj.as<VarNode>()) {
used_vars.insert(GetRef<Var>(v)); used_vars.insert(tvm::ffi::GetRef<Var>(v));
} }
}); });
if (used_vars.empty()) { if (used_vars.empty()) {
......
...@@ -211,7 +211,7 @@ public: ...@@ -211,7 +211,7 @@ public:
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b); PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) { if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
...@@ -265,7 +265,7 @@ public: ...@@ -265,7 +265,7 @@ public:
PrimExpr VisitExpr_(const NotNode *op) final { PrimExpr VisitExpr_(const NotNode *op) final {
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) { if (a.same_as(op->a)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return !(a); return !(a);
} }
...@@ -306,10 +306,10 @@ public: ...@@ -306,10 +306,10 @@ public:
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (value.dtype().is_scalable_or_fixed_length_vector()) { if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true; need_scalarize_ = true;
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Broadcast(op->value, op->lanes); return Broadcast(op->value, op->lanes);
} }
...@@ -321,7 +321,7 @@ public: ...@@ -321,7 +321,7 @@ public:
PrimExpr f = this->VisitExpr(op->false_value); PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) && t.same_as(op->true_value) && if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) { f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor();
...@@ -339,7 +339,7 @@ public: ...@@ -339,7 +339,7 @@ public:
PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr VisitExpr_(const CastNode *op) final {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
if (value.dtype().is_scalable_vector()) { if (value.dtype().is_scalable_vector()) {
return Cast(op->dtype.with_scalable_vscale_factor( return Cast(op->dtype.with_scalable_vscale_factor(
...@@ -352,20 +352,20 @@ public: ...@@ -352,20 +352,20 @@ public:
} }
PrimExpr VisitExpr_(const FloatImmNode *op) final { PrimExpr VisitExpr_(const FloatImmNode *op) final {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const IntImmNode *op) final { PrimExpr VisitExpr_(const IntImmNode *op) final {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const StringImmNode *op) final { PrimExpr VisitExpr_(const StringImmNode *op) final {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
// Variable // Variable
PrimExpr VisitExpr_(const VarNode *op) final { PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op); Var var = tvm::ffi::GetRef<Var>(op);
if (var.same_as(var_)) { if (var.same_as(var_)) {
return ramp_; return ramp_;
...@@ -382,13 +382,13 @@ public: ...@@ -382,13 +382,13 @@ public:
PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_scalable_or_fixed_length_vector()) { if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true; need_scalarize_ = true;
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]); PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
f.same_as(op->args[2])) { f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor(); int f_lanes = f.dtype().get_lanes_or_vscale_factor();
...@@ -410,7 +410,7 @@ public: ...@@ -410,7 +410,7 @@ public:
ICHECK(op->op.same_as(builtin::reinterpret())); ICHECK(op->op.same_as(builtin::reinterpret()));
PrimExpr value = this->VisitExpr(op->args[0]); PrimExpr value = this->VisitExpr(op->args[0]);
if (value.same_as(op->args[0])) { if (value.same_as(op->args[0])) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int lanes = value.dtype().get_lanes_or_vscale_factor(); int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) { if (value.dtype().is_scalable_vector()) {
...@@ -455,12 +455,12 @@ public: ...@@ -455,12 +455,12 @@ public:
auto new_arg = this->VisitExpr(arg); auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_scalable_or_fixed_length_vector()) { if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true; need_scalarize_ = true;
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
new_args.push_back(new_arg); new_args.push_back(new_arg);
} }
if (op->args.same_as(new_args)) { if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Call(op->dtype, op->op, new_args); return Call(op->dtype, op->op, new_args);
} }
...@@ -469,7 +469,7 @@ public: ...@@ -469,7 +469,7 @@ public:
Array<PrimExpr> new_args = MutateArray(op->args, &lane); Array<PrimExpr> new_args = MutateArray(op->args, &lane);
// normal code path. // normal code path.
if (op->args.same_as(new_args)) { if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Call(op->dtype.with_lanes(lane), op->op, new_args); return Call(op->dtype.with_lanes(lane), op->op, new_args);
} }
...@@ -477,7 +477,7 @@ public: ...@@ -477,7 +477,7 @@ public:
} }
// BufferLoad // BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = GetRef<BufferLoad>(op); auto load = tvm::ffi::GetRef<BufferLoad>(op);
auto fmutate = [this](const PrimExpr &index) { auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index); return this->VisitExpr(index);
...@@ -514,7 +514,7 @@ public: ...@@ -514,7 +514,7 @@ public:
let_binding_[op->var] = op->var; let_binding_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body); PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) { if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Let(op->var, value, body); return Let(op->var, value, body);
} }
...@@ -522,7 +522,7 @@ public: ...@@ -522,7 +522,7 @@ public:
} }
// BufferStore // BufferStore
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = GetRef<BufferStore>(op); auto store = tvm::ffi::GetRef<BufferStore>(op);
auto fmutate = [this](const PrimExpr &index) { auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index); return this->VisitExpr(index);
...@@ -585,11 +585,11 @@ public: ...@@ -585,11 +585,11 @@ public:
ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector()); ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
PrimExpr extent = this->VisitExpr(op->extent); PrimExpr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_scalable_or_fixed_length_vector()) { if (extent.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
Stmt body = this->VisitStmt(op->body); Stmt body = this->VisitStmt(op->body);
if (extent.same_as(op->extent) && body.same_as(op->body)) { if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return For(op->loop_var, op->min, extent, op->kind, body, return For(op->loop_var, op->min, extent, op->kind, body,
op->thread_binding, op->annotations); op->thread_binding, op->annotations);
...@@ -600,7 +600,7 @@ public: ...@@ -600,7 +600,7 @@ public:
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition); PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) { if (condition.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
Stmt then_case = this->VisitStmt(op->then_case); Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = std::nullopt; Optional<Stmt> else_case = std::nullopt;
...@@ -609,7 +609,7 @@ public: ...@@ -609,7 +609,7 @@ public:
} }
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return IfThenElse(condition, then_case, else_case); return IfThenElse(condition, then_case, else_case);
} }
...@@ -634,7 +634,7 @@ public: ...@@ -634,7 +634,7 @@ public:
let_binding_[op->var] = op->var; let_binding_[op->var] = op->var;
Stmt body = this->VisitStmt(op->body); Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) { if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return LetStmt(op->var, value, body); return LetStmt(op->var, value, body);
} }
...@@ -647,7 +647,7 @@ public: ...@@ -647,7 +647,7 @@ public:
if (condition.dtype().is_scalable_or_fixed_length_vector()) { if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint; << op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
// Mutate the extents // Mutate the extents
...@@ -657,7 +657,7 @@ public: ...@@ -657,7 +657,7 @@ public:
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) { if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint; << op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
extents.push_back(new_ext); extents.push_back(new_ext);
} }
...@@ -738,7 +738,7 @@ private: ...@@ -738,7 +738,7 @@ private:
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b); PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) { if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor();
...@@ -754,7 +754,7 @@ private: ...@@ -754,7 +754,7 @@ private:
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b); PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) { if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor();
......
...@@ -38,7 +38,7 @@ protected: ...@@ -38,7 +38,7 @@ protected:
if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) { if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(_index_bitwidth_), op->value); return IntImm(DataType::Int(_index_bitwidth_), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr VisitExpr_(const CastNode *op) final {
...@@ -88,23 +88,23 @@ private: ...@@ -88,23 +88,23 @@ private:
PrimExpr VisitExpr_(const VarNode *op) final { PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op)); return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const IntImmNode *op) final { PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value); return IntImm(DataType::Int(64), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value); return cast(DataType::Int(64), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
...@@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() { ...@@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() {
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth", refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth",
ConfigIndexBitwidth); ConfigIndexBitwidth);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -35,9 +35,7 @@ public: ...@@ -35,9 +35,7 @@ public:
Stmt VisitStmt_(const AttrStmtNode *op) final { Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "thread_extent") { if (op->attr_key == "thread_extent") {
const VarNode *var = nullptr; if (const auto *var = op->node.as<VarNode>()) {
if (op->node->IsInstance<VarNode>()) {
var = op->node.as<VarNode>();
if (var->name_hint == "threadIdx.x") { if (var->name_hint == "threadIdx.x") {
thread_extent_ = op; thread_extent_ = op;
} }
...@@ -82,7 +80,7 @@ public: ...@@ -82,7 +80,7 @@ public:
} }
Stmt VisitStmt_(const ForNode *op) final { Stmt VisitStmt_(const ForNode *op) final {
PostOrderVisit(GetRef<For>(op), [&](const ObjectRef &node) { PostOrderVisit(tvm::ffi::GetRef<For>(op), [&](const ObjectRef &node) {
if (const auto *call = node.as<CallNode>()) { if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) || if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) || call->op.same_as(mbarrier_wait_parity()) ||
...@@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() { ...@@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() {
{}); {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier", refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier",
EliminateStorageSyncForMBarrier); EliminateStorageSyncForMBarrier);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -75,23 +75,23 @@ private: ...@@ -75,23 +75,23 @@ private:
PrimExpr VisitExpr_(const VarNode *op) final { PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op)); return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const IntImmNode *op) final { PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value); return IntImm(DataType::Int(64), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value); return cast(DataType::Int(64), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
...@@ -115,7 +115,7 @@ private: ...@@ -115,7 +115,7 @@ private:
<< "All MatchBufferRegion should be removed in " << "All MatchBufferRegion should be removed in "
"tir.transform.LowerMatchBuffer."; "tir.transform.LowerMatchBuffer.";
Block block = GetRef<Block>(op); Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers; Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply( alloc_buffers.MutateByApply(
...@@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() { ...@@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {}); return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer); refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -89,10 +89,10 @@ Pass LetInline() { ...@@ -89,10 +89,10 @@ Pass LetInline() {
return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LetInline", LetInline); refl::GlobalDef().def("tl.transform.LetInline", LetInline);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -33,7 +33,7 @@ private: ...@@ -33,7 +33,7 @@ private:
auto then_case = VisitStmt(op->then_case); auto then_case = VisitStmt(op->then_case);
Optional<Stmt> else_case = op->else_case; Optional<Stmt> else_case = op->else_case;
if (else_case.defined()) { if (else_case.defined()) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} }
ICHECK(then_case.defined()) << "then_case must be defined"; ICHECK(then_case.defined()) << "then_case must be defined";
ICHECK(!else_case.defined()) << "else_case must be undefined"; ICHECK(!else_case.defined()) << "else_case must be undefined";
...@@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() { ...@@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() {
return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {}); return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding); refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() { ...@@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes); refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes);
}); }
} // namespace tvm::tl } // namespace tvm::tl
...@@ -319,10 +319,10 @@ tvm::transform::Pass InjectFenceProxy() { ...@@ -319,10 +319,10 @@ tvm::transform::Pass InjectFenceProxy() {
{}); {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy); refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace ffi;
namespace software_pipeline { namespace software_pipeline {
/*! /*!
...@@ -459,7 +459,8 @@ private: ...@@ -459,7 +459,8 @@ private:
* \return The resized buffer. * \return The resized buffer.
*/ */
Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) { Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get())); ObjectPtr<BufferNode> new_buffer =
tvm::ffi::make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (!new_buffer->strides.empty()) { if (!new_buffer->strides.empty()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
...@@ -865,7 +866,7 @@ private: ...@@ -865,7 +866,7 @@ private:
const SeqStmtNode *pipeline_body_seq = nullptr; const SeqStmtNode *pipeline_body_seq = nullptr;
std::vector<std::function<Stmt(Stmt)>> rewrap_fns; std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) { auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
ObjectRef node = attr->node; Any node = attr->node;
String attr_key = attr->attr_key; String attr_key = attr->attr_key;
PrimExpr value = attr->value; PrimExpr value = attr->value;
Span span = attr->span; Span span = attr->span;
...@@ -981,7 +982,7 @@ private: ...@@ -981,7 +982,7 @@ private:
// Step 4: Rewrite the pipeline body. // Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
GetRef<For>(op), pipeline_info) tvm::ffi::GetRef<For>(op), pipeline_info)
.BuildPipeline(); .BuildPipeline();
auto apply_wrappers = [&](Stmt stmt) { auto apply_wrappers = [&](Stmt stmt) {
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) { for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
...@@ -1072,11 +1073,11 @@ tir::transform::Pass InjectSoftwarePipeline() { ...@@ -1072,11 +1073,11 @@ tir::transform::Pass InjectSoftwarePipeline() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline", refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
InjectSoftwarePipeline); InjectSoftwarePipeline);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() { ...@@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy); refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -204,9 +204,9 @@ private: ...@@ -204,9 +204,9 @@ private:
void VisitStmt_(const EvaluateNode *op) final { void VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) { if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
pending_tma_ops_.push_back(GetRef<Call>(call)); pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
} else if (call->op.same_as(mbarrier_expect_tx())) { } else if (call->op.same_as(mbarrier_expect_tx())) {
pending_tma_ops_.push_back(GetRef<Call>(call)); pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
} else if (call->op.same_as(builtin::ptx_arrive_barrier())) { } else if (call->op.same_as(builtin::ptx_arrive_barrier())) {
PrimExpr barrier_id = call->args[0]; PrimExpr barrier_id = call->args[0];
for (const auto &tma_call : pending_tma_ops_) { for (const auto &tma_call : pending_tma_ops_) {
...@@ -295,8 +295,9 @@ public: ...@@ -295,8 +295,9 @@ public:
void VisitExpr_(const CallNode *op) final { void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(mbarrier_expect_tx())) { if (op->op.same_as(mbarrier_expect_tx())) {
PrimExpr e = PrimExpr e = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)]
tma_op_to_barrier_id_[GetRef<Call>(op)].as<CallNode>()->args[0]; .as<CallNode>()
->args[0];
auto int_set = arith::EvalSet(e, var_int_set_); auto int_set = arith::EvalSet(e, var_int_set_);
expect_.push_back(if_depth_ == 1); expect_.push_back(if_depth_ == 1);
sequence.push_back(0); sequence.push_back(0);
...@@ -406,7 +407,7 @@ public: ...@@ -406,7 +407,7 @@ public:
private: private:
Stmt VisitStmt_(const BlockNode *op) { Stmt VisitStmt_(const BlockNode *op) {
auto block = GetRef<Block>(op); auto block = tvm::ffi::GetRef<Block>(op);
if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() && if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() &&
op->name_hint == MainBlockName) { op->name_hint == MainBlockName) {
ICHECK(false) << "Please declare create_list_of_mbarrier."; ICHECK(false) << "Please declare create_list_of_mbarrier.";
...@@ -453,9 +454,9 @@ private: ...@@ -453,9 +454,9 @@ private:
PrimExpr VisitExpr_(const CallNode *op) { PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
// check this must be in the tma_op_to_barrier_id_ // check this must be in the tma_op_to_barrier_id_
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op))) ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
<< "tma_load must be in the tma_op_to_barrier_id_"; << "tma_load must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)]; auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
auto new_args = op->args; auto new_args = op->args;
auto arg0 = op->args[0].as<Call>(); auto arg0 = op->args[0].as<Call>();
auto is_1d_tma_load = auto is_1d_tma_load =
...@@ -468,9 +469,9 @@ private: ...@@ -468,9 +469,9 @@ private:
} }
return Call(op->dtype, op->op, new_args); return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(mbarrier_expect_tx())) { } else if (op->op.same_as(mbarrier_expect_tx())) {
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op))) ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
<< "mbarrier_expect_tx must be in the tma_op_to_barrier_id_"; << "mbarrier_expect_tx must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)]; auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
auto new_args = op->args; auto new_args = op->args;
new_args.Set(0, barrier_id); new_args.Set(0, barrier_id);
if (!has_warp_specialization_) if (!has_warp_specialization_)
...@@ -522,10 +523,10 @@ tvm::transform::Pass InjectTmaBarrier() { ...@@ -522,10 +523,10 @@ tvm::transform::Pass InjectTmaBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier); refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -330,7 +330,7 @@ private: ...@@ -330,7 +330,7 @@ private:
if (op->op.as<GlobalVarNode>()) if (op->op.as<GlobalVarNode>())
return; return;
auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_); auto p = ParseOperator(tvm::ffi::GetRef<Call>(op), buffer_data_to_buffer_);
if (p.defined()) { if (p.defined()) {
for (const auto &arg : op->args) { for (const auto &arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) { if (auto buffer = getBufferFromAccessPtr(arg)) {
...@@ -381,7 +381,7 @@ private: ...@@ -381,7 +381,7 @@ private:
} }
// Add the tile operator to infer_list_ // Add the tile operator to infer_list_
infer_list_stmt_.push_back(GetRef<ObjectRef>(op)); infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(p)); infer_list_.push_back(std::move(p));
} }
} }
...@@ -416,11 +416,11 @@ private: ...@@ -416,11 +416,11 @@ private:
void VisitStmt_(const ForNode *op) final { void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) { if (op->kind == ForKind::kParallel) {
auto infer = ParallelOp(GetRef<For>(op)); auto infer = ParallelOp(tvm::ffi::GetRef<For>(op));
for (const auto &[buffer, _] : infer->GetIndiceMap()) { for (const auto &[buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer); addToUseList(buffer);
} }
infer_list_stmt_.push_back(GetRef<ObjectRef>(op)); infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(infer)); infer_list_.push_back(std::move(infer));
thread_var_vec_.push_back(thread_var_); thread_var_vec_.push_back(thread_var_);
if (thread_var_.defined() && if (thread_var_.defined() &&
...@@ -713,8 +713,8 @@ private: ...@@ -713,8 +713,8 @@ private:
.value(); .value();
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op)); For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) { if (result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
auto root = GetRef<For>(op); auto root = tvm::ffi::GetRef<For>(op);
// This check is a workaround to support T.Parallel for local buffers. // This check is a workaround to support T.Parallel for local buffers.
// For example: // For example:
// for i in T.Parallel(1024): // for i in T.Parallel(1024):
...@@ -844,10 +844,10 @@ tvm::transform::Pass LayoutInference() { ...@@ -844,10 +844,10 @@ tvm::transform::Pass LayoutInference() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference); refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -362,10 +362,10 @@ tvm::transform::Pass LayoutReducer() { ...@@ -362,10 +362,10 @@ tvm::transform::Pass LayoutReducer() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer); refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment