/*! * \file target/codegen.h * \brief Utility to generate code */ #ifndef TVM_TL_TARGET_CODEGEN_HIP_H_ #define TVM_TL_TARGET_CODEGEN_HIP_H_ #include #include #include #include #include #include "target/source/codegen_c.h" namespace tvm { namespace codegen { class CodeGenTileLangHIP final : public CodeGenC { public: CodeGenTileLangHIP(); std::string Finish(); // override behavior void PrintFuncPrefix(std::ostream &os) final; void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final; void VisitStmt_(const ForNode *op) final; void PrintStorageSync(const CallNode *op) final; void PrintStorageScope(const std::string &scope, std::ostream &os) final; // NOLINT(*) void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream &os) final; // NOLINT(*) void PrintType(DataType t, std::ostream &os) final; // NOLINT(*) void PrintVecElemLoad(const std::string &vec, DataType t, int i, std::ostream &os) final; // NOLINT(*) void PrintVecElemStore(const std::string &vec, DataType t, int i, const std::string &value) final; void BindThreadIndex(const IterVar &iv) final; // NOLINT(*) void PrintVecElemLoadExpr(DataType t, int i, const std::string &value, std::ostream &os) final; std::string CastFromTo(std::string value, DataType from, DataType target) final; // overload visitor void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; void VisitExpr_(const CallNode *op, std::ostream &os) final; void VisitExpr_(const CastNode *op, std::ostream &os) final; void VisitStmt_(const AllocateNode *op) final; void VisitStmt_(const AttrStmtNode *op) final; // Override this as a work around for __grid_constant__ parameter void AddFunction(const PrimFunc &f); protected: virtual std::string GetBufferRef(DataType t, const BufferNode *buffer, PrimExpr index) final; void PrintCallExtern(Type ret_type, String global_symbol, const Array &args, bool skip_first_arg, std::ostream &os) final; // NOLINT(*) private: // Handle volatile loads void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op, std::ostream &os) final; // Whether scope such as "__shared__" or "__constant__" is part of type. bool IsScopePartOfType() const final { return false; } friend void PrintConst(const FloatImmNode *op, std::ostream &os, CodeGenTileLangHIP *p); // whether need math_constants.h bool need_math_constants_h_{false}; // whether need mfma.h bool need_wmma_h_{false}; // whether need fp8.h bool enable_fp8_{false}; // The size of the barrier array in shared memory int barrier_count_ = -1; // whether need mma.h bool need_mma_h_{false}; // whether need cast_smem_ptr_to_int helper function bool need_cast_smem_ptr_to_int_{false}; // The name of the barrier array in shared memory const std::string barrier_name_ = "barrier"; // The alignment of the barrier array in shared memory // Set to 16 to maintain minimum alignment requirements for async bulk copy const int barrier_alignment_bytes_ = 16; }; } // namespace codegen } // namespace tvm #endif // TVM_TL_TARGET_CODEGEN_HIP_H_