Commit 3bb0cbe7 authored by rocking

We only use one block in the K dimension.

Hence, we can simplify the indexing of global R/W.
parent 6d3ad8cd
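
For context, a minimal host-side sketch of why the old and new global offsets agree once the grid launches only one block along K. The parameter names mirror the diff; `check_offsets`, `thread_m_offset`, and the sample values in `main` are illustrative assumptions, not part of the CK sources:

```cpp
#include <cassert>

using index_t = int;

// Sketch only: with a single block along K, block_group_size == 1, so the
// pre-commit index arithmetic collapses to the simplified post-commit form.
void check_offsets(index_t block_global_id,
                   index_t thread_k_cluster_id,
                   index_t M_BlockTileSize,
                   index_t KThreadSliceSize,
                   index_t reduceSizePerBlock,
                   index_t thread_m_offset)
{
    const index_t block_group_size = 1; // only one block in the K dimension

    const index_t blkgroup_id    = block_global_id / block_group_size; // == block_global_id
    const index_t block_local_id = block_global_id % block_group_size; // == 0

    // Offsets as computed before this commit ...
    const index_t row_old = blkgroup_id * M_BlockTileSize + thread_m_offset;
    const index_t col_old =
        block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize;

    // ... and the simplified offsets used after it.
    const index_t row_new = block_global_id * M_BlockTileSize + thread_m_offset;
    const index_t col_new = thread_k_cluster_id * KThreadSliceSize;

    assert(row_old == row_new);
    assert(col_old == col_new);
}

int main()
{
    // Arbitrary sample values for illustration.
    check_offsets(/*block_global_id=*/3,
                  /*thread_k_cluster_id=*/2,
                  /*M_BlockTileSize=*/32,
                  /*KThreadSliceSize=*/8,
                  /*reduceSizePerBlock=*/256,
                  /*thread_m_offset=*/4);
    return 0;
}
```

With `block_group_size` fixed at 1, the division and modulo are dead code, which is what the diff removes from `GridwiseLayernorm_mk_to_mk` and the kernel/operator signatures.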
@@ -199,7 +199,6 @@ struct DeviceLayernorm : public BaseOperator
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
- arg.blkGroupSize,
arg.numBlockTileIteration,
arg.epsilon_,
arg.in_dev_,
@@ -25,7 +25,6 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
- index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
@@ -37,7 +36,6 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
- block_group_size,
num_k_block_tile_iteration,
epsilon,
p_x_global,
@@ -119,7 +117,6 @@ struct GridwiseLayernorm_mk_to_mk
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k,
- index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType epsilon,
const XDataType* const __restrict__ p_x_global,
@@ -171,8 +168,6 @@ struct GridwiseLayernorm_mk_to_mk
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
- const index_t blkgroup_id = block_global_id / block_group_size;
- const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
@@ -180,8 +175,6 @@ struct GridwiseLayernorm_mk_to_mk
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
- const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
@@ -197,9 +190,9 @@ struct GridwiseLayernorm_mk_to_mk
1,
true>(
x_grid_desc_m_k,
- make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
- block_local_id * reduceSizePerBlock +
- thread_k_cluster_id * KThreadSliceSize));
+ make_multi_index(block_global_id * M_BlockTileSize +
+ thread_m_cluster_id * MThreadSliceSize,
+ thread_k_cluster_id * KThreadSliceSize));
auto threadwise_gamma_load = ThreadwiseTensorSliceTransfer_v2<GammaDataType,
AccDataType,
@@ -212,9 +205,9 @@ struct GridwiseLayernorm_mk_to_mk
1,
true>(
gamma_grid_desc_m_k,
- make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
- block_local_id * reduceSizePerBlock +
- thread_k_cluster_id * KThreadSliceSize));
+ make_multi_index(block_global_id * M_BlockTileSize +
+ thread_m_cluster_id * MThreadSliceSize,
+ thread_k_cluster_id * KThreadSliceSize));
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType,
@@ -227,9 +220,9 @@ struct GridwiseLayernorm_mk_to_mk
1,
true>(
beta_grid_desc_m_k,
- make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
- block_local_id * reduceSizePerBlock +
- thread_k_cluster_id * KThreadSliceSize));
+ make_multi_index(block_global_id * M_BlockTileSize +
+ thread_m_cluster_id * MThreadSliceSize,
+ thread_k_cluster_id * KThreadSliceSize));
auto threadwise_y_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
@@ -244,9 +237,9 @@ struct GridwiseLayernorm_mk_to_mk
1,
true>(
y_grid_desc_m_k,
- make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
- block_local_id * reduceSizePerBlock +
- thread_k_cluster_id * KThreadSliceSize),
+ make_multi_index(block_global_id * M_BlockTileSize +
+ thread_m_cluster_id * MThreadSliceSize,
+ thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
// Copy x from Cache