/*!
 * \file Layout.h
 *
 * \brief Declarations of the tl layout object types (LayoutNode / Layout,
 *        FragmentNode / Fragment) and of the factory functions that build
 *        them.
 *
 * NOTE(review): in this view the #include targets and the template argument
 * lists (on Array, Map, Optional, std::pair) appear to be missing, most
 * likely stripped during extraction. Restore them from the original header
 * before compiling.
 */
#ifndef TVM_TL_LAYOUT_LAYOUT_H_
#define TVM_TL_LAYOUT_LAYOUT_H_

// NOTE(review): the three include targets below are not visible here.
#include
#include
#include

namespace tvm {
namespace tl {

// NOTE(review): a using-directive at namespace scope in a header makes every
// tir name visible inside tvm::tl for all includers of this file.
using namespace tir;

// Forward declarations of the reference types, needed because LayoutNode's
// member functions return them before the classes are defined.
class Layout;
class Fragment;

/*!
 * \brief Object node type registered under the type key "tl.Layout".
 *
 * Stores two arrays, input_size_ and forward_index_. Only the trivial
 * accessors are defined inline; every other member is declared here and
 * defined elsewhere.
 *
 * NOTE(review): presumably forward_index_ holds the expressions that map an
 * input index to an output index, and input_size_ the extent of each input
 * dimension. Confirm against the implementation file.
 */
class LayoutNode : public Object {
public:
  LayoutNode() = default;
  /*! \brief Construct from an input shape and forward index expressions. */
  LayoutNode(Array input_size, Array forward_index);

  /*! \brief Number of entries in input_size_. */
  size_t InputDim() const { return input_size_.size(); }

  /*! \brief Number of entries in forward_index_. */
  size_t OutputDim() const { return forward_index_.size(); }

  /*! \brief Returns a copy of the stored input_size_ array handle. */
  Array InputShape() const { return input_size_; }

  /*!
   * \brief Defined elsewhere.
   * NOTE(review): presumably the extents of the output produced by
   * forward_index_. Confirm.
   */
  Array OutputShape() const;

  /*! \brief Returns a copy of the stored forward_index_ array handle. */
  Array GetForwardIndex() const { return forward_index_; }

  // The virtual members below form the interface that FragmentNode
  // specializes. Their bodies are not visible in this header.

  /*! \brief Defined elsewhere; overridden (final) by FragmentNode. */
  virtual Array GetForwardVars() const;

  /*!
   * \brief Defined elsewhere.
   * NOTE(review): presumably evaluates the layout on the given input
   * expressions. Confirm.
   */
  virtual Array Forward(const Array &vars) const;

  /*! \brief Defined elsewhere; overridden (final) by FragmentNode. */
  virtual Layout Inverse() const;

  /*!
   * \brief Defined elsewhere; overridden (final) by FragmentNode.
   * NOTE(review): the element types of the returned pair are not visible
   * in this view.
   */
  virtual std::pair InverseWithLevel() const;

  /*! \brief Defined elsewhere; returns a string, presumably for debugging. */
  virtual std::string DebugOutput() const;

  /*!
   * \brief Defined elsewhere.
   * \param other Node to compare against.
   * \param skip_index Defaults to false. NOTE(review): presumably skips the
   *        comparison of forward_index_ when true. Confirm.
   */
  virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const;

  // Compile-time constants carrying the node's registration metadata.
  // NOTE(review): presumably read by the TVM object/reflection machinery.
  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr const char *_type_key = "tl.Layout";

  /*! \brief Structural-equality hook; defined elsewhere. */
  bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const;

  /*! \brief Defined elsewhere. */
  static void RegisterReflection();

  TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);

protected:
  /*! \brief Defined elsewhere; overridden (final) by FragmentNode. */
  virtual Map getVarMap() const;

  /*!
   * \brief Defined elsewhere.
   * NOTE(review): presumably registers the layout's variable ranges with
   * the analyzer. Confirm.
   */
  void UpdateAnalyzer(arith::Analyzer *analyzer) const;

