Commit 611064a1 authored by Adam Osewski
Browse files

Do not use macro.

parent 41fc6a24
......@@ -84,20 +84,37 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
float ave_time{0};
const auto Run = [&](const auto& kernel) {
using GemmKernel = ck_tile::remove_cvref_t<decltype(kernel)>;
auto kargs = GemmKernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = GemmKernel::BlockSize();
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
CDataType,
GemmShape,
ALayout,
BLayout,
CLayout,
kPadA,
kPadB,
kPadC,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.M,
args.N,
args.K,
args.stride_A,
args.stride_B,
args.stride_C);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
{
......@@ -108,79 +125,70 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(kernel, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
#define RUN_KERNEL_(has_hot_loop_, tail_number_) \
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< \
ck_tile::UniversalGemmPipelineProblem<ADataType, \
BDataType, \
CDataType, \
GemmShape, \
ALayout, \
BLayout, \
CLayout, \
kPadA, \
kPadB, \
kPadC, \
ck_tile::GemmPipelineScheduler::Intrawave, \
has_hot_loop_, \
tail_number_>>; \
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; \
Run(Kernel{});
if(has_hot_loop)
{
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
RUN_KERNEL_(true, ck_tile::TailNumber::One);
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
RUN_KERNEL_(true, ck_tile::TailNumber::Full);
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
RUN_KERNEL_(true, ck_tile::TailNumber::Two);
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{
RUN_KERNEL_(true, ck_tile::TailNumber::Three);
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
RUN_KERNEL_(true, ck_tile::TailNumber::Four);
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
RUN_KERNEL_(true, ck_tile::TailNumber::Five);
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
RUN_KERNEL_(true, ck_tile::TailNumber::Six);
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
RUN_KERNEL_(true, ck_tile::TailNumber::Seven);
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
}
......@@ -189,12 +197,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
// Tail number always 1
if(tail_num == ck_tile::TailNumber::One)
{
RUN_KERNEL_(false, ck_tile::TailNumber::One);
Run(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
}
#undef RUN_KERNEL_
return ave_time;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment