Commit f41a43a1 authored by Jiming Ruan
Browse files

Fix bug in non fuse_add_store cases

parent df45a6b5
@@ -153,24 +153,26 @@ struct Rmsnorm2dFwdPipelineTwoPass
{ {
auto acc = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution()); auto acc = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution());
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window);
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
}
else if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{ {
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window)); acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
move_tile_window(y_residual_window, {0, -Block_N}); move_tile_window(y_residual_window, {0, -Block_N});
} }
else
{
acc = cast_tile<ComputeDataType>(load_tile(x_window));
move_tile_window(x_window, {0, -Block_N});
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
{
auto x_resi = load_tile(x_residual_window);
sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
});
move_tile_window(x_residual_window, {0, -Block_N});
}
}
// load gamma (TODO: support no gamma?) // load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment