Commit 22a38a50 authored by Rocking's avatar Rocking
Browse files

[What] Fix bug of layernorm for greater than 2 dimension.

[Why] We need to get upper length from merge transform instead of embed transform.
parent 8166d875
......@@ -79,6 +79,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
......@@ -235,7 +236,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
// E(x), E[x^2], var(x)
int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
// FIXME: Should not hack the transform from deviceOP
int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
index_t reducedTiles = 0;
do
......
......@@ -71,6 +71,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
......@@ -78,7 +79,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
__device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
int thread_k_cluster_id)
{
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
// FIXME: Should not hack the transform from deviceOP
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
int kPerThread =
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment