Commit 8c3d43cf authored by rocking's avatar rocking
Browse files

Fix bug of padding

parent 629257f9
...@@ -125,29 +125,17 @@ struct Layernorm2dFwd ...@@ -125,29 +125,17 @@ struct Layernorm2dFwd
return out_dstr_tensor; return out_dstr_tensor;
} }
CK_TILE_DEVICE static int GetWelfordMaxCount(int N) CK_TILE_DEVICE static index_t GetLastloopIntraLaneReduceCount(index_t N)
{ {
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread; using S = typename Problem::BlockShape;
constexpr ck_tile::index_t kNThreadSliceSize = kNPerThread * kNRepeat; index_t LastloopN = N % kNPerBlock == 0 ? kNPerBlock : N % kNPerBlock;
constexpr ck_tile::index_t kNThreadStepSize = kNThreadPerBlock * kNPerThread; constexpr index_t NThread = S::kNWarpPerBlock * S::kNThreadPerWarp;
index_t iNLane = get_thread_id() % NThread;
int thread_id_n = get_thread_id() % kNThreadPerBlock; index_t iN0 = LastloopN / (S::kNPerThread * S::kNThreadPerWarp);
int max_count = __builtin_amdgcn_readfirstlane( index_t iN1 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) / S::kNPerThread;
N < kNPerBlock ? 0 : kNThreadSliceSize * (N / kNPerBlock)); index_t N2 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) % S::kNPerThread;
int n_per_block_tail_loop = index_t iN3 = iNLane < iN1 ? S::kNPerThread : iNLane == iN1 ? N2 : 0;
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock); return iN0 * S::kNPerThread + iN3;
if(n_per_block_tail_loop > 0)
{
static_for<0, kNRepeat, 1>{}([&](auto i) {
int thread_max_n = (thread_id_n + 1) * kNPerThread + kNThreadStepSize * i;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
});
}
return max_count;
} }
template <typename XBlockWindow, template <typename XBlockWindow,
...@@ -167,7 +155,7 @@ struct Layernorm2dFwd ...@@ -167,7 +155,7 @@ struct Layernorm2dFwd
ComputeDataType epsilon, ComputeDataType epsilon,
ck_tile::index_t N) const ck_tile::index_t N) const
{ {
int welford_max_count = GetWelfordMaxCount(N); index_t welford_max_count = GetLastloopIntraLaneReduceCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count}; ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window)); using XTensorType = decltype(load_tile(x_block_window));
...@@ -244,7 +232,8 @@ struct Layernorm2dFwd ...@@ -244,7 +232,8 @@ struct Layernorm2dFwd
index_t num_n_tile_iteration = index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock)); __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
int welford_max_count = GetWelfordMaxCount(N); index_t intra_thread_count = kNRepeat * kNPerThread * (num_n_tile_iteration - 1);
index_t welford_max_count = intra_thread_count + GetLastloopIntraLaneReduceCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count}; ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window)); using XTensorType = decltype(load_tile(x_block_window));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment