Commit 3bb0cbe7 authored by rocking

We only use one block in the K dimension.

Hence, we can simplify the indexing of the global reads/writes.
parent 6d3ad8cd
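For context, a minimal standalone sketch (not part of the commit; tile sizes and block id are illustrative) of why the old index arithmetic collapses to the new form once block_group_size == 1, as the commit message states:

// Sketch, assuming block_group_size == 1 (single block group along K).
// All concrete values are illustrative, not taken from the kernel config.
#include <cassert>

using index_t = int;

int main()
{
    const index_t block_group_size           = 1;  // the invariant this commit relies on
    const index_t M_BlockTileSize            = 32; // illustrative
    const index_t K_BlockTileSize            = 64; // illustrative
    const index_t num_k_block_tile_iteration = 4;  // illustrative
    const index_t block_global_id            = 7;  // any block id

    // Old indexing, as removed by this commit:
    const index_t blkgroup_id        = block_global_id / block_group_size; // == block_global_id
    const index_t block_local_id     = block_global_id % block_group_size; // == 0
    const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

    const index_t old_m = blkgroup_id * M_BlockTileSize;
    const index_t old_k = block_local_id * reduceSizePerBlock; // always 0

    // New indexing: one M tile per block, no block-level K offset.
    const index_t new_m = block_global_id * M_BlockTileSize;
    const index_t new_k = 0;

    assert(old_m == new_m && old_k == new_k);
    return 0;
}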
@@ -199,7 +199,6 @@ struct DeviceLayernorm : public BaseOperator
                 gamma_grid_desc_m_k,
                 beta_grid_desc_m_k,
                 y_grid_desc_m_k,
-                arg.blkGroupSize,
                 arg.numBlockTileIteration,
                 arg.epsilon_,
                 arg.in_dev_,
...
@@ -25,7 +25,6 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
                                  const GridDesc_M_K gamma_grid_desc_m_k,
                                  const GridDesc_M_K beta_grid_desc_m_k,
                                  const GridDesc_M_K y_grid_desc_m_k,
-                                 index_t block_group_size,
                                  index_t num_k_block_tile_iteration,
                                  AccDataType epsilon,
                                  const XDataType* const __restrict__ p_x_global,
@@ -37,7 +36,6 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
                             gamma_grid_desc_m_k,
                             beta_grid_desc_m_k,
                             y_grid_desc_m_k,
-                            block_group_size,
                             num_k_block_tile_iteration,
                             epsilon,
                             p_x_global,
@@ -119,7 +117,6 @@ struct GridwiseLayernorm_mk_to_mk
                       const GridDesc_M_K& gamma_grid_desc_m_k,
                       const GridDesc_M_K& beta_grid_desc_m_k,
                       const GridDesc_M_K& y_grid_desc_m_k,
-                      index_t block_group_size,
                       index_t num_k_block_tile_iteration,
                       AccDataType epsilon,
                       const XDataType* const __restrict__ p_x_global,
@@ -171,8 +168,6 @@ struct GridwiseLayernorm_mk_to_mk
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
-        const index_t blkgroup_id     = block_global_id / block_group_size;
-        const index_t block_local_id  = block_global_id % block_group_size;
         const auto thread_cluster_idx =
             thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
@@ -180,8 +175,6 @@ struct GridwiseLayernorm_mk_to_mk
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
         const auto thread_k_cluster_id = thread_cluster_idx[I1];
-        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
         using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
         constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
@@ -197,9 +190,9 @@ struct GridwiseLayernorm_mk_to_mk
                 1,
                 true>(
                 x_grid_desc_m_k,
-                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
-                                 block_local_id * reduceSizePerBlock +
-                                     thread_k_cluster_id * KThreadSliceSize));
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));

         auto threadwise_gamma_load = ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                                                       AccDataType,
@@ -212,9 +205,9 @@ struct GridwiseLayernorm_mk_to_mk
                 1,
                 true>(
                 gamma_grid_desc_m_k,
-                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
-                                 block_local_id * reduceSizePerBlock +
-                                     thread_k_cluster_id * KThreadSliceSize));
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));

         auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                                                      AccDataType,
@@ -227,9 +220,9 @@ struct GridwiseLayernorm_mk_to_mk
                 1,
                 true>(
                 beta_grid_desc_m_k,
-                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
-                                 block_local_id * reduceSizePerBlock +
-                                     thread_k_cluster_id * KThreadSliceSize));
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));

         auto threadwise_y_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                                      YDataType,
@@ -244,9 +237,9 @@ struct GridwiseLayernorm_mk_to_mk
                 1,
                 true>(
                 y_grid_desc_m_k,
-                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
-                                 block_local_id * reduceSizePerBlock +
-                                     thread_k_cluster_id * KThreadSliceSize),
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize),
                 PassThroughOp{});

         // Copy x from Cache
...
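After the change, every load/store start coordinate follows the same pattern: block_global_id selects the M tile, the thread cluster id selects the slice within it, and the K start depends only on thread_k_cluster_id, since the K sweep happens inside the block over num_k_block_tile_iteration iterations. A standalone sketch with hypothetical ids and tile sizes:

// Sketch of the simplified per-thread start coordinate; all values illustrative.
#include <cstdio>

int main()
{
    const int M_BlockTileSize  = 32; // illustrative block tile in M
    const int MThreadSliceSize = 4;  // illustrative per-thread slice in M
    const int KThreadSliceSize = 8;  // illustrative per-thread slice in K

    const int block_global_id     = 3; // hypothetical ids
    const int thread_m_cluster_id = 2;
    const int thread_k_cluster_id = 5;

    // Mirrors the new make_multi_index(...) arguments in the diff above.
    const int m_start = block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize;
    const int k_start = thread_k_cluster_id * KThreadSliceSize;

    std::printf("thread start = (m=%d, k=%d)\n", m_start, k_start); // (104, 40)
    return 0;
}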