"tests/pipelines/vscode:/vscode.git/clone" did not exist on "c059cc0992899383d1079fbea52b71a49aa3f88a"
Commit de6aad06 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Parameters renaming again in batchnorm backward kernels

parent 7d114e80
...@@ -26,7 +26,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final( ...@@ -26,7 +26,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_K scale_bias_diff_grid_desc_m_k, const ScaleBiasDiffGridDesc_M_K dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m,
const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m,
...@@ -48,7 +48,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final( ...@@ -48,7 +48,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k, GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k, dy_grid_desc_m_k,
dx_grid_desc_m_k, dx_grid_desc_m_k,
scale_bias_diff_grid_desc_m_k, dscale_dbias_grid_desc_m_k,
mean_var_grid_desc_m, mean_var_grid_desc_m,
scale_grid_desc_m, scale_grid_desc_m,
bias_grid_desc_m, bias_grid_desc_m,
...@@ -143,7 +143,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -143,7 +143,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k, const XYGridDesc_M_K& dy_grid_desc_m_k,
const XYGridDesc_M_K& dx_grid_desc_m_k, const XYGridDesc_M_K& dx_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_K& scale_bias_diff_grid_desc_m_k, const ScaleBiasDiffGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m, const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m, const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m, const ScaleBiasGridDesc_M& bias_grid_desc_m,
...@@ -168,14 +168,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -168,14 +168,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
reduce_scale_diff_thread_buf; reduce_dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
reduce_bias_diff_thread_buf; reduce_dbias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
scale_diff_thread_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
bias_diff_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf; x_thread_buf;
...@@ -212,7 +210,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -212,7 +210,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
// Step 1: do final reduction for scale_diff and bias_diff and output // Step 1: do final reduction for scale_diff and bias_diff and output
auto threadwise_scale_diff_load_m_k = auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType, ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType, AccDataType,
ScaleBiasDiffGridDesc_M_K, ScaleBiasDiffGridDesc_M_K,
...@@ -223,12 +221,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -223,12 +221,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
1, 1,
1, 1,
true>( true>(
scale_bias_diff_grid_desc_m_k, dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1)); thread_k_cluster_id * 1));
auto threadwise_bias_diff_load_m_k = auto threadwise_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<BiasDataType, ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType, AccDataType,
ScaleBiasDiffGridDesc_M_K, ScaleBiasDiffGridDesc_M_K,
...@@ -239,12 +237,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -239,12 +237,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
1, 1,
1, 1,
true>( true>(
scale_bias_diff_grid_desc_m_k, dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1)); thread_k_cluster_id * 1));
auto threadwise_scale_diff_store_m = auto threadwise_dscale_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType, ScaleDataType,
decltype(thread_buffer_desc_m), decltype(thread_buffer_desc_m),
...@@ -262,7 +260,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -262,7 +260,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{}); PassThroughOp{});
auto threadwise_bias_diff_store_m = auto threadwise_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType, BiasDataType,
decltype(thread_buffer_desc_m), decltype(thread_buffer_desc_m),
...@@ -280,67 +278,67 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -280,67 +278,67 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{}); PassThroughOp{});
const auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dscale, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize()); p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
const auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dbias, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize()); p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize()); p_dscale, scale_grid_desc_m.GetElementSpaceSize());
auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize()); p_dbias, bias_grid_desc_m.GetElementSpaceSize());
constexpr auto scale_bias_diff_thread_copy_step_m_k = constexpr auto dscale_dbias_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1); make_multi_index(0, KThreadClusterSize * 1);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
scale_diff_thread_buf(I) = type_convert<AccDataType>(0.0f); dscale_thread_buf(I) = type_convert<AccDataType>(0.0f);
bias_diff_thread_buf(I) = type_convert<AccDataType>(0.0f); dbias_thread_buf(I) = type_convert<AccDataType>(0.0f);
}); });
for(index_t reducedTiles = 0; reducedTiles < num_scale_bias_diff_k_block_tile_iteration; for(index_t reducedTiles = 0; reducedTiles < num_scale_bias_diff_k_block_tile_iteration;
++reducedTiles) ++reducedTiles)
{ {
threadwise_scale_diff_load_m_k.Run(scale_bias_diff_grid_desc_m_k, threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_scale_diff_global_val_buf, reduce_dscale_global_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
reduce_scale_diff_thread_buf); reduce_dscale_thread_buf);
threadwise_bias_diff_load_m_k.Run(scale_bias_diff_grid_desc_m_k, threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_bias_diff_global_val_buf, reduce_dbias_global_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
reduce_bias_diff_thread_buf); reduce_dbias_thread_buf);
ThreadwiseReduce::Reduce(reduce_scale_diff_thread_buf, scale_diff_thread_buf); ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(reduce_bias_diff_thread_buf, bias_diff_thread_buf); ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
threadwise_scale_diff_load_m_k.MoveSrcSliceWindow(scale_bias_diff_grid_desc_m_k, threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
scale_bias_diff_thread_copy_step_m_k); dscale_dbias_thread_copy_step_m_k);
threadwise_bias_diff_load_m_k.MoveSrcSliceWindow(scale_bias_diff_grid_desc_m_k, threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
scale_bias_diff_thread_copy_step_m_k); dscale_dbias_thread_copy_step_m_k);
} }
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
BlockwiseReduce::Reduce(reduce_work_buf, scale_diff_thread_buf(I)); BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
block_sync_lds(); block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, bias_diff_thread_buf(I)); BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
}); });
threadwise_scale_diff_store_m.Run(thread_buffer_desc_m, threadwise_dscale_store_m.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
scale_diff_thread_buf, dscale_thread_buf,
scale_grid_desc_m, scale_grid_desc_m,
scale_diff_global_val_buf); dscale_global_buf);
threadwise_bias_diff_store_m.Run(thread_buffer_desc_m, threadwise_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
bias_diff_thread_buf, dbias_thread_buf,
bias_grid_desc_m, bias_grid_desc_m,
bias_diff_global_val_buf); dbias_global_buf);
// Step 2: calculate dx = 1/N * invVar * scale * (N * dy - biasDiff - scaleDiff * (x - mean) // Step 2: calculate dx = 1/N * invVar * scale * (N * dy - biasDiff - scaleDiff * (x - mean)
// * invVar) and output // * invVar) and output
...@@ -426,38 +424,38 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -426,38 +424,38 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize)); thread_m_cluster_id * MThreadSliceSize));
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize()); p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize()); p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx, dx_grid_desc_m_k.GetElementSpaceSize()); p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize()); p_scale, scale_grid_desc_m.GetElementSpaceSize());
const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean, mean_var_grid_desc_m.GetElementSpaceSize()); p_mean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize()); p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize());
threadwise_scale_load.Run(scale_grid_desc_m, threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_val_buf, scale_global_buf,
thread_buffer_desc_m, thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
scale_thread_buf); scale_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m, threadwise_mean_var_load.Run(mean_var_grid_desc_m,
mean_global_val_buf, mean_global_buf,
thread_buffer_desc_m, thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
mean_thread_buf); mean_thread_buf);
threadwise_mean_var_load.Run(mean_var_grid_desc_m, threadwise_mean_var_load.Run(mean_var_grid_desc_m,
inv_var_global_val_buf, inv_var_global_buf,
thread_buffer_desc_m, thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
inv_var_thread_buf); inv_var_thread_buf);
...@@ -467,13 +465,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -467,13 +465,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles) for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{ {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf, x_global_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf); x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k, threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf, dy_global_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
dy_thread_buf); dy_thread_buf);
...@@ -490,12 +488,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -490,12 +488,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM]; inv_var_thread_buf[iM];
AccDataType tmpVal = norm_x * scale_diff_thread_buf[iM]; AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
dx_thread_buf(Number<offset>{}) = dx_thread_buf(Number<offset>{}) =
multiplier * multiplier *
(type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] - (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
bias_diff_thread_buf[iM] - tmpVal); dbias_thread_buf[iM] - tmpVal);
}); });
}); });
...@@ -503,7 +501,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -503,7 +501,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
make_tuple(I0, I0), make_tuple(I0, I0),
dx_thread_buf, dx_thread_buf,
dx_grid_desc_m_k, dx_grid_desc_m_k,
dx_global_val_buf); dx_global_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k); threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
......
...@@ -28,7 +28,7 @@ __global__ void kernel_welford_second_half_reduce_first_half( ...@@ -28,7 +28,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_G scale_bias_grid_desc_m_g, const ScaleBiasDiffGridDesc_M_G dscale_dbias_grid_desc_m_g,
index_t blkgroup_size, index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration,
...@@ -50,7 +50,7 @@ __global__ void kernel_welford_second_half_reduce_first_half( ...@@ -50,7 +50,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
dy_grid_desc_m_k, dy_grid_desc_m_k,
mean_var_grid_desc_m, mean_var_grid_desc_m,
mean_var_count_grid_desc_m_k, mean_var_count_grid_desc_m_k,
scale_bias_grid_desc_m_g, dscale_dbias_grid_desc_m_g,
blkgroup_size, blkgroup_size,
num_xy_k_block_tile_iteration, num_xy_k_block_tile_iteration,
num_mean_var_count_k_block_tile_iteration, num_mean_var_count_k_block_tile_iteration,
...@@ -149,7 +149,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -149,7 +149,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
const XYGridDesc_M_K& dy_grid_desc_m_k, const XYGridDesc_M_K& dy_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m, const MeanVarGridDesc_M& mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_G& scale_bias_diff_grid_desc_m_g, const ScaleBiasDiffGridDesc_M_G& dscale_dbias_grid_desc_m_g,
index_t blkgroup_size, index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration,
...@@ -201,9 +201,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -201,9 +201,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
tmp1_thread_buf; tmp1_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
reduce_scale_diff_thread_buf; reduce_dscale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
reduce_bias_diff_thread_buf; reduce_dbias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -231,10 +231,10 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -231,10 +231,10 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
if(haveSavedMeanInvVar) if(haveSavedMeanInvVar)
{ {
const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize()); p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize()); p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_inv_var_load = auto threadwise_mean_inv_var_load =
...@@ -253,26 +253,26 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -253,26 +253,26 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
thread_m_cluster_id * MThreadSliceSize)); thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
mean_global_val_buf, mean_global_buf,
thread_buffer_desc_m, thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
mean_thread_buf); mean_thread_buf);
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
inv_var_global_val_buf, inv_var_global_buf,
thread_buffer_desc_m, thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
inv_var_thread_buf); inv_var_thread_buf);
} }
else else
{ {
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
auto threadwise_mean_var_load_m_k = auto threadwise_mean_var_load_m_k =
...@@ -320,19 +320,19 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -320,19 +320,19 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
++reducedTiles) ++reducedTiles)
{ {
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_val_buf, welford_mean_global_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
in_welford_mean_thread_buf); in_welford_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_var_global_val_buf, welford_var_global_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
in_welford_var_thread_buf); in_welford_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k, threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_count_global_val_buf, welford_count_global_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
in_welford_count_thread_buf); in_welford_count_thread_buf);
...@@ -386,23 +386,23 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -386,23 +386,23 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{}); PassThroughOp{});
auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize()); p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());
auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize()); p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
mean_thread_buf, mean_thread_buf,
mean_var_grid_desc_m, mean_var_grid_desc_m,
mean_global_val_buf); mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
inv_var_thread_buf, inv_var_thread_buf,
mean_var_grid_desc_m, mean_var_grid_desc_m,
inv_var_global_val_buf); inv_var_global_buf);
}; };
}; };
...@@ -438,17 +438,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -438,17 +438,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
workSizePerBlock * block_local_id + workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize()); p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize()); p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize); constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
reduce_scale_diff_thread_buf(I) = type_convert<AccDataType>(0); reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
reduce_bias_diff_thread_buf(I) = type_convert<AccDataType>(0); reduce_dbias_thread_buf(I) = type_convert<AccDataType>(0);
}); });
// Step 2: do first-half reduction on dy and dy * (x-mean) * inv-variance // Step 2: do first-half reduction on dy and dy * (x-mean) * inv-variance
...@@ -456,13 +456,13 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -456,13 +456,13 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles) for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
{ {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf, x_global_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf); x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k, threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf, dy_global_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
dy_thread_buf); dy_thread_buf);
...@@ -479,20 +479,20 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -479,20 +479,20 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
}); });
}); });
ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_scale_diff_thread_buf); ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
ThreadwiseReduce::Reduce(dy_thread_buf, reduce_bias_diff_thread_buf); ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k); threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
}; };
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
BlockwiseReduce::Reduce(reduce_work_buf, reduce_scale_diff_thread_buf(I)); BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
block_sync_lds(); block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, reduce_bias_diff_thread_buf(I)); BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
}); });
auto threadwise_scale_diff_store = auto threadwise_dscale_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType, ScaleDataType,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
...@@ -505,13 +505,13 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -505,13 +505,13 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
scale_bias_diff_grid_desc_m_g, dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
block_local_id), block_local_id),
PassThroughOp{}); PassThroughOp{});
auto threadwise_bias_diff_store = auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType, BiasDataType,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
...@@ -524,31 +524,31 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -524,31 +524,31 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
scale_bias_diff_grid_desc_m_g, dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
block_local_id), block_local_id),
PassThroughOp{}); PassThroughOp{});
auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dscale, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize()); p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_dbias, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize()); p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
threadwise_scale_diff_store.Run(thread_buffer_desc_m_1, threadwise_dscale_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
reduce_scale_diff_thread_buf, reduce_dscale_thread_buf,
scale_bias_diff_grid_desc_m_g, dscale_dbias_grid_desc_m_g,
reduce_scale_diff_global_val_buf); reduce_dscale_global_buf);
threadwise_bias_diff_store.Run(thread_buffer_desc_m_1, threadwise_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
reduce_bias_diff_thread_buf, reduce_dbias_thread_buf,
scale_bias_diff_grid_desc_m_g, dscale_dbias_grid_desc_m_g,
reduce_bias_diff_global_val_buf); reduce_dbias_global_buf);
}; };
}; };
}; };
......
...@@ -204,10 +204,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -204,10 +204,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
inv_var_thread_buf = var_thread_buf; inv_var_thread_buf = var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
scale_diff_thread_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
bias_diff_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -289,7 +287,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -289,7 +287,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize)); thread_m_cluster_id * MThreadSliceSize));
auto threadwise_scale_diff_store = auto threadwise_dscale_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType, ScaleDataType,
decltype(thread_buffer_desc_m), decltype(thread_buffer_desc_m),
...@@ -307,7 +305,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -307,7 +305,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{}); PassThroughOp{});
auto threadwise_bias_diff_store = auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType, BiasDataType,
decltype(thread_buffer_desc_m), decltype(thread_buffer_desc_m),
...@@ -328,30 +326,30 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -328,30 +326,30 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize()); p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy, dy_grid_desc_m_k.GetElementSpaceSize()); p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx, dx_grid_desc_m_k.GetElementSpaceSize()); p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize()); p_scale, scale_grid_desc_m.GetElementSpaceSize());
auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize()); p_dscale, scale_grid_desc_m.GetElementSpaceSize());
auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize()); p_dbias, bias_grid_desc_m.GetElementSpaceSize());
if(haveSavedMeanInvVar) if(haveSavedMeanInvVar)
{ {
const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize()); p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize()); p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_inv_var_load = auto threadwise_mean_inv_var_load =
...@@ -370,13 +368,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -370,13 +368,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
thread_m_cluster_id * MThreadSliceSize)); thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
mean_global_val_buf, mean_global_buf,
thread_buffer_desc_m, thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
mean_thread_buf); mean_thread_buf);
threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
inv_var_global_val_buf, inv_var_global_buf,
thread_buffer_desc_m, thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
inv_var_thread_buf); inv_var_thread_buf);
...@@ -395,7 +393,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -395,7 +393,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
{ {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf, x_global_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf); x_thread_buf);
...@@ -425,20 +423,20 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -425,20 +423,20 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
}; };
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
scale_diff_thread_buf(I) = type_convert<AccDataType>(0); dscale_thread_buf(I) = type_convert<AccDataType>(0);
bias_diff_thread_buf(I) = type_convert<AccDataType>(0); dbias_thread_buf(I) = type_convert<AccDataType>(0);
}); });
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{ {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf, x_global_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf); x_thread_buf);
threadwise_dy_load.Run(dx_grid_desc_m_k, threadwise_dy_load.Run(dx_grid_desc_m_k,
dy_global_val_buf, dy_global_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
dy_thread_buf); dy_thread_buf);
...@@ -455,36 +453,36 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -455,36 +453,36 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
}); });
}); });
ThreadwiseReduce::Reduce(tmp1_thread_buf, scale_diff_thread_buf); ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(dy_thread_buf, bias_diff_thread_buf); ThreadwiseReduce::Reduce(dy_thread_buf, dbias_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
}; };
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
BlockwiseReduce::Reduce(reduce_work_buf, scale_diff_thread_buf(I)); BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
block_sync_lds(); block_sync_lds();
BlockwiseReduce::Reduce(reduce_work_buf, bias_diff_thread_buf(I)); BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
}); });
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
threadwise_scale_diff_store.Run(thread_buffer_desc_m, threadwise_dscale_store.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
scale_diff_thread_buf, dscale_thread_buf,
scale_grid_desc_m, scale_grid_desc_m,
scale_diff_global_val_buf); dscale_global_buf);
threadwise_bias_diff_store.Run(thread_buffer_desc_m, threadwise_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
bias_diff_thread_buf, dbias_thread_buf,
bias_grid_desc_m, bias_grid_desc_m,
bias_diff_global_val_buf); dbias_global_buf);
}; };
threadwise_scale_load.Run(scale_grid_desc_m, threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_val_buf, scale_global_buf,
thread_buffer_desc_m, thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
scale_thread_buf); scale_thread_buf);
...@@ -498,13 +496,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -498,13 +496,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{ {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf, x_global_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf); x_thread_buf);
threadwise_dy_load.Run(dy_grid_desc_m_k, threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf, dy_global_buf,
thread_buffer_desc_m_k, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
dy_thread_buf); dy_thread_buf);
...@@ -517,13 +515,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -517,13 +515,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM]; inv_var_thread_buf[iM];
AccDataType tmpVal = norm_x * scale_diff_thread_buf[iM]; AccDataType tmpVal = norm_x * dscale_thread_buf[iM];
dx_thread_buf(Number<offset>{}) = dx_thread_buf(Number<offset>{}) =
type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size) * type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size) *
inv_var_thread_buf[iM] * scale_thread_buf[iM] * inv_var_thread_buf[iM] * scale_thread_buf[iM] *
(type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] - (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
bias_diff_thread_buf[iM] - tmpVal); dbias_thread_buf[iM] - tmpVal);
}); });
}); });
...@@ -531,7 +529,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -531,7 +529,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
make_tuple(I0, I0), make_tuple(I0, I0),
dx_thread_buf, dx_thread_buf,
dx_grid_desc_m_k, dx_grid_desc_m_k,
dx_global_val_buf); dx_global_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment