Commit c636337e authored by rocking
Browse files

Extract padN and RawStore

parent 256dcc22
...@@ -40,8 +40,8 @@ auto create_args(int argc, char* argv[]) ...@@ -40,8 +40,8 @@ auto create_args(int argc, char* argv[])
.insert("prec_sy", .insert("prec_sy",
"auto", "auto",
"output quant scale type, set auto will use fp32. used when fquant=1 or 2") "output quant scale type, set auto will use fp32. used when fquant=1 or 2")
.insert("warmup", "5", "cold iter") .insert("warmup", "10", "cold iter")
.insert("repeat", "20", "hot iter"); .insert("repeat", "40", "hot iter");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
...@@ -129,9 +129,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -129,9 +129,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
beta_buf.ToDevice(beta_host.data()); beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data()); x_residual_buf.ToDevice(x_residual_host.data());
constexpr bool kTwoPass = false; constexpr bool kTwoPass = false;
constexpr auto kFuseAdd = ck_tile::Layernorm2dFusedAddEnum::PRE_ADD_STORE; constexpr bool kPadN = true;
constexpr auto kFuseQuant = ck_tile::Layernorm2dFusedQuantEnum::NO_SWEEP; constexpr bool UseRawStore = true;
constexpr auto kFuseAdd = ck_tile::Layernorm2dFusedAddEnum::PRE_ADD_STORE;
constexpr auto kFuseQuant = ck_tile::Layernorm2dFusedQuantEnum::NO_SWEEP;
using BlockWarps = ck_tile::sequence<1, 4>; using BlockWarps = ck_tile::sequence<1, 4>;
using BlockTile = ck_tile::sequence<1, 8192>; using BlockTile = ck_tile::sequence<1, 8192>;
...@@ -139,7 +141,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -139,7 +141,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Vector = ck_tile::sequence<1, 8>; using Vector = ck_tile::sequence<1, 8>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>; using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Trait = ck_tile::Layernorm2dFwdTraits<true, false, true, kTwoPass, kFuseAdd, kFuseQuant>; using Trait = ck_tile::Layernorm2dFwdTraits<kPadN, false, true, kTwoPass, kFuseAdd, kFuseQuant>;
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<XDataType, using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
...@@ -157,7 +159,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -157,7 +159,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>; using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
using EpilogueProblem = using EpilogueProblem =
ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, true, false>; ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, kPadN, UseRawStore>;
using Epilogue = ck_tile::Default2DEpilogue<EpilogueProblem>; using Epilogue = ck_tile::Default2DEpilogue<EpilogueProblem>;
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>; using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment