Commit dd21c599 authored by Jakub Piasecki
Browse files

tmp save

parent f23a2e2a
...@@ -23,6 +23,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t> ...@@ -23,6 +23,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM. // ToDo: Add more bias config to support different categories of GEMM.
}; };
/// GemmBasicTypeConfig specialization selected by ck_tile::bf16_t.
/// A, B and C all use ck_tile::bf16_t; the accumulation type is float.
template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
{
    // Element types of the A and B matrices.
    using ADataType = ck_tile::bf16_t;
    using BDataType = ck_tile::bf16_t;

    // Element type of the C matrix.
    using CDataType = ck_tile::bf16_t;

    // Type used for accumulation.
    using AccDataType = float;
};
/// GemmBasicTypeConfig specialization selected by ck_tile::fp8_t.
/// A and B use ck_tile::fp8_t, C uses ck_tile::half_t, and the accumulation
/// type is float.
template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
{
    // Element types of the A and B matrices.
    using ADataType = ck_tile::fp8_t;
    using BDataType = ck_tile::fp8_t;

    // Element type of the C matrix (differs from the A/B element type).
    using CDataType = ck_tile::half_t;

    // Type used for accumulation.
    using AccDataType = float;
};
/// GemmBasicTypeConfig specialization selected by ck_tile::bf8_t.
/// A and B use ck_tile::bf8_t, C uses ck_tile::half_t, and the accumulation
/// type is float.
template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
{
    // Element types of the A and B matrices.
    using ADataType = ck_tile::bf8_t;
    using BDataType = ck_tile::bf8_t;

    // Element type of the C matrix (differs from the A/B element type).
    using CDataType = ck_tile::half_t;

    // Type used for accumulation.
    using AccDataType = float;
};
template <typename T> template <typename T>
struct DataTypeTraits; struct DataTypeTraits;
...@@ -44,13 +71,23 @@ struct DataTypeTraits<ck_tile::half_t> ...@@ -44,13 +71,23 @@ struct DataTypeTraits<ck_tile::half_t>
static constexpr const char* name = "fp16"; static constexpr const char* name = "fp16";
}; };
using Types = GemmBasicTypeConfig<ck_tile::half_t>; template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
// Short textual name associated with ck_tile::bf16_t.
static constexpr const char* name = "bf16";
};
// Specific type aliases for easy access template <>
using ADataType = Types::ADataType; struct DataTypeTraits<ck_tile::fp8_t>
using BDataType = Types::BDataType; {
using AccDataType = Types::AccDataType; static constexpr const char* name = "fp8";
using CDataType = Types::CDataType; };
/// DataTypeTraits specialization for ck_tile::bf8_t.
/// Exposes the short textual name "bf8" for this element type.
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
    static constexpr char const* name = "bf8";
};
/** \brief Struct used for specifying desired gemm details*/ /** \brief Struct used for specifying desired gemm details*/
struct gemm_traits struct gemm_traits
......
...@@ -10,34 +10,54 @@ using S = ck_tile::stream_config; ...@@ -10,34 +10,54 @@ using S = ck_tile::stream_config;
template <typename Traits_> template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape< using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>, ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>,
ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>, ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>,
ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>; ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>; using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using GemmEpilogue = using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<Traits_::kPadM,
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename Traits_::AccDataType, Traits_::kPadN,
typename Traits_::CDataType, Traits_::kPadK,
Traits_::kPadM, typename Traits_::ALayout,
Traits_::kPadN>>; typename Traits_::BLayout,
constexpr bool TransposeC = false; typename Traits_::CLayout,
using GemmUniversalTraits = ck_tile:: TransposeC>;
TileGemmUniversalTraits<Traits_::kPadM, Traits_::kPadN, Traits_::kPadK, Traits_::ALayout, Traits_::BLayout, Traits_::CLayout, TransposeC>;
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM, using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
Traits_::kPadN, Traits_::kPadN,
Traits_::kPadK, Traits_::kPadK,
typename Traits_::ALayout, typename Traits_::ALayout,
typename Traits_::BLayout, typename Traits_::BLayout,
typename Traits_::CLayout>; typename Traits_::CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::GemmPipelineProblem<typename Traits_::ADataType, using GemmPipelineProblem = ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
typename Traits_::BDataType, typename Traits_::BDataType,
typename Traits_::AccDataType, typename Traits_::AccDataType,
GemmShape, GemmShape,
GemmTraits>>; GemmTraits>;
constexpr int kBlockPerCu = 1; using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
using GemmEpilogue =
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<typename Traits_::AccDataType,
typename Traits_::CDataType,
typename Traits_::CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Traits_::M_Warp,
Traits_::N_Warp,
Traits_::M_Warp_Tile,
Traits_::N_Warp_Tile,
Traits_::K_Warp_Tile,
TransposeC>>;
const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile; const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile;
...@@ -59,7 +79,8 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ...@@ -59,7 +79,8 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
GemmUniversalTraits, GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Intrawave, ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v, has_hot_loop_v,
tail_number_v>, ck_tile::UniversalGemmPipelineAgBgCrPolicy>; tail_number_v>,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
......
...@@ -10,31 +10,54 @@ using S = ck_tile::stream_config; ...@@ -10,31 +10,54 @@ using S = ck_tile::stream_config;
template <typename Traits_> template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{ {
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape< using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>, ck_tile::sequence<Traits_::M_Tile, Traits_::N_Tile, Traits_::K_Tile>,
ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>, ck_tile::sequence<Traits_::M_Warp, Traits_::N_Warp, Traits_::K_Warp>,
ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>; ck_tile::sequence<Traits_::M_Warp_Tile, Traits_::N_Warp_Tile, Traits_::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>; using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using GemmEpilogue = using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<Traits_::kPadM,
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename Traits_::AccDataType, Traits_::kPadN,
typename Traits_::CDataType, Traits_::kPadK,
Traits_::kPadM, typename Traits_::ALayout,
Traits_::kPadN>>; typename Traits_::BLayout,
typename Traits_::CLayout,
TransposeC>;
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM, using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
Traits_::kPadN, Traits_::kPadN,
Traits_::kPadK, Traits_::kPadK,
typename Traits_::ALayout, typename Traits_::ALayout,
typename Traits_::BLayout, typename Traits_::BLayout,
typename Traits_::CLayout>; typename Traits_::CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<typename Traits_::ADataType, using GemmPipelineProblem = ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
typename Traits_::BDataType, typename Traits_::BDataType,
typename Traits_::AccDataType, typename Traits_::AccDataType,
GemmShape, GemmShape,
GemmTraits>>; GemmTraits>;
constexpr int kBlockPerCu = 1; using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>;
using GemmEpilogue =
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<typename Traits_::AccDataType,
typename Traits_::CDataType,
typename Traits_::CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Traits_::M_Warp,
Traits_::N_Warp,
Traits_::M_Warp_Tile,
Traits_::N_Warp_Tile,
Traits_::K_Warp_Tile,
TransposeC>>;
const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile; const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile;
...@@ -53,10 +76,11 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ...@@ -53,10 +76,11 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
typename Traits_::BDataType, typename Traits_::BDataType,
typename Traits_::AccDataType, typename Traits_::AccDataType,
GemmShape, GemmShape,
GemmTraits, GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Interwave, ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment