"...composable_kernel-1.git" did not exist on "ccc4a1d365999a3e15623f490314e66c2d671389"
Commit 02236580 authored by rocking's avatar rocking
Browse files

refine welford max count calculation

parent 96568141
...@@ -31,9 +31,9 @@ struct Layernorm2dFwd ...@@ -31,9 +31,9 @@ struct Layernorm2dFwd
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock; static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock; static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
static constexpr bool kPadM = false; // TODO - Problem::kPadM static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass; static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp; static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread; static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
...@@ -106,21 +106,6 @@ struct Layernorm2dFwd ...@@ -106,21 +106,6 @@ struct Layernorm2dFwd
sequence<0, 3>>{}); sequence<0, 3>>{});
} }
// Computes the number of elements along the N (column) dimension that one
// thread owns under the given tile distribution, as a compile-time-foldable
// product of the N-span lengths.
//
// Dstr: a ck_tile tile-distribution type; only its static span info is used
//       (the value parameter is a tag and is ignored).
// Returns: product of all lengths in the distributed span for dimension 1 (N).
template <typename Dstr>
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
{
// Span of dimension 1 == the N axis of the 2D tile.
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
using Lengths = decltype(nDstrSpan.impl_);
ck_tile::index_t ret = 1;
// Unrolled compile-time loop: multiply every span length together.
ck_tile::static_for<0, Lengths::size(), 1>{}(
[&](auto idx) { ret *= Lengths::template at(idx); });
return ret;
}
template <typename DistributedTensor> template <typename DistributedTensor>
CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor, CK_TILE_DEVICE static auto InvSqrt(const DistributedTensor& in_dstr_tensor,
const ComputeDataType epsilon) const ComputeDataType epsilon)
...@@ -139,20 +124,25 @@ struct Layernorm2dFwd ...@@ -139,20 +124,25 @@ struct Layernorm2dFwd
return out_dstr_tensor; return out_dstr_tensor;
} }
CK_TILE_HOST_DEVICE static constexpr auto CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
GetLastloopLayerNormIntraLaneReduceCount(index_t NLength)
{ {
using S = typename Problem::BlockShape; constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
// S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread
auto LastloopN = NLength % kNPerBlock == 0 ? kNPerBlock : NLength % kNPerBlock; int thread_id_n = get_thread_id() % kNThreadPerBlock;
constexpr auto NThread = S::kNWarpPerBlock * S::kNThreadPerWarp; int max_count =
auto iNLane = get_thread_local_1d_id() % NThread; __builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
auto iN0 = LastloopN / (S::kNPerThread * S::kNThreadPerWarp); int n_per_block_tail_loop =
auto iN1 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) / S::kNPerThread; __builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
auto N2 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) % S::kNPerThread;
auto iN3 = iNLane < iN1 ? S::kNPerThread : iNLane == iN1 ? N2 : 0; if(n_per_block_tail_loop > 0)
{
return iN0 * S::kNPerThread + iN3; int thread_max_n = (thread_id_n + 1) * kNPerThread;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
}
return max_count;
} }
template <typename XBlockWindow, template <typename XBlockWindow,
...@@ -172,8 +162,8 @@ struct Layernorm2dFwd ...@@ -172,8 +162,8 @@ struct Layernorm2dFwd
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t N) const ck_tile::index_t N) const
{ {
auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N); int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count_last}; ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window)); using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor = auto mean_compute_block_tensor =
...@@ -246,15 +236,11 @@ struct Layernorm2dFwd ...@@ -246,15 +236,11 @@ struct Layernorm2dFwd
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t N) const ck_tile::index_t N) const
{ {
using S = typename Problem::BlockShape;
index_t num_n_tile_iteration = index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock); __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
auto intra_thread_count = S::kNRepeat * S::kNPerThread * (num_n_tile_iteration - 1);
auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count}; int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford_last{intra_thread_count_last}; ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window)); using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor = auto mean_compute_block_tensor =
...@@ -265,19 +251,13 @@ struct Layernorm2dFwd ...@@ -265,19 +251,13 @@ struct Layernorm2dFwd
clear_tile(mean_compute_block_tensor); clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor); clear_tile(var_compute_block_tensor);
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration - 1; ++iN) for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{ {
const auto x_block_tensor = load_tile(x_block_window); const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor); thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
move_tile_window(x_block_window, {0, kNPerBlock}); move_tile_window(x_block_window, {0, kNPerBlock});
} }
const auto x_block_tensor_ = load_tile(x_block_window);
thread_welford_last.cur_count_ += intra_thread_count;
thread_welford_last.max_count_ += intra_thread_count;
thread_welford_last(x_block_tensor_, mean_compute_block_tensor, var_compute_block_tensor);
thread_welford.cur_count_ += intra_thread_count_last;
// TODO: support cross warp Welford // TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}( WarpMergeWelford<ComputeDataType, true>{}(
...@@ -295,6 +275,7 @@ struct Layernorm2dFwd ...@@ -295,6 +275,7 @@ struct Layernorm2dFwd
ck_tile::index_t stride_to_right_most_window = ck_tile::index_t stride_to_right_most_window =
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock; N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {stride_to_right_most_window}); move_tile_window(gamma_block_window, {stride_to_right_most_window});
move_tile_window(beta_block_window, {stride_to_right_most_window}); move_tile_window(beta_block_window, {stride_to_right_most_window});
move_tile_window(y_block_window, {0, stride_to_right_most_window}); move_tile_window(y_block_window, {0, stride_to_right_most_window});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment