Commit df45a6b5 authored by Jiming Ruan's avatar Jiming Ruan
Browse files

50ms -> 28ms

parent 3c7fef7f
...@@ -535,10 +535,11 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, ...@@ -535,10 +535,11 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 1024, 8, True, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), # h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)
]}
total_blob = list() total_blob = list()
for hs_key in h_trait_dict: for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key] hs = h_trait_dict[hs_key]
......
...@@ -125,7 +125,7 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -125,7 +125,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
// compute inv-rms // compute inv-rms
auto inv_rms = tile_elementwise_in( auto inv_rms = tile_elementwise_in(
[&](const auto& v_) { [&](const auto& v_) {
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon)); return rsqrtf(v_ / row_size + epsilon);
}, },
square_sum); square_sum);
...@@ -136,32 +136,47 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -136,32 +136,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
ck_tile::index_t stride_to_right_most_window = ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
move_tile_window(y_residual_window, {0, -Block_N});
}
else
{
move_tile_window(x_window, {0, -Block_N}); move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N}); move_tile_window(x_residual_window, {0, -Block_N});
}
move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(gamma_window, {stride_to_right_most_window});
move_tile_window(y_window, {0, stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window});
// rmsnorm computation // rmsnorm computation
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
auto acc = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution());
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
{ {
auto x = load_tile(x_window); auto x = load_tile(x_window);
auto x_resi = load_tile(x_residual_window); auto x_resi = load_tile(x_residual_window);
auto acc = cast_tile<ComputeDataType>(x);
if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE ||
kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
{
sweep_tile(x_resi, [&](auto idx) { sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x // compute x = x_resi + x
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx); acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
}); });
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
}
else if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
{
acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
move_tile_window(y_residual_window, {0, -Block_N});
} }
// load gamma (TODO: support no gamma?) // load gamma (TODO: support no gamma?)
const auto gamma = load_tile(gamma_window); const auto gamma = load_tile(gamma_window);
// rmsnorm computation // rmsnorm computation
auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution()); auto rmsn = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution());
sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) { sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
...@@ -176,8 +191,6 @@ struct Rmsnorm2dFwdPipelineTwoPass ...@@ -176,8 +191,6 @@ struct Rmsnorm2dFwdPipelineTwoPass
static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP); static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
Epilogue{}(y_window, rmsn); Epilogue{}(y_window, rmsn);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(x_residual_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N});
move_tile_window(y_window, {0, -Block_N}); move_tile_window(y_window, {0, -Block_N});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment