Unverified Commit 7bcaf2a7 authored by Adam Osewski's avatar Adam Osewski Committed by GitHub
Browse files

Merge branch 'develop' into wavelet_model

parents e59daa22 0345963e
...@@ -10,8 +10,8 @@ namespace element_wise { ...@@ -10,8 +10,8 @@ namespace element_wise {
template <typename Activation> template <typename Activation>
struct Activation_Mul_Clamp struct Activation_Mul_Clamp
{ {
Activation_Mul_Clamp(float multiplier, Activation activationOp) Activation_Mul_Clamp(float requantScale, Activation activationOp)
: multiplier_(multiplier), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
} }
...@@ -19,7 +19,7 @@ struct Activation_Mul_Clamp ...@@ -19,7 +19,7 @@ struct Activation_Mul_Clamp
{ {
float x_fp32 = ck::type_convert<float>(x); float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32); activationOp_(x_fp32, x_fp32);
float y_fp32 = math::clamp(multiplier_ * x_fp32, -128.f, 127.f); float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32); y = ck::type_convert<int8_t>(y_fp32);
} }
...@@ -28,10 +28,29 @@ struct Activation_Mul_Clamp ...@@ -28,10 +28,29 @@ struct Activation_Mul_Clamp
// We might type_convert to int8 after lambda in someplace // We might type_convert to int8 after lambda in someplace
float x_fp32 = ck::type_convert<float>(x); float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32); activationOp_(x_fp32, x_fp32);
y = math::clamp(multiplier_ * x_fp32, -128.f, 127.f); y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
}
float requantScale_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template <typename Activation>
struct Activation_Mul2_Clamp
{
Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const float& requantScale) const
{
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
} }
float multiplier_;
Activation activationOp_; Activation activationOp_;
}; };
...@@ -39,21 +58,40 @@ struct Activation_Mul_Clamp ...@@ -39,21 +58,40 @@ struct Activation_Mul_Clamp
template <typename Activation> template <typename Activation>
struct Add_Activation_Mul_Clamp struct Add_Activation_Mul_Clamp
{ {
Add_Activation_Mul_Clamp(float multiplier, Activation activationOp) Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
: multiplier_(multiplier), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
} }
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float requantScale_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template <typename Activation>
struct Add_Activation_Mul2_Clamp
{
Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
{ {
float y_fp32 = ck::type_convert<float>(x1 + x2); float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32); activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(multiplier_ * y_fp32, -128.f, 127.f); y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32); y = ck::type_convert<int8_t>(y_fp32);
} }
float multiplier_;
Activation activationOp_; Activation activationOp_;
}; };
...@@ -61,23 +99,23 @@ struct Add_Activation_Mul_Clamp ...@@ -61,23 +99,23 @@ struct Add_Activation_Mul_Clamp
template <typename Activation> template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp struct Add_Mul_Activation_Mul_Clamp
{ {
Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp) Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp)
: multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp) : requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp)
{ {
} }
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{ {
float y_fp32 = ck::type_convert<float>(x1 + x2); float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = multiplier1_ * y_fp32; y_fp32 = requantScale1_ * y_fp32;
activationOp_(y_fp32, y_fp32); activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(multiplier2_ * y_fp32, -128.f, 127.f); y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32); y = ck::type_convert<int8_t>(y_fp32);
} }
float multiplier1_; float requantScale1_;
float multiplier2_; float requantScale2_;
Activation activationOp_; Activation activationOp_;
}; };
......
...@@ -16,7 +16,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_, ...@@ -16,7 +16,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename DyDataType, typename DyDataType,
typename DxDataType, typename DxDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename DscaleDbiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp, typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
...@@ -35,8 +35,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final( ...@@ -35,8 +35,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
long_index_t reduce_size, long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale, const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias, const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean, const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var, const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
...@@ -44,8 +44,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final( ...@@ -44,8 +44,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const ScaleDataType* const __restrict__ p_scale, const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op, const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, DscaleDbiasDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) DscaleDbiasDataType* const __restrict__ p_dbias)
{ {
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k, GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k, dy_grid_desc_m_k,
...@@ -76,7 +76,7 @@ template <typename XDataType, ...@@ -76,7 +76,7 @@ template <typename XDataType,
typename DxDataType, typename DxDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename DscaleDbiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp, typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
...@@ -92,8 +92,8 @@ template <typename XDataType, ...@@ -92,8 +92,8 @@ template <typename XDataType,
index_t XSrcVectorSize, index_t XSrcVectorSize,
index_t DySrcVectorSize, index_t DySrcVectorSize,
index_t DxDstVectorSize, index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize, index_t ScaleSrcVectorSize,
index_t BiasDstVectorSize, index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize> index_t MeanVarSrcVectorSize>
struct GridwiseReduceSecondHalfBatchNormBackwardFinal struct GridwiseReduceSecondHalfBatchNormBackwardFinal
{ {
...@@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k, const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m, const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m, const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m, const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
index_t blkgroup_size, index_t blkgroup_size,
long_index_t reduce_size, long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale, const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias, const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean, const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var, const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
...@@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const ScaleDataType* const __restrict__ p_scale, const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op, const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, DscaleDbiasDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) DscaleDbiasDataType* const __restrict__ p_dbias)
{ {
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
...@@ -222,24 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -222,24 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
// Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) // Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on // clang-format on
auto threadwise_dscale_load_m_k = auto threadwise_dscale_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType, ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType, AccDataType,
DscaleDbiasGridDesc_M_K, DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
...@@ -254,38 +238,20 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -254,38 +238,20 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1)); thread_k_cluster_id * 1));
auto threadwise_dscale_store_m = auto threadwise_dscale_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType, DscaleDbiasDataType,
decltype(thread_buffer_desc_m), decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M, ScaleBiasGridDesc_M,
PassThroughOp, PassThroughOp,
ThreadBufferLengths_M, ThreadBufferLengths_M,
Sequence<0>, Sequence<0>,
0, 0,
ScaleSrcDstVectorSize, DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
scale_grid_desc_m, dscale_dbias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
bias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{}); PassThroughOp{});
...@@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize()); p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize()); p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize()); p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
constexpr auto dscale_dbias_thread_copy_step_m_k = constexpr auto dscale_dbias_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1); make_multi_index(0, KThreadClusterSize * 1);
...@@ -313,25 +279,23 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -313,25 +279,23 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration; for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
++reducedTiles) ++reducedTiles)
{ {
threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k, threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dscale_global_buf, reduce_dscale_global_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
reduce_dscale_thread_buf); reduce_dscale_thread_buf);
threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k, threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dbias_global_buf, reduce_dbias_global_buf,
thread_buffer_desc_m_1, thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
reduce_dbias_thread_buf); reduce_dbias_thread_buf);
ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf); ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf); ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k, threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k); dscale_dbias_thread_copy_step_m_k);
threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
} }
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...@@ -343,17 +307,17 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -343,17 +307,17 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I)); BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
}); });
threadwise_dscale_store_m.Run(thread_buffer_desc_m, threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
dscale_thread_buf, dscale_thread_buf,
scale_grid_desc_m, dscale_dbias_grid_desc_m,
dscale_global_buf); dscale_global_buf);
threadwise_dbias_store_m.Run(thread_buffer_desc_m, threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
dbias_thread_buf, dbias_thread_buf,
bias_grid_desc_m, dscale_dbias_grid_desc_m,
dbias_global_buf); dbias_global_buf);
// clang-format off // clang-format off
// Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance) // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
...@@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
ThreadBufferLengths_M, ThreadBufferLengths_M,
Sequence<0>, Sequence<0>,
0, 0,
ScaleSrcDstVectorSize, ScaleSrcVectorSize,
1, 1,
true>( true>(
scale_grid_desc_m, scale_grid_desc_m,
......
...@@ -17,7 +17,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_, ...@@ -17,7 +17,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename DyDataType, typename DyDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename DscaleDbiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp, typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
...@@ -45,8 +45,8 @@ __global__ void kernel_welford_second_half_reduce_first_half( ...@@ -45,8 +45,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
MeanVarDataType* const __restrict__ p_out_welford_inv_variance, MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy, const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale, DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias) DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{ {
GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k, GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k, dy_grid_desc_m_k,
...@@ -76,7 +76,7 @@ template <typename XDataType, ...@@ -76,7 +76,7 @@ template <typename XDataType,
typename DyDataType, typename DyDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename DscaleDbiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp, typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
...@@ -174,8 +174,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -174,8 +174,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
MeanVarDataType* const __restrict__ p_out_welford_inv_variance, MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy, const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale, DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias) DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{ {
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
...@@ -511,28 +511,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -511,28 +511,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I)); BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
}); });
auto threadwise_dscale_store = auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType, DscaleDbiasDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G, DscaleDbiasGridDesc_M_G,
PassThroughOp, PassThroughOp,
...@@ -557,17 +538,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -557,17 +538,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
threadwise_dscale_store.Run(thread_buffer_desc_m_1, threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
reduce_dscale_thread_buf, reduce_dscale_thread_buf,
dscale_dbias_grid_desc_m_g, dscale_dbias_grid_desc_m_g,
reduce_dscale_global_buf); reduce_dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m_1, threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0), make_tuple(I0, I0),
reduce_dbias_thread_buf, reduce_dbias_thread_buf,
dscale_dbias_grid_desc_m_g, dscale_dbias_grid_desc_m_g,
reduce_dbias_global_buf); reduce_dbias_global_buf);
}; };
}; };
}; };
......
...@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
} }
}); });
} }
else
{
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
...@@ -21,7 +21,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_, ...@@ -21,7 +21,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename DxDataType, typename DxDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename DscaleDbiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp, typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
...@@ -33,7 +33,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford( ...@@ -33,7 +33,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size, long_index_t reduce_size,
...@@ -47,14 +47,14 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford( ...@@ -47,14 +47,14 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const MeanVarDataType* const __restrict__ p_savedInvVar, const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op, const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, DscaleDbiasDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) DscaleDbiasDataType* const __restrict__ p_dbias)
{ {
GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k, GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k, dy_grid_desc_m_k,
dx_grid_desc_m_k, dx_grid_desc_m_k,
scale_grid_desc_m, scale_grid_desc_m,
bias_grid_desc_m, dscale_dbias_grid_desc_m,
mean_var_grid_desc_m, mean_var_grid_desc_m,
get_reduce_count_per_thread, get_reduce_count_per_thread,
reduce_size, reduce_size,
...@@ -77,7 +77,7 @@ template <typename XDataType, ...@@ -77,7 +77,7 @@ template <typename XDataType,
typename DxDataType, typename DxDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename DscaleDbiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp, typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
...@@ -93,8 +93,8 @@ template <typename XDataType, ...@@ -93,8 +93,8 @@ template <typename XDataType,
index_t XSrcVectorSize, index_t XSrcVectorSize,
index_t DySrcVectorSize, index_t DySrcVectorSize,
index_t DxDstVectorSize, index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize, index_t ScaleSrcVectorSize,
index_t BiasDstVectorSize, index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize> index_t MeanVarSrcVectorSize>
struct GridwiseBatchNormBackwardWithBlockwiseWelford struct GridwiseBatchNormBackwardWithBlockwiseWelford
{ {
...@@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size, long_index_t reduce_size,
...@@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const MeanVarDataType* const __restrict__ p_savedInvVar, const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op, const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, DscaleDbiasDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) DscaleDbiasDataType* const __restrict__ p_dbias)
{ {
using ck::math::sqrt; using ck::math::sqrt;
...@@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
XSrcVectorSize, XSrcVectorSize,
1, 1,
true>( true>(
x_grid_desc_m_k, dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
...@@ -271,7 +271,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -271,7 +271,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
dy_grid_desc_m_k, dx_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize), thread_k_cluster_id * KThreadSliceSize),
...@@ -285,45 +285,27 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -285,45 +285,27 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
ThreadBufferLengths_M, ThreadBufferLengths_M,
Sequence<0>, Sequence<0>,
0, 0,
ScaleSrcDstVectorSize, ScaleSrcVectorSize,
1, 1,
true>( true>(
scale_grid_desc_m, scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize)); thread_m_cluster_id * MThreadSliceSize));
auto threadwise_dscale_store = auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType, DscaleDbiasDataType,
decltype(thread_buffer_desc_m), decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M, ScaleBiasGridDesc_M,
PassThroughOp, PassThroughOp,
ThreadBufferLengths_M, ThreadBufferLengths_M,
Sequence<0>, Sequence<0>,
0, 0,
BiasDstVectorSize, DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
bias_grid_desc_m, dscale_dbias_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize), thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{}); PassThroughOp{});
...@@ -344,10 +326,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -344,10 +326,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
p_scale, scale_grid_desc_m.GetElementSpaceSize()); p_scale, scale_grid_desc_m.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize()); p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize()); p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
// clang-format off // clang-format off
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance) // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
...@@ -487,17 +469,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -487,17 +469,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
threadwise_dscale_store.Run(thread_buffer_desc_m, threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
dscale_thread_buf, dscale_thread_buf,
scale_grid_desc_m, dscale_dbias_grid_desc_m,
dscale_global_buf); dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m, threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0), make_tuple(I0),
dbias_thread_buf, dbias_thread_buf,
bias_grid_desc_m, dscale_dbias_grid_desc_m,
dbias_global_buf); dbias_global_buf);
}; };
// clang-format off // clang-format off
......
// SPDX-License-Identifier: MIT
// // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
//
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise2dFunctor,
typename InGrid2dDescTuple,
typename OutGrid2dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation>
__global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tuple,
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
{
GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple,
out_grid_2d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
num_threads_m,
num_threads_n);
}
template <typename InGrid2dDescTuple,
typename OutGrid2dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
index_t MPerThread,
index_t NPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_2D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid2dDescTuple::Size() &&
NumOutput == OutGrid2dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto thread_buffer_desc_mn =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}, Number<NPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
{
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread,
true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread,
true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_2d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t thread_1d_id = get_thread_global_1d_id();
index_t tid_m = thread_1d_id / num_threads_n;
index_t tid_n = thread_1d_id % num_threads_n;
const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
decltype(in_grid_2d_desc_tuple[I]),
decltype(thread_buffer_desc_mn),
Sequence<MPerThread, NPerThread>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVectorSeq::At(I), // ScalarPerVector
1, // SrcScalarStrideInVector
true>{in_grid_2d_desc_tuple[I], thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<
DataType,
DataType,
decltype(thread_buffer_desc_mn),
decltype(out_grid_2d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread, NPerThread>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
true>(out_grid_2d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter_m = M / (loop_step_m);
do
{
index_t num_iter_n = N / (loop_step_n);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_mn,
make_tuple(I0, I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
make_multi_index(0, loop_step_n));
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
static_for<0, NPerThread, 1>{}([&](auto iN) {
constexpr auto offset =
thread_buffer_desc_mn.CalculateOffset(make_tuple(iM, iN));
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& {
return in_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumInput>{});
// get referenec to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& {
return out_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mn,
make_tuple(I0, I0),
out_thread_buf_tuple[I],
out_grid_2d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_2d_desc_tuple[I],
make_multi_index(0, loop_step_n));
});
} while(--num_iter_n);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_2d_desc_tuple[I],
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_2d_desc_tuple[I],
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
});
} while(--num_iter_m);
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename DsDataType,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1Value,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
typename M11N11ThreadClusterM110Xs,
typename M11N11ThreadClusterN110Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmDlMultipleD_km_kn_mn
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
{
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1;
const auto a_grid_desc_k0_m0_m1_k1 =
transform_tensor_descriptor(a_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_grid_desc_k0_m0_m1_k1;
}
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1;
const auto b_grid_desc_k0_n0_n1_k1 =
transform_tensor_descriptor(b_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_grid_desc_k0_n0_n1_k1;
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
N1PerThreadN111>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_grid_desc_m0_m10_m11_n0_n10_n11;
}
// Ds desc for source in blockwise copy
template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) { return MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n[i]); },
Number<NumDTensor>{});
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
using DsGridPointer = decltype(MakeDsGridPointer());
template <typename DsGridDesc_M0_M10_M11_N0_N10_N11,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AElementwiseOperation&,
const BElementwiseOperation&,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11& ds_grid_desc_m0_m10_m11_n0_n10_n11,
const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap& block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
decltype(a_block_desc_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_k0_m0_m1_k1,
make_multi_index(0, im0, 0, 0),
a_block_desc_k0_m0_m1_k1,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
decltype(b_block_desc_k0_n0_n1_k1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(b_grid_desc_k0_n0_n1_k1,
make_multi_index(0, in0, 0, 0),
b_block_desc_k0_n0_n1_k1,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
// Initialize C
c_thread_buf.Clear();
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_block_slice_copy_step);
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
block_sync_lds();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto ds_thread_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DDataType,
c_m10_m11_n10_n11_thread_tensor_lengths[I3],
true>{};
},
Number<NumDTensor>{});
auto ds_threadwise_copy = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return ThreadwiseTensorSliceTransfer_v2<
DDataType,
DDataType,
decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]),
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
Sequence<I1,
I1,
I1,
I1,
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
1,
false>(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]));
},
Number<NumDTensor>{});
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I0], 1>{}([&](auto m10) {
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I1], 1>{}([&](auto m11) {
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}([&](auto n10) {
// load d matrix data
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
ds_grid_buf[i],
c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
ds_thread_buf(i));
});
// cal element op
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I3], 1>{}(
[&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
return ds_thread_buf[iSrc][i];
},
Number<NumDTensor>{});
// get reference to dst data
constexpr index_t c_offset =
c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset(
make_tuple(0, m10, m11, 0, n10, i));
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& { return c_thread_buf(Number<c_offset>{}); },
Number<2>{});
unpack2(cde_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(0, 0, 0, 0, 1, 0));
});
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(
0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0));
});
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(
0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0));
});
});
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf);
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP
#include "data_type.hpp"
// TODO: Add arch limitation
namespace ck {
// wave32 only
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;
template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
};
// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32;
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
}
};
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32;
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
}
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
neg_a,
bit_cast<int32x4_t>(reg_a),
neg_b,
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp);
}
};
} // namespace ck
#endif
...@@ -114,7 +114,16 @@ static inline __device__ int4_t abs(int4_t x) ...@@ -114,7 +114,16 @@ static inline __device__ int4_t abs(int4_t x)
}; };
#endif #endif
static inline __device__ half_t abs(half_t x) { return ::__habs(x); }; static inline __device__ half_t abs(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x;
};
static inline __device__ bool isnan(float x) { return ::isnan(x); }; static inline __device__ bool isnan(float x) { return ::isnan(x); };
...@@ -140,7 +149,12 @@ static inline __device__ bool isnan(int4_t x) ...@@ -140,7 +149,12 @@ static inline __device__ bool isnan(int4_t x)
}; };
#endif #endif
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); }; static inline __device__ bool isnan(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
static inline __device__ float sqrt(float x) { return ::sqrtf(x); }; static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct ReferenceBatchNormBwd : public device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const DyDataType* p_dy,
const ScaleDataType* p_scale,
const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
DxDataType* p_dx,
DscaleDbiasDataType* p_dscale,
DscaleDbiasDataType* p_dbias)
: reduceDims_(reduceDims),
bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_dy_(p_dy),
p_scale_(p_scale),
p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx),
p_dscale_(p_dscale),
p_dbias_(p_dbias)
{
using ck::host_common::get_index_set;
if(std::any_of(
reduceDims.begin(), reduceDims.end(), [](int d) { return d < 0 || d >= Rank; }))
throw std::runtime_error("Invalid reduce dimensions!");
// get invariant_dims[] and invariant_lengths[]
for(int dim = 0, i = 0; dim < Rank; dim++)
if(std::none_of(
reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
{
invariantDims_[i] = dim;
invariant_lengths_[i] = xyLengths[dim];
i++;
};
// get reduce_lengths_[]
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims[j];
reduce_lengths_[i++] = xyLengths[dim];
};
for(int i = 0; i < NumInvariantDim; i++)
if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
throw std::runtime_error("Invalid lengths parameters!");
for(int j = 0, i = 0; j < NumInvariantDim; j++)
{
int dim = invariantDims_[j];
x_invariant_strides_[i] = xStrides[dim];
dy_invariant_strides_[i] = dyStrides[dim];
dx_invariant_strides_[i] = dxStrides[dim];
i++;
};
for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
{
int dim = reduceDims_[j];
x_reduce_strides_[i] = xStrides[dim];
dy_reduce_strides_[i] = dyStrides[dim];
dx_reduce_strides_[i] = dxStrides[dim];
i++;
};
reduceSize_ = std::accumulate(
reduce_lengths_.begin(), reduce_lengths_.end(), 1, std::multiplies<size_t>{});
invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
reduce_index_set_ = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
epsilon_ = type_convert<AccDataType>(epsilon);
haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
}
std::array<int, NumBatchNormReduceDim> reduceDims_;
std::array<int, NumInvariantDim> invariantDims_;
std::array<index_t, NumInvariantDim> invariant_lengths_;
std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
const std::array<index_t, NumInvariantDim> bnScaleStrides_;
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides_;
const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
std::array<index_t, NumInvariantDim> x_invariant_strides_;
std::array<index_t, NumInvariantDim> dy_invariant_strides_;
std::array<index_t, NumInvariantDim> dx_invariant_strides_;
std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> dy_reduce_strides_;
std::array<index_t, NumBatchNormReduceDim> dx_reduce_strides_;
const XDataType* p_x_;
const DyDataType* p_dy_;
const ScaleDataType* p_scale_;
const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
DscaleDbiasDataType* p_dscale_;
DscaleDbiasDataType* p_dbias_;
bool haveSavedMeanInvVar_;
std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
AccDataType epsilon_;
size_t reduceSize_;
};
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
using ck::host_common::get_offset_from_index;
auto thread_reduce_func = [&](auto invariant_index) {
size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.x_invariant_strides_, invariant_index);
size_t dy_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.dy_invariant_strides_, invariant_index);
size_t dx_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.dx_invariant_strides_, invariant_index);
AccDataType mean = type_convert<AccDataType>(0.0f);
AccDataType variance = type_convert<AccDataType>(0.0f);
AccDataType invVar;
int32_t curr_count = 0;
if(arg.haveSavedMeanInvVar_)
{
size_t mean_invVar_invariant_offset = get_offset_from_index<NumInvariantDim>(
arg.bnMeanVarStrides_, invariant_index);
mean =
type_convert<AccDataType>(arg.p_savedMean_[mean_invVar_invariant_offset]);
invVar =
type_convert<AccDataType>(arg.p_savedInvVar_[mean_invVar_invariant_offset]);
}
else
{
// compute mean, variance using welford method
for(const auto& reduce_index : arg.reduce_index_set_)
{
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
auto x_offset = x_invariant_offset + x_reduce_offset;
curr_count++;
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType delta = x - mean;
mean += delta / curr_count;
AccDataType delta2 = x - mean;
variance += delta * delta2;
};
// actual variance
variance = variance / curr_count;
// inv-variance defined as 1/sqrt(epsilon+variance)
invVar =
type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
};
AccDataType dbias =
type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy
AccDataType dscale =
type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy * norm_x
// 1) calculate dy * (x - mean) * inv-variance
// 2) calculate sum(dy) on reduced dimensions
// 3) calculate sum(dy * norm_x) on reduced dimensions
for(const auto& reduce_index : arg.reduce_index_set_)
{
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.dy_reduce_strides_, reduce_index);
auto x_offset = x_invariant_offset + x_reduce_offset;
auto dy_offset = dy_invariant_offset + dy_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);
arg.dy_elementwise_op_(dy, dy);
dbias += dy;
dscale += norm_x * dy;
};
size_t dscale_offset = get_offset_from_index<NumInvariantDim>(
arg.bnDscaleDbiasStrides_, invariant_index);
size_t dbias_offset = get_offset_from_index<NumInvariantDim>(
arg.bnDscaleDbiasStrides_, invariant_index);
arg.p_dscale_[dscale_offset] = type_convert<DscaleDbiasDataType>(dscale);
arg.p_dbias_[dbias_offset] = type_convert<DscaleDbiasDataType>(dbias);
size_t scale_offset =
get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[scale_offset]);
AccDataType multiplier = type_convert<AccDataType>(1.0f) /
type_convert<AccDataType>(arg.reduceSize_) * invVar *
scale;
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/reduceSize * inv-variance * scale * (reduceSize * dy - dbias
// - tmp)
for(const auto& reduce_index : arg.reduce_index_set_)
{
size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.x_reduce_strides_, reduce_index);
size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.dy_reduce_strides_, reduce_index);
size_t dx_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
arg.dx_reduce_strides_, reduce_index);
auto x_offset = x_invariant_offset + x_reduce_offset;
auto dy_offset = dy_invariant_offset + dy_reduce_offset;
auto dx_offset = dx_invariant_offset + dx_reduce_offset;
AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);
arg.dy_elementwise_op_(dy, dy);
AccDataType tmpVal = norm_x * dscale;
AccDataType dx = multiplier * (type_convert<AccDataType>(arg.reduceSize_) * dy -
dbias - tmpVal);
arg.p_dx_[dx_offset] = type_convert<DxDataType>(dx);
};
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread =
(arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t i_begin = it * work_per_thread;
std::size_t i_end = std::min(static_cast<size_t>((it + 1) * work_per_thread),
arg.invariant_index_set_.size());
auto f = [=] {
for(std::size_t i = i_begin; i < i_end; ++i)
{
thread_reduce_func(arg.invariant_index_set_[i]);
}
};
threads[it] = joinable_thread(f);
}
return (0.0f);
};
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
};
};
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
(void)p_arg;
return (true);
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<index_t, NumInvariantDim> bnScaleStrides,
const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
dxStrides,
dyStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnDscaleDbiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const DyDataType*>(p_dy),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon,
dy_elementwise_op,
static_cast<DxDataType*>(p_dx),
static_cast<DscaleDbiasDataType*>(p_dscale),
static_cast<DscaleDbiasDataType*>(p_dbias));
};
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Backward" << std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename XDataType,
typename DyDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp>
struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
: public device::DeviceBatchNormBwd<4, 3, DyElementwiseOp>
{
struct Argument : public device::BaseArgument
{
Argument(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> dyStrides,
const std::array<index_t, 4> dxStrides,
const std::array<int, 3> reduceDims,
const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, 1> bnScaleStrides,
const std::array<ck::index_t, 1> bnBiasStrides,
const std::array<ck::index_t, 1> bnMeanVarStrides,
const XDataType* p_x,
const DyDataType* p_dy,
const ScaleDataType* p_scale,
const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
DxDataType* p_dx,
ScaleDataType* p_dscale,
BiasDataType* p_dbias)
: p_x_(p_x),
p_dy_(p_dy),
p_scale_(p_scale),
p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar),
epsilon_(epsilon),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx),
p_dscale_(p_dscale),
p_dbias_(p_dbias)
{
ignore = xStrides;
ignore = dyStrides;
ignore = dxStrides;
ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!");
if(reduceDims[0] != 0 || reduceDims[1] != 1 || reduceDims[2] != 2)
throw std::runtime_error("Invalid reduce dimensions!");
n_ = xyLengths[0];
h_ = xyLengths[1];
w_ = xyLengths[2];
c_ = xyLengths[3];
haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
}
const XDataType* p_x_;
const DyDataType* p_dy_;
const ScaleDataType* p_scale_;
const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_;
double epsilon_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
ScaleDataType* p_dscale_;
BiasDataType* p_dbias_;
bool haveSavedMeanInvVar_;
index_t n_, h_, w_, c_;
};
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
auto thread_reduce_func = [&](auto iC) {
AccDataType reduceSize = type_convert<AccDataType>(arg.n_) *
type_convert<AccDataType>(arg.h_) *
type_convert<AccDataType>(arg.w_);
index_t offset_C = iC;
AccDataType mean;
AccDataType invVar;
if(arg.haveSavedMeanInvVar_)
{
mean = arg.p_savedMean_[offset_C];
invVar = arg.p_savedInvVar_[offset_C];
}
else
{
AccDataType meansquare;
meansquare = type_convert<AccDataType>(0.0f);
mean = type_convert<AccDataType>(0.0f);
// compute mean, meanquare, variance, inv-variance
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
{
index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
mean += x;
meansquare += x * x;
};
}
};
mean = mean / reduceSize;
meansquare = meansquare / reduceSize;
AccDataType variance = meansquare - mean * mean;
invVar = type_convert<AccDataType>(1.0f) /
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
};
AccDataType dbias = type_convert<AccDataType>(0.0f); // Sum on NHW of dy
AccDataType dscale = type_convert<AccDataType>(0.0f); // Sum on NHW of dy * norm_x
// 1) calculate dy * (x - mean) * inv-variance
// 2) calculate sum(dy) on NHW dimensions
// 3) calculate sum(dy * norm_x) on NHW dimensions
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
{
index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
arg.dy_elementwise_op_(dy, dy);
dbias += dy;
dscale += norm_x * dy;
};
}
};
arg.p_dscale_[offset_C] = type_convert<ScaleDataType>(dscale);
arg.p_dbias_[offset_C] = type_convert<BiasDataType>(dbias);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
AccDataType multiplier =
type_convert<AccDataType>(1.0f) / reduceSize * invVar * scale;
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - tmp)
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
for(index_t iH = 0; iH < arg.h_; iH++)
{
index_t offset_H = iH * arg.w_ * arg.c_;
for(index_t iW = 0; iW < arg.w_; iW++)
{
index_t offset_W = iW * arg.c_;
auto offset = offset_N + offset_H + offset_W + offset_C;
AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
arg.dy_elementwise_op_(dy, dy);
AccDataType tmpVal = norm_x * dscale;
AccDataType dx = multiplier * (reduceSize * dy - dbias - tmpVal);
arg.p_dx_[offset] = type_convert<XDataType>(dx);
};
}
};
};
std::size_t num_thread = std::thread::hardware_concurrency();
std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t ic_begin = it * work_per_thread;
std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);
auto f = [=] {
for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
{
thread_reduce_func(ic);
}
};
threads[it] = joinable_thread(f);
}
return (0.0f);
};
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
};
};
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
(void)p_arg;
return (true);
};
std::unique_ptr<device::BaseArgument>
MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
const std::array<index_t, 4> xStrides,
const std::array<index_t, 4> dyStrides,
const std::array<index_t, 4> dxStrides,
const std::array<int, 3> reduceDims,
const std::array<ck::index_t, 1> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, 1> bnScaleStrides,
const std::array<ck::index_t, 1> bnBiasStrides,
const std::array<ck::index_t, 1> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
dyStrides,
dxStrides,
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const DyDataType*>(p_dy),
static_cast<const ScaleDataType*>(p_scale),
static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon,
dy_elementwise_op,
static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias));
};
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "Reference_BatchNorm_Backward_NHWC_C<" << std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
...@@ -26,9 +26,9 @@ using Empty_Tuple = ck::Tuple<>; ...@@ -26,9 +26,9 @@ using Empty_Tuple = ck::Tuple<>;
using F16_Tuple = ck::Tuple<F16>; using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>; using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F32_Tuple = ck::Tuple<F32>; using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>;
using I32_Tuple = ck::Tuple<I32>; using I32_F32_Tuple = ck::Tuple<I32, F32>;
// GEMM layout // GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
...@@ -78,8 +78,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK; ...@@ -78,8 +78,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK;
using NDHWGK = ck::tensor_layout::convolution::NDHWGK; using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
// //
using GK = ck::tensor_layout::convolution::G_K; using GK = ck::tensor_layout::convolution::G_K;
using GK_TUPLE = ck::Tuple<GK>; using GK_Tuple = ck::Tuple<GK>;
using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor // pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...@@ -97,6 +98,13 @@ template <typename Activation> ...@@ -97,6 +98,13 @@ template <typename Activation>
using Add_Activation_Mul_Clamp = using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>; ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
template <typename Activation>
using Add_Activation_Mul2_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void> template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory; struct DeviceOperationInstanceFactory;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_backward_rank_4_3_f16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
// FP32
void add_device_batchnorm_backward_rank_4_3_f32_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_backward_rank_4_3_f64_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -131,6 +131,47 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances( ...@@ -131,6 +131,47 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -273,11 +314,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -273,11 +314,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
...@@ -289,6 +332,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -289,6 +332,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
} }
} }
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> && else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwd<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGroupedConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_conv2d_bias_perchannel_quantization_int8_instances(
std::vector<
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul2_Clamp<PassThrough>>>>&
instances);
void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_GK_Tuple,
GNHWK,
int8_t,
int8_t,
I32_F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Add_Activation_Mul2_Clamp<Relu>>>>&
instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul2_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Add_Activation_Mul2_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<DsDataType, I32_F32_Tuple> && is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, PassThrough>)
add_device_conv2d_bias_perchannel_quantization_int8_instances(op_ptrs);
else if constexpr(is_same_v<Activation, Relu>)
add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -23,7 +23,7 @@ void add_device_conv2d_bias_perlayer_quantization_int8_instances( ...@@ -23,7 +23,7 @@ void add_device_conv2d_bias_perlayer_quantization_int8_instances(
std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
GK_TUPLE, GK_Tuple,
GNHWK, GNHWK,
int8_t, int8_t,
int8_t, int8_t,
...@@ -38,7 +38,7 @@ void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( ...@@ -38,7 +38,7 @@ void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
GK_TUPLE, GK_Tuple,
GNHWK, GNHWK,
int8_t, int8_t,
int8_t, int8_t,
...@@ -91,7 +91,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -91,7 +91,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_TUPLE> && is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>) is_same_v<OutLayout, GNHWK>)
{ {
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_conv2d_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul2_Clamp<PassThrough>>>>&
instances);
void add_device_conv2d_relu_perchannel_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
GK_Tuple,
GNHWK,
int8_t,
int8_t,
F32_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul2_Clamp<Relu>>>>&
instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename DsDataType,
typename OutDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
DsDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Activation_Mul2_Clamp<Activation>>>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
GK_Tuple,
OutLayout,
InDataType,
WeiDataType,
F32_Tuple,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Activation_Mul2_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_Tuple> &&
is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
if constexpr(is_same_v<Activation, PassThrough>)
add_device_conv2d_perchannel_quantization_int8_instances(op_ptrs);
else if constexpr(is_same_v<Activation, Relu>)
add_device_conv2d_relu_perchannel_quantization_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment