Commit 22a38a50 authored by Rocking's avatar Rocking
Browse files

[What] Fix bug of layernorm for greater than 2 dimension.

[Why] We need to get upper length from merge transform instead of embed transform.
parent 8166d875
...@@ -79,6 +79,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk ...@@ -79,6 +79,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -235,7 +236,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk ...@@ -235,7 +236,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
// E(x), E[x^2], var(x) // E(x), E[x^2], var(x)
int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1]; // FIXME: Should not hack the transform from deviceOP
int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
index_t reducedTiles = 0; index_t reducedTiles = 0;
do do
......
...@@ -71,6 +71,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -71,6 +71,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -78,7 +79,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk ...@@ -78,7 +79,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
__device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
int thread_k_cluster_id) int thread_k_cluster_id)
{ {
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1]; // FIXME: Should not hack the transform from deviceOP
int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
int kPerThread = int kPerThread =
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment