Unverified Commit 3696fe1c authored by rocking's avatar rocking Committed by GitHub
Browse files

Layernorm and groupnorm support to save mean and inverse std in forward (#929)

* save mean and inverse std in normalization

* Save mean and inverse std in splitK

* Vector save mean and inv std

* Modify instance for save mean and std

* simplify the layernorm example

* Save mean and std in groupnorm example

* Save mean and inv std in ckProfiler and test

* Remove compute data type from base class

* Save mean and inv std in client example

* Add changelog

* clang format

* Fix compile error

* Refine naming

* Avoid error in bf16

* revert changelog
parent 58338bb2
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" #include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using ComputeDataType = float; using SaveMeanInvStdDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
...@@ -50,12 +52,16 @@ int main(int argc, char* argv[]) ...@@ -50,12 +52,16 @@ int main(int argc, char* argv[])
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * M);
SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M);
#endif
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
...@@ -84,14 +90,21 @@ int main(int argc, char* argv[]) ...@@ -84,14 +90,21 @@ int main(int argc, char* argv[])
{0, 1}, // gammaStrides {0, 1}, // gammaStrides
{0, 1}, // betaStrides {0, 1}, // betaStrides
{Stride, 1}, // yStrides {Stride, 1}, // yStrides
{1}, // save_mean Strides
{1}, // save_inv_std Strides
{1}, // reduceDims {1}, // reduceDims
1e-4, 1e-4,
x_device_buf.GetDeviceBuffer(), x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
PassThrough{}); PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
...@@ -109,6 +122,10 @@ int main(int argc, char* argv[]) ...@@ -109,6 +122,10 @@ int main(int argc, char* argv[])
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
#ifdef SAVE_MEAN_INV_STD
num_byte += sizeof(SaveMeanInvStdDataType) * M * 2;
#endif
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
...@@ -140,17 +157,24 @@ int main(int argc, char* argv[]) ...@@ -140,17 +157,24 @@ int main(int argc, char* argv[])
auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths
{Stride, 1}, // xStrides {Stride, 1}, // xStrides
{1}, // gammaStrides {0, 1}, // gammaStrides
{1}, // betaStrides {0, 1}, // betaStrides
{Stride, 1}, // yStrides {Stride, 1}, // yStrides
{1}, // save_mean Strides
{1}, // save_inv_std Strides
{1}, // reduceDims {1}, // reduceDims
1e-4, 1e-4,
x_device_buf.GetDeviceBuffer(), x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
PassThrough{}); PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp" #include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = float; using GammaDataType = float;
using BetaDataType = float; using BetaDataType = float;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using ComputeDataType = float; using SaveMeanInvStdDataType = float;
using Swish = ck::tensor_operation::element_wise::Swish; using Swish = ck::tensor_operation::element_wise::Swish;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
...@@ -49,19 +51,24 @@ int main(int argc, char* argv[]) ...@@ -49,19 +51,24 @@ int main(int argc, char* argv[])
std::size_t xy_size = N * H * W * G * C; std::size_t xy_size = N * H * W * G * C;
std::size_t gamma_beta_size = G * C; std::size_t gamma_beta_size = G * C;
std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1}; std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1}; std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
std::vector<ck::index_t> save_mean_inv_std_strides = {G, 1};
SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size); SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size); SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size);
SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
#endif
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
Swish, Swish,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
...@@ -75,19 +82,26 @@ int main(int argc, char* argv[]) ...@@ -75,19 +82,26 @@ int main(int argc, char* argv[])
const auto& generic_op_ptr = op_ptrs[0]; const auto& generic_op_ptr = op_ptrs[0];
auto generic_argument_ptr = auto generic_argument_ptr =
generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides xy_strides, // xStrides
gamma_beta_strides, // gammaStrides gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides gamma_beta_strides, // betaStrides
xy_strides, // yStrides xy_strides, // yStrides
{1, 2, 4}, // reduceDims save_mean_inv_std_strides, // save_mean Strides
save_mean_inv_std_strides, // save_inv_std Strides
{1, 2, 4}, // reduceDims
1e-6, 1e-6,
x_device_buf.GetDeviceBuffer(), x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
Swish{}); Swish{});
if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
...@@ -107,21 +121,29 @@ int main(int argc, char* argv[]) ...@@ -107,21 +121,29 @@ int main(int argc, char* argv[])
for(int i = 0; i < op_ptrs.size(); ++i) for(int i = 0; i < op_ptrs.size(); ++i)
{ {
auto& op_ptr = op_ptrs[i]; auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths auto argument_ptr =
xy_strides, // xStrides op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
gamma_beta_strides, // gammaStrides xy_strides, // xStrides
gamma_beta_strides, // betaStrides gamma_beta_strides, // gammaStrides
xy_strides, // yStrides gamma_beta_strides, // betaStrides
{1, 2, 4}, // reduceDims xy_strides, // yStrides
1e-6, save_mean_inv_std_strides, // save_mean Strides
x_device_buf.GetDeviceBuffer(), save_mean_inv_std_strides, // save_inv_std Strides
gamma_device_buf.GetDeviceBuffer(), {1, 2, 4}, // reduceDims
beta_device_buf.GetDeviceBuffer(), 1e-6,
y_device_buf.GetDeviceBuffer(), x_device_buf.GetDeviceBuffer(),
nullptr, gamma_device_buf.GetDeviceBuffer(),
nullptr, beta_device_buf.GetDeviceBuffer(),
Swish{}); y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
Swish{});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
...@@ -139,6 +161,10 @@ int main(int argc, char* argv[]) ...@@ -139,6 +161,10 @@ int main(int argc, char* argv[])
sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size + sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size +
sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size; sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size;
#ifdef SAVE_MEAN_INV_STD
num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2;
#endif
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
...@@ -169,20 +195,28 @@ int main(int argc, char* argv[]) ...@@ -169,20 +195,28 @@ int main(int argc, char* argv[])
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl; << std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths auto argument_ptr =
xy_strides, // xStrides op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
gamma_beta_strides, // gammaStrides xy_strides, // xStrides
gamma_beta_strides, // betaStrides gamma_beta_strides, // gammaStrides
xy_strides, // yStrides gamma_beta_strides, // betaStrides
{1, 2, 4}, // reduceDims xy_strides, // yStrides
1e-6, save_mean_inv_std_strides, // save_mean Strides
x_device_buf.GetDeviceBuffer(), save_mean_inv_std_strides, // save_inv_std Strides
gamma_device_buf.GetDeviceBuffer(), {1, 2, 4}, // reduceDims
beta_device_buf.GetDeviceBuffer(), 1e-6,
y_device_buf.GetDeviceBuffer(), x_device_buf.GetDeviceBuffer(),
nullptr, gamma_device_buf.GetDeviceBuffer(),
nullptr, beta_device_buf.GetDeviceBuffer(),
Swish{}); y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
Swish{});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
...@@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n, ...@@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
BetaDataType, BetaDataType,
HDataType, HDataType,
AccDataType, AccDataType,
AccDataType,
HElementOp, HElementOp,
2, 2,
1>; 1>;
Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N}); Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N}); Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> save_mean({M});
Tensor<AccDataType> save_inv_std({M});
auto ref_gemm = ReferenceGemm{}; auto ref_gemm = ReferenceGemm{};
auto ref_gemm_invoker = ref_gemm.MakeInvoker(); auto ref_gemm_invoker = ref_gemm.MakeInvoker();
...@@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n, ...@@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();
auto ref_layernorm_argument = ref_layernorm.MakeArgument( auto ref_layernorm_argument = ref_layernorm.MakeArgument(
e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon); e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon);
ref_layernorm_invoker.Run(ref_layernorm_argument); ref_layernorm_invoker.Run(ref_layernorm_argument);
} }
......
...@@ -3,12 +3,15 @@ ...@@ -3,12 +3,15 @@
#include "common.hpp" #include "common.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using ComputeDataType = float; using SaveMeanInvStdDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
...@@ -19,6 +22,7 @@ using DeviceInstance = ...@@ -19,6 +22,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -33,7 +37,8 @@ using DeviceInstance = ...@@ -33,7 +37,8 @@ using DeviceInstance =
8, // GammaScalarPerVector 8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector 8, // BetaScalarPerVector
8>; // OutScalarPerVector 8, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_layernorm_example.inc" #include "run_layernorm_example.inc"
int main() { return run_groupnorm_example<DeviceInstance>(); } int main() { return run_groupnorm_example<DeviceInstance>(); }
...@@ -3,12 +3,15 @@ ...@@ -3,12 +3,15 @@
#include "common.hpp" #include "common.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using ComputeDataType = float; using SaveMeanInvStdDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
...@@ -19,6 +22,7 @@ using DeviceInstance = ...@@ -19,6 +22,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -33,7 +37,8 @@ using DeviceInstance = ...@@ -33,7 +37,8 @@ using DeviceInstance =
8, // GammaScalarPerVector 8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector 8, // BetaScalarPerVector
8>; // YScalarPerVector 8, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_layernorm_example.inc" #include "run_layernorm_example.inc"
......
...@@ -10,22 +10,13 @@ int run_groupnorm_example() ...@@ -10,22 +10,13 @@ int run_groupnorm_example()
ck::index_t M = 1024; ck::index_t M = 1024;
ck::index_t N = 1024; ck::index_t N = 1024;
ck::index_t Stride = N;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { Tensor<XDataType> x({M, N});
return HostTensorDescriptor({len}, {stride}); Tensor<GammaDataType> gamma({N});
}; Tensor<BetaDataType> beta({N});
Tensor<YDataType> y({M, N});
auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { Tensor<SaveMeanInvStdDataType> save_mean({M});
using namespace ck::literals; Tensor<SaveMeanInvStdDataType> save_inv_std({M});
return HostTensorDescriptor({row, col}, {stride, 1_uz});
};
Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0}); x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0}); gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
...@@ -35,6 +26,11 @@ int run_groupnorm_example() ...@@ -35,6 +26,11 @@ int run_groupnorm_example()
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
#ifdef SAVE_MEAN_INV_STD
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
#endif
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
...@@ -47,14 +43,23 @@ int run_groupnorm_example() ...@@ -47,14 +43,23 @@ int run_groupnorm_example()
{0, 1}, {0, 1},
{0, 1}, {0, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
{1}, {1},
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
PassThrough{}); PassThrough{});
if(!device_instance.IsSupportedArgument(argument_ptr.get())) if(!device_instance.IsSupportedArgument(argument_ptr.get()))
...@@ -72,24 +77,45 @@ int run_groupnorm_example() ...@@ -72,24 +77,45 @@ int run_groupnorm_example()
bool pass = true; bool pass = true;
{ {
Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride)); Tensor<YDataType> host_y({M, N});
using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm<XDataType, Tensor<SaveMeanInvStdDataType> host_save_mean({M});
GammaDataType, Tensor<SaveMeanInvStdDataType> host_save_inv_std({M});
BetaDataType,
YDataType, using ReferenceInstance =
ComputeDataType, ck::tensor_operation::host::ReferenceLayernorm<XDataType,
PassThrough, GammaDataType,
Rank, BetaDataType,
NumReduceDim>; YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
PassThrough,
Rank,
NumReduceDim>;
ReferenceInstance ref; ReferenceInstance ref;
auto ref_argument = auto ref_argument = ref.MakeArgument(x,
ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4); gamma,
auto ref_invoker = ref.MakeInvoker(); beta,
host_y,
host_save_mean,
host_save_inv_std,
PassThrough{},
{M, N},
{1},
1e-4);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3);
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.FromDevice(save_mean.mData.data());
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(
save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3);
pass &= ck::utils::check_err(
save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3);
#endif
} }
return (pass ? 0 : 1); return (pass ? 0 : 1);
......
...@@ -6,11 +6,14 @@ ...@@ -6,11 +6,14 @@
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using ComputeDataType = float; using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
#define SAVE_MEAN_INV_STD
struct YElementOp struct YElementOp
{ {
...@@ -39,6 +42,7 @@ using DeviceInstance = ...@@ -39,6 +42,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementOp, YElementOp,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -53,7 +57,8 @@ using DeviceInstance = ...@@ -53,7 +57,8 @@ using DeviceInstance =
2, // GammaScalarPerVector 2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector 2, // BetaScalarPerVector
2>; // OutScalarPerVector 2, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_groupnorm_example.inc" #include "run_groupnorm_example.inc"
......
...@@ -6,12 +6,15 @@ ...@@ -6,12 +6,15 @@
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using ComputeDataType = float; using SaveMeanInvStdDataType = float;
using YElementOp = ck::tensor_operation::element_wise::Swish; using ComputeDataType = float;
using YElementOp = ck::tensor_operation::element_wise::Swish;
#define SAVE_MEAN_INV_STD
using DeviceInstance = using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationSplitKImpl<XDataType, ck::tensor_operation::device::DeviceNormalizationSplitKImpl<XDataType,
...@@ -19,6 +22,7 @@ using DeviceInstance = ...@@ -19,6 +22,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementOp, YElementOp,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -33,7 +37,8 @@ using DeviceInstance = ...@@ -33,7 +37,8 @@ using DeviceInstance =
2, // GammaScalarPerVector 2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector 2, // BetaScalarPerVector
2>; // OutScalarPerVector 2, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_groupnorm_example.inc" #include "run_groupnorm_example.inc"
......
...@@ -6,12 +6,15 @@ ...@@ -6,12 +6,15 @@
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using ComputeDataType = float; using SaveMeanInvStdDataType = float;
using YElementOp = ck::tensor_operation::element_wise::Swish; using ComputeDataType = float;
using YElementOp = ck::tensor_operation::element_wise::Swish;
#define SAVE_MEAN_INV_STD
using DeviceInstance = using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType, ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
...@@ -19,6 +22,7 @@ using DeviceInstance = ...@@ -19,6 +22,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementOp, YElementOp,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -33,7 +37,8 @@ using DeviceInstance = ...@@ -33,7 +37,8 @@ using DeviceInstance =
2, // GammaScalarPerVector 2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector 2, // BetaScalarPerVector
2>; // OutScalarPerVector 2, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_groupnorm_example.inc" #include "run_groupnorm_example.inc"
......
...@@ -34,6 +34,8 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -34,6 +34,8 @@ int run_groupnorm_example(int argc, char* argv[])
Tensor<YDataType> y({N, H, W, G, C}); Tensor<YDataType> y({N, H, W, G, C});
Tensor<GammaDataType> gamma({G, C}); Tensor<GammaDataType> gamma({G, C});
Tensor<BetaDataType> beta({G, C}); Tensor<BetaDataType> beta({G, C});
Tensor<SaveMeanInvStdDataType> save_mean({N, G});
Tensor<SaveMeanInvStdDataType> save_inv_std({N, G});
ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x); ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x);
ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma); ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma);
...@@ -43,6 +45,11 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -43,6 +45,11 @@ int run_groupnorm_example(int argc, char* argv[])
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
#ifdef SAVE_MEAN_INV_STD
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
#endif
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
...@@ -57,14 +64,23 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -57,14 +64,23 @@ int run_groupnorm_example(int argc, char* argv[])
{0, 0, 0, C, 1}, {0, 0, 0, C, 1},
{0, 0, 0, C, 1}, {0, 0, 0, C, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
{1, 2, 4}, // reduction dimension: [H, W, C] {1, 2, 4}, // reduction dimension: [H, W, C]
1e-6, 1e-6,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
y_element_op); y_element_op);
if(!device_instance.IsSupportedArgument(argument_ptr.get())) if(!device_instance.IsSupportedArgument(argument_ptr.get()))
...@@ -92,21 +108,40 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -92,21 +108,40 @@ int run_groupnorm_example(int argc, char* argv[])
bool pass = true; bool pass = true;
{ {
Tensor<YDataType> host_y({N, H, W, G, C}); Tensor<YDataType> host_y({N, H, W, G, C});
using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm<XDataType, Tensor<SaveMeanInvStdDataType> host_save_mean(HostTensorDescriptor{N, G});
GammaDataType, Tensor<SaveMeanInvStdDataType> host_save_inv_std(HostTensorDescriptor{N, G});
BetaDataType, using ReferenceInstance =
YDataType, ck::tensor_operation::host::ReferenceGroupnorm<XDataType,
ComputeDataType, GammaDataType,
YElementOp>; BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
YElementOp>;
ReferenceInstance ref; ReferenceInstance ref;
auto ref_argument = auto ref_argument = ref.MakeArgument(x,
ref.MakeArgument(x, gamma, beta, host_y, y_element_op, {N, H, W, G, C}, 1e-6); gamma,
auto ref_invoker = ref.MakeInvoker(); beta,
host_y,
host_save_mean,
host_save_inv_std,
y_element_op,
{N, H, W, G, C},
1e-6);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.FromDevice(save_mean.mData.data());
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(
save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3);
pass &= ck::utils::check_err(
save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3);
#endif
} }
return (pass ? 0 : 1); return (pass ? 0 : 1);
......
...@@ -167,20 +167,31 @@ int main() ...@@ -167,20 +167,31 @@ int main()
XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{}); XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{});
Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride)); Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
Tensor<AccDataType> host_save_mean({M});
Tensor<AccDataType> host_save_inv_std({M});
using ReferenceInstance = using ReferenceInstance =
ck::tensor_operation::host::ReferenceLayernorm<XDataType, ck::tensor_operation::host::ReferenceLayernorm<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
ReferenceInstance ref; ReferenceInstance ref;
auto ref_argument = auto ref_argument = ref.MakeArgument(x,
ref.MakeArgument(x, gamma, beta, host_y, YElementwiseOperation{}, {M, N}, {1}, 1e-4); gamma,
auto ref_invoker = ref.MakeInvoker(); beta,
host_y,
host_save_mean,
host_save_inv_std,
YElementwiseOperation{},
{M, N},
{1},
1e-4);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
......
...@@ -14,8 +14,8 @@ namespace device { ...@@ -14,8 +14,8 @@ namespace device {
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
...@@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator ...@@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
double epsilon, double epsilon,
const void* p_x, const void* p_x,
...@@ -43,16 +45,16 @@ struct DeviceNormalization : public BaseOperator ...@@ -43,16 +45,16 @@ struct DeviceNormalization : public BaseOperator
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType, using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim>>; NumReduceDim>>;
......
...@@ -28,6 +28,7 @@ template <typename XDataType, ...@@ -28,6 +28,7 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
...@@ -43,12 +44,13 @@ template <typename XDataType, ...@@ -43,12 +44,13 @@ template <typename XDataType,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize, index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool UseWelford = true> bool UseWelford = true>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim> NumReduceDim>
...@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim); // TODO
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides, const std::vector<index_t>& inStrides,
int numBlockTileIteration) int numBlockTileIteration)
{ {
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto grid_desc_m =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(InvariantDims{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded = transform_tensor_descriptor(
grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, pad_M)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return grid_desc_m_padded;
}
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
YElementwiseOperation y_elementwise_op, YElementwiseOperation y_elementwise_op,
double epsilon, double epsilon,
const XDataType* p_x, const XDataType* p_x,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y,
SaveMeanInvStdDataType* p_saveMean,
SaveMeanInvStdDataType* p_saveInvStd)
: p_x_(p_x), : p_x_(p_x),
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
p_saveMean_(p_saveMean),
p_saveInvStd_(p_saveInvStd),
y_elementwise_op_(y_elementwise_op) y_elementwise_op_(y_elementwise_op)
{ {
epsilon_ = static_cast<ComputeDataType>(epsilon); epsilon_ = static_cast<ComputeDataType>(epsilon);
...@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims); yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims); gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
saveMeanStrides_ = saveMeanStrides;
saveInvStdStrides_ = saveInvStdStrides;
long_index_t invariant_length; std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
long_index_t reduce_length;
std::tie(invariant_length, reduce_length) =
get_2d_lengths<Rank, NumReduceDim>(Lengths_);
numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize); numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize); gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_); x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
gamma_grid_desc_m_k_ = gamma_grid_desc_m_k_ =
...@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
beta_grid_desc_m_k_ = beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_); MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_); y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
isSweeponce_ = isSweeponce_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
} }
ComputeDataType epsilon_; ComputeDataType epsilon_;
...@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
const BetaDataType* p_beta_; const BetaDataType* p_beta_;
YDataType* p_y_; YDataType* p_y_;
SaveMeanInvStdDataType* p_saveMean_;
SaveMeanInvStdDataType* p_saveInvStd_;
std::vector<index_t> Lengths_; std::vector<index_t> Lengths_;
std::vector<index_t> xStrides_; std::vector<index_t> xStrides_;
std::vector<index_t> gammaStrides_; std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_; std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_; std::vector<index_t> yStrides_;
std::vector<index_t> saveMeanStrides_;
std::vector<index_t> saveInvStdStrides_;
YElementwiseOperation y_elementwise_op_; YElementwiseOperation y_elementwise_op_;
...@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GridDesc_M_K gamma_grid_desc_m_k_; GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K beta_grid_desc_m_k_; GridDesc_M_K beta_grid_desc_m_k_;
GridDesc_M_K y_grid_desc_m_k_; GridDesc_M_K y_grid_desc_m_k_;
GridDesc_M save_mean_grid_desc_m_;
GridDesc_M save_inv_std_grid_desc_m_;
bool isSweeponce_; bool isSweeponce_;
index_t MRaw_; // invarient length
index_t KRaw_; // reduce length
index_t invariant_lowest_length_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
...@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
BetaSrcVectorSize, BetaSrcVectorSize,
XYSrcVectorDim, XYSrcVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
UseWelford>(arg.isSweeponce_); UseWelford>(arg.isSweeponce_);
float avg_time = 0; float avg_time = 0;
...@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
arg.gamma_grid_desc_m_k_, arg.gamma_grid_desc_m_k_,
arg.beta_grid_desc_m_k_, arg.beta_grid_desc_m_k_,
arg.y_grid_desc_m_k_, arg.y_grid_desc_m_k_,
arg.save_mean_grid_desc_m_,
arg.save_inv_std_grid_desc_m_,
arg.numBlockTileIteration_, arg.numBlockTileIteration_,
arg.epsilon_, arg.epsilon_,
arg.p_x_, arg.p_x_,
arg.p_gamma_, arg.p_gamma_,
arg.p_beta_, arg.p_beta_,
arg.p_y_, arg.p_y_,
arg.p_saveMean_,
arg.p_saveInvStd_,
arg.y_elementwise_op_); arg.y_elementwise_op_);
return (avg_time); return (avg_time);
...@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYSrcVectorDim == 0) if constexpr(XYSrcVectorDim == 0)
{ {
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
...@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
} }
else else
{ {
printf("!!!! %d\n", p_arg_->invariant_lowest_length_);
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false; return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
return false; return false;
if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
return false; return false;
}; };
} }
...@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return (false); return (false);
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
return (false); return (false);
} }
else // if fastest dim is reduced else // if fastest dim is reduced
...@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (false); return (false);
} }
if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
return false;
return true; return true;
}; };
...@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
double epsilon, double epsilon,
const void* p_x, const void* p_x,
...@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const void* p_beta, const void* p_beta,
void* p_y, void* p_y,
void* p_saveMean, void* p_saveMean,
void* p_saveInvVar, void* p_saveInvStd,
YElementwiseOperation y_elementwise_op) override YElementwiseOperation y_elementwise_op) override
{ {
// TODO if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
// Optional cache of the intermediate results (mean and InvVariance) during the betaStrides.size() != Rank || yStrides.size() != Rank ||
// forward pass could speedup in the backward saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
ignore = p_saveMean; throw std::runtime_error("dimension is incorrect");
ignore = p_saveInvVar;
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
xStrides, xStrides,
gammaStrides, gammaStrides,
betaStrides, betaStrides,
yStrides, yStrides,
saveMeanStrides,
saveInvStdStrides,
reduceDims, reduceDims,
y_elementwise_op, y_elementwise_op,
epsilon, epsilon,
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma), static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta), static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y)); static_cast<YDataType*>(p_y),
static_cast<SaveMeanInvStdDataType*>(p_saveMean),
static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
}; };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -18,9 +18,11 @@ template <typename XDataType, ...@@ -18,9 +18,11 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename GridDesc_M_K, typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -34,6 +36,7 @@ template <typename XDataType, ...@@ -34,6 +36,7 @@ template <typename XDataType,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize, index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool SweepOnce> bool SweepOnce>
struct GridwiseNormalizationNaiveVariance_mk_to_mk struct GridwiseNormalizationNaiveVariance_mk_to_mk
{ {
...@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize);
...@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}))); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
...@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
reduce::Add, reduce::Add,
true>; true>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k,
const GridDesc_M& save_mean_grid_desc_m,
const GridDesc_M& save_inv_std_grid_desc_m,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
ComputeDataType epsilon, ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op) const YElementwiseOperation y_elementwise_op)
{ {
// LDS // LDS
...@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
auto x_thread_buf = generate_tuple( auto x_thread_buf = generate_tuple(
[&](auto) { [&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr, return StaticBuffer<AddressSpaceEnum::Vgpr,
...@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
mean_square_thread_buf; mean_square_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>& StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
var_thread_buf = mean_square_thread_buf; var_thread_buf = mean_square_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
inv_std_thread_buf = mean_square_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
thread_k_cluster_id * YDstVectorSize), thread_k_cluster_id * YDstVectorSize),
y_elementwise_op); y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k = constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
...@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// E(x), E[x^2], var(x) // E(x), E[x^2], var(x)
// FIXME: Should not hack the transform from deviceOP // FIXME: Should not hack the transform from deviceOP
int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; ComputeDataType reduce_length = type_convert<ComputeDataType>(
x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>(); mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
...@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2 // var(x) = E[x^2] - E[x]^2
var_thread_buf(I) = var_thread_buf(I) =
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
// save mean and inverse std for backward (optional)
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// normalization
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma & beta // gamma & beta
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
...@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2 // var(x) = E[x^2] - E[x]^2
var_thread_buf(I) = var_thread_buf(I) =
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
auto thread_copy_tail_m_k = auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
...@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma // gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
...@@ -12,31 +12,42 @@ template <typename GridwiseReduction, ...@@ -12,31 +12,42 @@ template <typename GridwiseReduction,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename GridDesc_M_K> typename GridDesc_M_K,
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, typename GridDesc_M>
const GridDesc_M_K gamma_grid_desc_m_k, __global__ void
const GridDesc_M_K beta_grid_desc_m_k, kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k,
index_t num_k_block_tile_iteration, const GridDesc_M_K beta_grid_desc_m_k,
ComputeDataType epsilon, const GridDesc_M_K y_grid_desc_m_k,
const XDataType* const __restrict__ p_x_global, const GridDesc_M save_mean_grid_desc_m,
const GammaDataType* const __restrict__ p_gamma_global, const GridDesc_M save_inv_std_grid_desc_m,
const BetaDataType* const __restrict__ p_beta_global, index_t num_k_block_tile_iteration,
YDataType* const __restrict__ p_y_global, ComputeDataType epsilon,
const YElementwiseOperation y_elementwise_op) const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op)
{ {
GridwiseReduction::Run(x_grid_desc_m_k, GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k, gamma_grid_desc_m_k,
beta_grid_desc_m_k, beta_grid_desc_m_k,
y_grid_desc_m_k, y_grid_desc_m_k,
save_mean_grid_desc_m,
save_inv_std_grid_desc_m,
num_k_block_tile_iteration, num_k_block_tile_iteration,
epsilon, epsilon,
p_x_global, p_x_global,
p_gamma_global, p_gamma_global,
p_beta_global, p_beta_global,
p_y_global, p_y_global,
p_save_mean_global,
p_save_inv_std_global,
y_elementwise_op); y_elementwise_op);
}; };
...@@ -44,9 +55,11 @@ template <typename XDataType, ...@@ -44,9 +55,11 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename GridDesc_M_K, typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -60,6 +73,7 @@ template <typename XDataType, ...@@ -60,6 +73,7 @@ template <typename XDataType,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize, index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool UseWelford> bool UseWelford>
auto NormalizationKernelSelector(bool isSweepOnce) auto NormalizationKernelSelector(bool isSweepOnce)
{ {
...@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorDim, YDstVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
false>; false>;
using GridwiseNormalizationSweepOnceNaive = using GridwiseNormalizationSweepOnceNaive =
GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType, GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorDim, YDstVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
true>; true>;
using GridwiseNormalizationGenericWelford = using GridwiseNormalizationGenericWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorDim, YDstVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
false>; false>;
using GridwiseNormalizationSweepOnceWelford = using GridwiseNormalizationSweepOnceWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorDim, YDstVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
true>; true>;
if constexpr(UseWelford) if constexpr(UseWelford)
...@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K> GridDesc_M_K,
GridDesc_M>
: kernel_normalization<GridwiseNormalizationGenericWelford, : kernel_normalization<GridwiseNormalizationGenericWelford,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K>; GridDesc_M_K,
GridDesc_M>;
} }
else else
{ {
...@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K> GridDesc_M_K,
GridDesc_M>
: kernel_normalization<GridwiseNormalizationGenericNaive, : kernel_normalization<GridwiseNormalizationGenericNaive,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K>; GridDesc_M_K,
GridDesc_M>;
} }
} }
......
...@@ -17,11 +17,13 @@ template <typename MeanVarDataType, ...@@ -17,11 +17,13 @@ template <typename MeanVarDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename MeanVarGridDesc_M_KBlock, typename MeanVarGridDesc_M_KBlock,
typename CountGridDesc_M_KBlock, typename CountGridDesc_M_KBlock,
typename XYGammaBetaGridDesc_M_K, typename XYGammaBetaGridDesc_M_K,
typename SaveMeanInvStdGridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -34,7 +36,8 @@ template <typename MeanVarDataType, ...@@ -34,7 +36,8 @@ template <typename MeanVarDataType,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize> index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize>
struct GridwiseNormalizationSplitK2nd struct GridwiseNormalizationSplitK2nd
{ {
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
...@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd ...@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize);
...@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd ...@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>; using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
static constexpr auto thread_buffer_desc_m_1 = static constexpr auto thread_buffer_desc_m_1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1)); make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));
...@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd ...@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd
const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
index_t num_k_mean_var_count_iteration, index_t num_k_mean_var_count_iteration,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
index_t k_grid_size, index_t k_grid_size,
...@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd ...@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op) const YElementwiseOperation y_elementwise_op)
{ {
// Thread/Block id // Thread/Block id
...@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd ...@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
// VGPR // VGPR
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_mean_thread_buf; in_mean_thread_buf;
...@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd ...@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd
var_thread_buf; var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf; welford_count_thread_buf;
auto& inv_std_thread_buf = var_thread_buf;
auto x_thread_buf = generate_tuple( auto x_thread_buf = generate_tuple(
[&](auto) { [&](auto) {
...@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd ...@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd
thread_k_cluster_id * YDstVectorSize), thread_k_cluster_id * YDstVectorSize),
y_elementwise_op); y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
SaveMeanInvStdGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
SaveMeanInvStdGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
// step1: Merge mean and variance // step1: Merge mean and variance
constexpr auto mean_var_count_thread_copy_step_I0_k = constexpr auto mean_var_count_thread_copy_step_I0_k =
make_multi_index(I0, KThreadClusterSize); make_multi_index(I0, KThreadClusterSize);
...@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd ...@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd
BlockwiseWelford::Run( BlockwiseWelford::Run(
mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I)); mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));
inv_std_thread_buf(I) =
type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
// step2: normalization // step2: save mean and inverse std for backward (optional)
if(block_k_cluster_id == 0 && thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// step3: normalization
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
for(index_t k = 0; k < num_k_block_tile_iteration; ++k) for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
...@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd ...@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd ...@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma // gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
...@@ -16,9 +16,11 @@ template <typename XDataType, ...@@ -16,9 +16,11 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename GridDesc_M_K, typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -32,6 +34,7 @@ template <typename XDataType, ...@@ -32,6 +34,7 @@ template <typename XDataType,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize, index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool SweepOnce> bool SweepOnce>
struct GridwiseNormalizationWelfordVariance_mk_to_mk struct GridwiseNormalizationWelfordVariance_mk_to_mk
{ {
...@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize);
...@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}))); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
...@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
ThreadClusterLengths_M_K, ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>; ThreadClusterArrangeOrder>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k,
const GridDesc_M& save_mean_grid_desc_m,
const GridDesc_M& save_inv_std_grid_desc_m,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
ComputeDataType epsilon, ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op) const YElementwiseOperation y_elementwise_op)
{ {
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto x_thread_buf = generate_tuple( auto x_thread_buf = generate_tuple(
[&](auto) { [&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr, return StaticBuffer<AddressSpaceEnum::Vgpr,
...@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
mean_thread_buf; mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
var_thread_buf; var_thread_buf;
auto& inv_std_thread_buf = var_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
thread_k_cluster_id * YDstVectorSize), thread_k_cluster_id * YDstVectorSize),
y_elementwise_op); y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k = constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
...@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford(); auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
...@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_; int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
// save mean and inverse std for backward (optional)
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// normalization
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma & beta // gamma & beta
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
...@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_; int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
auto thread_copy_tail_m_k = auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
...@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma // gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
...@@ -20,8 +20,9 @@ template <typename XDataType, ...@@ -20,8 +20,9 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename SaveMeanInvStdDataType,
typename AccElementwiseOperation> typename ComputeDataType,
typename YElementwiseOperation>
struct ReferenceGroupnorm : public device::BaseOperator struct ReferenceGroupnorm : public device::BaseOperator
{ {
// x = [N, H, W, G, C] // x = [N, H, W, G, C]
...@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma, const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta, const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y, Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
AccDataType epsilon) ComputeDataType epsilon)
: x_(x), : x_(x),
gamma_(gamma), gamma_(gamma),
beta_(beta), beta_(beta),
y_(y), y_(y),
acc_elementwise_op_(acc_elementwise_op), save_mean_(save_mean),
save_inv_std_(save_inv_std),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths), lengths_(lengths),
epsilon_(epsilon) epsilon_(epsilon)
{ {
...@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<XDataType> gamma_; const Tensor<XDataType> gamma_;
const Tensor<XDataType> beta_; const Tensor<XDataType> beta_;
Tensor<YDataType>& y_; Tensor<YDataType>& y_;
AccElementwiseOperation acc_elementwise_op_; Tensor<SaveMeanInvStdDataType>& save_mean_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_; std::vector<index_t> lengths_;
AccDataType epsilon_; ComputeDataType epsilon_;
}; };
// Invoker // Invoker
...@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator
int G = arg.lengths_[3]; int G = arg.lengths_[3];
int C = arg.lengths_[4]; int C = arg.lengths_[4];
Tensor<AccDataType> mean({N, G}); Tensor<ComputeDataType> mean({N, G});
Tensor<AccDataType> var({N, G}); Tensor<ComputeDataType> var({N, G});
// Compute mean & var in [H, W, C] by Welford Algorithm // Compute mean & var in [H, W, C] by Welford Algorithm
// TODO - parallel for each HWC // TODO - parallel for each HWC
...@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator
{ {
for(int g = 0; g < G; ++g) for(int g = 0; g < G; ++g)
{ {
AccDataType mean_val = type_convert<AccDataType>(0.0f); ComputeDataType mean_val = type_convert<ComputeDataType>(0.0f);
AccDataType var_val = type_convert<AccDataType>(0.0f); ComputeDataType var_val = type_convert<ComputeDataType>(0.0f);
int32_t curr_count = 0; int32_t curr_count = 0;
for(int h = 0; h < H; ++h) for(int h = 0; h < H; ++h)
{ {
...@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
for(int c = 0; c < C; ++c) for(int c = 0; c < C; ++c)
{ {
curr_count++; curr_count++;
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c)); ComputeDataType x =
AccDataType delta = x - mean_val; type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
ComputeDataType delta = x - mean_val;
mean_val += delta / curr_count; mean_val += delta / curr_count;
AccDataType delta2 = x - mean_val; ComputeDataType delta2 = x - mean_val;
var_val += delta * delta2; var_val += delta * delta2;
} }
} }
...@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator
mean(n, g) = mean_val; mean(n, g) = mean_val;
var(n, g) = var_val / curr_count; var(n, g) = var_val / curr_count;
arg.save_mean_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(mean(n, g));
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n, g) + arg.epsilon_);
arg.save_inv_std_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
} }
} }
...@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator
{ {
for(int c = 0; c < C; ++c) for(int c = 0; c < C; ++c)
{ {
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c)); ComputeDataType x =
AccDataType gamma = type_convert<AccDataType>(arg.gamma_(g, c)); type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
AccDataType beta = type_convert<AccDataType>(arg.beta_(g, c)); ComputeDataType gamma =
AccDataType mean_val = type_convert<AccDataType>(mean(n, g)); type_convert<ComputeDataType>(arg.gamma_(g, c));
AccDataType var_val = type_convert<AccDataType>(var(n, g)); ComputeDataType beta =
AccDataType y = gamma * (x - mean_val) / type_convert<ComputeDataType>(arg.beta_(g, c));
ck::math::sqrt(arg.epsilon_ + var_val) + ComputeDataType mean_val =
beta; type_convert<ComputeDataType>(mean(n, g));
arg.acc_elementwise_op_(y, y); ComputeDataType var_val = type_convert<ComputeDataType>(var(n, g));
ComputeDataType y = gamma * (x - mean_val) /
ck::math::sqrt(arg.epsilon_ + var_val) +
beta;
arg.y_elementwise_op_(y, y);
arg.y_(n, h, w, g, c) = type_convert<YDataType>(y); arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
} }
} }
...@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma, const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta, const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y, Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
AccDataType epsilon) ComputeDataType epsilon)
{ {
return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon}; return Argument{
x, gamma, beta, y, save_mean, save_inv_std, y_elementwise_op, lengths, epsilon};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
...@@ -20,8 +20,9 @@ template <typename XDataType, ...@@ -20,8 +20,9 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename SaveMeanInvStdDataType,
typename AccElementwiseOperation, typename ComputeDataType,
typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
struct ReferenceLayernorm : public device::BaseOperator struct ReferenceLayernorm : public device::BaseOperator
...@@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n, Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon) ComputeDataType epsilon)
: x_m_n_(x_m_n), : x_m_n_(x_m_n),
gamma_n_(gamma_n), gamma_n_(gamma_n),
beta_n_(beta_n), beta_n_(beta_n),
y_m_n_(y_m_n), y_m_n_(y_m_n),
acc_elementwise_op_(acc_elementwise_op), save_mean_m_(save_mean_m),
save_inv_std_m_(save_inv_std_m),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths), lengths_(lengths),
reduceDims_(reduceDims), reduceDims_(reduceDims),
epsilon_(epsilon) epsilon_(epsilon)
...@@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<XDataType> gamma_n_; const Tensor<XDataType> gamma_n_;
const Tensor<XDataType> beta_n_; const Tensor<XDataType> beta_n_;
Tensor<YDataType>& y_m_n_; Tensor<YDataType>& y_m_n_;
AccElementwiseOperation acc_elementwise_op_; Tensor<SaveMeanInvStdDataType>& save_mean_m_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_m_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_; std::vector<index_t> lengths_;
std::vector<index_t> reduceDims_; std::vector<index_t> reduceDims_;
AccDataType epsilon_; ComputeDataType epsilon_;
}; };
// Invoker // Invoker
...@@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator
int M = arg.lengths_[0]; int M = arg.lengths_[0];
int N = arg.lengths_[1]; int N = arg.lengths_[1];
Tensor<AccDataType> mean({M}); Tensor<ComputeDataType> mean({M});
Tensor<AccDataType> var({M}); Tensor<ComputeDataType> var({M});
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
...@@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n)); auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
mean(m) += x_val; mean(m) += x_val;
var(m) += x_val * x_val; var(m) += x_val * x_val;
} }
...@@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
AccDataType divisor = ComputeDataType divisor =
static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_); static_cast<ComputeDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n)); auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) * divisor; auto gamma_val = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n); auto beta_val = ck::type_convert<ComputeDataType>(arg.beta_n_(n));
arg.acc_elementwise_op_(y_val, y_val); auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * gamma_val) + beta_val;
arg.y_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val); arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
} }
arg.save_mean_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(mean(m));
arg.save_inv_std_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
} }
return 0; return 0;
...@@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n, Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon) ComputeDataType epsilon)
{ {
return Argument{ return Argument{x_m_n,
x_m_n, gamma_n, beta_n, y_m_n, acc_elementwise_op, lengths, reduceDims, epsilon}; gamma_n,
beta_n,
y_m_n,
save_mean_m,
save_inv_std_m,
y_elementwise_op,
lengths,
reduceDims,
epsilon};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment