Commit f41a43a1 authored by Jiming Ruan
Browse files

Fix bug in non fuse_add_store cases

parent df45a6b5
...@@ -153,23 +153,25 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -153,23 +153,25 @@ struct Rmsnorm2dFwdPipelineTwoPass
{ {
auto acc = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution()); auto acc = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution());
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
move_tile_window(y_residual_window, {0, -Block_N});
}
else
{
acc = cast_tile<ComputeDataType>(load_tile(x_window));
move_tile_window(x_window, {0, -Block_N});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
{ {
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
sweep_tile(x_resi, [&](auto idx) { sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x // compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx); acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
}); });
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
} }
else if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
move_tile_window(y_residual_window, {0, -Block_N});
} }
// load gamma (TODO: support no gamma?) // load gamma (TODO: support no gamma?)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment