/*! * \file tl/op/op.h * \brief Tile library operations. * */ #ifndef TVM_TL_OP_OP_H_ #define TVM_TL_OP_OP_H_ #include #include #include #include #include #include #include #include "../layout/layout.h" namespace tvm { namespace tl { using namespace tir; using AddWorkspaceCallback = std::function; using LayoutMap = Map; using BufferMap = Map; enum class InferLevel : uint8_t { kFree = 0, kCommon = 1, kStrict = 2, }; struct LowerArgs { Target target; Range thread_bounds; Var thread_var; AddWorkspaceCallback AddWorkspace; LayoutMap layout_map; Map buffer_remap; Array buffer_var_gemm; }; struct LayoutInferArgs { Target target; Range thread_bounds; LayoutMap layout_map; arith::Analyzer *analyzer; bool buffer_oob = false; Map buffer_remap; }; class TileOperator; class TileOperatorNode : public Object { public: virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0; virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const = 0; virtual TileOperator Clone() const = 0; static constexpr const char *_type_key = "tl.TileOperator"; TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object); }; class TileOperator : public ObjectRef { public: TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode); }; Var GetVarFromAccessPtr(const PrimExpr &expr); TileOperator ParseOperator(Call call, BufferMap vmap); TileOperator ParseOperator(Stmt stmt, BufferMap vmap); using OpBuilderFunc = ffi::TypedFunction, BufferMap)>; #define TIR_REGISTER_TL_OP(Entry, OpName) \ const Op &Entry::Get() { \ static const Op &op = Op::Get("tl." #OpName); \ return op; \ } \ TVM_REGISTER_OP("tl." #OpName) \ .set_attr("TScriptPrinterName", #OpName) \ .set_attr("TLOpBuilder", \ [](Array args, BufferMap vmap) { \ return Entry(args, vmap); \ }) } // namespace tl } // namespace tvm #endif // TVM_TL_OP_OP_H_