Commit 018e939f authored by Jiming Ruan
Browse files

Modify tests and bug fix

parent 54617a85
This diff is collapsed.
......@@ -200,6 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = rmsnorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
std::size_t num_byte =
sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n;
num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0;
......
......@@ -120,6 +120,13 @@ struct Rmsnorm2dFwdPipelineOnePass
block_norm_reduce_sync(square_mean, cur_count);
block_norm_reduce_cross_warp_sync(square_mean, cur_count, smem);
if constexpr(!kWelford)
{
sweep_tile(square_mean, [&](auto idx) {
square_mean(idx) = square_mean(idx) / type_convert<ComputeDataType>(row_size);
});
}
// compute inv-rms
auto inv_rms = tile_elementwise_in(
[&](const auto& v_) {
......
......@@ -70,6 +70,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
void* smem,
Epilogue) const
{
static_assert(kWelford == true, "2 pass only supports welford merge");
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto gamma_window = make_tile_window(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment