Unverified Commit 10911e28, authored by Lei Wang, committed by GitHub

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
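
Note on the diff below: it applies a few mechanical migrations across the tree. The most common one is the new form of the FFI registration macro — with tvm-ffi, TVM_FFI_STATIC_INIT_BLOCK is invoked with an empty argument list followed by a plain block, instead of taking the block as a macro argument. A minimal sketch of the pattern (a hypothetical MyPass stands in for the passes registered below):

    // Before the rebase (pre-v0.22.0 style):
    TVM_FFI_STATIC_INIT_BLOCK({
      namespace refl = tvm::ffi::reflection;
      refl::GlobalDef().def("tl.transform.MyPass", MyPass);
    });

    // After the rebase (tvm-ffi style): no trailing ");", just a block.
    TVM_FFI_STATIC_INIT_BLOCK() {
      namespace refl = tvm::ffi::reflection;
      refl::GlobalDef().def("tl.transform.MyPass", MyPass);
    }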
@@ -59,10 +59,10 @@ tvm::transform::Pass PersistThreadblock() {
   return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.PersistThreadblock", PersistThreadblock);
-});
+}
 
 } // namespace tl
 } // namespace tvm
@@ -103,7 +103,7 @@ private:
     ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
     auto var = call->args[1].as<VarNode>();
     ICHECK(var);
-    auto it = buffer_data_to_buffer_.find(GetRef<Var>(var));
+    auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(var));
     ICHECK(it != buffer_data_to_buffer_.end());
     return (*it).second;
   };
@@ -210,7 +210,7 @@ private:
     if (const auto *load = op->args[0].as<BufferLoadNode>()) {
       buffer_region = BufferRegion::FullRegion(load->buffer);
     } else if (const auto *var_node = op->args[0].as<VarNode>()) {
-      Var data_var = GetRef<Var>(var_node);
+      Var data_var = tvm::ffi::GetRef<Var>(var_node);
       auto it = buffer_data_to_buffer_.find(data_var);
       if (it != buffer_data_to_buffer_.end()) {
         buffer_region = BufferRegion::FullRegion((*it).second);
@@ -223,7 +223,7 @@ private:
     } else if (op->op.same_as(builtin::tvm_access_ptr())) {
       const VarNode *buffer_var = op->args[1].as<VarNode>();
       ICHECK(buffer_var);
-      auto it = buffer_data_to_buffer_.find(GetRef<Var>(buffer_var));
+      auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var));
       if (it != buffer_data_to_buffer_.end()) {
         const Buffer &buffer = (*it).second;
         const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
@@ -402,7 +402,7 @@ private:
     if (TargetHasAsyncCopy(target_) && use_async_copy_)
       annotations.Set(tir::attr::software_pipeline_async_stages,
                       Array<Integer>{0});
-    auto for_node = GetRef<For>(loop);
+    auto for_node = tvm::ffi::GetRef<For>(loop);
     for_node.CopyOnWrite()->annotations = annotations;
     return for_node;
   }
@@ -728,10 +728,10 @@ tvm::transform::Pass PipelinePlanning() {
   return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning);
-});
+}
 
 } // namespace tl
 } // namespace tvm
@@ -23,6 +23,7 @@ namespace tvm {
 namespace tl {
 
 using namespace tir;
+using namespace ffi;
 using namespace arith;
 
 struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
@@ -62,8 +63,8 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
                  "branch",
                  refl::DefaultValue(false));
   }
-  static constexpr const char *_type_key = "tl.transform.SimplifyConfig";
-  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.SimplifyConfig",
+                                    SimplifyConfigNode, BaseAttrsNode);
 
   RewriteSimplifier::Extension GetEnabledExtensions() const {
     RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
@@ -209,12 +210,11 @@ CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
 class SimplifyConfig : public Attrs {
 public:
-  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
-                                            SimplifyConfigNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs,
+                                                SimplifyConfigNode);
 };
 
-TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); });
-TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
+TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); }
 TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
 
 class StmtSimplifier : public IRMutatorWithAnalyzer {
@@ -391,7 +391,7 @@ private:
     if (can_inline && !used_in_buffer_def) {
       return body;
     } else if (value.same_as(op->value) && body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       auto n = this->CopyOnWrite(op);
       n->value = std::move(value);
@@ -522,10 +522,10 @@ tvm::transform::Pass Simplify(bool simplify_arguments = true) {
   return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.Simplify", Simplify);
-});
+}
 
 } // namespace tl
 } // namespace tvm
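Taken together, the hunks above show the full tvm-ffi recipe for a reflected attrs node: the type key moves out of a _type_key member and into TVM_FFI_DECLARE_OBJECT_INFO_FINAL, the not-nullable ref macro gains a TVM_FFI_ prefix, and TVM_REGISTER_NODE_TYPE is subsumed by a RegisterReflection() call in a static init block. A condensed sketch, with MyConfigNode/MyConfig as hypothetical stand-ins:

    struct MyConfigNode : public AttrsNodeReflAdapter<MyConfigNode> {
      // The type key is now an argument of the macro, not a _type_key member.
      TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.MyConfig", MyConfigNode,
                                        BaseAttrsNode);
    };

    class MyConfig : public Attrs {
    public:
      TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(MyConfig, Attrs,
                                                    MyConfigNode);
    };

    // RegisterReflection() replaces TVM_REGISTER_NODE_TYPE(MyConfigNode).
    TVM_FFI_STATIC_INIT_BLOCK() { MyConfigNode::RegisterReflection(); }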
@@ -37,7 +37,7 @@
 namespace tvm {
 namespace tl {
 
-using namespace ffi;
+namespace tir = tvm::tir;
 
 class HostDeviceSplitter : public tir::StmtMutator {
@@ -200,10 +200,10 @@ tvm::transform::Pass SplitHostDevice() {
                            {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
-});
+}
 
 } // namespace transform
 } // namespace tl
@@ -39,10 +39,11 @@ using namespace tir;
 
 void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
   Var buf = op->buffer->data;
-  buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
+  buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
   StorageScope scope = GetScope(buf);
   if (Enabled(buf.get(), scope)) {
-    ICHECK(allow_append_) << GetRef<BufferLoad>(op) << " " << scope.to_string();
+    ICHECK(allow_append_) << tvm::ffi::GetRef<BufferLoad>(op) << " "
+                          << scope.to_string();
     AccessEntry e;
     e.threads = env_threads();
     e.thread_range = this->ComputeThreadRange(e.threads);
@@ -66,7 +67,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
   curr_stmt_.stmt = op;
   Var buf = op->buffer->data;
-  buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
+  buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
   StorageScope scope = GetScope(buf);
   if (Enabled(buf.get(), scope)) {
     AccessEntry e;
@@ -326,8 +327,8 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
       Buffer buffer = load->buffer;
       DataType dtype = buffer->dtype;
       const VarNode *buffer_var = buffer->data.as<VarNode>();
-      buffer_data_to_buffer_.Set(GetRef<Var>(buffer_var), buffer);
-      StorageScope scope = GetScope(GetRef<Var>(buffer_var));
+      buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buffer_var), buffer);
+      StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
       Array<Range> buffer_ranges;
       // from indices to buffer indices
       ICHECK(buffer->shape.size() == load->indices.size());
@@ -365,17 +366,18 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
     PrimExpr offset = op->args[2];
     PrimExpr extent = op->args[3];
     const IntImmNode *flag = op->args[4].as<IntImmNode>();
-    StorageScope scope = GetScope(GetRef<Var>(buffer_var));
+    StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
     // The buffer scope.
     if (Enabled(buffer_var, scope)) {
       ICHECK(allow_append_);
       Array<Range> buffer_ranges;
-      if (buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)) ==
+      if (buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var)) ==
           buffer_data_to_buffer_.end()) {
         // cannot find buffer map, use the default buffer
         buffer_ranges = {Range::FromMinExtent(offset, extent)};
       } else {
-        Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var));
+        Buffer buffer =
+            buffer_data_to_buffer_.at(tvm::ffi::GetRef<Var>(buffer_var));
         auto buffer_shape = buffer->shape;
         // convert 1d offset to multi-dimensional index
         auto linear_to_indices = [this](PrimExpr offset,
@@ -406,7 +408,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
       e.threads = env_threads();
       e.thread_range = this->ComputeThreadRange(e.threads);
       e.dtype = dtype;
-      e.buffer = GetRef<Var>(buffer_var);
+      e.buffer = tvm::ffi::GetRef<Var>(buffer_var);
       e.buffer_ranges = buffer_ranges;
       e.is_pointer_access = true;
       e.touched = {
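The other mechanical change running through every C++ hunk is the spelling of GetRef: it now lives in the tvm::ffi namespace and is written fully qualified, since it is no longer pulled into scope by the using-directives these files rely on. A one-line sketch of the pattern, assuming a `const VarNode *var_node` as in the visitors above:

    // Before: unqualified, resolved through using-directives.
    // Var data_var = GetRef<Var>(var_node);

    // After: explicitly qualified under tvm-ffi.
    Var data_var = tvm::ffi::GetRef<Var>(var_node);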
@@ -39,6 +39,7 @@ namespace tvm {
 namespace tl {
 
 using namespace tir;
+using namespace ffi;
 using arith::IRVisitorWithAnalyzer;
 using runtime::StorageRank;
 using runtime::StorageScope;
@@ -544,7 +544,7 @@ public:
       }
       return it->second->alloc_var;
     } else {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
   }
 
   PrimExpr VisitExpr_(const CallNode *op) final {
@@ -978,8 +978,8 @@ private:
     ICHECK(alloc_info.count(var));
     const AllocEntry &entry = alloc_info.at(var);
     const AllocateNode *alloc = entry.alloc;
-    auto storage_scope =
-        StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
+    auto storage_scope = StorageScope::Create(
+        GetPtrStorageScope(tvm::ffi::GetRef<Var>(var)));
     StorageEntry *dst_entry = nullptr;
     // inplace detection
     if (detect_inplace) {
@@ -1732,7 +1732,7 @@ public:
     Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
     if (var.same_as(op->var) && value.same_as(op->value) &&
         body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     }
     return LetStmt(var, value, body);
   }
@@ -1985,10 +1985,10 @@ Pass StorageRewrite() {
   return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite);
-});
+}
 
 Pass PointerValueTypeRewrite() {
   auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
@@ -1997,11 +1997,11 @@ Pass PointerValueTypeRewrite() {
   return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite",
                         PointerValueTypeRewrite);
-});
+}
 
 } // namespace transform
 } // namespace tl
@@ -850,10 +850,10 @@ tvm::transform::Pass ThreadSync(const String &storage_scope) {
   return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync);
-});
+}
 
 } // namespace transform
 } // namespace tl
@@ -44,6 +44,7 @@ namespace tvm {
 namespace tl {
 
 using namespace tir;
+using namespace ffi;
 
 /*!
  * \brief Perform data type legalization on the given BufferLoadNode pointer.
@@ -252,7 +253,7 @@ public:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
      bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
@@ -306,7 +307,7 @@ public:
   PrimExpr VisitExpr_(const NotNode *op) final {
     PrimExpr a = this->VisitExpr(op->a);
     if (a.same_as(op->a)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return !(a);
     }
@@ -347,10 +348,10 @@ public:
     PrimExpr value = this->VisitExpr(op->value);
     if (value.dtype().is_scalable_or_fixed_length_vector()) {
       need_scalarize_ = true;
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
     if (value.same_as(op->value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Broadcast(op->value, op->lanes);
     }
@@ -362,7 +363,7 @@ public:
     PrimExpr f = this->VisitExpr(op->false_value);
     if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
         f.same_as(op->false_value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
       int t_lanes = t.dtype().get_lanes_or_vscale_factor();
@@ -380,7 +381,7 @@ public:
   PrimExpr VisitExpr_(const CastNode *op) final {
     PrimExpr value = this->VisitExpr(op->value);
     if (value.same_as(op->value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       if (value.dtype().is_scalable_vector()) {
         return Cast(op->dtype.with_scalable_vscale_factor(
@@ -393,20 +394,20 @@ public:
   }
 
   PrimExpr VisitExpr_(const FloatImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
 
   PrimExpr VisitExpr_(const IntImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
 
   PrimExpr VisitExpr_(const StringImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
 
   // Variable
   PrimExpr VisitExpr_(const VarNode *op) final {
-    Var var = GetRef<Var>(op);
+    Var var = tvm::ffi::GetRef<Var>(op);
 
     if (var.same_as(var_)) {
       return ramp_;
@@ -423,13 +424,13 @@ public:
     PrimExpr cond = this->VisitExpr(op->args[0]);
     if (cond.dtype().is_scalable_or_fixed_length_vector()) {
       need_scalarize_ = true;
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
     PrimExpr t = this->VisitExpr(op->args[1]);
     PrimExpr f = this->VisitExpr(op->args[2]);
     if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
         f.same_as(op->args[2])) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int t_lanes = t.dtype().get_lanes_or_vscale_factor();
       int f_lanes = f.dtype().get_lanes_or_vscale_factor();
@@ -451,7 +452,7 @@ public:
     ICHECK(op->op.same_as(builtin::reinterpret()));
     PrimExpr value = this->VisitExpr(op->args[0]);
     if (value.same_as(op->args[0])) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int lanes = value.dtype().get_lanes_or_vscale_factor();
       if (value.dtype().is_scalable_vector()) {
@@ -495,12 +496,12 @@ public:
       auto new_arg = this->VisitExpr(arg);
       if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
        need_scalarize_ = true;
-        return GetRef<PrimExpr>(op);
+        return tvm::ffi::GetRef<PrimExpr>(op);
       }
       new_args.push_back(new_arg);
     }
     if (op->args.same_as(new_args)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Call(op->dtype, op->op, new_args);
     }
@@ -509,7 +510,7 @@ public:
     Array<PrimExpr> new_args = MutateArray(op->args, &lane);
     // normal code path.
     if (op->args.same_as(new_args)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Call(op->dtype.with_lanes(lane), op->op, new_args);
     }
@@ -517,7 +518,7 @@ public:
   }
   // BufferLoad
   PrimExpr VisitExpr_(const BufferLoadNode *op) final {
-    auto load = GetRef<BufferLoad>(op);
+    auto load = tvm::ffi::GetRef<BufferLoad>(op);
 
     auto fmutate = [this](const PrimExpr &index) {
       return this->VisitExpr(index);
@@ -557,7 +558,7 @@ public:
     let_var_map_[op->var] = op->var;
     PrimExpr body = this->VisitExpr(op->body);
     if (value.same_as(op->value) && body.same_as(op->body)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Let(op->var, value, body);
    }
@@ -565,7 +566,7 @@ public:
   }
   // BufferStore
   Stmt VisitStmt_(const BufferStoreNode *op) final {
-    auto store = GetRef<BufferStore>(op);
+    auto store = tvm::ffi::GetRef<BufferStore>(op);
 
     auto fmutate = [this](const PrimExpr &index) {
       return this->VisitExpr(index);
@@ -628,11 +629,11 @@ public:
     ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
     PrimExpr extent = this->VisitExpr(op->extent);
     if (extent.dtype().is_scalable_or_fixed_length_vector()) {
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     Stmt body = this->VisitStmt(op->body);
     if (extent.same_as(op->extent) && body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return For(op->loop_var, op->min, extent, op->kind, body,
                  op->thread_binding, op->annotations);
@@ -643,7 +644,7 @@ public:
     ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
     PrimExpr condition = this->VisitExpr(op->condition);
     if (condition.dtype().is_scalable_or_fixed_length_vector()) {
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     Stmt then_case = this->VisitStmt(op->then_case);
     Optional<Stmt> else_case = std::nullopt;
@@ -652,7 +653,7 @@ public:
     }
     if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
         else_case.same_as(op->else_case)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return IfThenElse(condition, then_case, else_case);
     }
@@ -680,7 +681,7 @@ public:
     let_value_binding_[op->var] = value;
     Stmt body = this->VisitStmt(op->body);
     if (value.same_as(op->value) && body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return LetStmt(op->var, value, body);
     }
@@ -694,7 +695,7 @@ public:
     if (condition.dtype().is_scalable_or_fixed_length_vector()) {
       LOG(WARNING) << "Cannot handle vector extent in alloc of "
                    << op->buffer_var->name_hint;
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     return StmtMutator::VisitStmt_(op);
@@ -781,7 +782,7 @@ private:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
@@ -797,7 +798,7 @@ private:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int a_lanes = a.dtype().get_lanes_or_vscale_factor();
       int b_lanes = b.dtype().get_lanes_or_vscale_factor();
@@ -877,10 +878,10 @@ tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
   return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop);
-});
+}
 
 } // namespace tl
 } // namespace tvm
@@ -159,7 +159,7 @@ public:
     // Check reads from global
     Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
-                /*body*/ GetRef<Stmt>(op));
+                /*body*/ tvm::ffi::GetRef<Stmt>(op));
     auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
     auto reads = access[0];
     Role role = Role::kProducer;
@@ -511,7 +511,7 @@ private:
     annotations.Set(String("stmt_group"), Integer(1));
     auto original_node = (op->body).as<SeqStmtNode>();
     if (!original_node) {
-      return GetRef<For>(op);
+      return tvm::ffi::GetRef<For>(op);
     }
     Array<Stmt> new_body;
     int cur_id = 0;
@@ -646,7 +646,7 @@ private:
     if (role == Role::kBoth) {
      return StmtMutator::VisitStmt_(op);
     } else if ((role == Role::kProducer) == is_emitting_producer_) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return Evaluate(0);
     }
@@ -1284,7 +1284,7 @@ tvm::transform::Pass WarpSpecialized() {
       return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
                                                  disable_shuffle_elect);
     } else {
-      ObjectRef node = String("default");
+      auto node = ffi::String("default");
       f.CopyOnWrite()->body =
           AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body);
       return f;
@@ -1293,10 +1293,10 @@ tvm::transform::Pass WarpSpecialized() {
   return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
-});
+}
 
 } // namespace tl
 } // namespace tvm
@@ -266,10 +266,10 @@ tvm::transform::Pass RewriteWgmmaSync() {
   return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync);
-});
+}
 
 } // namespace tl
 } // namespace tvm
@@ -85,7 +85,7 @@ def run_gemm(
     stramp = "&*(XS)"
 
-    @tvm.register_func("tilelang_callback_cuda_postproc", override=True)
+    @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
     def tilelang_callback_cuda_postproc(code, _):
         code = f"// {stramp}\n" + code
         return code
@@ -407,4 +407,5 @@ def test_ctypes_dynamic_shape():
 
 if __name__ == "__main__":
-    tilelang.testing.main()
+    # tilelang.testing.main()
+    test_gemm_f16f16f16_nn()
@@ -85,7 +85,7 @@ def run_gemm(
     stramp = "&*(XS)"
 
-    @tvm.register_func("tilelang_callback_cuda_postproc", override=True)
+    @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
    def tilelang_callback_cuda_postproc(code, _):
         code = f"// {stramp}\n" + code
         return code
"""FFI APIs for tilelang"""
import tvm.ffi
import tvm_ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm.ffi._init_api("tl", __name__) # pylint: disable=protected-access
tvm_ffi.init_ffi_api("tl", __name__)
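The Python side mirrors the C++ changes: tilelang's FFI-API module now resolves its "tl.*" globals through the standalone tvm_ffi package instead of tvm.ffi. Assuming init_ffi_api keeps the prefix-scan behavior of the old private _init_api helper (the diff swaps one for the other in place), the migration is a two-line change:

    # Before: private helper on tvm.ffi
    import tvm.ffi
    tvm.ffi._init_api("tl", __name__)  # pylint: disable=protected-access

    # After: entry point on the standalone tvm_ffi package
    import tvm_ffi
    tvm_ffi.init_ffi_api("tl", __name__)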
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Wrapping functions to bridge frameworks with DLPack support to TVM"""
-from tvm.runtime import ndarray
+from tvm import runtime
 
 
 def convert_func(tvm_func, tensor_type, to_dlpack_func):
@@ -49,9 +49,9 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
                     torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
                     torch.float8_e5m2fnuz
             }:
-                return ndarray.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
+                return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
                     arg.shape, dtype=float8_dtype_map[arg.dtype])
-            return ndarray.from_dlpack(to_dlpack_func(arg))
+            return runtime.from_dlpack(to_dlpack_func(arg))
         return arg
 
     def _wrapper(*args):
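The DLPack bridge now imports `runtime` from `tvm` and calls `runtime.from_dlpack` directly rather than going through the `ndarray` submodule. A small usage sketch of the converted path, assuming PyTorch is installed and using torch.utils.dlpack.to_dlpack as the to_dlpack_func:

    import torch
    from torch.utils.dlpack import to_dlpack
    from tvm import runtime

    x = torch.arange(4, dtype=torch.float32)
    # Zero-copy handoff: export the torch tensor as a DLPack capsule,
    # then wrap it as a TVM runtime NDArray.
    tvm_array = runtime.from_dlpack(to_dlpack(x))
    print(tvm_array.dtype)  # float32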
@@ -9,7 +9,7 @@ from __future__ import absolute_import as _abs
 
 import subprocess
 
-import tvm.ffi
+import tvm_ffi
 from tvm.contrib import utils
 from tvm.base import py_str
@@ -96,7 +96,7 @@ def compile_hip(code,
         return data
 
 
-@tvm.ffi.register_func("tilelang_callback_hip_compile", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
 def tilelang_callback_hip_compile(code, target):
     """use hipcc to generate fatbin code for better optimization"""
     hsaco = compile_hip(code, target_format="hsaco")
@@ -8,8 +8,8 @@ import os
 import subprocess
 import warnings
 from tilelang.env import CUDA_HOME
 
-import tvm.ffi
+import tvm_ffi
 from tilelang import tvm as tvm
 from tvm.target import Target
 from tvm.base import py_str
@@ -182,14 +182,14 @@ def get_cuda_version(cuda_path=None):
     raise RuntimeError("Cannot read cuda version file")
 
 
-@tvm.ffi.register_func("tilelang_callback_cuda_compile", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
 def tilelang_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
     """use nvcc to generate fatbin code for better optimization"""
     ptx = compile_cuda(code, target_format="fatbin")
     return ptx
 
 
-@tvm.ffi.register_func("tilelang_callback_libdevice_path", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True)
 def find_libdevice_path(arch):
     """Utility function to find libdevice
@@ -254,7 +254,7 @@ def callback_libdevice_path(arch):
         return ""
 
 
-@tvm.ffi.register_func("tvm.contrib.nvcc.get_compute_version", override=True)
+@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version", override=True)
 def get_target_compute_version(target=None):
     """Utility function to get compute capability of compilation target.
@@ -400,7 +400,7 @@ def have_cudagraph():
         return False
 
 
-@tvm.ffi.register_func("tvm.contrib.nvcc.supports_bf16", override=True)
+@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16", override=True)
 def have_bf16(compute_version):
     """Either bf16 support is provided in the compute capability or not
@@ -413,7 +413,7 @@ def have_bf16(compute_version):
     return major >= 8
 
 
-@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp8", override=True)
+@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8", override=True)
 def have_fp8(compute_version):
     """Whether fp8 support is provided in the specified compute capability or not
@@ -430,7 +430,7 @@ def have_fp8(compute_version):
     return any(conditions)
 
 
-@tvm.ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True)
+@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_tma", override=True)
 def have_tma(target):
     """Whether TMA support is provided in the specified compute capability or not
@@ -21,7 +21,7 @@ import subprocess
 import os
 from os.path import join, exists
 
-import tvm.ffi
+import tvm_ffi
 from tvm.base import py_str
 import tvm.runtime
 import tvm.target
@@ -100,7 +100,7 @@ def rocm_link(in_file, out_file, lld=None):
         raise RuntimeError(msg)
 
 
-@tvm.ffi.register_func("tvm_callback_rocm_link", override=True)
+@tvm_ffi.register_global_func("tvm_callback_rocm_link", override=True)
 def callback_rocm_link(obj_bin):
     """Links object file generated from LLVM to HSA Code Object
@@ -124,7 +124,7 @@ def callback_rocm_link(obj_bin):
     return cobj_bin
 
 
-@tvm.ffi.register_func("tvm_callback_rocm_bitcode_path", override=True)
+@tvm_ffi.register_global_func("tvm_callback_rocm_bitcode_path", override=True)
 def callback_rocm_bitcode_path(rocdl_dir=None):
     """Utility function to find ROCm device library bitcodes
@@ -226,7 +226,7 @@ def have_matrixcore(compute_version=None):
     return False
 
 
-@tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True)
+@tvm_ffi.register_global_func("tvm_callback_rocm_get_arch", override=True)
 def get_rocm_arch(rocm_path="/opt/rocm"):
     """Utility function to get the AMD GPU architecture
 from __future__ import annotations
 from typing import Callable
-from tvm import register_func
+import tvm_ffi
 from tvm.target import Target
@@ -12,7 +12,7 @@ def register_cuda_postproc(func: Callable[[str, Target], str], override: bool = True):
         and returns the processed code (str).
         override: Whether to override existing registered function. Defaults to True.
     """
-    register_func("tilelang_callback_cuda_postproc", f=func, override=override)
+    tvm_ffi.register_global_func("tilelang_callback_cuda_postproc", f=func, override=override)
 
 
 def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True):
@@ -23,7 +23,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True):
         and returns the processed code (str).
         override: Whether to override existing registered function. Defaults to True.
     """
-    register_func("tilelang_callback_hip_postproc", f=func, override=override)
+    tvm_ffi.register_global_func("tilelang_callback_hip_postproc", f=func, override=override)
 
 
 def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True):
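All of the callback registrations above follow the same rename: tvm.ffi.register_func (or the bare register_func re-export) becomes tvm_ffi.register_global_func, with the name/f/override arguments unchanged. A usage sketch for the postproc hook, with a trivial illustrative callback (add_banner is hypothetical):

    import tvm_ffi

    def add_banner(code, target):
        # Prepend a marker comment to the generated CUDA source.
        return "// postprocessed\n" + code

    # Same name and override semantics as the old tvm.ffi.register_func.
    tvm_ffi.register_global_func(
        "tilelang_callback_cuda_postproc", f=add_banner, override=True)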
@@ -7,6 +7,7 @@ from typing import Callable
 import tilelang.transform
 from tilelang import tvm as tvm
 from tvm import tir
+import tvm_ffi
 from tvm.ir import CallingConv
 from tvm.target import Target
 from tilelang.contrib import hipcc, nvcc
@@ -52,7 +53,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
     return lambda func: not get_device_call(is_device_c)(func)
 
 
-@tvm.register_func("tilelang_callback_cuda_compile", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
 def tilelang_callback_cuda_compile(code, target):
     project_root = osp.join(osp.dirname(__file__), "../..")
     if "TL_TEMPLATE_PATH" in os.environ:
@@ -89,7 +90,7 @@ def tilelang_callback_cuda_compile(code, target):
     return ptx
 
 
-@tvm.register_func("tilelang_callback_hip_compile", override=True)
+@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
 def tilelang_callback_hip_compile(code, target):
     project_root = osp.join(osp.dirname(__file__), "../..")
     tl_template_path = osp.abspath(osp.join(project_root, "src"))
@@ -181,7 +182,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
     elif target.kind.name == "llvm":
         device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target)
     elif target.kind.name == "webgpu":
-        device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
+        device_mod = tvm.ffi.get_global_func("target.build.webgpu")(device_mod, target)
     elif target.kind.name == "metal":
         device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
     else:
@@ -240,6 +241,6 @@ def lower(
         host_mod = host_codegen(host_mod, target_host)
         host_mod.import_module(codegen_mod)
         return CompiledArtifact(
-            host_mod, device_mod, params, codegen_mod.get_source(), rt_mod=host_mod)
+            host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
-    return CompiledArtifact(host_mod, device_mod, params, codegen_mod.get_source())
+    return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source())
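The last hunk swaps Module.get_source() for inspect_source(), the replacement in the rebased runtime. For code that must run against both TVM versions, a small compatibility shim is one option (a sketch, not part of this diff):

    def module_source(mod):
        # Prefer the tvm-ffi era API; fall back to the pre-v0.22.0 one.
        if hasattr(mod, "inspect_source"):
            return mod.inspect_source()
        return mod.get_source()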