Unverified Commit 10911e28 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: default avatarXuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
...@@ -59,10 +59,10 @@ tvm::transform::Pass PersistThreadblock() { ...@@ -59,10 +59,10 @@ tvm::transform::Pass PersistThreadblock() {
return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {}); return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PersistThreadblock", PersistThreadblock); refl::GlobalDef().def("tl.transform.PersistThreadblock", PersistThreadblock);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -103,7 +103,7 @@ private: ...@@ -103,7 +103,7 @@ private:
ICHECK(call->op.same_as(builtin::tvm_access_ptr())); ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
auto var = call->args[1].as<VarNode>(); auto var = call->args[1].as<VarNode>();
ICHECK(var); ICHECK(var);
auto it = buffer_data_to_buffer_.find(GetRef<Var>(var)); auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(var));
ICHECK(it != buffer_data_to_buffer_.end()); ICHECK(it != buffer_data_to_buffer_.end());
return (*it).second; return (*it).second;
}; };
...@@ -210,7 +210,7 @@ private: ...@@ -210,7 +210,7 @@ private:
if (const auto *load = op->args[0].as<BufferLoadNode>()) { if (const auto *load = op->args[0].as<BufferLoadNode>()) {
buffer_region = BufferRegion::FullRegion(load->buffer); buffer_region = BufferRegion::FullRegion(load->buffer);
} else if (const auto *var_node = op->args[0].as<VarNode>()) { } else if (const auto *var_node = op->args[0].as<VarNode>()) {
Var data_var = GetRef<Var>(var_node); Var data_var = tvm::ffi::GetRef<Var>(var_node);
auto it = buffer_data_to_buffer_.find(data_var); auto it = buffer_data_to_buffer_.find(data_var);
if (it != buffer_data_to_buffer_.end()) { if (it != buffer_data_to_buffer_.end()) {
buffer_region = BufferRegion::FullRegion((*it).second); buffer_region = BufferRegion::FullRegion((*it).second);
...@@ -223,7 +223,7 @@ private: ...@@ -223,7 +223,7 @@ private:
} else if (op->op.same_as(builtin::tvm_access_ptr())) { } else if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode *buffer_var = op->args[1].as<VarNode>(); const VarNode *buffer_var = op->args[1].as<VarNode>();
ICHECK(buffer_var); ICHECK(buffer_var);
auto it = buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)); auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var));
if (it != buffer_data_to_buffer_.end()) { if (it != buffer_data_to_buffer_.end()) {
const Buffer &buffer = (*it).second; const Buffer &buffer = (*it).second;
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
...@@ -402,7 +402,7 @@ private: ...@@ -402,7 +402,7 @@ private:
if (TargetHasAsyncCopy(target_) && use_async_copy_) if (TargetHasAsyncCopy(target_) && use_async_copy_)
annotations.Set(tir::attr::software_pipeline_async_stages, annotations.Set(tir::attr::software_pipeline_async_stages,
Array<Integer>{0}); Array<Integer>{0});
auto for_node = GetRef<For>(loop); auto for_node = tvm::ffi::GetRef<For>(loop);
for_node.CopyOnWrite()->annotations = annotations; for_node.CopyOnWrite()->annotations = annotations;
return for_node; return for_node;
} }
...@@ -728,10 +728,10 @@ tvm::transform::Pass PipelinePlanning() { ...@@ -728,10 +728,10 @@ tvm::transform::Pass PipelinePlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {}); return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning); refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -23,6 +23,7 @@ namespace tvm { ...@@ -23,6 +23,7 @@ namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace ffi;
using namespace arith; using namespace arith;
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> { struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
...@@ -62,8 +63,8 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> { ...@@ -62,8 +63,8 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
"branch", "branch",
refl::DefaultValue(false)); refl::DefaultValue(false));
} }
static constexpr const char *_type_key = "tl.transform.SimplifyConfig"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.SimplifyConfig",
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode); SimplifyConfigNode, BaseAttrsNode);
RewriteSimplifier::Extension GetEnabledExtensions() const { RewriteSimplifier::Extension GetEnabledExtensions() const {
RewriteSimplifier::Extension flags = RewriteSimplifier::kNone; RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
...@@ -209,12 +210,11 @@ CollectVarsUsedInBufferDefinition(const Stmt &stmt) { ...@@ -209,12 +210,11 @@ CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
class SimplifyConfig : public Attrs { class SimplifyConfig : public Attrs {
public: public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs,
SimplifyConfigNode); SimplifyConfigNode);
}; };
TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); }); TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); }
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig); TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer { class StmtSimplifier : public IRMutatorWithAnalyzer {
...@@ -391,7 +391,7 @@ private: ...@@ -391,7 +391,7 @@ private:
if (can_inline && !used_in_buffer_def) { if (can_inline && !used_in_buffer_def) {
return body; return body;
} else if (value.same_as(op->value) && body.same_as(op->body)) { } else if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
auto n = this->CopyOnWrite(op); auto n = this->CopyOnWrite(op);
n->value = std::move(value); n->value = std::move(value);
...@@ -522,10 +522,10 @@ tvm::transform::Pass Simplify(bool simplify_arguments = true) { ...@@ -522,10 +522,10 @@ tvm::transform::Pass Simplify(bool simplify_arguments = true) {
return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {}); return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.Simplify", Simplify); refl::GlobalDef().def("tl.transform.Simplify", Simplify);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace ffi;
namespace tir = tvm::tir; namespace tir = tvm::tir;
class HostDeviceSplitter : public tir::StmtMutator { class HostDeviceSplitter : public tir::StmtMutator {
...@@ -200,10 +200,10 @@ tvm::transform::Pass SplitHostDevice() { ...@@ -200,10 +200,10 @@ tvm::transform::Pass SplitHostDevice() {
{}); {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice); refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -39,10 +39,11 @@ using namespace tir; ...@@ -39,10 +39,11 @@ using namespace tir;
void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) { void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
Var buf = op->buffer->data; Var buf = op->buffer->data;
buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer); buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
StorageScope scope = GetScope(buf); StorageScope scope = GetScope(buf);
if (Enabled(buf.get(), scope)) { if (Enabled(buf.get(), scope)) {
ICHECK(allow_append_) << GetRef<BufferLoad>(op) << " " << scope.to_string(); ICHECK(allow_append_) << tvm::ffi::GetRef<BufferLoad>(op) << " "
<< scope.to_string();
AccessEntry e; AccessEntry e;
e.threads = env_threads(); e.threads = env_threads();
e.thread_range = this->ComputeThreadRange(e.threads); e.thread_range = this->ComputeThreadRange(e.threads);
...@@ -66,7 +67,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) { ...@@ -66,7 +67,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
curr_stmt_.stmt = op; curr_stmt_.stmt = op;
Var buf = op->buffer->data; Var buf = op->buffer->data;
buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer); buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
StorageScope scope = GetScope(buf); StorageScope scope = GetScope(buf);
if (Enabled(buf.get(), scope)) { if (Enabled(buf.get(), scope)) {
AccessEntry e; AccessEntry e;
...@@ -326,8 +327,8 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) { ...@@ -326,8 +327,8 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
Buffer buffer = load->buffer; Buffer buffer = load->buffer;
DataType dtype = buffer->dtype; DataType dtype = buffer->dtype;
const VarNode *buffer_var = buffer->data.as<VarNode>(); const VarNode *buffer_var = buffer->data.as<VarNode>();
buffer_data_to_buffer_.Set(GetRef<Var>(buffer_var), buffer); buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buffer_var), buffer);
StorageScope scope = GetScope(GetRef<Var>(buffer_var)); StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
Array<Range> buffer_ranges; Array<Range> buffer_ranges;
// from indices to buffer indices // from indices to buffer indices
ICHECK(buffer->shape.size() == load->indices.size()); ICHECK(buffer->shape.size() == load->indices.size());
...@@ -365,17 +366,18 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) { ...@@ -365,17 +366,18 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
PrimExpr offset = op->args[2]; PrimExpr offset = op->args[2];
PrimExpr extent = op->args[3]; PrimExpr extent = op->args[3];
const IntImmNode *flag = op->args[4].as<IntImmNode>(); const IntImmNode *flag = op->args[4].as<IntImmNode>();
StorageScope scope = GetScope(GetRef<Var>(buffer_var)); StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
// The buffer scope. // The buffer scope.
if (Enabled(buffer_var, scope)) { if (Enabled(buffer_var, scope)) {
ICHECK(allow_append_); ICHECK(allow_append_);
Array<Range> buffer_ranges; Array<Range> buffer_ranges;
if (buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)) == if (buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var)) ==
buffer_data_to_buffer_.end()) { buffer_data_to_buffer_.end()) {
// cannot find buffer map, use the default buffer // cannot find buffer map, use the default buffer
buffer_ranges = {Range::FromMinExtent(offset, extent)}; buffer_ranges = {Range::FromMinExtent(offset, extent)};
} else { } else {
Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var)); Buffer buffer =
buffer_data_to_buffer_.at(tvm::ffi::GetRef<Var>(buffer_var));
auto buffer_shape = buffer->shape; auto buffer_shape = buffer->shape;
// convert 1d offset to multi-dimensional index // convert 1d offset to multi-dimensional index
auto linear_to_indices = [this](PrimExpr offset, auto linear_to_indices = [this](PrimExpr offset,
...@@ -406,7 +408,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) { ...@@ -406,7 +408,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
e.threads = env_threads(); e.threads = env_threads();
e.thread_range = this->ComputeThreadRange(e.threads); e.thread_range = this->ComputeThreadRange(e.threads);
e.dtype = dtype; e.dtype = dtype;
e.buffer = GetRef<Var>(buffer_var); e.buffer = tvm::ffi::GetRef<Var>(buffer_var);
e.buffer_ranges = buffer_ranges; e.buffer_ranges = buffer_ranges;
e.is_pointer_access = true; e.is_pointer_access = true;
e.touched = { e.touched = {
......
...@@ -39,6 +39,7 @@ namespace tvm { ...@@ -39,6 +39,7 @@ namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace ffi;
using arith::IRVisitorWithAnalyzer; using arith::IRVisitorWithAnalyzer;
using runtime::StorageRank; using runtime::StorageRank;
using runtime::StorageScope; using runtime::StorageScope;
......
...@@ -544,7 +544,7 @@ public: ...@@ -544,7 +544,7 @@ public:
} }
return it->second->alloc_var; return it->second->alloc_var;
} else { } else {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
} }
PrimExpr VisitExpr_(const CallNode *op) final { PrimExpr VisitExpr_(const CallNode *op) final {
...@@ -978,8 +978,8 @@ private: ...@@ -978,8 +978,8 @@ private:
ICHECK(alloc_info.count(var)); ICHECK(alloc_info.count(var));
const AllocEntry &entry = alloc_info.at(var); const AllocEntry &entry = alloc_info.at(var);
const AllocateNode *alloc = entry.alloc; const AllocateNode *alloc = entry.alloc;
auto storage_scope = auto storage_scope = StorageScope::Create(
StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var))); GetPtrStorageScope(tvm::ffi::GetRef<Var>(var)));
StorageEntry *dst_entry = nullptr; StorageEntry *dst_entry = nullptr;
// inplace detection // inplace detection
if (detect_inplace) { if (detect_inplace) {
...@@ -1732,7 +1732,7 @@ public: ...@@ -1732,7 +1732,7 @@ public:
Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var; Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
if (var.same_as(op->var) && value.same_as(op->value) && if (var.same_as(op->var) && value.same_as(op->value) &&
body.same_as(op->body)) { body.same_as(op->body)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} }
return LetStmt(var, value, body); return LetStmt(var, value, body);
} }
...@@ -1985,10 +1985,10 @@ Pass StorageRewrite() { ...@@ -1985,10 +1985,10 @@ Pass StorageRewrite() {
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite); refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite);
}); }
Pass PointerValueTypeRewrite() { Pass PointerValueTypeRewrite() {
auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) { auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
...@@ -1997,11 +1997,11 @@ Pass PointerValueTypeRewrite() { ...@@ -1997,11 +1997,11 @@ Pass PointerValueTypeRewrite() {
return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {}); return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite", refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite",
PointerValueTypeRewrite); PointerValueTypeRewrite);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -850,10 +850,10 @@ tvm::transform::Pass ThreadSync(const String &storage_scope) { ...@@ -850,10 +850,10 @@ tvm::transform::Pass ThreadSync(const String &storage_scope) {
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync); refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -44,6 +44,7 @@ namespace tvm { ...@@ -44,6 +44,7 @@ namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace ffi;
/*! /*!
* \brief Perform data type legalization on the given BufferLoadNode pointer. * \brief Perform data type legalization on the given BufferLoadNode pointer.
...@@ -252,7 +253,7 @@ public: ...@@ -252,7 +253,7 @@ public:
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b); PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) { if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
...@@ -306,7 +307,7 @@ public: ...@@ -306,7 +307,7 @@ public:
PrimExpr VisitExpr_(const NotNode *op) final { PrimExpr VisitExpr_(const NotNode *op) final {
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) { if (a.same_as(op->a)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return !(a); return !(a);
} }
...@@ -347,10 +348,10 @@ public: ...@@ -347,10 +348,10 @@ public:
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (value.dtype().is_scalable_or_fixed_length_vector()) { if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true; need_scalarize_ = true;
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Broadcast(op->value, op->lanes); return Broadcast(op->value, op->lanes);
} }
...@@ -362,7 +363,7 @@ public: ...@@ -362,7 +363,7 @@ public:
PrimExpr f = this->VisitExpr(op->false_value); PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) && t.same_as(op->true_value) && if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) { f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor();
...@@ -380,7 +381,7 @@ public: ...@@ -380,7 +381,7 @@ public:
PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr VisitExpr_(const CastNode *op) final {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
if (value.dtype().is_scalable_vector()) { if (value.dtype().is_scalable_vector()) {
return Cast(op->dtype.with_scalable_vscale_factor( return Cast(op->dtype.with_scalable_vscale_factor(
...@@ -393,20 +394,20 @@ public: ...@@ -393,20 +394,20 @@ public:
} }
PrimExpr VisitExpr_(const FloatImmNode *op) final { PrimExpr VisitExpr_(const FloatImmNode *op) final {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const IntImmNode *op) final { PrimExpr VisitExpr_(const IntImmNode *op) final {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const StringImmNode *op) final { PrimExpr VisitExpr_(const StringImmNode *op) final {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
// Variable // Variable
PrimExpr VisitExpr_(const VarNode *op) final { PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op); Var var = tvm::ffi::GetRef<Var>(op);
if (var.same_as(var_)) { if (var.same_as(var_)) {
return ramp_; return ramp_;
...@@ -423,13 +424,13 @@ public: ...@@ -423,13 +424,13 @@ public:
PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_scalable_or_fixed_length_vector()) { if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true; need_scalarize_ = true;
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]); PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
f.same_as(op->args[2])) { f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor(); int f_lanes = f.dtype().get_lanes_or_vscale_factor();
...@@ -451,7 +452,7 @@ public: ...@@ -451,7 +452,7 @@ public:
ICHECK(op->op.same_as(builtin::reinterpret())); ICHECK(op->op.same_as(builtin::reinterpret()));
PrimExpr value = this->VisitExpr(op->args[0]); PrimExpr value = this->VisitExpr(op->args[0]);
if (value.same_as(op->args[0])) { if (value.same_as(op->args[0])) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int lanes = value.dtype().get_lanes_or_vscale_factor(); int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) { if (value.dtype().is_scalable_vector()) {
...@@ -495,12 +496,12 @@ public: ...@@ -495,12 +496,12 @@ public:
auto new_arg = this->VisitExpr(arg); auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_scalable_or_fixed_length_vector()) { if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true; need_scalarize_ = true;
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
new_args.push_back(new_arg); new_args.push_back(new_arg);
} }
if (op->args.same_as(new_args)) { if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Call(op->dtype, op->op, new_args); return Call(op->dtype, op->op, new_args);
} }
...@@ -509,7 +510,7 @@ public: ...@@ -509,7 +510,7 @@ public:
Array<PrimExpr> new_args = MutateArray(op->args, &lane); Array<PrimExpr> new_args = MutateArray(op->args, &lane);
// normal code path. // normal code path.
if (op->args.same_as(new_args)) { if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Call(op->dtype.with_lanes(lane), op->op, new_args); return Call(op->dtype.with_lanes(lane), op->op, new_args);
} }
...@@ -517,7 +518,7 @@ public: ...@@ -517,7 +518,7 @@ public:
} }
// BufferLoad // BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = GetRef<BufferLoad>(op); auto load = tvm::ffi::GetRef<BufferLoad>(op);
auto fmutate = [this](const PrimExpr &index) { auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index); return this->VisitExpr(index);
...@@ -557,7 +558,7 @@ public: ...@@ -557,7 +558,7 @@ public:
let_var_map_[op->var] = op->var; let_var_map_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body); PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) { if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Let(op->var, value, body); return Let(op->var, value, body);
} }
...@@ -565,7 +566,7 @@ public: ...@@ -565,7 +566,7 @@ public:
} }
// BufferStore // BufferStore
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = GetRef<BufferStore>(op); auto store = tvm::ffi::GetRef<BufferStore>(op);
auto fmutate = [this](const PrimExpr &index) { auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index); return this->VisitExpr(index);
...@@ -628,11 +629,11 @@ public: ...@@ -628,11 +629,11 @@ public:
ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector()); ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
PrimExpr extent = this->VisitExpr(op->extent); PrimExpr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_scalable_or_fixed_length_vector()) { if (extent.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
Stmt body = this->VisitStmt(op->body); Stmt body = this->VisitStmt(op->body);
if (extent.same_as(op->extent) && body.same_as(op->body)) { if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return For(op->loop_var, op->min, extent, op->kind, body, return For(op->loop_var, op->min, extent, op->kind, body,
op->thread_binding, op->annotations); op->thread_binding, op->annotations);
...@@ -643,7 +644,7 @@ public: ...@@ -643,7 +644,7 @@ public:
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition); PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) { if (condition.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
Stmt then_case = this->VisitStmt(op->then_case); Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = std::nullopt; Optional<Stmt> else_case = std::nullopt;
...@@ -652,7 +653,7 @@ public: ...@@ -652,7 +653,7 @@ public:
} }
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return IfThenElse(condition, then_case, else_case); return IfThenElse(condition, then_case, else_case);
} }
...@@ -680,7 +681,7 @@ public: ...@@ -680,7 +681,7 @@ public:
let_value_binding_[op->var] = value; let_value_binding_[op->var] = value;
Stmt body = this->VisitStmt(op->body); Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) { if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return LetStmt(op->var, value, body); return LetStmt(op->var, value, body);
} }
...@@ -694,7 +695,7 @@ public: ...@@ -694,7 +695,7 @@ public:
if (condition.dtype().is_scalable_or_fixed_length_vector()) { if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint; << op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
return StmtMutator::VisitStmt_(op); return StmtMutator::VisitStmt_(op);
...@@ -781,7 +782,7 @@ private: ...@@ -781,7 +782,7 @@ private:
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b); PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) { if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor();
...@@ -797,7 +798,7 @@ private: ...@@ -797,7 +798,7 @@ private:
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b); PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) { if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor();
...@@ -877,10 +878,10 @@ tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) { ...@@ -877,10 +878,10 @@ tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {}); return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop); refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -159,7 +159,7 @@ public: ...@@ -159,7 +159,7 @@ public:
// Check reads from global // Check reads from global
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ GetRef<Stmt>(op)); /*body*/ tvm::ffi::GetRef<Stmt>(op));
auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0]; auto reads = access[0];
Role role = Role::kProducer; Role role = Role::kProducer;
...@@ -511,7 +511,7 @@ private: ...@@ -511,7 +511,7 @@ private:
annotations.Set(String("stmt_group"), Integer(1)); annotations.Set(String("stmt_group"), Integer(1));
auto original_node = (op->body).as<SeqStmtNode>(); auto original_node = (op->body).as<SeqStmtNode>();
if (!original_node) { if (!original_node) {
return GetRef<For>(op); return tvm::ffi::GetRef<For>(op);
} }
Array<Stmt> new_body; Array<Stmt> new_body;
int cur_id = 0; int cur_id = 0;
...@@ -646,7 +646,7 @@ private: ...@@ -646,7 +646,7 @@ private:
if (role == Role::kBoth) { if (role == Role::kBoth) {
return StmtMutator::VisitStmt_(op); return StmtMutator::VisitStmt_(op);
} else if ((role == Role::kProducer) == is_emitting_producer_) { } else if ((role == Role::kProducer) == is_emitting_producer_) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return Evaluate(0); return Evaluate(0);
} }
...@@ -1284,7 +1284,7 @@ tvm::transform::Pass WarpSpecialized() { ...@@ -1284,7 +1284,7 @@ tvm::transform::Pass WarpSpecialized() {
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized, return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
disable_shuffle_elect); disable_shuffle_elect);
} else { } else {
ObjectRef node = String("default"); auto node = ffi::String("default");
f.CopyOnWrite()->body = f.CopyOnWrite()->body =
AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body); AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body);
return f; return f;
...@@ -1293,10 +1293,10 @@ tvm::transform::Pass WarpSpecialized() { ...@@ -1293,10 +1293,10 @@ tvm::transform::Pass WarpSpecialized() {
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {}); return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized); refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -266,10 +266,10 @@ tvm::transform::Pass RewriteWgmmaSync() { ...@@ -266,10 +266,10 @@ tvm::transform::Pass RewriteWgmmaSync() {
return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {}); return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync); refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -85,7 +85,7 @@ def run_gemm( ...@@ -85,7 +85,7 @@ def run_gemm(
stramp = "&*(XS)" stramp = "&*(XS)"
@tvm.register_func("tilelang_callback_cuda_postproc", override=True) @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
def tilelang_callback_cuda_postproc(code, _): def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code code = f"// {stramp}\n" + code
return code return code
...@@ -407,4 +407,5 @@ def test_ctypes_dynamic_shape(): ...@@ -407,4 +407,5 @@ def test_ctypes_dynamic_shape():
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() # tilelang.testing.main()
test_gemm_f16f16f16_nn()
...@@ -85,7 +85,7 @@ def run_gemm( ...@@ -85,7 +85,7 @@ def run_gemm(
stramp = "&*(XS)" stramp = "&*(XS)"
@tvm.register_func("tilelang_callback_cuda_postproc", override=True) @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
def tilelang_callback_cuda_postproc(code, _): def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code code = f"// {stramp}\n" + code
return code return code
......
"""FFI APIs for tilelang""" """FFI APIs for tilelang"""
import tvm.ffi import tvm_ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func); # TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm.ffi._init_api("tl", __name__) # pylint: disable=protected-access tvm_ffi.init_ffi_api("tl", __name__)
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM""" """Wrapping functions to bridge frameworks with DLPack support to TVM"""
from tvm.runtime import ndarray from tvm import runtime
def convert_func(tvm_func, tensor_type, to_dlpack_func): def convert_func(tvm_func, tensor_type, to_dlpack_func):
...@@ -49,9 +49,9 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func): ...@@ -49,9 +49,9 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
torch.float8_e5m2fnuz torch.float8_e5m2fnuz
}: }:
return ndarray.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view( return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
arg.shape, dtype=float8_dtype_map[arg.dtype]) arg.shape, dtype=float8_dtype_map[arg.dtype])
return ndarray.from_dlpack(to_dlpack_func(arg)) return runtime.from_dlpack(to_dlpack_func(arg))
return arg return arg
def _wrapper(*args): def _wrapper(*args):
......
...@@ -9,7 +9,7 @@ from __future__ import absolute_import as _abs ...@@ -9,7 +9,7 @@ from __future__ import absolute_import as _abs
import subprocess import subprocess
import tvm.ffi import tvm_ffi
from tvm.contrib import utils from tvm.contrib import utils
from tvm.base import py_str from tvm.base import py_str
...@@ -96,7 +96,7 @@ def compile_hip(code, ...@@ -96,7 +96,7 @@ def compile_hip(code,
return data return data
@tvm.ffi.register_func("tilelang_callback_hip_compile", override=True) @tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target): def tilelang_callback_hip_compile(code, target):
"""use hipcc to generate fatbin code for better optimization""" """use hipcc to generate fatbin code for better optimization"""
hsaco = compile_hip(code, target_format="hsaco") hsaco = compile_hip(code, target_format="hsaco")
......
...@@ -8,8 +8,8 @@ import os ...@@ -8,8 +8,8 @@ import os
import subprocess import subprocess
import warnings import warnings
from tilelang.env import CUDA_HOME from tilelang.env import CUDA_HOME
import tvm_ffi
import tvm.ffi from tilelang import tvm as tvm
from tvm.target import Target from tvm.target import Target
from tvm.base import py_str from tvm.base import py_str
...@@ -182,14 +182,14 @@ def get_cuda_version(cuda_path=None): ...@@ -182,14 +182,14 @@ def get_cuda_version(cuda_path=None):
raise RuntimeError("Cannot read cuda version file") raise RuntimeError("Cannot read cuda version file")
@tvm.ffi.register_func("tilelang_callback_cuda_compile", override=True) @tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument
"""use nvcc to generate fatbin code for better optimization""" """use nvcc to generate fatbin code for better optimization"""
ptx = compile_cuda(code, target_format="fatbin") ptx = compile_cuda(code, target_format="fatbin")
return ptx return ptx
@tvm.ffi.register_func("tilelang_callback_libdevice_path", override=True) @tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True)
def find_libdevice_path(arch): def find_libdevice_path(arch):
"""Utility function to find libdevice """Utility function to find libdevice
...@@ -254,7 +254,7 @@ def callback_libdevice_path(arch): ...@@ -254,7 +254,7 @@ def callback_libdevice_path(arch):
return "" return ""
@tvm.ffi.register_func("tvm.contrib.nvcc.get_compute_version", override=True) @tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version", override=True)
def get_target_compute_version(target=None): def get_target_compute_version(target=None):
"""Utility function to get compute capability of compilation target. """Utility function to get compute capability of compilation target.
...@@ -400,7 +400,7 @@ def have_cudagraph(): ...@@ -400,7 +400,7 @@ def have_cudagraph():
return False return False
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_bf16", override=True) @tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16", override=True)
def have_bf16(compute_version): def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not """Either bf16 support is provided in the compute capability or not
...@@ -413,7 +413,7 @@ def have_bf16(compute_version): ...@@ -413,7 +413,7 @@ def have_bf16(compute_version):
return major >= 8 return major >= 8
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp8", override=True) @tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8", override=True)
def have_fp8(compute_version): def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or not """Whether fp8 support is provided in the specified compute capability or not
...@@ -430,7 +430,7 @@ def have_fp8(compute_version): ...@@ -430,7 +430,7 @@ def have_fp8(compute_version):
return any(conditions) return any(conditions)
@tvm.ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True) @tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_tma", override=True)
def have_tma(target): def have_tma(target):
"""Whether TMA support is provided in the specified compute capability or not """Whether TMA support is provided in the specified compute capability or not
......
...@@ -21,7 +21,7 @@ import subprocess ...@@ -21,7 +21,7 @@ import subprocess
import os import os
from os.path import join, exists from os.path import join, exists
import tvm.ffi import tvm_ffi
from tvm.base import py_str from tvm.base import py_str
import tvm.runtime import tvm.runtime
import tvm.target import tvm.target
...@@ -100,7 +100,7 @@ def rocm_link(in_file, out_file, lld=None): ...@@ -100,7 +100,7 @@ def rocm_link(in_file, out_file, lld=None):
raise RuntimeError(msg) raise RuntimeError(msg)
@tvm.ffi.register_func("tvm_callback_rocm_link", override=True) @tvm_ffi.register_global_func("tvm_callback_rocm_link", override=True)
def callback_rocm_link(obj_bin): def callback_rocm_link(obj_bin):
"""Links object file generated from LLVM to HSA Code Object """Links object file generated from LLVM to HSA Code Object
...@@ -124,7 +124,7 @@ def callback_rocm_link(obj_bin): ...@@ -124,7 +124,7 @@ def callback_rocm_link(obj_bin):
return cobj_bin return cobj_bin
@tvm.ffi.register_func("tvm_callback_rocm_bitcode_path", override=True) @tvm_ffi.register_global_func("tvm_callback_rocm_bitcode_path", override=True)
def callback_rocm_bitcode_path(rocdl_dir=None): def callback_rocm_bitcode_path(rocdl_dir=None):
"""Utility function to find ROCm device library bitcodes """Utility function to find ROCm device library bitcodes
...@@ -226,7 +226,7 @@ def have_matrixcore(compute_version=None): ...@@ -226,7 +226,7 @@ def have_matrixcore(compute_version=None):
return False return False
@tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True) @tvm_ffi.register_global_func("tvm_callback_rocm_get_arch", override=True)
def get_rocm_arch(rocm_path="/opt/rocm"): def get_rocm_arch(rocm_path="/opt/rocm"):
"""Utility function to get the AMD GPU architecture """Utility function to get the AMD GPU architecture
......
from __future__ import annotations from __future__ import annotations
from typing import Callable from typing import Callable
from tvm import register_func import tvm_ffi
from tvm.target import Target from tvm.target import Target
...@@ -12,7 +12,7 @@ def register_cuda_postproc(func: Callable[[str, Target], str], override: bool = ...@@ -12,7 +12,7 @@ def register_cuda_postproc(func: Callable[[str, Target], str], override: bool =
and returns the processed code (str). and returns the processed code (str).
override: Whether to override existing registered function. Defaults to True. override: Whether to override existing registered function. Defaults to True.
""" """
register_func("tilelang_callback_cuda_postproc", f=func, override=override) tvm_ffi.register_global_func("tilelang_callback_cuda_postproc", f=func, override=override)
def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True): def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True):
...@@ -23,7 +23,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T ...@@ -23,7 +23,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T
and returns the processed code (str). and returns the processed code (str).
override: Whether to override existing registered function. Defaults to True. override: Whether to override existing registered function. Defaults to True.
""" """
register_func("tilelang_callback_hip_postproc", f=func, override=override) tvm_ffi.register_global_func("tilelang_callback_hip_postproc", f=func, override=override)
def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True): def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True):
......
...@@ -7,6 +7,7 @@ from typing import Callable ...@@ -7,6 +7,7 @@ from typing import Callable
import tilelang.transform import tilelang.transform
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import tir from tvm import tir
import tvm_ffi
from tvm.ir import CallingConv from tvm.ir import CallingConv
from tvm.target import Target from tvm.target import Target
from tilelang.contrib import hipcc, nvcc from tilelang.contrib import hipcc, nvcc
...@@ -52,7 +53,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: ...@@ -52,7 +53,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
return lambda func: not get_device_call(is_device_c)(func) return lambda func: not get_device_call(is_device_c)(func)
@tvm.register_func("tilelang_callback_cuda_compile", override=True) @tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target): def tilelang_callback_cuda_compile(code, target):
project_root = osp.join(osp.dirname(__file__), "../..") project_root = osp.join(osp.dirname(__file__), "../..")
if "TL_TEMPLATE_PATH" in os.environ: if "TL_TEMPLATE_PATH" in os.environ:
...@@ -89,7 +90,7 @@ def tilelang_callback_cuda_compile(code, target): ...@@ -89,7 +90,7 @@ def tilelang_callback_cuda_compile(code, target):
return ptx return ptx
@tvm.register_func("tilelang_callback_hip_compile", override=True) @tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target): def tilelang_callback_hip_compile(code, target):
project_root = osp.join(osp.dirname(__file__), "../..") project_root = osp.join(osp.dirname(__file__), "../..")
tl_template_path = osp.abspath(osp.join(project_root, "src")) tl_template_path = osp.abspath(osp.join(project_root, "src"))
...@@ -181,7 +182,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> ...@@ -181,7 +182,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
elif target.kind.name == "llvm": elif target.kind.name == "llvm":
device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target)
elif target.kind.name == "webgpu": elif target.kind.name == "webgpu":
device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.webgpu")(device_mod, target)
elif target.kind.name == "metal": elif target.kind.name == "metal":
device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target) device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
else: else:
...@@ -240,6 +241,6 @@ def lower( ...@@ -240,6 +241,6 @@ def lower(
host_mod = host_codegen(host_mod, target_host) host_mod = host_codegen(host_mod, target_host)
host_mod.import_module(codegen_mod) host_mod.import_module(codegen_mod)
return CompiledArtifact( return CompiledArtifact(
host_mod, device_mod, params, codegen_mod.get_source(), rt_mod=host_mod) host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod)
return CompiledArtifact(host_mod, device_mod, params, codegen_mod.get_source()) return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment