// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. /*! * \file target/codegen.h * \brief Utility to generate code */ #ifndef TVM_TL_TARGET_CODEGEN_CUDA_H_ #define TVM_TL_TARGET_CODEGEN_CUDA_H_ #include #include #include #include #include #include "target/source/codegen_c.h" namespace tvm { namespace codegen { class CodeGenTileLangCUDA final : public CodeGenC { public: CodeGenTileLangCUDA(); std::string Finish(); // override behavior void PrintFuncPrefix(std::ostream& os) final; void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; std::string CastFromTo(std::string value, DataType from, DataType target) final; // overload visitor void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; void VisitExpr_(const CallNode* op, std::ostream& os) final; void VisitExpr_(const CastNode* op, std::ostream& os) final; void VisitStmt_(const AllocateNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; // Override this as a work around for __grid_constant__ parameter void AddFunction(const PrimFunc& f); protected: virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) final; void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, bool skip_first_arg, std::ostream& os) final; // NOLINT(*) private: // Handle volatile loads void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, std::ostream& os) final; // Whether scope such as "__shared__" or "__constant__" is part of type. bool IsScopePartOfType() const final { return false; } friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p); // The size of the barrier array in shared memory int barrier_count_ = -1; // whether need mma.h bool need_mma_h_{false}; // whether need cast_smem_ptr_to_int helper function bool need_cast_smem_ptr_to_int_{false}; // The name of the barrier array in shared memory const std::string barrier_name_ = "barrier"; // The alignment of the barrier array in shared memory // Set to 16 to maintain minimum alignment requirements for async bulk copy const int barrier_alignment_bytes_ = 16; std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p); void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, std::ostream& os); int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size); }; } // namespace codegen } // namespace tvm #endif // TVM_TL_TARGET_CODEGEN_CUDA_H_