Unverified Commit 10911e28 authored by Lei Wang, committed by GitHub

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
......@@ -30,23 +30,13 @@ enum class ReduceTypeEnum : uint8_t {
class ReduceTypeNode : public Object {
public:
int type{-1}; ///< Internal type identifier
static constexpr const char *_type_key = "tl.ReduceType";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceTypeNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceType", ReduceTypeNode, Object);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ReduceTypeNode>().def_ro("type", &ReduceTypeNode::type);
}
bool SEqualReduce(const ReduceTypeNode *other, SEqualReducer equal) const {
return equal(type, other->type);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(type); }
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/// Type checking methods
bool isSum() const { return type == int(ReduceTypeEnum::kSum); }
bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); }
......@@ -61,9 +51,10 @@ public:
/// Wrapper class for reduction type with string-based construction
class ReduceType : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ReduceType, ObjectRef, ReduceTypeNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceType, ObjectRef,
ReduceTypeNode);
TVM_DLL ReduceType(std::string type) {
auto node = make_object<ReduceTypeNode>();
auto node = tvm::ffi::make_object<ReduceTypeNode>();
if (type == "sum") {
node->type = int(ReduceTypeEnum::kSum);
} else if (type == "abssum") {
......@@ -95,8 +86,8 @@ public:
ReduceType type; ///< Type of reduction operation
bool clear; ///< Whether to clear destination before reduction
static constexpr const char *_type_key = "tl.ReduceOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -108,23 +99,6 @@ public:
.def_ro("clear", &ReduceOpNode::clear);
}
bool SEqualReduce(const ReduceOpNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(dim, other->dim) && equal(type, other->type) &&
equal(clear, other->clear);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(dim);
hash_reduce(type);
hash_reduce(clear);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/// Lower the operator to TIR statements
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
/// Infer memory layout for buffers
......@@ -145,7 +119,8 @@ private:
/// Wrapper class for reduction operations
class ReduceOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator,
ReduceOpNode);
TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......@@ -156,8 +131,17 @@ public:
tir::Buffer src, dst; ///< Source and destination buffers
int dim; ///< Dimension along which to compute cumulative sum
bool reverse; ///< Whether to compute in reverse order
static constexpr const char *_type_key = "tl.CumSumOp";
TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CumSumOpNode>()
.def_ro("src", &CumSumOpNode::src)
.def_ro("dst", &CumSumOpNode::dst)
.def_ro("dim", &CumSumOpNode::dim)
.def_ro("reverse", &CumSumOpNode::reverse);
}
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
......@@ -169,7 +153,8 @@ public:
/// Wrapper class for cumulative sum operations
class CumSumOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator,
CumSumOpNode);
TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......
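The hunks above show the core pattern of the tvm-ffi rebase for operator nodes: the hand-written `_type_key`, `TVM_DECLARE_FINAL_OBJECT_INFO`, `SEqualReduce`, and `SHashReduce` boilerplate is replaced by `TVM_FFI_DECLARE_OBJECT_INFO_FINAL` plus reflection registration, and `make_object` is spelled through `tvm::ffi`. Below is a minimal sketch of the new-style declaration; `ExampleNode`/`Example` are made up for illustration, and the sketch assumes the tvm-ffi headers that appear elsewhere in this diff.

```cpp
#include <utility>

#include <tvm/ffi/memory.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/object.h>

namespace tvm {
namespace tl {

class ExampleNode : public Object {
public:
  int value{0};

  // One macro now carries both the type key and the final-object info.
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Example", ExampleNode, Object);

  // Structural equality/hashing are derived from the reflected fields,
  // so hand-written SEqualReduce/SHashReduce methods are no longer needed.
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<ExampleNode>().def_ro("value", &ExampleNode::value);
  }
};

class Example : public ObjectRef {
public:
  explicit Example(int value) {
    auto node = tvm::ffi::make_object<ExampleNode>();  // ffi-qualified make_object
    node->value = value;
    data_ = std::move(node);
  }
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Example, ObjectRef, ExampleNode);
};

// Reflection metadata is registered once at static-initialization time.
TVM_FFI_STATIC_INIT_BLOCK() { ExampleNode::RegisterReflection(); }

}  // namespace tl
}  // namespace tvm
```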
......@@ -44,7 +44,7 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
PrimExpr extent = args[2 + i];
ranges.push_back(Range::FromMinExtent(min, extent));
}
ObjectPtr<RegionOpNode> node = make_object<RegionOpNode>();
ObjectPtr<RegionOpNode> node = tvm::ffi::make_object<RegionOpNode>();
node->buffer_ = load->buffer;
node->access_mask_ = static_cast<int>(*as_const_int(args[1]));
node->ranges_ = ranges;
......@@ -57,7 +57,7 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A new TileOperator that owns a copied RegionOpNode.
*/
TileOperator RegionOpNode::Clone() const {
auto op = make_object<RegionOpNode>(*this);
auto op = tvm::ffi::make_object<RegionOpNode>(*this);
return RegionOp(op);
}
......@@ -118,5 +118,7 @@ TIR_REGISTER_TL_OP(RegionOp, region)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TVM_FFI_STATIC_INIT_BLOCK() { RegionOpNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
......@@ -80,8 +80,8 @@ public:
Array<Range> ranges_;
int access_mask_;
static constexpr const char *_type_key = "tl.RegionOp";
TVM_DECLARE_FINAL_OBJECT_INFO(RegionOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode,
TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
......@@ -101,25 +101,12 @@ public:
.def_ro("ranges", &RegionOpNode::ranges_)
.def_ro("access_mask", &RegionOpNode::access_mask_);
}
bool SEqualReduce(const RegionOpNode *other, SEqualReducer equal) const {
return equal(buffer_, other->buffer_) && equal(ranges_, other->ranges_) &&
equal(access_mask_, other->access_mask_);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer_);
hash_reduce(ranges_);
hash_reduce(access_mask_);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
};
class RegionOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RegionOp, TileOperator, RegionOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator,
RegionOpNode);
TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
......
......@@ -89,7 +89,7 @@ struct TensorMapArgs {
};
// set device api
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args,
Any *ret) {
......@@ -104,7 +104,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
}
*ret = static_cast<int>(result);
});
});
}
struct TensorMapIm2ColArgs {
CUtensorMap *map;
......@@ -180,7 +180,7 @@ struct TensorMapIm2ColArgs {
}
};
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed(
"tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
......@@ -197,7 +197,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
}
*ret = static_cast<int>(result);
});
});
}
#endif // (CUDA_MAJOR_VERSION >= 12)
......
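The runtime hunks above all make the same mechanical change: `TVM_FFI_STATIC_INIT_BLOCK({ ... });`, which passed the body as a macro argument, becomes `TVM_FFI_STATIC_INIT_BLOCK() { ... }` with an ordinary block. A minimal sketch of a registration in the new form; the global name `tl.example.add` is made up purely for illustration.

```cpp
#include <tvm/ffi/reflection/registry.h>

// Old form: TVM_FFI_STATIC_INIT_BLOCK({ ... });  (body as a macro argument)
// New form: TVM_FFI_STATIC_INIT_BLOCK() { ... }  (body as an ordinary block)
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // "tl.example.add" is a made-up global name, used only for this sketch.
  refl::GlobalDef().def("tl.example.add", [](int a, int b) { return a + b; });
}
```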
#pragma once
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>
namespace tvm {
using ffi::Array;
using ffi::Function;
using ffi::Map;
using ffi::Optional;
using ffi::String;
} // namespace tvm
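This new header centralizes the `using` declarations so existing code can keep the short names while the underlying types now come from `tvm::ffi`. A small, purely illustrative use; the helper and its file location are assumptions, not part of this PR.

```cpp
// Assumed to sit next to the TileLang sources so the relative include resolves.
#include "../support/ffi_aliases.h"

namespace tvm {
namespace tl {

// Made-up helper: the short names resolve to ffi::Array / ffi::String
// through the using-declarations in ffi_aliases.h.
Array<String> MakeTags() {
  return {String("threadIdx.x"), String("blockIdx.x")};
}

}  // namespace tl
}  // namespace tvm
```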
......@@ -29,6 +29,7 @@
#include <unordered_set>
#include <utility>
#include "../support/ffi_aliases.h"
#include "support/str_escape.h"
#include "target/build_common.h"
#include "target/source/codegen_params.h"
......@@ -54,8 +55,7 @@ void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts,
}
void CodeGenTileLangCPP::InitGlobalContext() {
decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx
<< " = NULL;\n";
decl_stream << "void* " << ffi::symbol::tvm_ffi_library_ctx << " = NULL;\n";
}
void CodeGenTileLangCPP::DefineModuleName() {
......@@ -256,8 +256,8 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) {
// reserve keywords
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
......
......@@ -73,10 +73,10 @@ public:
void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
void VisitStmt_(const AllocateNode *op) final; // NOLINT(*)
void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<Type> &arg_types,
void GenerateForwardFunctionDeclarations(ffi::String global_symbol,
const ffi::Array<Type> &arg_types,
const Type &ret_type) override;
Array<String> GetFunctionNames() { return function_names_; }
ffi::Array<ffi::String> GetFunctionNames() { return function_names_; }
private:
/* \brief Internal structure to store information about function calls */
......@@ -92,7 +92,7 @@ private:
/* \brief mapping global packed func to the unique name */
std::unordered_map<std::string, std::string> declared_globals_;
/* \brief names of the functions declared in this module */
Array<String> function_names_;
ffi::Array<ffi::String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;
/*! \brief whether to emit forward function declarations in the resulting C
......
......@@ -20,6 +20,7 @@
namespace tvm {
namespace codegen {
using namespace tvm::tl::codegen;
using namespace ffi;
struct CUDAMath {
std::string operator()(DataType t, std::string name) const {
......@@ -2165,8 +2166,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
"A_ptr, B_ptr, C_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
ICHECK(op->args.size() == 5)
<< "tl_gemm_sp expects 5 arguments <op_instance, A_ptr, B_ptr, C_ptr, "
......@@ -2174,8 +2175,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
enable_sparse_gemm_ = true;
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::get_lane_idx())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_lane_idx expects at most one argument <warp_size>.";
......@@ -2458,8 +2459,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef<Ramp>(op) << " with "
<< lanes << " lanes is not allowed.";
CHECK_LE(lanes, 4) << "Translate Ramp Node " << tvm::ffi::GetRef<Ramp>(op)
<< " with " << lanes << " lanes is not allowed.";
os << "(make_";
PrintType(op->dtype, os);
os << "(";
......@@ -2971,7 +2972,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
......
......@@ -60,14 +60,14 @@ public:
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
void PrintFunctionSignature(const String &function_name, const PrimFunc &func,
std::ostream &os);
void PrintFunctionSignature(const ffi::String &function_name,
const PrimFunc &func, std::ostream &os);
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
void PrintCallExtern(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private:
......
......@@ -959,8 +959,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
"A_ptr, B_ptr, C_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
LOG(FATAL) << "tl_gemm_sp is not supported on HIP";
} else if (op->op.same_as(tl::loop_break())) {
......@@ -1309,7 +1309,7 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
ICHECK(global_symbol.has_value())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
......
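The codegen hunks above also change how optional attributes are checked: `GetAttr` now yields an ffi optional that is tested with `has_value()` (or used directly in a boolean context) instead of the old `.defined()`. A minimal sketch of that pattern; `GetGlobalSymbolOrFail` is an illustrative helper, not code from this PR.

```cpp
#include <string>

#include <tvm/ir/function.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/function.h>

// Illustrative helper; mirrors the global_symbol checks in AddFunction above.
std::string GetGlobalSymbolOrFail(const tvm::tir::PrimFunc &f) {
  auto global_symbol = f->GetAttr<tvm::ffi::String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.has_value())  // replaces global_symbol.defined()
      << "Expect PrimFunc to have the global_symbol attribute";
  return static_cast<std::string>(global_symbol.value());
}
```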
......@@ -56,8 +56,8 @@ public:
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
void PrintCallExtern(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private:
......
This diff is collapsed.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.h
* \brief Generate WebGPU shaders in WGSL.
*
* This module generates WGSL shading language.
* See https://www.w3.org/TR/WGSL/ for the language reference.
*/
#ifndef TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#define TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#include <tvm/target/codegen.h>
#include <string>
#include "target/source/codegen_c.h"
namespace tvm {
namespace codegen {
/*!
* \brief WebGPU code generator.
*
* Note WGSL have a different syntax from normal C.
* We only leverage the C for expression generation and
* write most of the language generations.
*/
class CodeGenTileLangWebGPU final : public CodeGenC {
public:
explicit CodeGenTileLangWebGPU(Target target);
// overrides
std::string Finish() final;
using CodeGenC::AddFunction;
runtime::FunctionInfo AddFunction(const PrimFunc &f,
bool skip_readonly_decl); // NOLINT(*)
void InitFuncState(const PrimFunc &f) final;
void PrintStorageSync(const CallNode *op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
// assignment printing
void PrintSSAAssign(const std::string &target, const std::string &src,
DataType type) final;
// overload visitor
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BufferLoadNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const CastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const IntImmNode *op, std::ostream &os) final; // NOLINT(*)
// stmt printing
void VisitStmt_(const LetStmtNode *op) final;
void VisitStmt_(const BufferStoreNode *op) final;
void VisitStmt_(const ForNode *op) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AssertStmtNode *op) final;
void VisitStmt_(const AllocateConstNode *op) final;
void VisitStmt_(const WhileNode *op) final;
private:
/*!
* \brief Enforce value to be U32.
*/
static PrimExpr EnforceU32(PrimExpr value);
/*!
* \brief Storage type of bool values.
*/
DataType boolean_storage_type_{DataType::Int(8)};
// whether enable fp16
bool enable_fp16_{false};
/*! \brief the header stream for function label and enable directive if any,
* goes before any other declaration */
std::ostringstream header_stream;
Target target_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
......@@ -5,6 +5,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
#include "target/intrin_rule.h"
namespace tvm {
......
......@@ -5,6 +5,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
#include "target/intrin_rule.h"
namespace tvm {
......
#include "codegen_cpp.h"
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace codegen {
runtime::Module BuildCPPHost(IRModule mod, Target target) {
ffi::Module BuildCPPHost(IRModule mod, Target target) {
bool output_ssa = false;
bool emit_asserts = false;
bool emit_fwd_func_decl = true;
......@@ -67,10 +70,10 @@ runtime::Module BuildCPPHost(IRModule mod, Target target) {
return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost);
});
}
} // namespace codegen
} // namespace tvm
......@@ -26,18 +26,19 @@ ExtractFuncInfo(const IRModule &mod) {
}
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
return fmap;
}
runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
ffi::Module BuildTileLangCUDA(IRModule mod, Target target) {
bool output_ssa = false;
CodeGenTileLangCUDA cg;
cg.Init(output_ssa);
......@@ -70,7 +71,7 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
ffi::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
bool output_ssa = false;
CodeGenTileLangCUDA cg;
cg.Init(output_ssa);
......@@ -93,13 +94,13 @@ runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("target.build.tilelang_cuda", BuildTileLangCUDA)
.def("target.build.tilelang_cuda_without_compile",
BuildTileLangCUDAWithoutCompile);
});
}
} // namespace codegen
} // namespace tvm
......@@ -37,18 +37,19 @@ ExtractFuncInfo(const IRModule &mod) {
}
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
return fmap;
}
runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
ffi::Module BuildTileLangHIP(IRModule mod, Target target) {
bool output_ssa = false;
CodeGenTileLangHIP cg;
cg.Init(output_ssa);
......@@ -84,7 +85,7 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}
runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
ffi::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
bool output_ssa = false;
CodeGenTileLangHIP cg;
cg.Init(output_ssa);
......@@ -110,13 +111,13 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
std::string());
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("target.build.tilelang_hip", BuildTileLangHIP)
.def("target.build.tilelang_hip_without_compile",
BuildTileLangHIPWithoutCompile);
});
}
} // namespace codegen
} // namespace tvm
......@@ -5,6 +5,9 @@
#include "utils.h"
#include "../support/ffi_aliases.h"
#include <tvm/node/node.h>
namespace tvm {
namespace tl {
......@@ -16,8 +19,8 @@ bool TargetIsRocm(Target target) {
}
int GetArchInt(Target target) {
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
auto s = target->GetAttr<tvm::ffi::String>("arch");
ICHECK(s.has_value());
const std::string arch_str = s.value();
ICHECK(arch_str.size() >= 3);
ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0)
......@@ -71,7 +74,7 @@ bool TargetIsCDNA(Target target) {
if (!TargetIsRocm(target))
return false;
if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
// if mcpu start with "gfx9", it is CDNA
return mcpu.find("gfx9") == 0;
}
......@@ -84,7 +87,7 @@ bool TargetHasAsyncCopy(Target target) {
return arch >= 80;
} else if (TargetIsCDNA(target)) {
if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
if (mcpu.rfind("gfx9", 0) == 0) {
int gfx_version = std::stoi(mcpu.substr(3, 2));
return gfx_version >= 94;
......@@ -131,7 +134,7 @@ int TargetGetWarpSize(Target target) {
return res;
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tl.TargetIsCuda",
......@@ -160,7 +163,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
[](Target target) { return TargetHasBulkCopy(target); })
.def("tl.TargetGetWarpSize",
[](Target target) { return TargetGetWarpSize(target); });
});
}
} // namespace tl
} // namespace tvm
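The target utility hunks show the same migration applied to target attributes: values pulled out of `target->attrs` are downcast to `tvm::ffi::String`. A short sketch under those assumptions; `ExampleTargetIsCDNA` is illustrative only and mirrors `TargetIsCDNA` above.

```cpp
#include <string>

#include <tvm/node/node.h>
#include <tvm/target/target.h>

namespace tvm {
namespace tl {

// Illustrative helper, not part of this PR; mirrors TargetIsCDNA above.
bool ExampleTargetIsCDNA(Target target) {
  if (!target->attrs.count("mcpu"))
    return false;
  // Attribute map values are downcast to the ffi string type.
  std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
  return mcpu.rfind("gfx9", 0) == 0;  // CDNA parts report gfx9xx mcpu values
}

}  // namespace tl
}  // namespace tvm
```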
......@@ -47,7 +47,7 @@ public:
}
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply([this](Buffer buf) {
auto storage_scope =
......@@ -58,7 +58,7 @@ public:
buf->dtype.bytes());
if (!new_shape.same_as(buf->shape)) {
ObjectPtr<BufferNode> new_buffer =
make_object<BufferNode>(*(buf.get()));
tvm::ffi::make_object<BufferNode>(*(buf.get()));
new_buffer->shape = std::move(new_shape);
buffer_remap_.Set(buf, Buffer(new_buffer));
return Buffer(new_buffer);
......@@ -73,7 +73,7 @@ public:
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store_node = GetRef<BufferStore>(op);
auto store_node = tvm::ffi::GetRef<BufferStore>(op);
Buffer buf = op->buffer;
if (buffer_remap_.count(buf)) {
buf = buffer_remap_[buf];
......@@ -83,7 +83,7 @@ public:
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load_node = GetRef<BufferLoad>(op);
auto load_node = tvm::ffi::GetRef<BufferLoad>(op);
Buffer buf = op->buffer;
if (buffer_remap_.count(buf)) {
buf = buffer_remap_[buf];
......@@ -149,11 +149,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) {
"tl.AlignDynamicSharedMemoryAllocations", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations",
AlignDynamicSharedMemoryAllocations);
});
}
} // namespace tl
} // namespace tvm
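Finally, the shared-memory alignment pass shows how node back-references and copies are spelled after the rebase: `GetRef` and `make_object` are qualified with `tvm::ffi::`. A compact sketch; `WithNewShape` is a hypothetical helper used only to show those calls.

```cpp
#include <utility>

#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/memory.h>
#include <tvm/tir/buffer.h>

// Hypothetical helper, illustration only: clone a BufferNode with a new shape.
tvm::tir::Buffer WithNewShape(const tvm::tir::BufferNode *op,
                              tvm::ffi::Array<tvm::PrimExpr> new_shape) {
  // Back-reference from a raw node pointer via the ffi-qualified GetRef.
  tvm::tir::Buffer original = tvm::ffi::GetRef<tvm::tir::Buffer>(op);
  // Copy-on-write style clone via the ffi-qualified make_object.
  auto node = tvm::ffi::make_object<tvm::tir::BufferNode>(*original.get());
  node->shape = std::move(new_shape);
  return tvm::tir::Buffer(node);
}
```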