  // Data members. Declaration order fixes the in-memory layout of the node.
  Array forward_index_;
  Array input_size_;
};

/*!
 * \brief Layout reference class.
*/
// Reference (handle) type whose node type is LayoutNode, as wired up by
// TVM_DEFINE_OBJECT_REF_METHODS below.
// NOTE(review): template argument lists appear to be stripped in this view,
// so the two constructors look identical here. In the original header they
// presumably differ in the element type of their first Array parameter.
class Layout : public ObjectRef {
public:
  /*! \brief Construct from forward variables and forward index expressions. */
  TVM_DLL Layout(Array forward_var, Array forward_index);
  /*! \brief Construct from an input shape and forward index expressions. */
  TVM_DLL Layout(Array input_size, Array forward_index);

  TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
};

/*!
 * \brief Object node type registered under the type key "tl.Fragment".
 *
 * Extends LayoutNode with three additional members: thread_range_,
 * forward_thread_ and replicate_size_. It is declared as a final object type
 * (TVM_DECLARE_FINAL_OBJECT_INFO), and all of its overrides are marked final.
 *
 * NOTE(review): presumably describes how a tile is distributed across
 * threads, with forward_thread_ giving the owning thread and replicate_size_
 * the replication factor. Confirm against the implementation file.
 */
class FragmentNode : public LayoutNode {
public:
  FragmentNode() = default;
  /*!
   * \brief Construct from an input shape, forward index expressions, a
   *        forward thread expression and a replicate size.
   */
  FragmentNode(Array input_size, Array forward_index, PrimExpr forward_thread,
               PrimExpr replicate_size);

  /*! \brief Returns the stored forward_thread_ expression. */
  PrimExpr GetForwardThread() const { return forward_thread_; }

  // Final overrides of the LayoutNode virtual interface; defined elsewhere.
  Array GetForwardVars() const final;
  Layout Inverse() const final;
  std::pair InverseWithLevel() const final;

  /*! \brief Defined elsewhere. */
  PrimExpr ThreadExtent() const;

  /*!
   * \brief Returns the stored replicate_size_ expression.
   * NOTE(review): the ';' after the inline body is a redundant empty
   * declaration. It is harmless but could be dropped.
   */
  PrimExpr ReplicateExtent() const { return replicate_size_; };

  /*!
   * \brief Defined elsewhere.
   * \param vars Input expressions.
   * \param rep_var Optional replication variable.
   */
  PrimExpr ForwardThread(const Array &vars, const Optional &rep_var) const;

  /*!
   * \brief Defined elsewhere; returns a new Fragment.
   * \param repeats Repeat counts. NOTE(review): presumably one per input
   *        dimension. Confirm.
   * \param repeat_on_thread NOTE(review): semantics not visible here.
   * \param lower_dim_first Defaults to true.
   */
  Fragment Repeat(const Array &repeats, bool repeat_on_thread,
                  bool lower_dim_first = true) const;

  // The following transformations each return a new Fragment. Their bodies
  // are defined elsewhere.
  Fragment Replicate(int repeats) const;
  Fragment DeReplicate() const;
  Fragment CondenseReplicateVar() const;

  /*! \brief Final override of LayoutNode::DebugOutput; defined elsewhere. */
  std::string DebugOutput() const final;

  /*! \brief Defined elsewhere; returns a new Fragment. */
  Fragment BindThreadRange(Range thread_range) const;

  /*! \brief Returns the stored thread_range_. */
  Range ThreadRange() const { return thread_range_; }

  /*!
   * \brief Defined elsewhere.
   * NOTE(review): the first parameter is `const FragmentNode *`, whereas the
   * virtual LayoutNode::IsEqual takes `const LayoutNode *`. This therefore
   * declares a separate non-virtual overload that hides the base function
   * rather than overriding it. Confirm this is intended.
   */
  bool IsEqual(const FragmentNode *other, bool skip_index = false) const;

  /*! \brief Defined elsewhere. */
  bool IsCompletedReplicated() const;

  /*! \brief Defined elsewhere. */
  static void RegisterReflection();

  /*! \brief Structural-equality hook; defined elsewhere. */
  bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;

  static constexpr const char *_type_key = "tl.Fragment";
  TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);

protected:
  /*! \brief Final override of LayoutNode::getVarMap; defined elsewhere. */
  Map getVarMap() const final;

  // Data members. Declaration order fixes the in-memory layout of the node.
  Range thread_range_;
  PrimExpr forward_thread_;
  PrimExpr replicate_size_;
};

/*!
 * \brief Fragment reference class.
*/
// Reference (handle) type deriving from Layout, whose node type is
// FragmentNode, as wired up by TVM_DEFINE_OBJECT_REF_METHODS below.
// NOTE(review): template argument lists on Array / Optional appear to be
// stripped in this view. Confirm the element types in the original header.
class Fragment : public Layout {
public:
  /*!
   * \brief Construct from forward variables, forward index expressions, a
   *        forward thread expression and a thread-replicate IterVar.
   */
  TVM_DLL Fragment(Array forward_var, Array forward_index,
                   PrimExpr forward_thread, IterVar thread_replicate);
  /*!
   * \brief Construct from an input shape, forward index expressions, a
   *        forward thread expression, a replicate size and an optional
   *        replicate variable.
   */
  TVM_DLL Fragment(Array input_size, Array forward_index,
                   PrimExpr forward_thread, PrimExpr replicate_size,
                   Optional replicate_var);

  TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
};

// ---------------------------------------------------------------------------
// Free helper and factory declarations. All are defined elsewhere; only their
// signatures are visible in this header.
// ---------------------------------------------------------------------------

/*! \brief Returns a Var for the given index; defined elsewhere. */
Var InputPlaceholder(size_t idx);
/*! \brief Returns a Var; defined elsewhere. */
Var ReplicationPlaceholder();
/*! \brief Builds an IterVar from a name and a PrimExpr; defined elsewhere. */
IterVar make_itervar(std::string name, PrimExpr dom);

// Fragment factories.
// NOTE(review): judging by the names, these build the register fragments for
// GEMM operands A / B and accumulator C, with target-specific variants
// (CDNA, Hopper, sparse). Parameter meanings (block / warp tile sizes,
// element_size units) are not established by this header. Confirm against
// the definitions.
Fragment makeGemmFragment8x8();
Fragment makeGemmFragment8x8Transposed();
Fragment makeGemmFragmentC(const int block_m, const int block_n,
                           const int warp_m, const int warp_n,
                           const int element_size);
Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size);
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size);
// `transposed` defaults to false in the three operand factories below.
Fragment makeGemmFragmentA(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n, const int element_size,
                           bool transposed = false);
// Unlike makeGemmFragmentA, this overload takes no element_size parameter.
Fragment makeGemmFragmentB(const int block_m, const int block_n,
                           const int block_k, const int warp_m,
                           const int warp_n, bool transposed = false);
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                               const int block_k, const int warp_m,
                               const int warp_n, const int element_size,
                               const int k_pack, bool transposed = false);

// Default Memory Layout
// NOTE(review): the factories below return Layout rather than Fragment.
// Presumably they describe buffer memory layouts (linear, padded, swizzled).
// Confirm against the definitions. `k_inner` defaults to true wherever it
// appears.
Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
                        int element_size, bool k_inner = true);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
                              int continuity, int element_size,
                              bool k_inner = true);
Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous,
                             int continuity, int element_size,
                             bool k_inner = true);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
                            int kPack);

// Volta-named variants.
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
                                const int warp_m, const int warp_n,
                                const int element_size);
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
                                const int warp_n);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
                             bool k_inner = true);

// Note the parameter spelling `elementsize` (no underscore) in the next two
// declarations, versus `element_size` elsewhere in this header.
Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous,
                                int elementsize, int crosswise);
Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
                                    int elementsize);

// Swizzle layout factories (full / half / quarter bank, per their names).
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeQuarterBankSwizzleLayout(int stride, int continuous,
                                    int element_size);

namespace attr {
// Block attribute key ("layout_map") under which the layouts for all buffers
// in the block are stored.
constexpr const char *kLayoutMap = "layout_map";
} // namespace attr

} // namespace tl
} // namespace tvm

#endif // TVM_TL_LAYOUT_LAYOUT_H_