Commit de6aad06 authored by Qianfeng Zhang

Rename parameters again in batchnorm backward kernels
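The gradient outputs are now named after what they are: scale_diff becomes dscale, bias_diff becomes dbias, the combined descriptor scale_bias_diff_grid_desc becomes dscale_dbias_grid_desc, and the *_global_val_buf buffers drop the redundant _val. For orientation, below is a minimal host-side sketch of the per-channel math these kernels implement; the function name and plain float types are illustrative only, not part of the kernels.

#include <cstddef>

// Reference batchnorm backward over one channel slice of length N, mirroring
// dx = 1/N * invVar * scale * (N * dy - dbias - dscale * (x - mean) * invVar)
// as computed in the kernels below.
void batchnorm_bwd_reference(const float* x, const float* dy, float* dx,
                             float scale, float mean, float invVar,
                             float& dscale, float& dbias, std::size_t N)
{
    dscale = 0.0f; // reduction of dy * normalized(x) -> d(loss)/d(scale)
    dbias  = 0.0f; // reduction of dy                 -> d(loss)/d(bias)
    for(std::size_t k = 0; k < N; ++k)
    {
        const float norm_x = (x[k] - mean) * invVar;
        dscale += dy[k] * norm_x;
        dbias += dy[k];
    }
    for(std::size_t k = 0; k < N; ++k)
    {
        const float norm_x = (x[k] - mean) * invVar;
        dx[k] = scale * invVar / static_cast<float>(N) *
                (static_cast<float>(N) * dy[k] - dbias - dscale * norm_x);
    }
}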

parent 7d114e80
@@ -26,7 +26,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
-const ScaleBiasDiffGridDesc_M_K scale_bias_diff_grid_desc_m_k,
+const ScaleBiasDiffGridDesc_M_K dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
@@ -48,7 +48,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
dx_grid_desc_m_k,
-scale_bias_diff_grid_desc_m_k,
+dscale_dbias_grid_desc_m_k,
mean_var_grid_desc_m,
scale_grid_desc_m,
bias_grid_desc_m,
@@ -143,7 +143,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k,
const XYGridDesc_M_K& dx_grid_desc_m_k,
-const ScaleBiasDiffGridDesc_M_K& scale_bias_diff_grid_desc_m_k,
+const ScaleBiasDiffGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
@@ -168,14 +168,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
-reduce_scale_diff_thread_buf;
+reduce_dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
-reduce_bias_diff_thread_buf;
+reduce_dbias_thread_buf;
-StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-scale_diff_thread_buf;
-StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-bias_diff_thread_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
@@ -212,7 +210,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
// Step 1: do final reduction for scale_diff and bias_diff and output
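// (i.e. each block sums the per-block partial dscale/dbias tiles produced by
//  the first-half kernel along the K dimension before writing the final values)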
-auto threadwise_scale_diff_load_m_k =
+auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasDiffGridDesc_M_K,
@@ -223,12 +221,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
1,
1,
true>(
-scale_bias_diff_grid_desc_m_k,
+dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
-auto threadwise_bias_diff_load_m_k =
+auto threadwise_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
ScaleBiasDiffGridDesc_M_K,
@@ -239,12 +237,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
1,
1,
true>(
-scale_bias_diff_grid_desc_m_k,
+dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
-auto threadwise_scale_diff_store_m =
+auto threadwise_dscale_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m),
@@ -262,7 +260,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
-auto threadwise_bias_diff_store_m =
+auto threadwise_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m),
@@ -280,67 +278,67 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
-const auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_reduce_dscale, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
+const auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
-const auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_reduce_dbias, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
+const auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
-auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
-auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
-constexpr auto scale_bias_diff_thread_copy_step_m_k =
+constexpr auto dscale_dbias_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-scale_diff_thread_buf(I) = type_convert<AccDataType>(0.0f);
-bias_diff_thread_buf(I) = type_convert<AccDataType>(0.0f);
+dscale_thread_buf(I) = type_convert<AccDataType>(0.0f);
+dbias_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_scale_bias_diff_k_block_tile_iteration;
++reducedTiles)
{
-threadwise_scale_diff_load_m_k.Run(scale_bias_diff_grid_desc_m_k,
-reduce_scale_diff_global_val_buf,
+threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
+reduce_dscale_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
-reduce_scale_diff_thread_buf);
+reduce_dscale_thread_buf);
-threadwise_bias_diff_load_m_k.Run(scale_bias_diff_grid_desc_m_k,
-reduce_bias_diff_global_val_buf,
+threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
+reduce_dbias_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
-reduce_bias_diff_thread_buf);
+reduce_dbias_thread_buf);
-ThreadwiseReduce::Reduce(reduce_scale_diff_thread_buf, scale_diff_thread_buf);
-ThreadwiseReduce::Reduce(reduce_bias_diff_thread_buf, bias_diff_thread_buf);
+ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
+ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
-threadwise_scale_diff_load_m_k.MoveSrcSliceWindow(scale_bias_diff_grid_desc_m_k,
-scale_bias_diff_thread_copy_step_m_k);
-threadwise_bias_diff_load_m_k.MoveSrcSliceWindow(scale_bias_diff_grid_desc_m_k,
-scale_bias_diff_thread_copy_step_m_k);
+threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
+dscale_dbias_thread_copy_step_m_k);
+threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
+dscale_dbias_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-BlockwiseReduce::Reduce(reduce_work_buf, scale_diff_thread_buf(I));
+BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
block_sync_lds();
-BlockwiseReduce::Reduce(reduce_work_buf, bias_diff_thread_buf(I));
+BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
});
-threadwise_scale_diff_store_m.Run(thread_buffer_desc_m,
+threadwise_dscale_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
-scale_diff_thread_buf,
+dscale_thread_buf,
scale_grid_desc_m,
-scale_diff_global_val_buf);
+dscale_global_buf);
-threadwise_bias_diff_store_m.Run(thread_buffer_desc_m,
+threadwise_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
-bias_diff_thread_buf,
+dbias_thread_buf,
bias_grid_desc_m,
-bias_diff_global_val_buf);
+dbias_global_buf);
// Step 2: calculate dx = 1/N * invVar * scale * (N * dy - biasDiff - scaleDiff * (x - mean)
// * invVar) and output
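// (biasDiff and scaleDiff refer to the values reduced into dbias_thread_buf and
//  dscale_thread_buf in Step 1; invVar is the saved inverse standard deviation)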
@@ -426,38 +424,38 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
-const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
-const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
-auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
-const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
-const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean, mean_var_grid_desc_m.GetElementSpaceSize());
-const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize());
threadwise_scale_load.Run(scale_grid_desc_m,
-scale_global_val_buf,
+scale_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
-mean_global_val_buf,
+mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m,
-inv_var_global_val_buf,
+inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
@@ -467,13 +465,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
-x_global_val_buf,
+x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
-dy_global_val_buf,
+dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
@@ -490,12 +488,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
-AccDataType tmpVal = norm_x * scale_diff_thread_buf[iM];
+AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
dx_thread_buf(Number<offset>{}) =
multiplier *
(type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
-bias_diff_thread_buf[iM] - tmpVal);
+dbias_thread_buf[iM] - tmpVal);
});
});
@@ -503,7 +501,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
-dx_global_val_buf);
+dx_global_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
...
@@ -28,7 +28,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
const XYGridDesc_M_K dy_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
-const ScaleBiasDiffGridDesc_M_G scale_bias_grid_desc_m_g,
+const ScaleBiasDiffGridDesc_M_G dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
@@ -50,7 +50,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
dy_grid_desc_m_k,
mean_var_grid_desc_m,
mean_var_count_grid_desc_m_k,
-scale_bias_grid_desc_m_g,
+dscale_dbias_grid_desc_m_g,
blkgroup_size,
num_xy_k_block_tile_iteration,
num_mean_var_count_k_block_tile_iteration,
@@ -149,7 +149,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
const XYGridDesc_M_K& dy_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
-const ScaleBiasDiffGridDesc_M_G& scale_bias_diff_grid_desc_m_g,
+const ScaleBiasDiffGridDesc_M_G& dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
@@ -201,9 +201,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
tmp1_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-reduce_scale_diff_thread_buf;
+reduce_dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-reduce_bias_diff_thread_buf;
+reduce_dbias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
@@ -231,10 +231,10 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
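// (when the forward pass saved mean and inverse variance, simply reload them;
//  otherwise merge the per-block Welford partial mean/variance/count results)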
if(haveSavedMeanInvVar)
{
-const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
-const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_inv_var_load =
@@ -253,26 +253,26 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
-mean_global_val_buf,
+mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
-inv_var_global_val_buf,
+inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
}
else
{
-const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
-const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
-const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
auto threadwise_mean_var_load_m_k =
@@ -320,19 +320,19 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
++reducedTiles)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
-welford_mean_global_val_buf,
+welford_mean_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
-welford_var_global_val_buf,
+welford_var_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
-welford_count_global_val_buf,
+welford_count_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
@@ -386,23 +386,23 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
-auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());
-auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
mean_var_grid_desc_m,
-mean_global_val_buf);
+mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf,
mean_var_grid_desc_m,
-inv_var_global_val_buf);
+inv_var_global_buf);
};
};
@@ -438,17 +438,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
-const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
-const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-reduce_scale_diff_thread_buf(I) = type_convert<AccDataType>(0);
-reduce_bias_diff_thread_buf(I) = type_convert<AccDataType>(0);
+reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
+reduce_dbias_thread_buf(I) = type_convert<AccDataType>(0);
});
// Step 2: do first-half reduction on dy and dy * (x-mean) * inv-variance
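// (the two partial sums reduce dy into reduce_dbias_thread_buf and
//  dy * (x - mean) * inv-variance into reduce_dscale_thread_buf; each block
//  then writes its partial result into its own column of the M x G grid)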
@@ -456,13 +456,13 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
-x_global_val_buf,
+x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
-dy_global_val_buf,
+dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
@@ -479,20 +479,20 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
});
});
-ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_scale_diff_thread_buf);
-ThreadwiseReduce::Reduce(dy_thread_buf, reduce_bias_diff_thread_buf);
+ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
+ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-BlockwiseReduce::Reduce(reduce_work_buf, reduce_scale_diff_thread_buf(I));
+BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
block_sync_lds();
-BlockwiseReduce::Reduce(reduce_work_buf, reduce_bias_diff_thread_buf(I));
+BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
});
-auto threadwise_scale_diff_store =
+auto threadwise_dscale_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m_1),
@@ -505,13 +505,13 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
InMemoryDataOperationEnum::Set,
1,
true>(
-scale_bias_diff_grid_desc_m_g,
+dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
-auto threadwise_bias_diff_store =
+auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m_1),
@@ -524,31 +524,31 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
InMemoryDataOperationEnum::Set,
1,
true>(
-scale_bias_diff_grid_desc_m_g,
+dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
-auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_reduce_dscale, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
+auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
-auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_reduce_dbias, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
+auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
if(thread_k_cluster_id == 0)
{
-threadwise_scale_diff_store.Run(thread_buffer_desc_m_1,
+threadwise_dscale_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
-reduce_scale_diff_thread_buf,
-scale_bias_diff_grid_desc_m_g,
-reduce_scale_diff_global_val_buf);
+reduce_dscale_thread_buf,
+dscale_dbias_grid_desc_m_g,
+reduce_dscale_global_buf);
-threadwise_bias_diff_store.Run(thread_buffer_desc_m_1,
+threadwise_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
-reduce_bias_diff_thread_buf,
-scale_bias_diff_grid_desc_m_g,
-reduce_bias_diff_global_val_buf);
+reduce_dbias_thread_buf,
+dscale_dbias_grid_desc_m_g,
+reduce_dbias_global_buf);
};
};
};
...
@@ -204,10 +204,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
inv_var_thread_buf = var_thread_buf;
-StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-scale_diff_thread_buf;
-StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-bias_diff_thread_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
@@ -289,7 +287,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
-auto threadwise_scale_diff_store =
+auto threadwise_dscale_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m),
@@ -307,7 +305,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
-auto threadwise_bias_diff_store =
+auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m),
@@ -328,30 +326,30 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
-const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
-const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
-auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
-const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
-auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
-auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
if(haveSavedMeanInvVar)
{
-const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
-const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_inv_var_load =
@@ -370,13 +368,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
-mean_global_val_buf,
+mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf);
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
-inv_var_global_val_buf,
+inv_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
inv_var_thread_buf);
@@ -395,7 +393,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
{
threadwise_x_load.Run(x_grid_desc_m_k,
-x_global_val_buf,
+x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
@@ -425,20 +423,20 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-scale_diff_thread_buf(I) = type_convert<AccDataType>(0);
-bias_diff_thread_buf(I) = type_convert<AccDataType>(0);
+dscale_thread_buf(I) = type_convert<AccDataType>(0);
+dbias_thread_buf(I) = type_convert<AccDataType>(0);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
-x_global_val_buf,
+x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dx_grid_desc_m_k,
-dy_global_val_buf,
+dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
@@ -455,36 +453,36 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
});
});
-ThreadwiseReduce::Reduce(tmp1_thread_buf, scale_diff_thread_buf);
-ThreadwiseReduce::Reduce(dy_thread_buf, bias_diff_thread_buf);
+ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf);
+ThreadwiseReduce::Reduce(dy_thread_buf, dbias_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-BlockwiseReduce::Reduce(reduce_work_buf, scale_diff_thread_buf(I));
+BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
block_sync_lds();
-BlockwiseReduce::Reduce(reduce_work_buf, bias_diff_thread_buf(I));
+BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
-threadwise_scale_diff_store.Run(thread_buffer_desc_m,
+threadwise_dscale_store.Run(thread_buffer_desc_m,
make_tuple(I0),
-scale_diff_thread_buf,
+dscale_thread_buf,
scale_grid_desc_m,
-scale_diff_global_val_buf);
+dscale_global_buf);
-threadwise_bias_diff_store.Run(thread_buffer_desc_m,
+threadwise_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
-bias_diff_thread_buf,
+dbias_thread_buf,
bias_grid_desc_m,
-bias_diff_global_val_buf);
+dbias_global_buf);
};
threadwise_scale_load.Run(scale_grid_desc_m,
-scale_global_val_buf,
+scale_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
@@ -498,13 +496,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
-x_global_val_buf,
+x_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k,
-dy_global_val_buf,
+dy_global_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
@@ -517,13 +515,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
-AccDataType tmpVal = norm_x * scale_diff_thread_buf[iM];
+AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
dx_thread_buf(Number<offset>{}) =
type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size) *
inv_var_thread_buf[iM] * scale_thread_buf[iM] *
(type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
-bias_diff_thread_buf[iM] - tmpVal);
+dbias_thread_buf[iM] - tmpVal);
});
});
@@ -531,7 +529,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
-dx_global_val_buf);
+dx_global_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
...