Commit 4071440c authored by rocking's avatar rocking
Browse files

Revert "refine welford max count calculation"

This reverts commit 02236580.
parent d62f0358
...@@ -31,9 +31,9 @@ struct Layernorm2dFwd ...@@ -31,9 +31,9 @@ struct Layernorm2dFwd
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadM = false; // TODO - Problem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass; static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread; static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
...@@ -106,6 +106,21 @@ struct Layernorm2dFwd ...@@ -106,6 +106,21 @@ struct Layernorm2dFwd
sequence<0, 3>>{}); sequence<0, 3>>{});
} }
template <typename Dstr>
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
{
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
using Lengths = decltype(nDstrSpan.impl_);
ck_tile::index_t ret = 1;
ck_tile::static_for<0, Lengths::size(), 1>{}(
[&](auto idx) { ret *= Lengths::template at(idx); });
return ret;
}
template <typename DistributedTensor> template <typename DistributedTensor>
CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor, CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
const ComputeDataType epsilon) const ComputeDataType epsilon)
...@@ -124,25 +139,20 @@ struct Layernorm2dFwd ...@@ -124,25 +139,20 @@ struct Layernorm2dFwd
return out_dstr_tensor; return out_dstr_tensor;
} }
CK_TILE_DEVICE static int GetWelfordMaxCount(int N) CK_TILE_HOST_DEVICE static constexpr auto
GetLastloopLayerNormIntraLaneReduceCount(index_t NLength)
{ {
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread; using S = typename Problem::BlockShape;
// S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread
int thread_id_n = get_thread_id() % kNThreadPerBlock; auto LastloopN = NLength % kNPerBlock == 0 ? kNPerBlock : NLength % kNPerBlock;
int max_count = constexpr auto NThread = S::kNWarpPerBlock * S::kNThreadPerWarp;
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock)); auto iNLane = get_thread_local_1d_id() % NThread;
int n_per_block_tail_loop = auto iN0 = LastloopN / (S::kNPerThread * S::kNThreadPerWarp);
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock); auto iN1 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) / S::kNPerThread;
auto N2 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) % S::kNPerThread;
if(n_per_block_tail_loop > 0) auto iN3 = iNLane < iN1 ? S::kNPerThread : iNLane == iN1 ? N2 : 0;
{
int thread_max_n = (thread_id_n + 1) * kNPerThread; return iN0 * S::kNPerThread + iN3;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
}
return max_count;
} }
template <typename XBlockWindow, template <typename XBlockWindow,
...@@ -162,8 +172,8 @@ struct Layernorm2dFwd ...@@ -162,8 +172,8 @@ struct Layernorm2dFwd
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t N) const ck_tile::index_t N) const
{ {
int welford_max_count = GetWelfordMaxCount(N); auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count}; ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count_last};
using XTensorType = decltype(load_tile(x_block_window)); using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor = auto mean_compute_block_tensor =
...@@ -236,11 +246,15 @@ struct Layernorm2dFwd ...@@ -236,11 +246,15 @@ struct Layernorm2dFwd
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t N) const ck_tile::index_t N) const
{ {
using S = typename Problem::BlockShape;
index_t num_n_tile_iteration = index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock)); __builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock);
auto intra_thread_count = S::kNRepeat * S::kNPerThread * (num_n_tile_iteration - 1);
auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
int welford_max_count = GetWelfordMaxCount(N); ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count};
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count}; ThreadWelford<ComputeDataType, XDataType> thread_welford_last{intra_thread_count_last};
using XTensorType = decltype(load_tile(x_block_window)); using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor = auto mean_compute_block_tensor =
...@@ -251,13 +265,19 @@ struct Layernorm2dFwd ...@@ -251,13 +265,19 @@ struct Layernorm2dFwd
clear_tile(mean_compute_block_tensor); clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor); clear_tile(var_compute_block_tensor);
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration - 1; ++iN)
{ {
const auto x_block_tensor = load_tile(x_block_window); const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
move_tile_window(x_block_window, {0, kNPerBlock}); move_tile_window(x_block_window, {0, kNPerBlock});
} }
const auto x_block_tensor_ = load_tile(x_block_window);
thread_welford_last.cur_count_ += intra_thread_count;
thread_welford_last.max_count_ += intra_thread_count;
thread_welford_last(x_block_tensor_, mean_compute_block_tensor, var_compute_block_tensor);
thread_welford.cur_count_ += intra_thread_count_last;
// TODO: support cross warp Welford // TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}( WarpMergeWelford<ComputeDataType, true>{}(
...@@ -275,7 +295,6 @@ struct Layernorm2dFwd ...@@ -275,7 +295,6 @@ struct Layernorm2dFwd
ck_tile::index_t stride_to_right_most_window = ck_tile::index_t stride_to_right_most_window =
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock; N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {stride_to_right_most_window}); move_tile_window(gamma_block_window, {stride_to_right_most_window});
move_tile_window(beta_block_window, {stride_to_right_most_window}); move_tile_window(beta_block_window, {stride_to_right_most_window});
move_tile_window(y_block_window, {0, stride_to_right_most_window}); move_tile_window(y_block_window, {0, stride_to_right_most_window});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment