"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "27bb8ca68beeea7792938fdad3a5568bcf10857a"
Unverified commit 10911e28, authored by Lei Wang, committed by GitHub

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
@@ -30,23 +30,13 @@ enum class ReduceTypeEnum : uint8_t {
 class ReduceTypeNode : public Object {
 public:
   int type{-1}; ///< Internal type identifier
-  static constexpr const char *_type_key = "tl.ReduceType";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceTypeNode, Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceType", ReduceTypeNode, Object);
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<ReduceTypeNode>().def_ro("type", &ReduceTypeNode::type);
   }
 
-  bool SEqualReduce(const ReduceTypeNode *other, SEqualReducer equal) const {
-    return equal(type, other->type);
-  }
-  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(type); }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
-
   /// Type checking methods
   bool isSum() const { return type == int(ReduceTypeEnum::kSum); }
   bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); }
@@ -61,9 +51,10 @@ public:
 /// Wrapper class for reduction type with string-based construction
 class ReduceType : public ObjectRef {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(ReduceType, ObjectRef, ReduceTypeNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceType, ObjectRef,
+                                             ReduceTypeNode);
   TVM_DLL ReduceType(std::string type) {
-    auto node = make_object<ReduceTypeNode>();
+    auto node = tvm::ffi::make_object<ReduceTypeNode>();
     if (type == "sum") {
       node->type = int(ReduceTypeEnum::kSum);
     } else if (type == "abssum") {
@@ -95,8 +86,8 @@ public:
   ReduceType type; ///< Type of reduction operation
   bool clear;      ///< Whether to clear destination before reduction
-  static constexpr const char *_type_key = "tl.ReduceOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode,
+                                    TileOperatorNode);
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -108,23 +99,6 @@ public:
         .def_ro("clear", &ReduceOpNode::clear);
   }
 
-  bool SEqualReduce(const ReduceOpNode *other, SEqualReducer equal) const {
-    return equal(src, other->src) && equal(dst, other->dst) &&
-           equal(dim, other->dim) && equal(type, other->type) &&
-           equal(clear, other->clear);
-  }
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(src);
-    hash_reduce(dst);
-    hash_reduce(dim);
-    hash_reduce(type);
-    hash_reduce(clear);
-  }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
-
   /// Lower the operator to TIR statements
   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   /// Infer memory layout for buffers
@@ -145,7 +119,8 @@ private:
 /// Wrapper class for reduction operations
 class ReduceOp : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator,
+                                             ReduceOpNode);
   TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
 };
@@ -156,8 +131,17 @@ public:
   tir::Buffer src, dst; ///< Source and destination buffers
   int dim;              ///< Dimension along which to compute cumulative sum
   bool reverse;         ///< Whether to compute in reverse order
-  static constexpr const char *_type_key = "tl.CumSumOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode,
+                                    TileOperatorNode);
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<CumSumOpNode>()
+        .def_ro("src", &CumSumOpNode::src)
+        .def_ro("dst", &CumSumOpNode::dst)
+        .def_ro("dim", &CumSumOpNode::dim)
+        .def_ro("reverse", &CumSumOpNode::reverse);
+  }
 
   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   LayoutMap InferLayout(const LayoutInferArgs &T,
@@ -169,7 +153,8 @@ public:
 /// Wrapper class for cumulative sum operations
 class CumSumOp : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator,
+                                             CumSumOpNode);
   TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
 };
...
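Across these operator headers the rebase applies one recurring pattern: the hand-written `_type_key` constant, `TVM_DECLARE_FINAL_OBJECT_INFO`, and the `SEqualReduce`/`SHashReduce` boilerplate give way to `TVM_FFI_DECLARE_OBJECT_INFO_FINAL` plus reflection registration, with tvm-ffi deriving structural equality and hashing from the registered fields. A minimal sketch of the new-style node, assuming the tvm v0.22 headers and the `Object`/`ObjectRef` aliases these files already use; `ExampleNode` and its field are hypothetical:

// Sketch only: mirrors the pattern in the diff above.
#include <tvm/ffi/reflection/registry.h>

class ExampleNode : public Object { // hypothetical node for illustration
public:
  int value{0};

  // Replaces: static constexpr const char *_type_key = ...;
  //           TVM_DECLARE_FINAL_OBJECT_INFO(ExampleNode, Object);
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Example", ExampleNode, Object);

  // Registered fields drive structural equality and hashing, so no
  // hand-written SEqualReduce/SHashReduce is needed.
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<ExampleNode>().def_ro("value", &ExampleNode::value);
  }
};

class Example : public ObjectRef {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Example, ObjectRef, ExampleNode);
};

// Reflection is hooked up once at static-init time, as the next diff does
// for RegionOpNode.
TVM_FFI_STATIC_INIT_BLOCK() { ExampleNode::RegisterReflection(); }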
@@ -44,7 +44,7 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
     PrimExpr extent = args[2 + i];
     ranges.push_back(Range::FromMinExtent(min, extent));
   }
-  ObjectPtr<RegionOpNode> node = make_object<RegionOpNode>();
+  ObjectPtr<RegionOpNode> node = tvm::ffi::make_object<RegionOpNode>();
   node->buffer_ = load->buffer;
   node->access_mask_ = static_cast<int>(*as_const_int(args[1]));
   node->ranges_ = ranges;
@@ -57,7 +57,7 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
  * @return TileOperator A new TileOperator that owns a copied RegionOpNode.
  */
 TileOperator RegionOpNode::Clone() const {
-  auto op = make_object<RegionOpNode>(*this);
+  auto op = tvm::ffi::make_object<RegionOpNode>(*this);
   return RegionOp(op);
 }
@@ -118,5 +118,7 @@ TIR_REGISTER_TL_OP(RegionOp, region)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kPure));
 
+TVM_FFI_STATIC_INIT_BLOCK() { RegionOpNode::RegisterReflection(); }
+
 } // namespace tl
 } // namespace tvm
@@ -80,8 +80,8 @@ public:
   Array<Range> ranges_;
   int access_mask_;
-  static constexpr const char *_type_key = "tl.RegionOp";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RegionOpNode, TileOperatorNode);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode,
+                                    TileOperatorNode);
 
   Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
   LayoutMap InferLayout(const LayoutInferArgs &T,
@@ -101,25 +101,12 @@ public:
         .def_ro("ranges", &RegionOpNode::ranges_)
         .def_ro("access_mask", &RegionOpNode::access_mask_);
   }
-
-  bool SEqualReduce(const RegionOpNode *other, SEqualReducer equal) const {
-    return equal(buffer_, other->buffer_) && equal(ranges_, other->ranges_) &&
-           equal(access_mask_, other->access_mask_);
-  }
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce(buffer_);
-    hash_reduce(ranges_);
-    hash_reduce(access_mask_);
-  }
-  static constexpr bool _type_has_method_sequal_reduce = true;
-  static constexpr bool _type_has_method_shash_reduce = true;
 };
 
 class RegionOp : public TileOperator {
 public:
-  TVM_DEFINE_OBJECT_REF_METHODS(RegionOp, TileOperator, RegionOpNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator,
+                                             RegionOpNode);
   TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap);
   static const Op &Get();
...
@@ -89,7 +89,7 @@ struct TensorMapArgs {
 };
 
 // set device api
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args,
                                                                 Any *ret) {
@@ -104,7 +104,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
     }
     *ret = static_cast<int>(result);
   });
-});
+}
 
 struct TensorMapIm2ColArgs {
   CUtensorMap *map;
@@ -180,7 +180,7 @@ struct TensorMapIm2ColArgs {
   }
 };
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def_packed(
       "tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
@@ -197,7 +197,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
         }
         *ret = static_cast<int>(result);
       });
-});
+}
 
 #endif // (CUDA_MAJOR_VERSION >= 12)
...
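tvm v0.22 also changes `TVM_FFI_STATIC_INIT_BLOCK` from a macro that takes the whole init body as a brace-enclosed argument to a function-like form that opens a definition, which is why every call site in this commit drops the trailing `});`. A sketch of the two spellings, assuming the reflection registry used above; the registered name and lambda are hypothetical:

// Old (pre-v0.22): body passed as a macro argument, closed with "});".
// TVM_FFI_STATIC_INIT_BLOCK({ ... });

// New (v0.22): the macro opens a definition; the body is a plain block.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.ExampleAddOne", // hypothetical global name
                        [](int x) { return x + 1; });
}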
#pragma once
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>
namespace tvm {
using ffi::Array;
using ffi::Function;
using ffi::Map;
using ffi::Optional;
using ffi::String;
} // namespace tvm
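This new alias header re-exports the ffi container types into the `tvm::` namespace, so code written against the pre-rebase `tvm::Array`/`tvm::String` spellings keeps compiling; the files below pull it in as `"../support/ffi_aliases.h"`. A usage sketch (the helper function is hypothetical):

#include "../support/ffi_aliases.h" // path as spelled by the includes below

namespace tvm {
// With the aliases in scope, Array<T> is ffi::Array<T> and String is
// ffi::String, so pre-rebase signatures still resolve.
String FirstOrDefault(const Array<String> &names) { // hypothetical helper
  return names.size() == 0 ? String("<empty>") : names[0];
}
} // namespace tvm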
@@ -29,6 +29,7 @@
 #include <unordered_set>
 #include <utility>
 
+#include "../support/ffi_aliases.h"
 #include "support/str_escape.h"
 #include "target/build_common.h"
 #include "target/source/codegen_params.h"
@@ -54,8 +55,7 @@ void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts,
 }
 
 void CodeGenTileLangCPP::InitGlobalContext() {
-  decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx
-              << " = NULL;\n";
+  decl_stream << "void* " << ffi::symbol::tvm_ffi_library_ctx << " = NULL;\n";
 }
 
 void CodeGenTileLangCPP::DefineModuleName() {
@@ -256,8 +256,8 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) {
   // reserve keywords
   ReserveKeywordsAsUnique();
 
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
+  auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
+  ICHECK(global_symbol)
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
   bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
...
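The other recurring change in the codegen files: with tvm-ffi, `GetAttr<T>` returns an optional value, so the old `.defined()` test becomes a truthiness check or `has_value()` (both spellings appear in this commit). A sketch of the new idiom, assuming `f` is a `tir::PrimFunc` as in the surrounding code:

// Optional-style attribute lookup under tvm v0.22.
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol) // equivalently: ICHECK(global_symbol.has_value())
    << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
std::string name = global_symbol.value(); // unwrap only after the check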
@@ -73,10 +73,10 @@ public:
   void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
   void VisitStmt_(const AllocateNode *op) final;   // NOLINT(*)
 
-  void GenerateForwardFunctionDeclarations(String global_symbol,
-                                           const Array<Type> &arg_types,
+  void GenerateForwardFunctionDeclarations(ffi::String global_symbol,
+                                           const ffi::Array<Type> &arg_types,
                                            const Type &ret_type) override;
-  Array<String> GetFunctionNames() { return function_names_; }
+  ffi::Array<ffi::String> GetFunctionNames() { return function_names_; }
 
 private:
   /* \brief Internal structure to store information about function calls */
@@ -92,7 +92,7 @@ private:
   /* \brief mapping global packed func to the unique name */
   std::unordered_map<std::string, std::string> declared_globals_;
   /* \brief names of the functions declared in this module */
-  Array<String> function_names_;
+  ffi::Array<ffi::String> function_names_;
   /*! \brief whether to emit asserts in the resulting C code */
   bool emit_asserts_;
   /*! \brief whether to emit forward function declarations in the resulting C
...
@@ -20,6 +20,7 @@
 namespace tvm {
 namespace codegen {
 using namespace tvm::tl::codegen;
+using namespace ffi;
 
 struct CUDAMath {
   std::string operator()(DataType t, std::string name) const {
@@ -2165,8 +2166,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
            "A_ptr, B_ptr, C_ptr>, but got "
         << op->args.size();
     auto op_instance = Downcast<StringImm>(op->args[0]);
-    this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
-                          op->args, true, os);
+    this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
+                          op_instance->value, op->args, true, os);
   } else if (op->op.same_as(tl::tl_gemm_sp())) {
     ICHECK(op->args.size() == 5)
         << "tl_gemm_sp expects 5 arguments <op_instance, A_ptr, B_ptr, C_ptr, "
@@ -2174,8 +2175,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
        << op->args.size();
     auto op_instance = Downcast<StringImm>(op->args[0]);
     enable_sparse_gemm_ = true;
-    this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
-                          op->args, true, os);
+    this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
+                          op_instance->value, op->args, true, os);
   } else if (op->op.same_as(tl::get_lane_idx())) {
     ICHECK_LE(op->args.size(), 1)
         << "tl.get_lane_idx expects at most one argument <warp_size>.";
@@ -2458,8 +2459,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
 
 void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
   int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
-  CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef<Ramp>(op) << " with "
-                     << lanes << " lanes is not allowed.";
+  CHECK_LE(lanes, 4) << "Translate Ramp Node " << tvm::ffi::GetRef<Ramp>(op)
+                     << " with " << lanes << " lanes is not allowed.";
   os << "(make_";
   PrintType(op->dtype, os);
   os << "(";
@@ -2971,7 +2972,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
   ReserveKeywordsAsUnique();
 
   auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
+  ICHECK(global_symbol)
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
   bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
...
@@ -60,14 +60,14 @@ public:
   // Override this as a work around for __grid_constant__ parameter
   void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
 
-  void PrintFunctionSignature(const String &function_name, const PrimFunc &func,
-                              std::ostream &os);
+  void PrintFunctionSignature(const ffi::String &function_name,
+                              const PrimFunc &func, std::ostream &os);
 
 protected:
   virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                    PrimExpr index) final;
-  void PrintCallExtern(Type ret_type, String global_symbol,
-                       const Array<PrimExpr> &args, bool skip_first_arg,
+  void PrintCallExtern(Type ret_type, ffi::String global_symbol,
+                       const ffi::Array<PrimExpr> &args, bool skip_first_arg,
                        std::ostream &os) final; // NOLINT(*)
 
 private:
...
@@ -959,8 +959,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
            "A_ptr, B_ptr, C_ptr>, but got "
         << op->args.size();
     auto op_instance = Downcast<StringImm>(op->args[0]);
-    this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
-                          op->args, true, os);
+    this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
+                          op_instance->value, op->args, true, os);
   } else if (op->op.same_as(tl::tl_gemm_sp())) {
     LOG(FATAL) << "tl_gemm_sp is not supported on HIP";
   } else if (op->op.same_as(tl::loop_break())) {
@@ -1309,7 +1309,7 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
   ReserveKeywordsAsUnique();
 
   auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
+  ICHECK(global_symbol.has_value())
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
   bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
...
@@ -56,8 +56,8 @@ public:
 protected:
   virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                    PrimExpr index) final;
-  void PrintCallExtern(Type ret_type, String global_symbol,
-                       const Array<PrimExpr> &args, bool skip_first_arg,
+  void PrintCallExtern(Type ret_type, ffi::String global_symbol,
+                       const ffi::Array<PrimExpr> &args, bool skip_first_arg,
                        std::ostream &os) final; // NOLINT(*)
 
 private:
...
This diff is collapsed.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.h
* \brief Generate WebGPU shaders in WGSL.
*
* This module generates WGSL shading language.
* See https://www.w3.org/TR/WGSL/ for the language reference.
*/
#ifndef TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#define TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#include <tvm/target/codegen.h>
#include <string>
#include "target/source/codegen_c.h"
namespace tvm {
namespace codegen {
/*!
* \brief WebGPU code generator.
*
* Note WGSL have a different syntax from normal C.
* We only leverage the C for expression generation and
* write most of the language generations.
*/
class CodeGenTileLangWebGPU final : public CodeGenC {
public:
explicit CodeGenTileLangWebGPU(Target target);
// overrides
std::string Finish() final;
using CodeGenC::AddFunction;
runtime::FunctionInfo AddFunction(const PrimFunc &f,
bool skip_readonly_decl); // NOLINT(*)
void InitFuncState(const PrimFunc &f) final;
void PrintStorageSync(const CallNode *op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
// assignment printing
void PrintSSAAssign(const std::string &target, const std::string &src,
DataType type) final;
// overload visitor
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BufferLoadNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const CastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const IntImmNode *op, std::ostream &os) final; // NOLINT(*)
// stmt printing
void VisitStmt_(const LetStmtNode *op) final;
void VisitStmt_(const BufferStoreNode *op) final;
void VisitStmt_(const ForNode *op) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AssertStmtNode *op) final;
void VisitStmt_(const AllocateConstNode *op) final;
void VisitStmt_(const WhileNode *op) final;
private:
/*!
* \brief Enforce value to be U32.
*/
static PrimExpr EnforceU32(PrimExpr value);
/*!
* \brief Storage type of bool values.
*/
DataType boolean_storage_type_{DataType::Int(8)};
// whether enable fp16
bool enable_fp16_{false};
/*! \brief the header stream for function label and enable directive if any,
* goes before any other declaration */
std::ostringstream header_stream;
Target target_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
@@ -5,6 +5,7 @@
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/op_attr_types.h>
 
+#include "../support/ffi_aliases.h"
 #include "target/intrin_rule.h"
 
 namespace tvm {
...
@@ -5,6 +5,7 @@
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/op_attr_types.h>
 
+#include "../support/ffi_aliases.h"
 #include "target/intrin_rule.h"
 
 namespace tvm {
...
#include "codegen_cpp.h" #include "codegen_cpp.h"
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include "../support/ffi_aliases.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
runtime::Module BuildCPPHost(IRModule mod, Target target) { ffi::Module BuildCPPHost(IRModule mod, Target target) {
bool output_ssa = false; bool output_ssa = false;
bool emit_asserts = false; bool emit_asserts = false;
bool emit_fwd_func_decl = true; bool emit_fwd_func_decl = true;
...@@ -67,10 +70,10 @@ runtime::Module BuildCPPHost(IRModule mod, Target target) { ...@@ -67,10 +70,10 @@ runtime::Module BuildCPPHost(IRModule mod, Target target) {
return CSourceModuleCreate(code, "c", cg.GetFunctionNames()); return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost); refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost);
}); }
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
@@ -26,18 +26,19 @@ ExtractFuncInfo(const IRModule &mod) {
       }
       info.arg_types.push_back(f->params[i].dtype());
     }
-    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
+    if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
+            tir::attr::kKernelLaunchParams)) {
       for (const auto &tag : opt.value()) {
         info.launch_param_tags.push_back(tag);
       }
     }
-    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
     fmap[static_cast<std::string>(global_symbol.value())] = info;
   }
   return fmap;
 }
 
-runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
+ffi::Module BuildTileLangCUDA(IRModule mod, Target target) {
   bool output_ssa = false;
   CodeGenTileLangCUDA cg;
   cg.Init(output_ssa);
@@ -70,7 +71,7 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
   return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
 }
 
-runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
+ffi::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
   bool output_ssa = false;
   CodeGenTileLangCUDA cg;
   cg.Init(output_ssa);
@@ -93,13 +94,13 @@ runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
   return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("target.build.tilelang_cuda", BuildTileLangCUDA)
       .def("target.build.tilelang_cuda_without_compile",
            BuildTileLangCUDAWithoutCompile);
-});
+}
 
 } // namespace codegen
 } // namespace tvm
@@ -37,18 +37,19 @@ ExtractFuncInfo(const IRModule &mod) {
       }
      info.arg_types.push_back(f->params[i].dtype());
     }
-    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
+    if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
+            tir::attr::kKernelLaunchParams)) {
       for (const auto &tag : opt.value()) {
         info.launch_param_tags.push_back(tag);
       }
     }
-    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
     fmap[static_cast<std::string>(global_symbol.value())] = info;
   }
   return fmap;
 }
 
-runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
+ffi::Module BuildTileLangHIP(IRModule mod, Target target) {
   bool output_ssa = false;
   CodeGenTileLangHIP cg;
   cg.Init(output_ssa);
@@ -84,7 +85,7 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
   return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
 }
 
-runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
+ffi::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
   bool output_ssa = false;
   CodeGenTileLangHIP cg;
   cg.Init(output_ssa);
@@ -110,13 +111,13 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
                         std::string());
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("target.build.tilelang_hip", BuildTileLangHIP)
       .def("target.build.tilelang_hip_without_compile",
           BuildTileLangHIPWithoutCompile);
-});
+}
 
 } // namespace codegen
 } // namespace tvm
@@ -5,6 +5,9 @@
 #include "utils.h"
 
+#include "../support/ffi_aliases.h"
+#include <tvm/node/node.h>
+
 namespace tvm {
 namespace tl {
@@ -16,8 +19,8 @@ bool TargetIsRocm(Target target) {
 }
 
 int GetArchInt(Target target) {
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
+  auto s = target->GetAttr<tvm::ffi::String>("arch");
+  ICHECK(s.has_value());
   const std::string arch_str = s.value();
   ICHECK(arch_str.size() >= 3);
   ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0)
@@ -71,7 +74,7 @@ bool TargetIsCDNA(Target target) {
   if (!TargetIsRocm(target))
     return false;
   if (target->attrs.count("mcpu")) {
-    std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
+    std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
     // if mcpu start with "gfx9", it is CDNA
     return mcpu.find("gfx9") == 0;
   }
@@ -84,7 +87,7 @@ bool TargetHasAsyncCopy(Target target) {
     return arch >= 80;
   } else if (TargetIsCDNA(target)) {
     if (target->attrs.count("mcpu")) {
-      std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
+      std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
       if (mcpu.rfind("gfx9", 0) == 0) {
         int gfx_version = std::stoi(mcpu.substr(3, 2));
         return gfx_version >= 94;
@@ -131,7 +134,7 @@ int TargetGetWarpSize(Target target) {
   return res;
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("tl.TargetIsCuda",
@@ -160,7 +163,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
            [](Target target) { return TargetHasBulkCopy(target); })
       .def("tl.TargetGetWarpSize",
            [](Target target) { return TargetGetWarpSize(target); });
-});
+}
 
 } // namespace tl
 } // namespace tvm
@@ -47,7 +47,7 @@ public:
   }
 
   Stmt VisitStmt_(const BlockNode *op) final {
-    Block block = GetRef<Block>(op);
+    Block block = tvm::ffi::GetRef<Block>(op);
     Array<Buffer> alloc_buffers = op->alloc_buffers;
     alloc_buffers.MutateByApply([this](Buffer buf) {
       auto storage_scope =
@@ -58,7 +58,7 @@ public:
                              buf->dtype.bytes());
       if (!new_shape.same_as(buf->shape)) {
         ObjectPtr<BufferNode> new_buffer =
-            make_object<BufferNode>(*(buf.get()));
+            tvm::ffi::make_object<BufferNode>(*(buf.get()));
         new_buffer->shape = std::move(new_shape);
         buffer_remap_.Set(buf, Buffer(new_buffer));
         return Buffer(new_buffer);
@@ -73,7 +73,7 @@ public:
   }
 
   Stmt VisitStmt_(const BufferStoreNode *op) final {
-    auto store_node = GetRef<BufferStore>(op);
+    auto store_node = tvm::ffi::GetRef<BufferStore>(op);
     Buffer buf = op->buffer;
     if (buffer_remap_.count(buf)) {
       buf = buffer_remap_[buf];
@@ -83,7 +83,7 @@ public:
   }
 
   PrimExpr VisitExpr_(const BufferLoadNode *op) final {
-    auto load_node = GetRef<BufferLoad>(op);
+    auto load_node = tvm::ffi::GetRef<BufferLoad>(op);
     Buffer buf = op->buffer;
     if (buffer_remap_.count(buf)) {
       buf = buffer_remap_[buf];
@@ -149,11 +149,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) {
       "tl.AlignDynamicSharedMemoryAllocations", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations",
                         AlignDynamicSharedMemoryAllocations);
-});
+}
 
 } // namespace tl
 } // namespace tvm
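Lastly, free helpers such as `GetRef` and `make_object` now live in `tvm::ffi` and are spelled fully qualified here; the copy-mutate-rewrap idiom itself is unchanged. A sketch under those assumptions (the `WithNewShape` helper is hypothetical):

#include <tvm/tir/buffer.h>

// Copy the node, mutate the copy, rewrap it as a new reference, exactly as
// the pass above does for shared-memory buffers.
tvm::tir::Buffer WithNewShape(const tvm::tir::Buffer &buf,
                              tvm::ffi::Array<tvm::PrimExpr> shape) {
  auto new_buffer =
      tvm::ffi::make_object<tvm::tir::BufferNode>(*(buf.get())); // node copy
  new_buffer->shape = std::move(shape);
  return tvm::tir::Buffer(new_buffer); // fresh reference to the new node
}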