/*! * \file tl/op/op.h * \brief Tile library operations. * */ #ifndef TVM_TL_OP_OP_H_ #define TVM_TL_OP_OP_H_ #include #include #include #include #include "../layout/layout.h" namespace tvm { namespace tl { using namespace tir; using AddWorkspaceCallback = std::function; using LayoutMap = Map; using BufferMap = Map; using OpBuilderFunc = TypedPackedFunc, BufferMap)>; #define TIR_REGISTER_TL_OP(Entry, OpName) \ const Op &Entry::Get() { \ static const Op &op = Op::Get("tl." #OpName); \ return op; \ } \ TVM_REGISTER_OP("tl." #OpName) \ .set_attr("TScriptPrinterName", #OpName) \ .set_attr("TLOpBuilder", \ [](Array a, BufferMap b) { \ return (void *)(new Entry(a, b)); \ }) enum class InferLevel { kFree = 0, kCommon = 1, kStrict = 2, }; struct LowerArgs { Target target; Range thread_bounds; Var thread_var; AddWorkspaceCallback AddWorkspace; LayoutMap layout_map; Map buffer_remap; bool disable_tma_lower; }; struct LayoutInferArgs { Target target; Range thread_bounds; LayoutMap layout_map; Map buffer_remap; }; struct CanonializeArgs { Target target; }; class Operator { public: virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; virtual Stmt Canonialize(const CanonializeArgs &T, arith::Analyzer *analyzer) const; virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level); virtual ~Operator() = default; }; class RegionOp : public Operator { public: RegionOp(Array args, BufferMap vmap); static const Op &Get(); const Buffer &GetBuffer() const { return buffer_; } const Array &GetRanges() const { return ranges_; } int GetAccessMask() const { return access_mask_; } bool IsFullRegion() const; private: Buffer buffer_; Array ranges_; int access_mask_; }; Var GetVarFromAccessPtr(const PrimExpr &expr); std::unique_ptr ParseOperator(Call call, BufferMap vmap); std::unique_ptr ParseOperator(Stmt stmt, BufferMap vmap); } // namespace tl } // namespace tvm #endif // TVM_TL_OP_OP_H_