"...composable_kernel_rocm.git" did not exist on "5f2c89e8b43d670e3405a4f17ff475d25960f9b3"
Commit dd21c599 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

tmp save

parent f23a2e2a
......@@ -23,6 +23,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
......@@ -44,13 +71,23 @@ struct DataTypeTraits<ck_tile::half_t>
static constexpr const char* name = "fp16";
};
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
/** \brief Struct used for specifying desired gemm details*/
struct gemm_traits
......
......@@ -10,34 +10,54 @@ using S = ck_tile::stream_config;
template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>,
ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>,
ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using GemmEpilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename Traits_::AccDataType,
typename Traits_::CDataType,
Traits_::kPadM,
Traits_::kPadN>>;
constexpr bool TransposeC = false;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<Traits_::kPadM, Traits_::kPadN, Traits_::kPadK, Traits_::ALayout, Traits_::BLayout, Traits_::CLayout, TransposeC>;
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<Traits_::kPadM,
Traits_::kPadN,
Traits_::kPadK,
typename Traits_::ALayout,
typename Traits_::BLayout,
typename Traits_::CLayout,
TransposeC>;
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
Traits_::kPadN,
Traits_::kPadK,
typename Traits_::ALayout,
typename Traits_::BLayout,
typename Traits_::CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
typename Traits_::BDataType,
typename Traits_::AccDataType,
GemmShape,
GemmTraits>>;
constexpr int kBlockPerCu = 1;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
typename Traits_::BDataType,
typename Traits_::AccDataType,
GemmShape,
GemmTraits>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
using GemmEpilogue =
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<typename Traits_::AccDataType,
typename Traits_::CDataType,
typename Traits_::CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Traits_::M_Warp,
Traits_::N_Warp,
Traits_::M_Warp_Tile,
Traits_::N_Warp_Tile,
Traits_::K_Warp_Tile,
TransposeC>>;
const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile;
......@@ -59,7 +79,8 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>, ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
tail_number_v>,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
......
......@@ -10,31 +10,54 @@ using S = ck_tile::stream_config;
template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>,
ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>,
ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using GemmEpilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename Traits_::AccDataType,
typename Traits_::CDataType,
Traits_::kPadM,
Traits_::kPadN>>;
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<Traits_::kPadM,
Traits_::kPadN,
Traits_::kPadK,
typename Traits_::ALayout,
typename Traits_::BLayout,
typename Traits_::CLayout,
TransposeC>;
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
Traits_::kPadN,
Traits_::kPadK,
typename Traits_::ALayout,
typename Traits_::BLayout,
typename Traits_::CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
typename Traits_::BDataType,
typename Traits_::AccDataType,
GemmShape,
GemmTraits>>;
typename Traits_::CLayout>;
constexpr int kBlockPerCu = 1;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
typename Traits_::BDataType,
typename Traits_::AccDataType,
GemmShape,
GemmTraits>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>;
using GemmEpilogue =
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<typename Traits_::AccDataType,
typename Traits_::CDataType,
typename Traits_::CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Traits_::M_Warp,
Traits_::N_Warp,
Traits_::M_Warp_Tile,
Traits_::N_Warp_Tile,
Traits_::K_Warp_Tile,
TransposeC>>;
const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile;
......@@ -53,10 +76,11 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
typename Traits_::BDataType,
typename Traits_::AccDataType,
GemmShape,
GemmTraits,
ck_tile::GemmPipelineScheduler::Interwave,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>>;
tail_number_v>,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment