"...composable_kernel_rocm.git" did not exist on "ff92222f937b54955011d394f46130fc5002110c"
Commit 611064a1 authored by Adam Osewski's avatar Adam Osewski
Browse files

Do not use macro.

parent 41fc6a24
...@@ -84,20 +84,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -84,20 +84,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
float ave_time{0}; float ave_time{0};
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
using GemmKernel = ck_tile::remove_cvref_t<decltype(kernel)>; constexpr bool has_hot_loop_v = has_hot_loop_.value;
auto kargs = GemmKernel::MakeKargs(args.p_a, constexpr auto tail_number_v = tail_number_.value;
args.p_b,
args.p_c, using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
args.M, ck_tile::UniversalGemmPipelineProblem<ADataType,
args.N, BDataType,
args.K, CDataType,
args.stride_A, GemmShape,
args.stride_B, ALayout,
args.stride_C); BLayout,
CLayout,
const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.kbatch); kPadA,
constexpr dim3 blocks = GemmKernel::BlockSize(); kPadB,
kPadC,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
...@@ -108,79 +125,70 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -108,79 +125,70 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
} }
ave_time = ck_tile::launch_kernel( ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(kernel, grids, blocks, 0, kargs)); s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}; };
#define RUN_KERNEL_(has_hot_loop_, tail_number_) \
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< \
ck_tile::UniversalGemmPipelineProblem<ADataType, \
BDataType, \
CDataType, \
GemmShape, \
ALayout, \
BLayout, \
CLayout, \
kPadA, \
kPadB, \
kPadC, \
ck_tile::GemmPipelineScheduler::Intrawave, \
has_hot_loop_, \
tail_number_>>; \
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; \
Run(Kernel{});
if(has_hot_loop) if(has_hot_loop)
{ {
// Tail pipeline One to Seven // Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One) if(tail_num == ck_tile::TailNumber::One)
{ {
RUN_KERNEL_(true, ck_tile::TailNumber::One); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
} }
else if(tail_num == ck_tile::TailNumber::Full) else if(tail_num == ck_tile::TailNumber::Full)
{ {
RUN_KERNEL_(true, ck_tile::TailNumber::Full); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
} }
if constexpr(BaseGemmPipeline::PrefetchStages > 2) if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{ {
if(tail_num == ck_tile::TailNumber::Two) if(tail_num == ck_tile::TailNumber::Two)
{ {
RUN_KERNEL_(true, ck_tile::TailNumber::Two); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
} }
} }
if constexpr(BaseGemmPipeline::PrefetchStages > 3) if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{ {
if(tail_num == ck_tile::TailNumber::Three) if(tail_num == ck_tile::TailNumber::Three)
{ {
RUN_KERNEL_(true, ck_tile::TailNumber::Three); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
} }
} }
if constexpr(BaseGemmPipeline::PrefetchStages > 4) if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{ {
if(tail_num == ck_tile::TailNumber::Four) if(tail_num == ck_tile::TailNumber::Four)
{ {
RUN_KERNEL_(true, ck_tile::TailNumber::Four); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
} }
} }
if constexpr(BaseGemmPipeline::PrefetchStages > 5) if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{ {
if(tail_num == ck_tile::TailNumber::Five) if(tail_num == ck_tile::TailNumber::Five)
{ {
RUN_KERNEL_(true, ck_tile::TailNumber::Five); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
} }
} }
if constexpr(BaseGemmPipeline::PrefetchStages > 6) if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{ {
if(tail_num == ck_tile::TailNumber::Six) if(tail_num == ck_tile::TailNumber::Six)
{ {
RUN_KERNEL_(true, ck_tile::TailNumber::Six); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
} }
} }
if constexpr(BaseGemmPipeline::PrefetchStages > 7) if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{ {
if(tail_num == ck_tile::TailNumber::Seven) if(tail_num == ck_tile::TailNumber::Seven)
{ {
RUN_KERNEL_(true, ck_tile::TailNumber::Seven); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
} }
} }
} }
...@@ -189,12 +197,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -189,12 +197,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Tail number always 1 // Tail number always 1
if(tail_num == ck_tile::TailNumber::One) if(tail_num == ck_tile::TailNumber::One)
{ {
RUN_KERNEL_(false, ck_tile::TailNumber::One); Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
} }
} }
#undef RUN_KERNEL_
return ave_time; return ave_time;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment