"...composable_kernel_rocm.git" did not exist on "f0759faff4a1c3ba5f739dfed468530e0ee9f28b"
Commit 8c3d43cf authored by rocking's avatar rocking
Browse files

Fix bug of padding

parent 629257f9
......@@ -125,29 +125,17 @@ struct Layernorm2dFwd
return out_dstr_tensor;
}
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
CK_TILE_DEVICE static index_t GetLastloopIntraLaneReduceCount(index_t N)
{
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
constexpr ck_tile::index_t kNThreadSliceSize = kNPerThread * kNRepeat;
constexpr ck_tile::index_t kNThreadStepSize = kNThreadPerBlock * kNPerThread;
int thread_id_n = get_thread_id() % kNThreadPerBlock;
int max_count = __builtin_amdgcn_readfirstlane(
N < kNPerBlock ? 0 : kNThreadSliceSize * (N / kNPerBlock));
int n_per_block_tail_loop =
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
if(n_per_block_tail_loop > 0)
{
static_for<0, kNRepeat, 1>{}([&](auto i) {
int thread_max_n = (thread_id_n + 1) * kNPerThread + kNThreadStepSize * i;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
});
}
return max_count;
using S = typename Problem::BlockShape;
index_t LastloopN = N % kNPerBlock == 0 ? kNPerBlock : N % kNPerBlock;
constexpr index_t NThread = S::kNWarpPerBlock * S::kNThreadPerWarp;
index_t iNLane = get_thread_id() % NThread;
index_t iN0 = LastloopN / (S::kNPerThread * S::kNThreadPerWarp);
index_t iN1 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) / S::kNPerThread;
index_t N2 = (LastloopN % (S::kNPerThread * S::kNThreadPerWarp)) % S::kNPerThread;
index_t iN3 = iNLane < iN1 ? S::kNPerThread : iNLane == iN1 ? N2 : 0;
return iN0 * S::kNPerThread + iN3;
}
template <typename XBlockWindow,
......@@ -167,7 +155,7 @@ struct Layernorm2dFwd
ComputeDataType epsilon,
ck_tile::index_t N) const
{
int welford_max_count = GetWelfordMaxCount(N);
index_t welford_max_count = GetLastloopIntraLaneReduceCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
......@@ -244,7 +232,8 @@ struct Layernorm2dFwd
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
int welford_max_count = GetWelfordMaxCount(N);
index_t intra_thread_count = kNRepeat * kNPerThread * (num_n_tile_iteration - 1);
index_t welford_max_count = intra_thread_count + GetLastloopIntraLaneReduceCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment