Unverified Commit 1462ee22 authored by arai713, committed by GitHub

Merge branch 'develop' into gridwise_2d

parents 2c4305b2 d1567094
@@ -11,6 +11,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
@@ -163,26 +164,25 @@ bool run_grouped_conv_fwd(bool do_verification,
     // do Conv
     auto conv = DeviceConvNDFwdInstance{};
     auto invoker = conv.MakeInvoker();
-    auto argument = conv.MakeArgument(
-        in_device_buf.GetDeviceBuffer(),
-        wei_device_buf.GetDeviceBuffer(),
-        std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
-        out_device_buf.GetDeviceBuffer(),
-        a_g_n_c_wis_lengths,
-        a_g_n_c_wis_strides,
-        b_g_k_c_xs_lengths,
-        b_g_k_c_xs_strides,
-        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d0_g_n_k_wos_lengths}},
-        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d0_g_n_k_wos_strides}},
-        e_g_n_k_wos_lengths,
-        e_g_n_k_wos_strides,
-        conv_filter_strides,
-        conv_filter_dilations,
-        input_left_pads,
-        input_right_pads,
-        in_element_op,
-        wei_element_op,
-        out_element_op);
+    auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
+                                      wei_device_buf.GetDeviceBuffer(),
+                                      {bias_device_buf.GetDeviceBuffer()},
+                                      out_device_buf.GetDeviceBuffer(),
+                                      a_g_n_c_wis_lengths,
+                                      a_g_n_c_wis_strides,
+                                      b_g_k_c_xs_lengths,
+                                      b_g_k_c_xs_strides,
+                                      {d0_g_n_k_wos_lengths},
+                                      {d0_g_n_k_wos_strides},
+                                      e_g_n_k_wos_lengths,
+                                      e_g_n_k_wos_strides,
+                                      conv_filter_strides,
+                                      conv_filter_dilations,
+                                      input_left_pads,
+                                      input_right_pads,
+                                      in_element_op,
+                                      wei_element_op,
+                                      out_element_op);

     if(!conv.IsSupportedArgument(argument))
     {
@@ -235,8 +235,8 @@ bool run_grouped_conv_fwd(bool do_verification,
         out_device_buf.FromDevice(out_device.mData.data());

-        pass &= ck::utils::check_err(
-            out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
+        pass &=
+            ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
     }

     return (pass ? 0 : 1);
...
@@ -11,6 +11,7 @@
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
 #include "ck/library/utility/convolution_parameter.hpp"
 #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
@@ -150,14 +151,14 @@ bool run_grouped_conv_fwd(bool do_verification,
     auto invoker = conv.MakeInvoker();
     auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
                                       wei_device_buf.GetDeviceBuffer(),
-                                      std::array<const void*, 0>{},
+                                      {},
                                       out_device_buf.GetDeviceBuffer(),
                                       a_g_n_c_wis_lengths,
                                       a_g_n_c_wis_strides,
                                       b_g_k_c_xs_lengths,
                                       b_g_k_c_xs_strides,
-                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
-                                      std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
+                                      {},
+                                      {},
                                       e_g_n_k_wos_lengths,
                                       e_g_n_k_wos_strides,
                                       conv_filter_strides,
@@ -213,8 +214,8 @@ bool run_grouped_conv_fwd(bool do_verification,
         out_device_buf.FromDevice(out_device.mData.data());

-        pass &= ck::utils::check_err(
-            out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
+        pass &=
+            ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
     }

     return (pass ? 0 : 1);
...
@@ -25,7 +25,7 @@
 // check GPU target
 #ifdef __HIP_DEVICE_COMPILE__
 #if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-      defined(__gfx90a__) || defined(__gfx1030__))
+      defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__))
 #error Not supported target
 #endif
 #endif
@@ -38,6 +38,8 @@
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx1030__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
+#elif defined(__gfx1100__) // for GPU code
+#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000
 #endif

 // FMA instruction
@@ -62,6 +64,13 @@
 #define CK_USE_AMD_MFMA_BF16_1K_OP
 #endif

+// WMMA instruction
+#ifndef __HIP_DEVICE_COMPILE__ // for host code
+#define CK_USE_AMD_WMMA
+#elif defined(__gfx1100__) // for GPU code
+#define CK_USE_AMD_WMMA
+#endif
+
 // buffer load
 #define CK_USE_AMD_BUFFER_LOAD 1
...
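Note: the added CK_USE_AMD_WMMA switch mirrors the existing MFMA switches — it is defined for host compilation and for gfx1100 device compilation only. A minimal, hypothetical sketch (not CK code; the function names are placeholders, only the CK_USE_AMD_WMMA / __gfx1100__ guards come from the header above) of how downstream code could gate a WMMA-specific path on it:

// Hypothetical sketch: run_wmma_path/run_fallback_path are placeholders,
// only the macro guard pattern is taken from the header above.
#include <cstdio>

void run_wmma_path() { std::printf("WMMA path\n"); }
void run_fallback_path() { std::printf("fallback path\n"); }

void dispatch()
{
#if defined(CK_USE_AMD_WMMA)
    run_wmma_path(); // host pass, or a device pass targeting __gfx1100__
#else
    run_fallback_path(); // device targets without WMMA support
#endif
}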
@@ -13,7 +13,16 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
+template <typename XDataType,
+          typename DxDataType,
+          typename DyDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename DscaleDbiasDataType,
+          typename MeanVarDataType,
+          typename DyElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
 struct DeviceBatchNormBwd : public BaseOperator
 {
     static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
@@ -26,7 +35,7 @@ struct DeviceBatchNormBwd : public BaseOperator
         const std::array<int, NumBatchNormReduceDim> reduceDims,
         const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
         const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
-        const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
+        const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
         const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
         const void* p_x,
         const void* p_dy,
@@ -42,9 +51,26 @@ struct DeviceBatchNormBwd : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
-using DeviceBatchNormBwdPtr =
-    std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
+template <typename XDataType,
+          typename DxDataType,
+          typename DyDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename DscaleDbiasDataType,
+          typename MeanVarDataType,
+          typename DyElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
+                                                                 DxDataType,
+                                                                 DyDataType,
+                                                                 AccDataType,
+                                                                 ScaleDataType,
+                                                                 DscaleDbiasDataType,
+                                                                 MeanVarDataType,
+                                                                 DyElementwiseOp,
+                                                                 Rank,
+                                                                 NumBatchNormReduceDim>>;

 } // namespace device
 } // namespace tensor_operation
...
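Note: with the data types moved into the interface, a batchnorm-backward instance pointer now spells out every element type. An illustrative instantiation follows — the concrete float types, the PassThrough op choice, and the include paths are assumptions for the sketch, not something this commit provides:

// Assumed include paths; adjust to the actual header locations.
#include <memory>
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Example: NHWC batchnorm backward, Rank = 4, reducing over N/H/W (3 dims),
// with float for every tensor including dscale/dbias.
using BnormBwdFp32Ptr =
    ck::tensor_operation::device::DeviceBatchNormBwdPtr<float,       // XDataType
                                                        float,       // DxDataType
                                                        float,       // DyDataType
                                                        float,       // AccDataType
                                                        float,       // ScaleDataType
                                                        float,       // DscaleDbiasDataType
                                                        float,       // MeanVarDataType
                                                        PassThrough, // DyElementwiseOp
                                                        4,           // Rank
                                                        3>;          // NumBatchNormReduceDim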
@@ -27,7 +27,7 @@ template <typename XDataType,
           typename DyDataType,
           typename AccDataType,
           typename ScaleDataType,
-          typename BiasDataType,
+          typename DscaleDbiasDataType,
           typename MeanVarDataType,
           typename DyElementwiseOp,
           index_t Rank,
@@ -42,11 +42,19 @@ template <typename XDataType,
           index_t XSrcVectorSize,
           index_t DySrcVectorSize,
           index_t DxDstVectorSize,
-          index_t ScaleSrcDstVectorSize,
-          index_t BiasDstVectorSize,
+          index_t ScaleSrcVectorSize,
+          index_t DscaleDbiasDstVectorSize,
           index_t MeanVarSrcVectorSize>
-struct DeviceBatchNormBwdImpl
-    : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
+struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
+                                                          DxDataType,
+                                                          DyDataType,
+                                                          AccDataType,
+                                                          ScaleDataType,
+                                                          DscaleDbiasDataType,
+                                                          MeanVarDataType,
+                                                          DyElementwiseOp,
+                                                          Rank,
+                                                          NumBatchNormReduceDim>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
@@ -194,7 +202,7 @@ struct DeviceBatchNormBwdImpl
                  const std::array<int, NumBatchNormReduceDim> reduceDims,
                  const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                  const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
-                 const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
+                 const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
                  const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
                  const XDataType* p_x,
                  const DyDataType* p_dy,
@@ -204,11 +212,11 @@ struct DeviceBatchNormBwdImpl
                  const DyElementwiseOp dy_elementwise_op,
                  double epsilon,
                  DxDataType* p_dx,
-                 ScaleDataType* p_dscale,
-                 BiasDataType* p_dbias)
+                 DscaleDbiasDataType* p_dscale,
+                 DscaleDbiasDataType* p_dbias)
             : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
               bnScaleStrides_(bnScaleStrides),
-              bnBiasStrides_(bnBiasStrides),
+              bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
               bnMeanVarStrides_(bnMeanVarStrides),
               p_x_(p_x),
               p_dy_(p_dy),
@@ -272,8 +280,8 @@ struct DeviceBatchNormBwdImpl
                 MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);
             scale_grid_desc_m =
                 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
-            bias_grid_desc_m =
-                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides);
+            dscale_dbias_grid_desc_m =
+                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
             mean_var_grid_desc_m =
                 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
         }
@@ -289,7 +297,7 @@ struct DeviceBatchNormBwdImpl
         std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
         std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
-        std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
+        std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
         std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;

         const XDataType* p_x_;
@@ -299,8 +307,8 @@ struct DeviceBatchNormBwdImpl
         const MeanVarDataType* p_savedInvVar_;
         const DyElementwiseOp dy_elementwise_op_;
         DxDataType* p_dx_;
-        ScaleDataType* p_dscale_;
-        BiasDataType* p_dbias_;
+        DscaleDbiasDataType* p_dscale_;
+        DscaleDbiasDataType* p_dbias_;

         long_index_t invariant_length;
         long_index_t reduce_length;
@@ -313,7 +321,7 @@ struct DeviceBatchNormBwdImpl
         XYGridDesc_M_K dy_grid_desc_m_k;
         XYGridDesc_M_K dx_grid_desc_m_k;
         ScaleBiasGridDesc_M scale_grid_desc_m;
-        ScaleBiasGridDesc_M bias_grid_desc_m;
+        ScaleBiasGridDesc_M dscale_dbias_grid_desc_m;
         MeanVarGridDesc_M mean_var_grid_desc_m;

         void* workspace_mean;
@@ -337,11 +345,11 @@ struct DeviceBatchNormBwdImpl
         {
             // workspace for the partial reduced result for dscale
             workspace_size +=
-                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType) + 64;
+                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;

             // workspace for the partial reduced result for dbias
             workspace_size +=
-                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType) + 64;
+                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;

             if(!pArg_->haveSavedMeanInvVar_)
             {
@@ -379,7 +387,7 @@ struct DeviceBatchNormBwdImpl
             // setup buffer for the partial reduced result for dscale
             pArg_->workspace_reduce_dscale = pArg_->p_workspace_;

-            space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
+            space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
             space_sz = math::integer_least_multiple(space_sz, 64);

             // setup buffer for the partial reduced result for dbias
@@ -388,7 +396,7 @@ struct DeviceBatchNormBwdImpl
             if(UseMultiblockInK && pArg_->blkGroupSize > 1)
             {
-                space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType);
+                space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
                 space_sz = math::integer_least_multiple(space_sz, 64);

                 // setup buffer for welford intermediate mean
@@ -454,7 +462,7 @@ struct DeviceBatchNormBwdImpl
                 DyDataType,
                 AccDataType,
                 ScaleDataType,
-                BiasDataType,
+                DscaleDbiasDataType,
                 MeanVarDataType,
                 DyElementwiseOp,
                 XYGridDesc_M_K,
@@ -477,7 +485,7 @@ struct DeviceBatchNormBwdImpl
                 DxDataType,
                 AccDataType,
                 ScaleDataType,
-                BiasDataType,
+                DscaleDbiasDataType,
                 MeanVarDataType,
                 DyElementwiseOp,
                 XYGridDesc_M_K,
@@ -493,8 +501,8 @@ struct DeviceBatchNormBwdImpl
                 XSrcVectorSize,
                 DySrcVectorSize,
                 DxDstVectorSize,
-                ScaleSrcDstVectorSize,
-                BiasDstVectorSize,
+                ScaleSrcVectorSize,
+                DscaleDbiasDstVectorSize,
                 MeanVarSrcVectorSize>;

             if(UseMultiblockInK && arg.blkGroupSize > 1)
@@ -553,7 +561,7 @@ struct DeviceBatchNormBwdImpl
                     DyDataType,
                     AccDataType,
                     ScaleDataType,
-                    BiasDataType,
+                    DscaleDbiasDataType,
                     MeanVarDataType,
                     DyElementwiseOp,
                     XYGridDesc_M_K,
@@ -568,7 +576,7 @@ struct DeviceBatchNormBwdImpl
                     DyDataType,
                     DxDataType,
                     ScaleDataType,
-                    BiasDataType,
+                    DscaleDbiasDataType,
                     MeanVarDataType,
                     DyElementwiseOp,
                     XYGridDesc_M_K,
@@ -614,8 +622,8 @@ struct DeviceBatchNormBwdImpl
                         : static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
                     arg.p_x_,
                     arg.p_dy_,
-                    static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
-                    static_cast<BiasDataType*>(arg.workspace_reduce_dbias));
+                    static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
+                    static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));

                 avg_time += launch_and_time_kernel(
                     stream_config,
@@ -629,13 +637,13 @@ struct DeviceBatchNormBwdImpl
                     dscale_dbias_grid_desc_m_k,
                     arg.mean_var_grid_desc_m,
                     arg.scale_grid_desc_m,
-                    arg.bias_grid_desc_m,
+                    arg.dscale_dbias_grid_desc_m,
                     arg.blkGroupSize,
                     arg.reduce_length,
                     arg.numBlockTileIteration,
                     numDscaleDbiasBlockTileIteration,
-                    static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
-                    static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
+                    static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
+                    static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
                     arg.haveSavedMeanInvVar_
                         ? arg.p_savedMean_
                         : static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
@@ -664,7 +672,7 @@ struct DeviceBatchNormBwdImpl
                     DxDataType,
                     AccDataType,
                     ScaleDataType,
-                    BiasDataType,
+                    DscaleDbiasDataType,
                     MeanVarDataType,
                     DyElementwiseOp,
                     XYGridDesc_M_K,
@@ -680,8 +688,8 @@ struct DeviceBatchNormBwdImpl
                     XSrcVectorSize,
                     DySrcVectorSize,
                     DxDstVectorSize,
-                    ScaleSrcDstVectorSize,
-                    BiasDstVectorSize,
+                    ScaleSrcVectorSize,
+                    DscaleDbiasDstVectorSize,
                     MeanVarSrcVectorSize>;

                 const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
@@ -691,7 +699,7 @@ struct DeviceBatchNormBwdImpl
                     DxDataType,
                     AccDataType,
                     ScaleDataType,
-                    BiasDataType,
+                    DscaleDbiasDataType,
                     MeanVarDataType,
                     DyElementwiseOp,
                     XYGridDesc_M_K,
@@ -708,7 +716,7 @@ struct DeviceBatchNormBwdImpl
                     arg.dy_grid_desc_m_k,
                     arg.dx_grid_desc_m_k,
                     arg.scale_grid_desc_m,
-                    arg.bias_grid_desc_m,
+                    arg.dscale_dbias_grid_desc_m,
                     arg.mean_var_grid_desc_m,
                     get_reduce_count_per_thread,
                     arg.reduce_length,
@@ -764,16 +772,16 @@ struct DeviceBatchNormBwdImpl
                 return false;
         };

-        if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcDstVectorSize != 1)
+        if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
             return false;

-        if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasDstVectorSize != 1)
+        if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
             return false;

-        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcDstVectorSize != 0)
+        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
             return false;

-        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasDstVectorSize != 0)
+        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
             return false;

         if(pArg_->haveSavedMeanInvVar_)
@@ -806,7 +814,7 @@ struct DeviceBatchNormBwdImpl
         const std::array<int, NumBatchNormReduceDim> reduceDims,
         const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
         const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
-        const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
+        const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
         const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
         const void* p_x,
         const void* p_dy,
@@ -826,7 +834,7 @@ struct DeviceBatchNormBwdImpl
             reduceDims,
             bnScaleBiasMeanVarLengths,
             bnScaleStrides,
-            bnBiasStrides,
+            bnDscaleDbiasStrides,
             bnMeanVarStrides,
             static_cast<const XDataType*>(p_x),
             static_cast<const DyDataType*>(p_dy),
@@ -836,8 +844,8 @@ struct DeviceBatchNormBwdImpl
             dy_elementwise_op,
             epsilon,
             static_cast<DxDataType*>(p_dx),
-            static_cast<ScaleDataType*>(p_dscale),
-            static_cast<BiasDataType*>(p_dbias));
+            static_cast<DscaleDbiasDataType*>(p_dscale),
+            static_cast<DscaleDbiasDataType*>(p_dbias));
     };

     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
@@ -854,7 +862,7 @@ struct DeviceBatchNormBwdImpl
         str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
         str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
         str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
-        str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcDstVectorSize << "_bias_" << BiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
+        str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
         // clang-format on
         return str.str();
...
@@ -187,6 +187,22 @@ struct AddRelu
         const float a = x0 + type_convert<float>(x1);
         y = a > 0.0f ? a : 0.0f;
     };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
+    {
+        const int8_t a = x0 + x1;
+        y = a > 0 ? a : 0;
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
+    {
+        const int8_t a = x0 + x1;
+        y = a > 0 ? a : 0;
+    };
 };

 struct AddHardswish
...
@@ -10,8 +10,8 @@ namespace element_wise {
 template <typename Activation>
 struct Activation_Mul_Clamp
 {
-    Activation_Mul_Clamp(float multiplier, Activation activationOp)
-        : multiplier_(multiplier), activationOp_(activationOp)
+    Activation_Mul_Clamp(float requantScale, Activation activationOp)
+        : requantScale_(requantScale), activationOp_(activationOp)
     {
     }
@@ -19,7 +19,7 @@ struct Activation_Mul_Clamp
     {
         float x_fp32 = ck::type_convert<float>(x);
         activationOp_(x_fp32, x_fp32);
-        float y_fp32 = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
+        float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
         y = ck::type_convert<int8_t>(y_fp32);
     }
@@ -28,10 +28,29 @@ struct Activation_Mul_Clamp
         // We might type_convert to int8 after lambda in someplace
         float x_fp32 = ck::type_convert<float>(x);
         activationOp_(x_fp32, x_fp32);
-        y = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
+        y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
     }
+
+    float requantScale_;
+    Activation activationOp_;
+};
+
+// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
+// relu, leaky relu ...etc
+template <typename Activation>
+struct Activation_Mul2_Clamp
+{
+    Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
+
+    __host__ __device__ constexpr void
+    operator()(int8_t& y, const int32_t& x, const float& requantScale) const
+    {
+        float y_fp32 = ck::type_convert<float>(x);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int8_t>(y_fp32);
+    }

-    float multiplier_;
     Activation activationOp_;
 };
@@ -39,21 +58,40 @@ struct Activation_Mul_Clamp
 template <typename Activation>
 struct Add_Activation_Mul_Clamp
 {
-    Add_Activation_Mul_Clamp(float multiplier, Activation activationOp)
-        : multiplier_(multiplier), activationOp_(activationOp)
+    Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
+        : requantScale_(requantScale), activationOp_(activationOp)
     {
     }

     __host__ __device__ constexpr void
-    operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
+    operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
+    {
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int8_t>(y_fp32);
+    }
+
+    float requantScale_;
+    Activation activationOp_;
+};
+
+// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
+// relu, leaky relu ...etc
+template <typename Activation>
+struct Add_Activation_Mul2_Clamp
+{
+    Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
+
+    __host__ __device__ constexpr void
+    operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
     {
-        float y_fp32 = ck::type_convert<float>(x1 + x2);
+        float y_fp32 = ck::type_convert<float>(x + bias);
         activationOp_(y_fp32, y_fp32);
-        y_fp32 = math::clamp(multiplier_ * y_fp32, -128.f, 127.f);
+        y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
         y = ck::type_convert<int8_t>(y_fp32);
     }

-    float multiplier_;
     Activation activationOp_;
 };
@@ -61,23 +99,23 @@ struct Add_Activation_Mul_Clamp
 template <typename Activation>
 struct Add_Mul_Activation_Mul_Clamp
 {
-    Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp)
-        : multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp)
+    Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp)
+        : requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp)
     {
     }

     __host__ __device__ constexpr void
-    operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
+    operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
     {
-        float y_fp32 = ck::type_convert<float>(x1 + x2);
-        y_fp32 = multiplier1_ * y_fp32;
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        y_fp32 = requantScale1_ * y_fp32;
         activationOp_(y_fp32, y_fp32);
-        y_fp32 = math::clamp(multiplier2_ * y_fp32, -128.f, 127.f);
+        y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f);
         y = ck::type_convert<int8_t>(y_fp32);
     }

-    float multiplier1_;
-    float multiplier2_;
+    float requantScale1_;
+    float requantScale2_;
     Activation activationOp_;
 };
...
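Note: the renamed operators all follow the same requantization pattern — activate the int32 accumulator in float, multiply by the requantization scale, clamp to the int8 range, and narrow. A standalone reference sketch (not CK code; ReluFloat and the class name are stand-ins for the Activation template parameter and the CK functor):

// Standalone illustration of the Activation_Mul_Clamp pattern above.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct ReluFloat
{
    void operator()(float& y, const float& x) const { y = x > 0.f ? x : 0.f; }
};

template <typename Activation>
struct ActivationMulClampRef
{
    ActivationMulClampRef(float requantScale, Activation op) : requantScale_(requantScale), op_(op) {}

    void operator()(std::int8_t& y, const std::int32_t& x) const
    {
        float x_fp32 = static_cast<float>(x); // widen the int32 accumulator
        op_(x_fp32, x_fp32);                  // piecewise-linear activation in float
        const float y_fp32 = std::min(std::max(requantScale_ * x_fp32, -128.f), 127.f);
        y = static_cast<std::int8_t>(y_fp32); // narrow back to int8
    }

    float requantScale_;
    Activation op_;
};

int main()
{
    ActivationMulClampRef<ReluFloat> requant{0.05f, ReluFloat{}};
    std::int8_t y = 0;
    requant(y, 1000); // 1000 * 0.05 = 50, inside the int8 range
    std::printf("%d\n", static_cast<int>(y));
}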
@@ -16,7 +16,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
           typename DyDataType,
           typename DxDataType,
           typename ScaleDataType,
-          typename BiasDataType,
+          typename DscaleDbiasDataType,
           typename MeanVarDataType,
           typename DyElementwiseOp,
           typename XYGridDesc_M_K,
@@ -35,8 +35,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
     long_index_t reduce_size,
     index_t num_xy_k_block_tile_iteration,
     index_t num_dscale_dbias_k_block_tile_iteration,
-    const ScaleDataType* const __restrict__ p_reduce_dscale,
-    const BiasDataType* const __restrict__ p_reduce_dbias,
+    const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
+    const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
     const MeanVarDataType* const __restrict__ p_mean,
     const MeanVarDataType* const __restrict__ p_inv_var,
     const XDataType* const __restrict__ p_x,
@@ -44,8 +44,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
     const ScaleDataType* const __restrict__ p_scale,
     const DyElementwiseOp dy_elementwise_op,
     DxDataType* const __restrict__ p_dx,
-    ScaleDataType* const __restrict__ p_dscale,
-    BiasDataType* const __restrict__ p_dbias)
+    DscaleDbiasDataType* const __restrict__ p_dscale,
+    DscaleDbiasDataType* const __restrict__ p_dbias)
 {
     GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
                                                          dy_grid_desc_m_k,
@@ -76,7 +76,7 @@ template <typename XDataType,
           typename DxDataType,
           typename AccDataType,
           typename ScaleDataType,
-          typename BiasDataType,
+          typename DscaleDbiasDataType,
           typename MeanVarDataType,
           typename DyElementwiseOp,
           typename XYGridDesc_M_K,
@@ -92,8 +92,8 @@ template <typename XDataType,
           index_t XSrcVectorSize,
           index_t DySrcVectorSize,
           index_t DxDstVectorSize,
-          index_t ScaleSrcDstVectorSize,
-          index_t BiasDstVectorSize,
+          index_t ScaleSrcVectorSize,
+          index_t DscaleDbiasDstVectorSize,
           index_t MeanVarSrcVectorSize>
 struct GridwiseReduceSecondHalfBatchNormBackwardFinal
 {
@@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
         const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
         const MeanVarGridDesc_M& mean_var_grid_desc_m,
         const ScaleBiasGridDesc_M& scale_grid_desc_m,
-        const ScaleBiasGridDesc_M& bias_grid_desc_m,
+        const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
         index_t blkgroup_size,
         long_index_t reduce_size,
         index_t num_xy_k_block_tile_iteration,
         index_t num_dscale_dbias_k_block_tile_iteration,
-        const ScaleDataType* const __restrict__ p_reduce_dscale,
-        const BiasDataType* const __restrict__ p_reduce_dbias,
+        const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
+        const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
         const MeanVarDataType* const __restrict__ p_mean,
         const MeanVarDataType* const __restrict__ p_inv_var,
         const XDataType* const __restrict__ p_x,
@@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
         const ScaleDataType* const __restrict__ p_scale,
         const DyElementwiseOp dy_elementwise_op,
         DxDataType* const __restrict__ p_dx,
-        ScaleDataType* const __restrict__ p_dscale,
-        BiasDataType* const __restrict__ p_dbias)
+        DscaleDbiasDataType* const __restrict__ p_dscale,
+        DscaleDbiasDataType* const __restrict__ p_dbias)
     {
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -222,24 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
         // Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
         // clang-format on

-        auto threadwise_dscale_load_m_k =
-            ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
-                                             AccDataType,
-                                             DscaleDbiasGridDesc_M_K,
-                                             decltype(thread_buffer_desc_m_1),
-                                             ThreadBufferLengths_M_1,
-                                             Sequence<0, 1>,
-                                             1,
-                                             1,
-                                             1,
-                                             true>(
-                dscale_dbias_grid_desc_m_k,
-                make_multi_index(blkgroup_id * M_BlockTileSize +
-                                     thread_m_cluster_id * MThreadSliceSize,
-                                 thread_k_cluster_id * 1));
-
-        auto threadwise_dbias_load_m_k =
-            ThreadwiseTensorSliceTransfer_v2<BiasDataType,
+        auto threadwise_dscale_dbias_load_m_k =
+            ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
                                              AccDataType,
                                              DscaleDbiasGridDesc_M_K,
                                              decltype(thread_buffer_desc_m_1),
@@ -254,38 +238,20 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_k_cluster_id * 1));

-        auto threadwise_dscale_store_m =
+        auto threadwise_dscale_dbias_store_m =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                               ScaleDataType,
+                                               DscaleDbiasDataType,
                                                decltype(thread_buffer_desc_m),
                                                ScaleBiasGridDesc_M,
                                                PassThroughOp,
                                                ThreadBufferLengths_M,
                                                Sequence<0>,
                                                0,
-                                               ScaleSrcDstVectorSize,
+                                               DscaleDbiasDstVectorSize,
                                                InMemoryDataOperationEnum::Set,
                                                1,
                                                true>(
-                scale_grid_desc_m,
-                make_multi_index(blkgroup_id * M_BlockTileSize +
-                                     thread_m_cluster_id * MThreadSliceSize),
-                PassThroughOp{});
-
-        auto threadwise_dbias_store_m =
-            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                               BiasDataType,
-                                               decltype(thread_buffer_desc_m),
-                                               ScaleBiasGridDesc_M,
-                                               PassThroughOp,
-                                               ThreadBufferLengths_M,
-                                               Sequence<0>,
-                                               0,
-                                               BiasDstVectorSize,
-                                               InMemoryDataOperationEnum::Set,
-                                               1,
-                                               true>(
-                bias_grid_desc_m,
+                dscale_dbias_grid_desc_m,
                 make_multi_index(blkgroup_id * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize),
                 PassThroughOp{});
@@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
             p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());

         auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_dscale, scale_grid_desc_m.GetElementSpaceSize());
+            p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());

         auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_dbias, bias_grid_desc_m.GetElementSpaceSize());
+            p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());

         constexpr auto dscale_dbias_thread_copy_step_m_k =
             make_multi_index(0, KThreadClusterSize * 1);
@@ -313,25 +279,23 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
         for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
             ++reducedTiles)
         {
-            threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
+            threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
                                                  reduce_dscale_global_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  reduce_dscale_thread_buf);

-            threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
+            threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
                                                  reduce_dbias_global_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  reduce_dbias_thread_buf);

             ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
             ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);

-            threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
+            threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
                                                           dscale_dbias_thread_copy_step_m_k);
-            threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
-                                                         dscale_dbias_thread_copy_step_m_k);
         }

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -343,17 +307,17 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
             BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
         });

-        threadwise_dscale_store_m.Run(thread_buffer_desc_m,
+        threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
                                             make_tuple(I0),
                                             dscale_thread_buf,
-                                            scale_grid_desc_m,
+                                            dscale_dbias_grid_desc_m,
                                             dscale_global_buf);

-        threadwise_dbias_store_m.Run(thread_buffer_desc_m,
+        threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
                                             make_tuple(I0),
                                             dbias_thread_buf,
-                                            bias_grid_desc_m,
+                                            dscale_dbias_grid_desc_m,
                                             dbias_global_buf);

         // clang-format off
         // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
@@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                                ThreadBufferLengths_M,
                                                Sequence<0>,
                                                0,
-                                               ScaleSrcDstVectorSize,
+                                               ScaleSrcVectorSize,
                                                1,
                                                true>(
                 scale_grid_desc_m,
...
@@ -17,7 +17,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
           typename DyDataType,
           typename AccDataType,
           typename ScaleDataType,
-          typename BiasDataType,
+          typename DscaleDbiasDataType,
           typename MeanVarDataType,
           typename DyElementwiseOp,
           typename XYGridDesc_M_K,
@@ -45,8 +45,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
     MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
     const XDataType* const __restrict__ p_x,
     const DyDataType* const __restrict__ p_dy,
-    ScaleDataType* const __restrict__ p_reduce_dscale,
-    BiasDataType* const __restrict__ p_reduce_dbias)
+    DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
+    DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
 {
     GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
                                                    dy_grid_desc_m_k,
@@ -76,7 +76,7 @@ template <typename XDataType,
           typename DyDataType,
           typename AccDataType,
           typename ScaleDataType,
-          typename BiasDataType,
+          typename DscaleDbiasDataType,
           typename MeanVarDataType,
           typename DyElementwiseOp,
           typename XYGridDesc_M_K,
@@ -174,8 +174,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
         MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
         const XDataType* const __restrict__ p_x,
         const DyDataType* const __restrict__ p_dy,
-        ScaleDataType* const __restrict__ p_reduce_dscale,
-        BiasDataType* const __restrict__ p_reduce_dbias)
+        DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
+        DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
     {
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -511,28 +511,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
             BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
         });

-        auto threadwise_dscale_store =
+        auto threadwise_dscale_dbias_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                               ScaleDataType,
-                                               decltype(thread_buffer_desc_m_1),
-                                               DscaleDbiasGridDesc_M_G,
-                                               PassThroughOp,
-                                               ThreadBufferLengths_M_1,
-                                               Sequence<0, 1>,
-                                               1,
-                                               1,
-                                               InMemoryDataOperationEnum::Set,
-                                               1,
-                                               true>(
-                dscale_dbias_grid_desc_m_g,
-                make_multi_index(blkgroup_id * M_BlockTileSize +
-                                     thread_m_cluster_id * MThreadSliceSize,
-                                 block_local_id),
-                PassThroughOp{});
-
-        auto threadwise_dbias_store =
-            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                               BiasDataType,
+                                               DscaleDbiasDataType,
                                                decltype(thread_buffer_desc_m_1),
                                                DscaleDbiasGridDesc_M_G,
                                                PassThroughOp,
@@ -557,17 +538,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
         if(thread_k_cluster_id == 0)
         {
-            threadwise_dscale_store.Run(thread_buffer_desc_m_1,
+            threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               reduce_dscale_thread_buf,
                                               dscale_dbias_grid_desc_m_g,
                                               reduce_dscale_global_buf);

-            threadwise_dbias_store.Run(thread_buffer_desc_m_1,
+            threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               reduce_dbias_thread_buf,
                                               dscale_dbias_grid_desc_m_g,
                                               reduce_dbias_global_buf);
         };
     };
 };
...
@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 }
             });
         }
+        else
+        {
+            static_for<0, acc_thread_buf.Size(), 1>{}(
+                [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
+        }

         block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
...
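Note: the new else-branch makes sure the accumulator element-wise op is still applied to every element of the thread-local accumulator buffer when the masking path is skipped. A standalone sketch (not CK code) of that compile-time-unrolled apply pattern, modelled with a plain array and a C++17 fold instead of ck::static_for and the CK buffer types:

// Standalone sketch of an unrolled element-wise apply over a thread-local buffer.
#include <array>
#include <cstdio>
#include <utility>

struct ScaleOp
{
    float alpha;
    void operator()(float& y, const float& x) const { y = alpha * x; }
};

template <typename Op, std::size_t N, std::size_t... Is>
void apply_all(std::array<float, N>& buf, Op op, std::index_sequence<Is...>)
{
    (op(buf[Is], buf[Is]), ...); // unrolled at compile time, like ck::static_for
}

int main()
{
    std::array<float, 4> acc{1.f, 2.f, 3.f, 4.f};
    apply_all(acc, ScaleOp{0.5f}, std::make_index_sequence<4>{});
    for(float v : acc)
        std::printf("%g ", v);
    std::printf("\n");
}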
...@@ -21,7 +21,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_, ...@@ -21,7 +21,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename DxDataType, typename DxDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename DscaleDbiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp, typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
...@@ -33,7 +33,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford( ...@@ -33,7 +33,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size, long_index_t reduce_size,
...@@ -47,14 +47,14 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford( ...@@ -47,14 +47,14 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const MeanVarDataType* const __restrict__ p_savedInvVar, const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op, const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, DscaleDbiasDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) DscaleDbiasDataType* const __restrict__ p_dbias)
{ {
GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k, GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k, dy_grid_desc_m_k,
dx_grid_desc_m_k, dx_grid_desc_m_k,
scale_grid_desc_m, scale_grid_desc_m,
bias_grid_desc_m, dscale_dbias_grid_desc_m,
mean_var_grid_desc_m, mean_var_grid_desc_m,
get_reduce_count_per_thread, get_reduce_count_per_thread,
reduce_size, reduce_size,
...@@ -77,7 +77,7 @@ template <typename XDataType, ...@@ -77,7 +77,7 @@ template <typename XDataType,
typename DxDataType, typename DxDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename DscaleDbiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp, typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
...@@ -93,8 +93,8 @@ template <typename XDataType, ...@@ -93,8 +93,8 @@ template <typename XDataType,
index_t XSrcVectorSize, index_t XSrcVectorSize,
index_t DySrcVectorSize, index_t DySrcVectorSize,
index_t DxDstVectorSize, index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize, index_t ScaleSrcVectorSize,
index_t BiasDstVectorSize, index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize> index_t MeanVarSrcVectorSize>
struct GridwiseBatchNormBackwardWithBlockwiseWelford struct GridwiseBatchNormBackwardWithBlockwiseWelford
{ {
...@@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size, long_index_t reduce_size,
...@@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const MeanVarDataType* const __restrict__ p_savedInvVar, const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op, const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, DscaleDbiasDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) DscaleDbiasDataType* const __restrict__ p_dbias)
{ {
using ck::math::sqrt; using ck::math::sqrt;
...@@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
XSrcVectorSize, XSrcVectorSize,
1, 1,
true>( true>(
x_grid_desc_m_k, dy_grid_desc_m_k,
                 make_multi_index(block_global_id * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_k_cluster_id * KThreadSliceSize));
...
@@ -271,7 +271,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                                                InMemoryDataOperationEnum::Set,
                                                1,
                                                true>(
-                dy_grid_desc_m_k,
+                dx_grid_desc_m_k,
                 make_multi_index(block_global_id * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_k_cluster_id * KThreadSliceSize),
...
@@ -285,45 +285,27 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                                                ThreadBufferLengths_M,
                                                Sequence<0>,
                                                0,
-                                               ScaleSrcDstVectorSize,
+                                               ScaleSrcVectorSize,
                                                1,
                                                true>(
                 scale_grid_desc_m,
                 make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize));
-        auto threadwise_dscale_store =
-            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                               ScaleDataType,
-                                               decltype(thread_buffer_desc_m),
-                                               ScaleBiasGridDesc_M,
-                                               PassThroughOp,
-                                               ThreadBufferLengths_M,
-                                               Sequence<0>,
-                                               0,
-                                               ScaleSrcDstVectorSize,
-                                               InMemoryDataOperationEnum::Set,
-                                               1,
-                                               true>(
-                scale_grid_desc_m,
-                make_multi_index(block_global_id * M_BlockTileSize +
-                                 thread_m_cluster_id * MThreadSliceSize),
-                PassThroughOp{});
-        auto threadwise_dbias_store =
+        auto threadwise_dscale_dbias_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                               BiasDataType,
+                                               DscaleDbiasDataType,
                                                decltype(thread_buffer_desc_m),
                                                ScaleBiasGridDesc_M,
                                                PassThroughOp,
                                                ThreadBufferLengths_M,
                                                Sequence<0>,
                                                0,
-                                               BiasDstVectorSize,
+                                               DscaleDbiasDstVectorSize,
                                                InMemoryDataOperationEnum::Set,
                                                1,
                                                true>(
-                bias_grid_desc_m,
+                dscale_dbias_grid_desc_m,
                 make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize),
                 PassThroughOp{});
...
@@ -344,10 +326,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
             p_scale, scale_grid_desc_m.GetElementSpaceSize());
         auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_dscale, scale_grid_desc_m.GetElementSpaceSize());
+            p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
         auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_dbias, bias_grid_desc_m.GetElementSpaceSize());
+            p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
         // clang-format off
         // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
...
@@ -487,17 +469,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
         if(thread_k_cluster_id == 0)
         {
-            threadwise_dscale_store.Run(thread_buffer_desc_m,
+            threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
                                         make_tuple(I0),
                                         dscale_thread_buf,
-                                        scale_grid_desc_m,
+                                        dscale_dbias_grid_desc_m,
                                         dscale_global_buf);
-            threadwise_dbias_store.Run(thread_buffer_desc_m,
+            threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
                                        make_tuple(I0),
                                        dbias_thread_buf,
-                                       bias_grid_desc_m,
+                                       dscale_dbias_grid_desc_m,
                                        dbias_global_buf);
         };
         // clang-format off
...
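For reference, the "Step 1" comment in the hunk above relies on Welford's streaming mean/variance recurrence, with inv-variance computed as 1/sqrt(epsilon + variance). Below is a minimal scalar sketch of that recurrence; it is illustrative only, and the names are hypothetical rather than identifiers from the kernel.

#include <cmath>

// Fold one sample x into the running (count, mean, m2) Welford state.
inline void welford_update(int& count, float& mean, float& m2, float x)
{
    count += 1;
    const float delta = x - mean;
    mean += delta / count;    // running mean
    m2 += delta * (x - mean); // running sum of squared deviations
}

// Biased variance (as used by batchnorm) and the inverse standard deviation.
inline float welford_inv_variance(int count, float m2, float epsilon)
{
    const float variance = m2 / count;
    return 1.0f / std::sqrt(epsilon + variance);
}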
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
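// First half of the two-stage ("multiblock") Welford reduction: each block reduces its
// own K-slice of x to a partial mean/variance/count and writes the partials into a
// workspace laid out as (M, block-group id), to be merged by a follow-up kernel.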
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
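        // the grid is split into groups of blkgroup_size blocks that cooperate on one
        // M block-tile: blkgroup_id selects the M tile, block_local_id selects which
        // K-slice of that tile this block reduces (and its column in the workspace)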
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
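        // source window of this thread: M offset from the block group plus the thread's
        // M cluster id, K offset from this block's slice plus the thread's K cluster id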
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
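        // accumulate the per-thread Welford partials tile by tile, moving the source
        // window forward by K_BlockTileSize after each load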
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
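        // merge the per-thread partials across the thread block; each thread's sample
        // count comes from its threadwise Welford accumulator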
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
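        // only the K-cluster leader writes this block's partial mean/variance/count to
        // the (M, block-group) workspace; the second-half kernel finishes the reduction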
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck