Commit b426b99a authored by Jiming Ruan
Browse files

remove unnecessary change

parent 485e530b
...@@ -125,7 +125,7 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -125,7 +125,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
// compute inv-rms // compute inv-rms
auto inv_rms = tile_elementwise_in( auto inv_rms = tile_elementwise_in(
[&](const auto& v_) { [&](const auto& v_) {
return rsqrtf(v_ / row_size + epsilon); return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
}, },
square_sum); square_sum);
...@@ -151,7 +151,8 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -151,7 +151,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
// rmsnorm computation // rmsnorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
auto acc = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution()); auto acc = make_static_distributed_tensor<ComputeDataType>(
decltype(load_tile(x_window))::get_tile_distribution());
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{ {
...@@ -178,7 +179,8 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -178,7 +179,8 @@ struct Rmsnorm2dFwdPipelineTwoPass
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
// rmsnorm computation // rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution()); auto rmsn = make_static_distributed_tensor<ComputeDataType>(
decltype(load_tile(x_window))::get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) { sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
...@@ -192,7 +194,7 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -192,7 +194,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP); static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
Epilogue{}(y_window, rmsn); Epilogue{}(y_window, rmsn);
move_tile_window(gamma_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N}); move_tile_window(y_window, {0, -Block_N});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment