Commit df45a6b5 authored by Jiming Ruan

50ms -> 28ms

parent 3c7fef7f
@@ -535,10 +535,11 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
                h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0),
                h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0),
                h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
-    'big' :[   h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0),
-               h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
-               h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
-               h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
+    'big' :[   h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 1024, 8, True, False, True, 0, 0),
+               # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
+               # h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
+               # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)
+               ]}
     total_blob = list()
     for hs_key in h_trait_dict:
         hs = h_trait_dict[hs_key]
......
@@ -125,7 +125,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
         // compute inv-rms
         auto inv_rms = tile_elementwise_in(
             [&](const auto& v_) {
-                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
+                return rsqrtf(v_ / row_size + epsilon);
             },
             square_sum);
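For reference, the removed and added expressions compute the same quantity, inv_rms = 1 / sqrt(square_sum / row_size + epsilon); rsqrtf just maps to a single hardware reciprocal-square-root instead of a sqrt followed by a divide. The host-side C++ sketch below is illustrative only (hypothetical names, plain std::sqrt, not the ck_tile pipeline) and spells out the quantity the two-pass kernel computes:

// Host-side reference for the inv-rms step of a two-pass RMSNorm forward.
// Illustrative only; names are hypothetical and this is not the ck_tile code.
#include <cmath>
#include <cstdio>
#include <vector>

float inv_rms_reference(const std::vector<float>& row, float epsilon)
{
    // Pass 1: accumulate the sum of squares over one row.
    float square_sum = 0.f;
    for (float v : row)
        square_sum += v * v;
    // inv_rms = 1 / sqrt(mean(x^2) + epsilon); equivalent to
    // rsqrtf(square_sum / row_size + epsilon) on the device.
    return 1.0f / std::sqrt(square_sum / static_cast<float>(row.size()) + epsilon);
}

int main()
{
    const std::vector<float> row = {1.f, 2.f, 3.f, 4.f};
    // Pass 2 of the real pipeline then computes y[i] = x[i] * inv_rms * gamma[i].
    std::printf("inv_rms = %f\n", inv_rms_reference(row, 1e-6f));
    return 0;
}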
@@ -136,32 +136,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
         ck_tile::index_t stride_to_right_most_window =
             row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
-        move_tile_window(x_window, {0, -Block_N});
-        move_tile_window(x_residual_window, {0, -Block_N});
+        if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
+        {
+            move_tile_window(y_residual_window, {0, -Block_N});
+        }
+        else
+        {
+            move_tile_window(x_window, {0, -Block_N});
+            move_tile_window(x_residual_window, {0, -Block_N});
+        }
         move_tile_window(gamma_window, {stride_to_right_most_window});
         move_tile_window(y_window, {0, stride_to_right_most_window});
         // rmsnorm computation
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
-            auto x      = load_tile(x_window);
-            auto x_resi = load_tile(x_residual_window);
-            auto acc    = cast_tile<ComputeDataType>(x);
-            if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE ||
-                         kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
+            auto acc = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution());
+            if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD)
             {
+                auto x      = load_tile(x_window);
+                auto x_resi = load_tile(x_residual_window);
                 sweep_tile(x_resi, [&](auto idx) {
                     // compute x = x_resi + x
                     acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
                 });
+                move_tile_window(x_window, {0, -Block_N});
+                move_tile_window(x_residual_window, {0, -Block_N});
             }
+            else if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE)
+            {
+                acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
+                move_tile_window(y_residual_window, {0, -Block_N});
+            }
             // load gamma (TODO: support no gamma?)
             const auto gamma = load_tile(gamma_window);
             // rmsnorm computation
-            auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
+            auto rmsn = make_static_distributed_tensor<ComputeDataType>(decltype(load_tile(x_window))::get_tile_distribution());
             sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
                 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
                 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
@@ -175,9 +190,7 @@ struct Rmsnorm2dFwdPipelineTwoPass
                 static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP);
                 Epilogue{}(y_window, rmsn);
-            move_tile_window(x_window, {0, -Block_N});
-            move_tile_window(x_residual_window, {0, -Block_N});
             move_tile_window(gamma_window, {-Block_N});
             move_tile_window(y_window, {0, -Block_N});
         }
......
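The hunks above restructure the second pass so that each fused-add mode loads only the tiles it actually needs: PRE_ADD still loads x and the residual input, PRE_ADD_STORE instead re-reads y_residual (presumably holding the x + residual sum stored earlier) rather than re-loading and re-adding both inputs, and each tile window is now advanced inside the branch that consumed it instead of unconditionally at the end of the loop. The simplified host-side C++17 sketch below shows only that per-mode dispatch pattern, with hypothetical names and raw pointers; the real pipeline uses ck_tile tile windows and distributed tensors.

// Simplified illustration of per-mode loading with if constexpr. Hypothetical
// names; the real pipeline operates on ck_tile tile windows, not raw pointers.
#include <cstdio>
#include <vector>

enum class FusedAdd { NO, PRE_ADD, PRE_ADD_STORE };

// Accumulate one "tile" of a row, touching only the inputs this mode needs.
template <FusedAdd kFusedAdd>
float accumulate_tile(const float* x, const float* x_resi, const float* y_resi, int n)
{
    float acc = 0.f;
    if constexpr (kFusedAdd == FusedAdd::PRE_ADD)
    {
        // Needs both x and the residual input.
        for (int i = 0; i < n; ++i)
            acc += x[i] + x_resi[i];
    }
    else if constexpr (kFusedAdd == FusedAdd::PRE_ADD_STORE)
    {
        // x + residual was already stored once; re-read that single buffer
        // instead of loading x and x_resi again.
        for (int i = 0; i < n; ++i)
            acc += y_resi[i];
    }
    else
    {
        // Plain RMSNorm: only x is needed.
        for (int i = 0; i < n; ++i)
            acc += x[i];
    }
    return acc;
}

int main()
{
    const std::vector<float> x = {1.f, 2.f}, xr = {0.5f, 0.5f}, yr = {1.5f, 2.5f};
    std::printf("%f\n", accumulate_tile<FusedAdd::PRE_ADD_STORE>(x.data(), xr.data(), yr.data(), 2));
    return 0;
}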