Commit 6fe9e964 authored by Adam Osewski's avatar Adam Osewski
Browse files

Adapt example to refactor changes.

parent 241baec1
...@@ -199,16 +199,14 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -199,16 +199,14 @@ int run_gemm_example(int argc, char* argv[])
{ {
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
} }
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not else if(a_layout == "C" && b_layout == "C")
// work. {
// else if(a_layout == "C" && b_layout == "C") return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// { }
// return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); else if(a_layout == "C" && b_layout == "R")
// } {
// else if(a_layout == "C" && b_layout == "R") return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// { }
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else else
{ {
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
......
...@@ -37,8 +37,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -37,8 +37,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
#endif
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
// Compute friendly for Intrawave scheduler // Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t N_Tile = 256;
...@@ -57,6 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -57,6 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr bool kPadN = false; constexpr bool kPadN = false;
constexpr bool kPadK = false; constexpr bool kPadK = false;
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
// =============================================== // ===============================================
...@@ -71,10 +73,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -71,10 +73,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
#endif #endif
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
...@@ -97,14 +102,16 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -97,14 +102,16 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
BDataType, BDataType,
AccDataType, AccDataType,
GemmShape, GemmShape,
Traits, GemmUniversalTraits,
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
ck_tile::GemmPipelineScheduler::Interwave, ck_tile::GemmPipelineScheduler::Interwave,
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
ck_tile::GemmPipelineScheduler::Intrawave, ck_tile::GemmPipelineScheduler::Intrawave,
#endif #endif
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b, args.p_b,
...@@ -139,6 +146,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -139,6 +146,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
if(has_hot_loop) if(has_hot_loop)
{ {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven // Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One) if(tail_num == ck_tile::TailNumber::One)
{ {
...@@ -199,6 +221,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -199,6 +221,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}); ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
} }
} }
#endif
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment