// SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/epilogue.hpp" template struct GemmBasicTypeConfig; template <> struct GemmBasicTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using AccDataType = float; using CDataType = ck_tile::half_t; // ToDo: Add more bias config to support different categories of GEMM. }; template <> struct GemmBasicTypeConfig { using ADataType = ck_tile::bf16_t; using BDataType = ck_tile::bf16_t; using AccDataType = float; using CDataType = ck_tile::bf16_t; }; template <> struct GemmBasicTypeConfig { using ADataType = ck_tile::fp8_t; using BDataType = ck_tile::fp8_t; using AccDataType = float; using CDataType = ck_tile::fp8_t; }; template <> struct GemmBasicTypeConfig { using ADataType = ck_tile::bf8_t; using BDataType = ck_tile::bf8_t; using AccDataType = float; using CDataType = ck_tile::bf8_t; }; template struct DataTypeTraits; template <> struct DataTypeTraits { static constexpr const char* name = "fp32"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp64"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp16"; }; template <> struct DataTypeTraits { static constexpr const char* name = "bf16"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp8"; }; template <> struct DataTypeTraits { static constexpr const char* name = "bf8"; }; /** \brief Struct used for specifying desired gemm details*/ struct gemm_traits { std::string data_type; /** Tensors datatype, can be set to either fp16 or bf16*/ bool is_a_rowmajor; /** Whether A matrix is rowmajor */ bool is_b_rowmajor; /** Whether B matrix is rowmajor */ bool is_c_rowmajor; /** Whether C matrix is rowmajor */ }; template struct gemm_traits_ { using ADataType = ck_tile::remove_cvref_t; using BDataType = ck_tile::remove_cvref_t; using AccDataType = ck_tile::remove_cvref_t; using CDataType = ck_tile::remove_cvref_t; using ALayout = ck_tile::remove_cvref_t; using BLayout = ck_tile::remove_cvref_t; using CLayout = ck_tile::remove_cvref_t; static constexpr ck_tile::index_t M_Tile = M_Tile_; static constexpr ck_tile::index_t N_Tile = N_Tile_; static constexpr ck_tile::index_t K_Tile = K_Tile_; static constexpr ck_tile::index_t M_Warp = M_Warp_; static constexpr ck_tile::index_t N_Warp = N_Warp_; static constexpr ck_tile::index_t K_Warp = K_Warp_; static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_; static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; static constexpr bool kPadK = kPadK_; }; // host API template float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); /** * \brief Invoke gemm function * * \param traits Gemm traits which are used for choosing best instance. * \param args Runtime gemm host arguments. * \param s Stream configuration. * \return Time of execution. */ float gemm(const gemm_traits& traits, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);