Commit 6d3ad8cd authored by rocking
Browse files

[What] Get the reduce length from the transform's upper lengths instead of from the descriptor's length.

[Why] If we query the length directly from the descriptor, we may get the length after padding.
parent 0a2a25e3
...@@ -251,10 +251,8 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -251,10 +251,8 @@ struct GridwiseLayernorm_mk_to_mk
// Copy x from Cache // Copy x from Cache
// one pass: fwd, second pass: bwd // one pass: fwd, second pass: bwd
constexpr auto thread_copy_fwd_step = constexpr auto thread_copy_fwd_step = make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize); constexpr auto thread_copy_bwd_step = make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
constexpr auto thread_copy_bwd_step =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
...@@ -266,7 +264,8 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -266,7 +264,8 @@ struct GridwiseLayernorm_mk_to_mk
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
// E(x), E[x^2], var(x) // E(x), E[x^2], var(x)
int reduce_length = x_grid_desc_m_k.GetLength(I1); int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
index_t reducedTiles = 0; index_t reducedTiles = 0;
do do
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment