Unverified Commit 1462ee22 authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into gridwise_2d

parents 2c4305b2 d1567094
......@@ -11,6 +11,7 @@
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
......@@ -163,17 +164,16 @@ bool run_grouped_conv_fwd(bool do_verification,
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(
in_device_buf.GetDeviceBuffer(),
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
{bias_device_buf.GetDeviceBuffer()},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d0_g_n_k_wos_lengths}},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d0_g_n_k_wos_strides}},
{d0_g_n_k_wos_lengths},
{d0_g_n_k_wos_strides},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
......@@ -235,8 +235,8 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data());
pass &= ck::utils::check_err(
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
pass &=
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
}
return (pass ? 0 : 1);
......
......@@ -11,6 +11,7 @@
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
......@@ -150,14 +151,14 @@ bool run_grouped_conv_fwd(bool do_verification,
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
{},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{{}},
{},
{},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
......@@ -213,8 +214,8 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data());
pass &= ck::utils::check_err(
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
pass &=
ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
}
return (pass ? 0 : 1);
......
......@@ -25,7 +25,7 @@
// check GPU target
#ifdef __HIP_DEVICE_COMPILE__
#if !(defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx1030__))
defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__))
#error Not supported target
#endif
#endif
......@@ -38,6 +38,8 @@
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000
#endif
// FMA instruction
......@@ -62,6 +64,13 @@
#define CK_USE_AMD_MFMA_BF16_1K_OP
#endif
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#elif defined(__gfx1100__) // for GPU code
#define CK_USE_AMD_WMMA
#endif
// buffer load
#define CK_USE_AMD_BUFFER_LOAD 1
......
......@@ -13,7 +13,16 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct DeviceBatchNormBwd : public BaseOperator
{
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
......@@ -26,7 +35,7 @@ struct DeviceBatchNormBwd : public BaseOperator
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
......@@ -42,9 +51,26 @@ struct DeviceBatchNormBwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
using DeviceBatchNormBwdPtr =
std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumBatchNormReduceDim>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -27,7 +27,7 @@ template <typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
......@@ -42,11 +42,19 @@ template <typename XDataType,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct DeviceBatchNormBwdImpl
: public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumBatchNormReduceDim>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......@@ -194,7 +202,7 @@ struct DeviceBatchNormBwdImpl
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const XDataType* p_x,
const DyDataType* p_dy,
......@@ -204,11 +212,11 @@ struct DeviceBatchNormBwdImpl
const DyElementwiseOp dy_elementwise_op,
double epsilon,
DxDataType* p_dx,
ScaleDataType* p_dscale,
BiasDataType* p_dbias)
DscaleDbiasDataType* p_dscale,
DscaleDbiasDataType* p_dbias)
: bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
bnScaleStrides_(bnScaleStrides),
bnBiasStrides_(bnBiasStrides),
bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
bnMeanVarStrides_(bnMeanVarStrides),
p_x_(p_x),
p_dy_(p_dy),
......@@ -272,8 +280,8 @@ struct DeviceBatchNormBwdImpl
MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);
scale_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
bias_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides);
dscale_dbias_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
mean_var_grid_desc_m =
MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
}
......@@ -289,7 +297,7 @@ struct DeviceBatchNormBwdImpl
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
const XDataType* p_x_;
......@@ -299,8 +307,8 @@ struct DeviceBatchNormBwdImpl
const MeanVarDataType* p_savedInvVar_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
ScaleDataType* p_dscale_;
BiasDataType* p_dbias_;
DscaleDbiasDataType* p_dscale_;
DscaleDbiasDataType* p_dbias_;
long_index_t invariant_length;
long_index_t reduce_length;
......@@ -313,7 +321,7 @@ struct DeviceBatchNormBwdImpl
XYGridDesc_M_K dy_grid_desc_m_k;
XYGridDesc_M_K dx_grid_desc_m_k;
ScaleBiasGridDesc_M scale_grid_desc_m;
ScaleBiasGridDesc_M bias_grid_desc_m;
ScaleBiasGridDesc_M dscale_dbias_grid_desc_m;
MeanVarGridDesc_M mean_var_grid_desc_m;
void* workspace_mean;
......@@ -337,11 +345,11 @@ struct DeviceBatchNormBwdImpl
{
// workspace for the partial reduced result for dscale
workspace_size +=
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType) + 64;
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
// workspace for the partial reduced result for dbias
workspace_size +=
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType) + 64;
pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;
if(!pArg_->haveSavedMeanInvVar_)
{
......@@ -379,7 +387,7 @@ struct DeviceBatchNormBwdImpl
// setup buffer for the partial reduced result for dscale
pArg_->workspace_reduce_dscale = pArg_->p_workspace_;
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for the partial reduced result for dbias
......@@ -388,7 +396,7 @@ struct DeviceBatchNormBwdImpl
if(UseMultiblockInK && pArg_->blkGroupSize > 1)
{
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(BiasDataType);
space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
space_sz = math::integer_least_multiple(space_sz, 64);
// setup buffer for welford intermediate mean
......@@ -454,7 +462,7 @@ struct DeviceBatchNormBwdImpl
DyDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -477,7 +485,7 @@ struct DeviceBatchNormBwdImpl
DxDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -493,8 +501,8 @@ struct DeviceBatchNormBwdImpl
XSrcVectorSize,
DySrcVectorSize,
DxDstVectorSize,
ScaleSrcDstVectorSize,
BiasDstVectorSize,
ScaleSrcVectorSize,
DscaleDbiasDstVectorSize,
MeanVarSrcVectorSize>;
if(UseMultiblockInK && arg.blkGroupSize > 1)
......@@ -553,7 +561,7 @@ struct DeviceBatchNormBwdImpl
DyDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -568,7 +576,7 @@ struct DeviceBatchNormBwdImpl
DyDataType,
DxDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -614,8 +622,8 @@ struct DeviceBatchNormBwdImpl
: static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
arg.p_x_,
arg.p_dy_,
static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<BiasDataType*>(arg.workspace_reduce_dbias));
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));
avg_time += launch_and_time_kernel(
stream_config,
......@@ -629,13 +637,13 @@ struct DeviceBatchNormBwdImpl
dscale_dbias_grid_desc_m_k,
arg.mean_var_grid_desc_m,
arg.scale_grid_desc_m,
arg.bias_grid_desc_m,
arg.dscale_dbias_grid_desc_m,
arg.blkGroupSize,
arg.reduce_length,
arg.numBlockTileIteration,
numDscaleDbiasBlockTileIteration,
static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
arg.haveSavedMeanInvVar_
? arg.p_savedMean_
: static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
......@@ -664,7 +672,7 @@ struct DeviceBatchNormBwdImpl
DxDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -680,8 +688,8 @@ struct DeviceBatchNormBwdImpl
XSrcVectorSize,
DySrcVectorSize,
DxDstVectorSize,
ScaleSrcDstVectorSize,
BiasDstVectorSize,
ScaleSrcVectorSize,
DscaleDbiasDstVectorSize,
MeanVarSrcVectorSize>;
const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
......@@ -691,7 +699,7 @@ struct DeviceBatchNormBwdImpl
DxDataType,
AccDataType,
ScaleDataType,
BiasDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
......@@ -708,7 +716,7 @@ struct DeviceBatchNormBwdImpl
arg.dy_grid_desc_m_k,
arg.dx_grid_desc_m_k,
arg.scale_grid_desc_m,
arg.bias_grid_desc_m,
arg.dscale_dbias_grid_desc_m,
arg.mean_var_grid_desc_m,
get_reduce_count_per_thread,
arg.reduce_length,
......@@ -764,16 +772,16 @@ struct DeviceBatchNormBwdImpl
return false;
};
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcDstVectorSize != 1)
if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
return false;
if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasDstVectorSize != 1)
if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcDstVectorSize != 0)
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
return false;
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasDstVectorSize != 0)
if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
return false;
if(pArg_->haveSavedMeanInvVar_)
......@@ -806,7 +814,7 @@ struct DeviceBatchNormBwdImpl
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
......@@ -826,7 +834,7 @@ struct DeviceBatchNormBwdImpl
reduceDims,
bnScaleBiasMeanVarLengths,
bnScaleStrides,
bnBiasStrides,
bnDscaleDbiasStrides,
bnMeanVarStrides,
static_cast<const XDataType*>(p_x),
static_cast<const DyDataType*>(p_dy),
......@@ -836,8 +844,8 @@ struct DeviceBatchNormBwdImpl
dy_elementwise_op,
epsilon,
static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias));
static_cast<DscaleDbiasDataType*>(p_dscale),
static_cast<DscaleDbiasDataType*>(p_dbias));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......@@ -854,7 +862,7 @@ struct DeviceBatchNormBwdImpl
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcDstVectorSize << "_bias_" << BiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
// clang-format on
return str.str();
......
......@@ -187,6 +187,22 @@ struct AddRelu
const float a = x0 + type_convert<float>(x1);
y = a > 0.0f ? a : 0.0f;
};
template <>
__host__ __device__ constexpr void
operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
{
const int8_t a = x0 + x1;
y = a > 0 ? a : 0;
};
template <>
__host__ __device__ constexpr void
operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
const int8_t a = x0 + x1;
y = a > 0 ? a : 0;
};
};
struct AddHardswish
......
......@@ -10,8 +10,8 @@ namespace element_wise {
template <typename Activation>
struct Activation_Mul_Clamp
{
Activation_Mul_Clamp(float multiplier, Activation activationOp)
: multiplier_(multiplier), activationOp_(activationOp)
Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp)
{
}
......@@ -19,7 +19,7 @@ struct Activation_Mul_Clamp
{
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
float y_fp32 = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
......@@ -28,10 +28,29 @@ struct Activation_Mul_Clamp
// We might type_convert to int8 after lambda in someplace
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
y = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
}
float requantScale_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template <typename Activation>
struct Activation_Mul2_Clamp
{
Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const float& requantScale) const
{
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float multiplier_;
Activation activationOp_;
};
......@@ -39,21 +58,40 @@ struct Activation_Mul_Clamp
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
Add_Activation_Mul_Clamp(float multiplier, Activation activationOp)
: multiplier_(multiplier), activationOp_(activationOp)
Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
: requantScale_(requantScale), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float requantScale_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template <typename Activation>
struct Add_Activation_Mul2_Clamp
{
Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
{
float y_fp32 = ck::type_convert<float>(x1 + x2);
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(multiplier_ * y_fp32, -128.f, 127.f);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float multiplier_;
Activation activationOp_;
};
......@@ -61,23 +99,23 @@ struct Add_Activation_Mul_Clamp
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp)
: multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp)
Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp)
: requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{
float y_fp32 = ck::type_convert<float>(x1 + x2);
y_fp32 = multiplier1_ * y_fp32;
float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = requantScale1_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(multiplier2_ * y_fp32, -128.f, 127.f);
y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float multiplier1_;
float multiplier2_;
float requantScale1_;
float requantScale2_;
Activation activationOp_;
};
......
......@@ -16,7 +16,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename DyDataType,
typename DxDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -35,8 +35,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
......@@ -44,8 +44,8 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
......@@ -76,7 +76,7 @@ template <typename XDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -92,8 +92,8 @@ template <typename XDataType,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseReduceSecondHalfBatchNormBackwardFinal
{
......@@ -155,13 +155,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
index_t blkgroup_size,
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
const MeanVarDataType* const __restrict__ p_inv_var,
const XDataType* const __restrict__ p_x,
......@@ -169,8 +169,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
......@@ -222,8 +222,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
// Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
// clang-format on
auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
auto threadwise_dscale_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
......@@ -238,54 +238,20 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
dscale_dbias_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_dscale_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store_m =
auto threadwise_dscale_dbias_store_m =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasDstVectorSize,
DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
......@@ -297,10 +263,10 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
constexpr auto dscale_dbias_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
......@@ -313,13 +279,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
++reducedTiles)
{
threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dscale_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf);
threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
reduce_dbias_global_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
......@@ -328,9 +294,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);
threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
dscale_dbias_thread_copy_step_m_k);
}
......@@ -343,16 +307,16 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
});
threadwise_dscale_store_m.Run(thread_buffer_desc_m,
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
scale_grid_desc_m,
dscale_dbias_grid_desc_m,
dscale_global_buf);
threadwise_dbias_store_m.Run(thread_buffer_desc_m,
threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
dbias_global_buf);
// clang-format off
......@@ -418,7 +382,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
......
......@@ -17,7 +17,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -45,8 +45,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias)
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
......@@ -76,7 +76,7 @@ template <typename XDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -174,8 +174,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
ScaleDataType* const __restrict__ p_reduce_dscale,
BiasDataType* const __restrict__ p_reduce_dbias)
DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
......@@ -511,28 +511,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
});
auto threadwise_dscale_store =
auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
dscale_dbias_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m_1),
DscaleDbiasGridDesc_M_G,
PassThroughOp,
......@@ -557,13 +538,13 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
if(thread_k_cluster_id == 0)
{
threadwise_dscale_store.Run(thread_buffer_desc_m_1,
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dscale_thread_buf,
dscale_dbias_grid_desc_m_g,
reduce_dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m_1,
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
reduce_dbias_thread_buf,
dscale_dbias_grid_desc_m_g,
......
......@@ -796,6 +796,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
}
});
}
else
{
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
......@@ -21,7 +21,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -33,7 +33,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
......@@ -47,14 +47,14 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
dy_grid_desc_m_k,
dx_grid_desc_m_k,
scale_grid_desc_m,
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
mean_var_grid_desc_m,
get_reduce_count_per_thread,
reduce_size,
......@@ -77,7 +77,7 @@ template <typename XDataType,
typename DxDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
......@@ -93,8 +93,8 @@ template <typename XDataType,
index_t XSrcVectorSize,
index_t DySrcVectorSize,
index_t DxDstVectorSize,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t ScaleSrcVectorSize,
index_t DscaleDbiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct GridwiseBatchNormBackwardWithBlockwiseWelford
{
......@@ -165,7 +165,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
long_index_t reduce_size,
......@@ -179,8 +179,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
DscaleDbiasDataType* const __restrict__ p_dscale,
DscaleDbiasDataType* const __restrict__ p_dbias)
{
using ck::math::sqrt;
......@@ -253,7 +253,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
......@@ -271,7 +271,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
InMemoryDataOperationEnum::Set,
1,
true>(
dy_grid_desc_m_k,
dx_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
......@@ -285,45 +285,27 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_dscale_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
scale_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dbias_store =
auto threadwise_dscale_dbias_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
DscaleDbiasDataType,
decltype(thread_buffer_desc_m),
ScaleBiasGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasDstVectorSize,
DscaleDbiasDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
......@@ -344,10 +326,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
p_scale, scale_grid_desc_m.GetElementSpaceSize());
auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dscale, scale_grid_desc_m.GetElementSpaceSize());
p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());
auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dbias, bias_grid_desc_m.GetElementSpaceSize());
p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());
// clang-format off
// Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance)
......@@ -487,16 +469,16 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
if(thread_k_cluster_id == 0)
{
threadwise_dscale_store.Run(thread_buffer_desc_m,
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dscale_thread_buf,
scale_grid_desc_m,
dscale_dbias_grid_desc_m,
dscale_global_buf);
threadwise_dbias_store.Run(thread_buffer_desc_m,
threadwise_dscale_dbias_store.Run(thread_buffer_desc_m,
make_tuple(I0),
dbias_thread_buf,
bias_grid_desc_m,
dscale_dbias_grid_desc_m,
dbias_global_buf);
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP
#include "data_type.hpp"
// TODO: Add arch limitation
namespace ck {
// wave32 only
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;
template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
};
// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;
template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<float8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
}
};
// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32;
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
}
};
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32;
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
{
template <class FloatC>
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
{
// opsel usage
// false: D0.[0:15] = result
// true : D0.[16:31]= result
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
}
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
neg_a,
bit_cast<int32x4_t>(reg_a),
neg_b,
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp);
}
};
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment