/*! * \file tl/op/parallel.h * \brief Infer layout from ops and parallel for */ #ifndef TVM_TL_OP_PARALLEL_H_ #define TVM_TL_OP_PARALLEL_H_ #include #include #include "../layout/layout.h" #include "op.h" namespace tvm { namespace tl { using namespace tir; class ParallelOp; class ParallelLoopNestVisitor : public StmtExprVisitor { private: ParallelLoopNestVisitor(ParallelOp *op) : p(op){}; void VisitStmt_(const ForNode *op) final; void VisitStmt_(const BufferStoreNode *op) final; void VisitExpr_(const BufferLoadNode *op) final; ParallelOp *p; friend class ParallelOp; }; class ParallelOp : public Operator { public: ParallelOp(For root); LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; Fragment GetLoopLayout() const { return loop_layout_; } For GetRoot() const { return root_; } Map> GetIndiceMap() const { return indice_map_; } Optional GetPredicate(Var thread_var) const; private: Fragment CompleteBufferFragment(const Buffer &buffer); bool IsCommonAccessIndice(const Buffer &buffer) const; void AddPredicate(PrimExpr expr) { predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; } For root_; ParallelLoopNestVisitor V; Map> indice_map_; std::unordered_set buffer_is_write_; Array loop_vars_; Fragment loop_layout_; mutable arith::Analyzer analyzer_; Optional predicate_; friend class ParallelLoopNestVisitor; }; } // namespace tl } // namespace tvm #endif // TVM_TL_OP_PARALLEL_H_