Commit 02236580 authored by rocking

refine Welford max count calculation

parent 96568141
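Minimal host-side sketch of the refined per-thread max-count logic (illustration only, not the kernel code: `kNPerBlock`/`kNPerThread` are assumed example values, `welford_max_count` is a hypothetical stand-in for `GetWelfordMaxCount`, and `std::clamp` replaces ck_tile's `clamp`; the kernel additionally broadcasts the uniform terms with `__builtin_amdgcn_readfirstlane`):

```cpp
// Standalone illustration of the per-thread Welford element count along the N dimension.
#include <algorithm>
#include <cassert>
#include <cstdio>

constexpr int kNPerBlock       = 64; // columns covered by one block iteration (assumed)
constexpr int kNPerThread      = 4;  // columns covered by one thread per iteration (assumed)
constexpr int kNThreadPerBlock = kNPerBlock / kNPerThread;

// Number of elements the Welford accumulator of thread column `thread_id_n`
// will see for a row of length N.
int welford_max_count(int N, int thread_id_n)
{
    // Every full kNPerBlock tile contributes exactly kNPerThread elements per thread.
    int max_count = (N < kNPerBlock) ? 0 : kNPerThread * (N / kNPerBlock);
    // Columns left over after the full tiles (the tail of the last block iteration).
    int tail = N - max_count * kNThreadPerBlock;
    if(tail > 0)
    {
        // Threads whose column range [tid*kNPerThread, (tid+1)*kNPerThread) lies inside
        // the tail get kNPerThread more elements; the boundary thread gets a partial count.
        int thread_max_n = (thread_id_n + 1) * kNPerThread;
        int delta        = std::clamp(thread_max_n - tail, 0, kNPerThread);
        max_count += kNPerThread - delta;
    }
    return max_count;
}

int main()
{
    for(int N : {48, 64, 70, 130})
    {
        int total = 0;
        for(int tid = 0; tid < kNThreadPerBlock; ++tid)
            total += welford_max_count(N, tid);
        assert(total == N); // the per-thread counts partition the row
        std::printf("N=%3d total=%3d\n", N, total);
    }
}
```

Summed over the `kNThreadPerBlock` thread columns, the counts partition the row length `N`; this is what allows the two-pass pipeline below to drop the separate last-iteration Welford object and run a single uniform loop.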
@@ -31,9 +31,9 @@ struct Layernorm2dFwd
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
static constexpr bool kPadM = false; // TODO - Problem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
@@ -106,21 +106,6 @@ struct Layernorm2dFwd
sequence<0, 3>>{});
}
template <typename Dstr>
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
{
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
using Lengths = decltype(nDstrSpan.impl_);
ck_tile::index_t ret = 1;
ck_tile::static_for<0, Lengths::size(), 1>{}(
[&](auto idx) { ret *= Lengths::template at(idx); });
return ret;
}
template <typename DistributedTensor>
CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
const ComputeDataType epsilon)
@@ -139,20 +124,25 @@ struct Layernorm2dFwd
return out_dstr_tensor;
}
CK_TILE_HOST_DEVICE static constexpr auto
GetLastloopLayerNormIntraLaneReduceCount(index_t NLength)
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
{
using S = typename Problem::BlockShape;
// S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread
auto LastloopN = NLength % kNPerBlock == 0 ? kNPerBlock : NLength % kNPerBlock;
constexpr auto NThread = S::kNWarpPerBlock * S::kNThreadPerWarp;
auto iNLane = get_thread_local_1d_id() % NThread;
auto iN0 = LastloopN / (S::kNPerThread * S::kNThreadPerWarp);
auto iN1 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) / S::kNPerThread;
auto N2 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) % S::kNPerThread;
auto iN3 = iNLane < iN1 ? S::kNPerThread : iNLane == iN1 ? N2 : 0;
return iN0 * S::kNPerThread + iN3;
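// each full kNPerBlock tile contributes exactly kNPerThread elements per thread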
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
int thread_id_n = get_thread_id() % kNThreadPerBlock;
int max_count =
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
int n_per_block_tail_loop =
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
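// threads whose column range falls inside the tail pick up the leftover elements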
if(n_per_block_tail_loop > 0)
{
int thread_max_n = (thread_id_n + 1) * kNPerThread;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
}
return max_count;
}
template <typename XBlockWindow,
@@ -172,8 +162,8 @@ struct Layernorm2dFwd
ComputeDataType epsilon,
ck_tile::index_t N) const
{
auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count_last};
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
@@ -246,15 +236,11 @@ struct Layernorm2dFwd
ComputeDataType epsilon,
ck_tile::index_t N) const
{
using S = typename Problem::BlockShape;
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock);
auto intra_thread_count = S::kNRepeat * S::kNPerThread * (num_n_tile_iteration - 1);
auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count};
ThreadWelford<ComputeDataType, XDataType> thread_welford_last{intra_thread_count_last};
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
@@ -265,19 +251,13 @@ struct Layernorm2dFwd
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration - 1; ++iN)
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
move_tile_window(x_block_window, {0, kNPerBlock});
}
const auto x_block_tensor_ = load_tile(x_block_window);
thread_welford_last.cur_count_ += intra_thread_count;
thread_welford_last.max_count_ += intra_thread_count;
thread_welford_last(x_block_tensor_, mean_compute_block_tensor, var_compute_block_tensor);
thread_welford.cur_count_ += intra_thread_count_last;
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
@@ -295,6 +275,7 @@ struct Layernorm2dFwd
ck_tile::index_t stride_to_right_most_window =
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {stride_to_right_most_window});
move_tile_window(beta_block_window, {stride_to_right_most_window});
move_tile_window(y_block_window, {0, stride_to_right_most_window});