Commit 8745c0ca authored by rocking

Update naive variance kernel

parent 3b6f9c16
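Note (not part of the commit): the "naive variance" scheme this kernel implements accumulates the sum and the sum of squares of x in the same sweep, derives var(x) = E[x^2] - E[x]^2, and then writes y = gamma * (x - E[x]) / sqrt(var(x) + epsilon) + beta. A minimal host-side C++ sketch of the same arithmetic follows; all names in it (naive_layernorm, eps) are hypothetical and only illustrate the math, not the kernel's API.

// Illustrative sketch only -- not part of this commit.
// One pass accumulates sum(x) and sum(x^2); variance is E[x^2] - E[x]^2.
#include <cmath>
#include <cstddef>
#include <vector>

void naive_layernorm(const std::vector<float>& x,
                     const std::vector<float>& gamma,
                     const std::vector<float>& beta,
                     std::vector<float>& y,
                     float eps = 1e-5f)
{
    const std::size_t n = x.size(); // assumes x, gamma, beta have equal, non-zero length
    y.resize(n);

    float sum = 0.f, sum_sq = 0.f;
    for(std::size_t k = 0; k < n; ++k)
    {
        sum += x[k];           // accumulates N * E[x]
        sum_sq += x[k] * x[k]; // accumulates N * E[x^2]
    }

    const float mean        = sum / n;
    const float mean_square = sum_sq / n;
    const float var         = mean_square - mean * mean; // var(x) = E[x^2] - E[x]^2
    const float divisor     = 1.f / std::sqrt(var + eps);

    for(std::size_t k = 0; k < n; ++k)
        y[k] = (x[k] - mean) * divisor * gamma[k] + beta[k]; // normalize, then gamma & beta
}

The kernel in the diff below does this per row of the M x K problem, with the two accumulations split across the thread block and the K dimension walked in vectorized tiles.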
@@ -4,9 +4,8 @@
 #pragma once
 #include "ck/utility/data_type.hpp"
-#include "ck/utility/reduction_common.hpp"
 #include "ck/utility/reduction_operator.hpp"
-#include "ck/utility/reduction_functions_accumulate.hpp"
 #include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
 #include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
@@ -19,8 +18,8 @@ template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
-          typename AccDataType,
-          typename AccElementwiseOperation,
+          typename ComputeDataType,
+          typename YElementwiseOperation,
           typename GridDesc_M_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
@@ -46,6 +45,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                       (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+    static_assert(XSrcVectorSize == YDstVectorSize);
+    static_assert(XSrcVectorSize == GammaSrcVectorSize);
+    static_assert(XSrcVectorSize == BetaSrcVectorSize);
     static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
     using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
@@ -59,19 +62,23 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
+    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
+    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
     using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
-    using BlockwiseSumReduce = PartitionedBlockwiseReduction<AccDataType,
+    using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
                                                              BlockSize,
                                                              ThreadClusterLengths_M_K,
                                                              ThreadClusterArrangeOrder,
                                                              reduce::Add,
                                                              true>;
-    using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
+    using ThreadwiseSumReduce = ThreadwiseReduction<ComputeDataType,
                                                     ThreadReduceSrcDesc_M_K,
                                                     ThreadReduceDstDesc_M,
                                                     reduce::Add,
@@ -81,64 +88,77 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
+    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
                                const GridDesc_M_K& gamma_grid_desc_m_k,
                                const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
-                               AccDataType epsilon,
+                               ComputeDataType epsilon,
                                const XDataType* const __restrict__ p_x_global,
                                const GammaDataType* const __restrict__ p_gamma_global,
                                const BetaDataType* const __restrict__ p_beta_global,
                                YDataType* const __restrict__ p_y_global,
-                               const AccElementwiseOperation acc_elementwise_op)
+                               const YElementwiseOperation y_elementwise_op)
     {
+        if constexpr(SweepOnce)
+        {
+            num_k_block_tile_iteration = 1;
+        }
         // LDS
-        __shared__ AccDataType p_reduce_work_buffer[BlockSize];
-        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
+        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
         auto reduce_work_buf =
             make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-            x_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-            gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>& beta_thread_buf = gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-            y_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr,
-                     AccDataType,
-                     MThreadSliceSize * KThreadSliceSize,
-                     true>& x_square_thread_buf = y_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-            mean_square_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_thread_buf =
-            mean_square_thread_buf;
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            mean_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
-            mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
-        });
+        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
+        auto x_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * XSrcVectorSize,
+                                    true>{};
+            },
+            Number<ThreadBufferNumber>{});
+        auto gamma_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * GammaSrcVectorSize,
+                                    true>{};
+            },
+            Number<ThreadBufferNumber>{});
+        auto beta_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * BetaSrcVectorSize,
+                                    true>{};
+            },
+            Number<ThreadBufferNumber>{});
+        auto y_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * YDstVectorSize,
+                                    true>{};
+            },
+            Number<ThreadBufferNumber>{});
+        auto& x_square_thread_buf = y_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            mean_square_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
+            var_thread_buf = mean_square_thread_buf;
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
@@ -149,12 +169,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
         const auto thread_k_cluster_id = thread_cluster_idx[I1];
-        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
-                                                                  AccDataType,
+                                                                  ComputeDataType,
                                                                   GridDesc_M_K,
                                                                   decltype(thread_buffer_desc_m_k),
                                                                   ThreadBufferLengths_M_K,
@@ -166,11 +182,11 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             x_grid_desc_m_k,
             make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize));
+                             thread_k_cluster_id * XSrcVectorSize));
         auto threadwise_gamma_load =
             ThreadwiseTensorSliceTransfer_v2<GammaDataType,
-                                             AccDataType,
+                                             ComputeDataType,
                                              GridDesc_M_K,
                                              decltype(thread_buffer_desc_m_k),
                                              ThreadBufferLengths_M_K,
@@ -182,11 +198,11 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             gamma_grid_desc_m_k,
             make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize));
+                             thread_k_cluster_id * GammaSrcVectorSize));
         auto threadwise_beta_load =
             ThreadwiseTensorSliceTransfer_v2<BetaDataType,
-                                             AccDataType,
+                                             ComputeDataType,
                                              GridDesc_M_K,
                                              decltype(thread_buffer_desc_m_k),
                                              ThreadBufferLengths_M_K,
@@ -198,14 +214,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             beta_grid_desc_m_k,
             make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize));
+                             thread_k_cluster_id * BetaSrcVectorSize));
         auto threadwise_y_store =
-            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
+            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                                YDataType,
                                                decltype(thread_buffer_desc_m_k),
                                                GridDesc_M_K,
-                                               AccElementwiseOperation,
+                                               YElementwiseOperation,
                                                ThreadBufferLengths_M_K,
                                                ThreadBufferDimAccessOrder,
                                                YDstVectorDim,
@@ -216,13 +232,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             y_grid_desc_m_k,
             make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize),
-            acc_elementwise_op);
+                             thread_k_cluster_id * YDstVectorSize),
+            y_elementwise_op);
-        // Copy x from Cache
-        // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step_m_k =
-            make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
+        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
         constexpr auto thread_copy_bwd_step_m_k =
             make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
@@ -239,121 +252,244 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
         // FIXME: Should not hack the transform from deviceOP
         int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
-        index_t reducedTiles = 0;
-        do
-        {
-            threadwise_x_load.Run(x_grid_desc_m_k,
-                                  x_global_val_buf,
-                                  thread_buffer_desc_m_k,
-                                  make_tuple(I0, I0),
-                                  x_thread_buf);
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    x_square_thread_buf(Number<offset_m_k>{}) =
-                        x_thread_buf(Number<offset_m_k>{}) * x_thread_buf(Number<offset_m_k>{});
-                });
-            });
-            ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf);
-            ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf);
-            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
-            ++reducedTiles;
-        } while(reducedTiles < num_k_block_tile_iteration);
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            if constexpr(I > 0)
-                block_sync_lds();
-            BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
-            mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
-            block_sync_lds();
-            BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
-            mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
-            // var(x) = E[x^2] - E[x]^2
-            var_thread_buf(I) =
-                mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
-        });
-        // y = (x - E[x]) / sqrt(var[x] + epsilon)
-        auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
-        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
-        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
-        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
-        reducedTiles = 0;
-        do
-        {
-            if constexpr(!SweepOnce)
-            {
-                threadwise_x_load.Run(x_grid_desc_m_k,
-                                      x_global_val_buf,
-                                      thread_buffer_desc_m_k,
-                                      make_tuple(I0, I0),
-                                      x_thread_buf);
-            }
-            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
-                                      gamma_global_val_buf,
-                                      thread_buffer_desc_m_k,
-                                      make_tuple(I0, I0),
-                                      gamma_thread_buf);
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    // normalize
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
-                        sqrt(var_thread_buf(iM) + epsilon);
-                    // gamma
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
-                });
-            });
-            threadwise_beta_load.Run(beta_grid_desc_m_k,
-                                     beta_global_val_buf,
-                                     thread_buffer_desc_m_k,
-                                     make_tuple(I0, I0),
-                                     beta_thread_buf);
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    // beta
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
-                });
-            });
-            threadwise_y_store.Run(thread_buffer_desc_m_k,
-                                   make_tuple(I0, I0),
-                                   y_thread_buf,
-                                   y_grid_desc_m_k,
-                                   y_global_val_buf);
-            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            ++reducedTiles;
-        } while(reducedTiles < num_k_block_tile_iteration);
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
+            mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
+        });
+        // Separate sweep once and sweep twice pipeline
+        if constexpr(SweepOnce)
+        {
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_x_load.Run(x_grid_desc_m_k,
+                                      x_global_val_buf,
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
+                                      x_thread_buf(i));
+                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
+                                          gamma_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          gamma_thread_buf(i));
+                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                        x_square_thread_buf(i)(Number<offset_m_k>{}) =
+                            x_thread_buf(i)(Number<offset_m_k>{}) *
+                            x_thread_buf(i)(Number<offset_m_k>{});
+                    });
+                });
+                ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
+                ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
+                if constexpr(i != ThreadBufferNumber - 1)
+                {
+                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+                    threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                             thread_copy_fwd_step_m_k);
+                }
+            });
+            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+                if constexpr(I > 0)
+                    block_sync_lds();
+                BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
+                mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
+                block_sync_lds();
+                BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
+                mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
+                // var(x) = E[x^2] - E[x]^2
+                var_thread_buf(I) =
+                    mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+            });
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                    threadwise_beta_load.Run(beta_grid_desc_m_k,
+                                             beta_global_val_buf,
+                                             thread_buffer_desc_m_k,
+                                             make_tuple(I0, I0),
+                                             beta_thread_buf(iK0));
+                    if constexpr(iK0 != ThreadBufferNumber - 1)
+                        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                                thread_copy_fwd_step_m_k);
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                        // normalize
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                            divisor;
+                        // gamma & beta
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                                gamma_thread_buf(iK0)(Number<offset_m_k>{}) +
+                            beta_thread_buf(iK0)(Number<offset_m_k>{});
+                    });
+                });
+            });
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_y_store.Run(thread_buffer_desc_m_k,
+                                       make_tuple(I0, I0),
+                                       y_thread_buf(i),
+                                       y_grid_desc_m_k,
+                                       y_global_val_buf);
+                if constexpr(i != ThreadBufferNumber - 1)
+                    threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
+                                                          thread_copy_fwd_step_m_k);
+            });
+        } // end of sweep once
+        else
+        {
+            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
+            {
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_x_load.Run(x_grid_desc_m_k,
+                                          x_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          x_thread_buf(i));
+                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+                    static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                        static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
+                            constexpr auto offset_m_k =
+                                thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                            x_square_thread_buf(i)(Number<offset_m_k>{}) =
+                                x_thread_buf(i)(Number<offset_m_k>{}) *
+                                x_thread_buf(i)(Number<offset_m_k>{});
+                        });
+                    });
+                    ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
+                    ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
+                });
+            }
+            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+                if constexpr(I > 0)
+                    block_sync_lds();
+                BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
+                mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
+                block_sync_lds();
+                BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
+                mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
+                // var(x) = E[x^2] - E[x]^2
+                var_thread_buf(I) =
+                    mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+            });
+            auto thread_copy_tail_m_k =
+                (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
+            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
+            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
+            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
+            {
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_x_load.Run(x_grid_desc_m_k,
+                                          x_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          x_thread_buf(i));
+                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+                });
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_gamma_load.Run(gamma_grid_desc_m_k,
+                                              gamma_global_val_buf,
+                                              thread_buffer_desc_m_k,
+                                              make_tuple(I0, I0),
+                                              gamma_thread_buf(i));
+                    threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                             thread_copy_fwd_step_m_k);
+                });
+                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                    auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
+                    static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                        static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                            constexpr auto offset_m_k =
+                                thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                            // normalize
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                                (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                                divisor;
+                            // gamma
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                                y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                                gamma_thread_buf(iK0)(Number<offset_m_k>{});
+                        });
+                    });
+                });
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_beta_load.Run(beta_grid_desc_m_k,
+                                             beta_global_val_buf,
+                                             thread_buffer_desc_m_k,
+                                             make_tuple(I0, I0),
+                                             beta_thread_buf(i));
+                    threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                            thread_copy_fwd_step_m_k);
+                });
+                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                    static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                        static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                            constexpr auto offset_m_k =
+                                thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                            // beta
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                                y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                                beta_thread_buf(iK0)(Number<offset_m_k>{});
+                        });
+                    });
+                });
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_y_store.Run(thread_buffer_desc_m_k,
+                                           make_tuple(I0, I0),
+                                           y_thread_buf(i),
+                                           y_grid_desc_m_k,
+                                           y_global_val_buf);
+                    threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
+                                                          thread_copy_fwd_step_m_k);
+                });
+                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
+                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                         2 * thread_copy_bwd_step_m_k);
+                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                        2 * thread_copy_bwd_step_m_k);
+                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k,
+                                                      2 * thread_copy_bwd_step_m_k);
+            }
+        } // end of sweep twice
     }
 };