Unverified Commit a4e34d88 authored by rocking5566's avatar rocking5566 Committed by GitHub
Browse files

Merge branch 'develop' into gemm_layernorm_welford

parents 6916e3e4 ad541ad6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
struct DeviceBatchNormBwd : public BaseOperator
{
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
const std::array<index_t, Rank> xStrides,
const std::array<index_t, Rank> dyStrides,
const std::array<index_t, Rank> dxStrides,
const std::array<int, NumBatchNormReduceDim> reduceDims,
const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
const void* p_x,
const void* p_dy,
const void* p_scale,
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim>
using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumBatchNormReduceDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -700,7 +700,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -700,7 +700,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< KPerBlock << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << ", " << BK1 << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec)
......
...@@ -10,8 +10,8 @@ namespace element_wise { ...@@ -10,8 +10,8 @@ namespace element_wise {
template <typename Activation> template <typename Activation>
struct Activation_Mul_Clamp struct Activation_Mul_Clamp
{ {
Activation_Mul_Clamp(float multiplier, Activation activationOp) Activation_Mul_Clamp(float requantScale, Activation activationOp)
: multiplier_(multiplier), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
} }
...@@ -19,7 +19,7 @@ struct Activation_Mul_Clamp ...@@ -19,7 +19,7 @@ struct Activation_Mul_Clamp
{ {
float x_fp32 = ck::type_convert<float>(x); float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32); activationOp_(x_fp32, x_fp32);
float y_fp32 = math::clamp(multiplier_ * x_fp32, -128.f, 127.f); float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32); y = ck::type_convert<int8_t>(y_fp32);
} }
...@@ -28,10 +28,29 @@ struct Activation_Mul_Clamp ...@@ -28,10 +28,29 @@ struct Activation_Mul_Clamp
// We might type_convert to int8 after lambda in someplace // We might type_convert to int8 after lambda in someplace
float x_fp32 = ck::type_convert<float>(x); float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32); activationOp_(x_fp32, x_fp32);
y = math::clamp(multiplier_ * x_fp32, -128.f, 127.f); y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
}
float requantScale_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template <typename Activation>
struct Activation_Mul2_Clamp
{
Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const float& requantScale) const
{
float y_fp32 = ck::type_convert<float>(x);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
} }
float multiplier_;
Activation activationOp_; Activation activationOp_;
}; };
...@@ -39,21 +58,40 @@ struct Activation_Mul_Clamp ...@@ -39,21 +58,40 @@ struct Activation_Mul_Clamp
template <typename Activation> template <typename Activation>
struct Add_Activation_Mul_Clamp struct Add_Activation_Mul_Clamp
{ {
Add_Activation_Mul_Clamp(float multiplier, Activation activationOp) Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
: multiplier_(multiplier), activationOp_(activationOp) : requantScale_(requantScale), activationOp_(activationOp)
{ {
} }
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{
float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float requantScale_;
Activation activationOp_;
};
// Conv Perchannel quantization + Activation function which is piecewise linear function, such as
// relu, leaky relu ...etc
template <typename Activation>
struct Add_Activation_Mul2_Clamp
{
Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
{ {
float y_fp32 = ck::type_convert<float>(x1 + x2); float y_fp32 = ck::type_convert<float>(x + bias);
activationOp_(y_fp32, y_fp32); activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(multiplier_ * y_fp32, -128.f, 127.f); y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32); y = ck::type_convert<int8_t>(y_fp32);
} }
float multiplier_;
Activation activationOp_; Activation activationOp_;
}; };
...@@ -61,23 +99,23 @@ struct Add_Activation_Mul_Clamp ...@@ -61,23 +99,23 @@ struct Add_Activation_Mul_Clamp
template <typename Activation> template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp struct Add_Mul_Activation_Mul_Clamp
{ {
Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp) Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp)
: multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp) : requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp)
{ {
} }
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
{ {
float y_fp32 = ck::type_convert<float>(x1 + x2); float y_fp32 = ck::type_convert<float>(x + bias);
y_fp32 = multiplier1_ * y_fp32; y_fp32 = requantScale1_ * y_fp32;
activationOp_(y_fp32, y_fp32); activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(multiplier2_ * y_fp32, -128.f, 127.f); y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32); y = ck::type_convert<int8_t>(y_fp32);
} }
float multiplier1_; float requantScale1_;
float multiplier2_; float requantScale2_;
Activation activationOp_; Activation activationOp_;
}; };
......
...@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf ...@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
// clang-format off
// First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
// clang-format on
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k, __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g, const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
......
...@@ -529,6 +529,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal ...@@ -529,6 +529,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance)
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) = welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]); type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
......
...@@ -441,6 +441,7 @@ struct GridwiseBatchNormForwardWithBlockwiseWelford ...@@ -441,6 +441,7 @@ struct GridwiseBatchNormForwardWithBlockwiseWelford
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
var_thread_buf(I) = var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]); type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
......
...@@ -27,8 +27,8 @@ using F16_Tuple = ck::Tuple<F16>; ...@@ -27,8 +27,8 @@ using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>; using F16_F16_Tuple = ck::Tuple<F16, F16>;
using F32_Tuple = ck::Tuple<F32>; using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>; using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>;
// GEMM layout // GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
...@@ -79,7 +79,8 @@ using NDHWGK = ck::tensor_layout::convolution::NDHWGK; ...@@ -79,7 +79,8 @@ using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
// //
using GK = ck::tensor_layout::convolution::G_K; using GK = ck::tensor_layout::convolution::G_K;
using GK_TUPLE = ck::Tuple<GK>; using GK_Tuple = ck::Tuple<GK>;
using GK_GK_Tuple = ck::Tuple<GK, GK>;
// pointwise functor // pointwise functor
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...@@ -97,6 +98,13 @@ template <typename Activation> ...@@ -97,6 +98,13 @@ template <typename Activation>
using Add_Activation_Mul_Clamp = using Add_Activation_Mul_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>; ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
template <typename Activation>
using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp<Activation>;
template <typename Activation>
using Add_Activation_Mul2_Clamp =
ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp<Activation>;
template <typename DeviceOp, typename Tag = void> template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory; struct DeviceOperationInstanceFactory;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// FP16
void add_device_batchnorm_backward_rank_4_3_f16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
// FP32
void add_device_batchnorm_backward_rank_4_3_f32_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
// BF16
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
// FP64
void add_device_batchnorm_backward_rank_4_3_f64_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
template <typename XDataType,
typename DxDataType,
typename DyDataType,
typename AccDataType,
typename ScaleDataType,
typename DscaleDbiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumReduceDim>>
{
using DeviceOp = DeviceBatchNormBwd<XDataType,
DxDataType,
DyDataType,
AccDataType,
ScaleDataType,
DscaleDbiasDataType,
MeanVarDataType,
DyElementwiseOp,
Rank,
NumReduceDim>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -2,6 +2,9 @@ add_instance_library(device_batchnorm_instance ...@@ -2,6 +2,9 @@ add_instance_library(device_batchnorm_instance
device_batchnorm_forward_f16_instance.cpp device_batchnorm_forward_f16_instance.cpp
device_batchnorm_forward_f32_instance.cpp device_batchnorm_forward_f32_instance.cpp
device_batchnorm_forward_bf16_instance.cpp device_batchnorm_forward_bf16_instance.cpp
device_batchnorm_forward_i8_instance.cpp
device_batchnorm_forward_f64_instance.cpp device_batchnorm_forward_f64_instance.cpp
device_batchnorm_backward_f16_instance.cpp
device_batchnorm_backward_f32_instance.cpp
device_batchnorm_backward_bf16_instance.cpp
device_batchnorm_backward_f64_instance.cpp
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment