Unverified Commit 3696fe1c authored by rocking, committed by GitHub

Layernorm and groupnorm: support saving the mean and inverse std in the forward pass (#929)

* save mean and inverse std in normalization

* Save mean and inverse std in splitK

* Vector save mean and inv std

* Modify instance for save mean and std

* simplify the layernorm example

* Save mean and std in groupnorm example

* Save mean and inv std in ckProfiler and test

* Remove compute data type from base class

* Save mean and inv std in client example

* Add changelog

* clang format

* Fix compile error

* Refine naming

* Avoid error in bf16

* revert changelog
parent 58338bb2
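For context on the terms used throughout this change: for a row (or group) with reduction length K, the statistics the forward kernels can now write out alongside y are the mean and the inverse standard deviation actually used for normalization. A standard formulation, stated here for reference rather than copied from the kernels, is

$$\mu_m = \frac{1}{K}\sum_{k=1}^{K} x_{m,k},\qquad \sigma_m^2 = \frac{1}{K}\sum_{k=1}^{K}\left(x_{m,k}-\mu_m\right)^2,\qquad \mathrm{inv\_std}_m = \frac{1}{\sqrt{\sigma_m^2+\varepsilon}},\qquad y_{m,k} = \gamma_k\,(x_{m,k}-\mu_m)\,\mathrm{inv\_std}_m + \beta_k.$$

Caching these per-row values during the forward pass lets a backward pass reuse them instead of redoing the reduction, which is the motivation mentioned in the TODO comment this commit removes.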
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" #include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using ComputeDataType = float; using SaveMeanInvStdDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
...@@ -50,12 +52,16 @@ int main(int argc, char* argv[]) ...@@ -50,12 +52,16 @@ int main(int argc, char* argv[])
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * M);
SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M);
#endif
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
...@@ -84,14 +90,21 @@ int main(int argc, char* argv[]) ...@@ -84,14 +90,21 @@ int main(int argc, char* argv[])
{0, 1}, // gammaStrides {0, 1}, // gammaStrides
{0, 1}, // betaStrides {0, 1}, // betaStrides
{Stride, 1}, // yStrides {Stride, 1}, // yStrides
{1}, // save_mean Strides
{1}, // save_inv_std Strides
{1}, // reduceDims {1}, // reduceDims
1e-4, 1e-4,
x_device_buf.GetDeviceBuffer(), x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
PassThrough{}); PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
...@@ -109,6 +122,10 @@ int main(int argc, char* argv[]) ...@@ -109,6 +122,10 @@ int main(int argc, char* argv[])
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
#ifdef SAVE_MEAN_INV_STD
num_byte += sizeof(SaveMeanInvStdDataType) * M * 2;
#endif
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
...@@ -140,17 +157,24 @@ int main(int argc, char* argv[]) ...@@ -140,17 +157,24 @@ int main(int argc, char* argv[])
auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths
{Stride, 1}, // xStrides {Stride, 1}, // xStrides
{1}, // gammaStrides {0, 1}, // gammaStrides
{1}, // betaStrides {0, 1}, // betaStrides
{Stride, 1}, // yStrides {Stride, 1}, // yStrides
{1}, // save_mean Strides
{1}, // save_inv_std Strides
{1}, // reduceDims {1}, // reduceDims
1e-4, 1e-4,
x_device_buf.GetDeviceBuffer(), x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
PassThrough{}); PassThrough{});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
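A quick host-side way to sanity-check the two new outputs of the example above is to recompute them directly. The sketch below is illustrative only: it is not part of the commit, uses plain float instead of ck::half_t, and assumes a contiguous row-major [M, N] input with the same epsilon passed to MakeArgumentPointer.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Recompute the per-row mean and inverse standard deviation that the device op
    // writes into save_mean_device_buf / save_inv_std_device_buf.
    void reference_mean_inv_std(const std::vector<float>& x, // [M * N], row-major
                                std::size_t M,
                                std::size_t N,
                                float epsilon,
                                std::vector<float>& mean,    // [M]
                                std::vector<float>& inv_std) // [M]
    {
        mean.resize(M);
        inv_std.resize(M);
        for(std::size_t m = 0; m < M; ++m)
        {
            float sum = 0.f;
            for(std::size_t n = 0; n < N; ++n)
                sum += x[m * N + n];
            const float mu = sum / static_cast<float>(N);

            float var = 0.f;
            for(std::size_t n = 0; n < N; ++n)
            {
                const float d = x[m * N + n] - mu;
                var += d * d;
            }
            var /= static_cast<float>(N); // biased variance, as used in the layernorm forward pass

            mean[m]    = mu;
            inv_std[m] = 1.f / std::sqrt(var + epsilon);
        }
    }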
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp" #include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = float; using GammaDataType = float;
using BetaDataType = float; using BetaDataType = float;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using ComputeDataType = float; using SaveMeanInvStdDataType = float;
using Swish = ck::tensor_operation::element_wise::Swish; using Swish = ck::tensor_operation::element_wise::Swish;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
...@@ -49,19 +51,24 @@ int main(int argc, char* argv[]) ...@@ -49,19 +51,24 @@ int main(int argc, char* argv[])
std::size_t xy_size = N * H * W * G * C; std::size_t xy_size = N * H * W * G * C;
std::size_t gamma_beta_size = G * C; std::size_t gamma_beta_size = G * C;
std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1}; std::vector<ck::index_t> xy_strides = {H * W * G * C, W * G * C, G * C, C, 1};
std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1}; std::vector<ck::index_t> gamma_beta_strides = {0, 0, 0, C, 1};
std::vector<ck::index_t> save_mean_inv_std_strides = {G, 1};
SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size); SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size); SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size);
SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G);
#endif
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
Swish, Swish,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
...@@ -75,19 +82,26 @@ int main(int argc, char* argv[]) ...@@ -75,19 +82,26 @@ int main(int argc, char* argv[])
const auto& generic_op_ptr = op_ptrs[0]; const auto& generic_op_ptr = op_ptrs[0];
auto generic_argument_ptr = auto generic_argument_ptr =
generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides xy_strides, // xStrides
gamma_beta_strides, // gammaStrides gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides gamma_beta_strides, // betaStrides
xy_strides, // yStrides xy_strides, // yStrides
{1, 2, 4}, // reduceDims save_mean_inv_std_strides, // save_mean Strides
save_mean_inv_std_strides, // save_inv_std Strides
{1, 2, 4}, // reduceDims
1e-6, 1e-6,
x_device_buf.GetDeviceBuffer(), x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf.GetDeviceBuffer(),
save_inv_std_device_buf.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
Swish{}); Swish{});
if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
...@@ -107,21 +121,29 @@ int main(int argc, char* argv[]) ...@@ -107,21 +121,29 @@ int main(int argc, char* argv[])
for(int i = 0; i < op_ptrs.size(); ++i) for(int i = 0; i < op_ptrs.size(); ++i)
{ {
auto& op_ptr = op_ptrs[i]; auto& op_ptr = op_ptrs[i];
        auto argument_ptr =
            op_ptr->MakeArgumentPointer({N, H, W, G, C},           // lengths
                                        xy_strides,                // xStrides
                                        gamma_beta_strides,        // gammaStrides
                                        gamma_beta_strides,        // betaStrides
                                        xy_strides,                // yStrides
+                                       save_mean_inv_std_strides, // save_mean Strides
+                                       save_mean_inv_std_strides, // save_inv_std Strides
                                        {1, 2, 4},                 // reduceDims
                                        1e-6,
                                        x_device_buf.GetDeviceBuffer(),
                                        gamma_device_buf.GetDeviceBuffer(),
                                        beta_device_buf.GetDeviceBuffer(),
                                        y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                       save_mean_device_buf.GetDeviceBuffer(),
+                                       save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                        nullptr,
                                        nullptr,
+#endif
                                        Swish{});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
...@@ -139,6 +161,10 @@ int main(int argc, char* argv[]) ...@@ -139,6 +161,10 @@ int main(int argc, char* argv[])
sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size + sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size +
sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size; sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size;
#ifdef SAVE_MEAN_INV_STD
num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2;
#endif
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
...@@ -169,20 +195,28 @@ int main(int argc, char* argv[]) ...@@ -169,20 +195,28 @@ int main(int argc, char* argv[])
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl; << std::endl;
        auto argument_ptr =
            op_ptr->MakeArgumentPointer({N, H, W, G, C},           // lengths
                                        xy_strides,                // xStrides
                                        gamma_beta_strides,        // gammaStrides
                                        gamma_beta_strides,        // betaStrides
                                        xy_strides,                // yStrides
+                                       save_mean_inv_std_strides, // save_mean Strides
+                                       save_mean_inv_std_strides, // save_inv_std Strides
                                        {1, 2, 4},                 // reduceDims
                                        1e-6,
                                        x_device_buf.GetDeviceBuffer(),
                                        gamma_device_buf.GetDeviceBuffer(),
                                        beta_device_buf.GetDeviceBuffer(),
                                        y_device_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                                       save_mean_device_buf.GetDeviceBuffer(),
+                                       save_inv_std_device_buf.GetDeviceBuffer(),
+#else
                                        nullptr,
                                        nullptr,
+#endif
                                        Swish{});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
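In the groupnorm example the saved statistics hold one value per (batch, group), i.e. an [N, G] tensor, which is why save_mean_inv_std_strides is {G, 1}. A small illustrative helper (an assumption for clarity, not code from the commit) for locating element (n, g) in the flat save_mean / save_inv_std device buffers:

    #include <cstddef>

    // Row-major [N, G] layout with strides {G, 1}: element (n, g) lives at offset n * G + g.
    inline std::size_t stat_offset(std::size_t n, std::size_t g, std::size_t G)
    {
        return n * G + g;
    }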
...@@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n, ...@@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
BetaDataType, BetaDataType,
HDataType, HDataType,
AccDataType, AccDataType,
AccDataType,
HElementOp, HElementOp,
2, 2,
1>; 1>;
Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N}); Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N}); Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> save_mean({M});
Tensor<AccDataType> save_inv_std({M});
auto ref_gemm = ReferenceGemm{}; auto ref_gemm = ReferenceGemm{};
auto ref_gemm_invoker = ref_gemm.MakeInvoker(); auto ref_gemm_invoker = ref_gemm.MakeInvoker();
...@@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n, ...@@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();
auto ref_layernorm_argument = ref_layernorm.MakeArgument( auto ref_layernorm_argument = ref_layernorm.MakeArgument(
e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon); e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon);
ref_layernorm_invoker.Run(ref_layernorm_argument); ref_layernorm_invoker.Run(ref_layernorm_argument);
} }
......
...@@ -3,12 +3,15 @@ ...@@ -3,12 +3,15 @@
#include "common.hpp" #include "common.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
+ using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
...@@ -19,6 +22,7 @@ using DeviceInstance = ...@@ -19,6 +22,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -33,7 +37,8 @@ using DeviceInstance = ...@@ -33,7 +37,8 @@ using DeviceInstance =
8, // GammaScalarPerVector 8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector 8, // BetaScalarPerVector
8>; // OutScalarPerVector 8, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_layernorm_example.inc" #include "run_layernorm_example.inc"
int main() { return run_groupnorm_example<DeviceInstance>(); } int main() { return run_groupnorm_example<DeviceInstance>(); }
...@@ -3,12 +3,15 @@ ...@@ -3,12 +3,15 @@
#include "common.hpp" #include "common.hpp"
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
+ using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
...@@ -19,6 +22,7 @@ using DeviceInstance = ...@@ -19,6 +22,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
PassThrough, PassThrough,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -33,7 +37,8 @@ using DeviceInstance = ...@@ -33,7 +37,8 @@ using DeviceInstance =
8, // GammaScalarPerVector 8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector 8, // BetaScalarPerVector
8>; // YScalarPerVector 8, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_layernorm_example.inc" #include "run_layernorm_example.inc"
......
...@@ -10,22 +10,13 @@ int run_groupnorm_example() ...@@ -10,22 +10,13 @@ int run_groupnorm_example()
ck::index_t M = 1024; ck::index_t M = 1024;
ck::index_t N = 1024; ck::index_t N = 1024;
-   ck::index_t Stride = N;
-   auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
-       return HostTensorDescriptor({len}, {stride});
-   };
-   auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
-       using namespace ck::literals;
-       return HostTensorDescriptor({row, col}, {stride, 1_uz});
-   };
-   Tensor<XDataType> x(f_host_tensor_descriptor2d(M, N, Stride));
-   Tensor<GammaDataType> gamma(f_host_tensor_descriptor1d(N, 1));
-   Tensor<BetaDataType> beta(f_host_tensor_descriptor1d(N, 1));
-   Tensor<YDataType> y(f_host_tensor_descriptor2d(M, N, Stride));
+   Tensor<XDataType> x({M, N});
+   Tensor<GammaDataType> gamma({N});
+   Tensor<BetaDataType> beta({N});
+   Tensor<YDataType> y({M, N});
+   Tensor<SaveMeanInvStdDataType> save_mean({M});
+   Tensor<SaveMeanInvStdDataType> save_inv_std({M});
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0}); x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0}); gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
...@@ -35,6 +26,11 @@ int run_groupnorm_example() ...@@ -35,6 +26,11 @@ int run_groupnorm_example()
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
#ifdef SAVE_MEAN_INV_STD
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
#endif
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
...@@ -47,14 +43,23 @@ int run_groupnorm_example() ...@@ -47,14 +43,23 @@ int run_groupnorm_example()
{0, 1}, {0, 1},
{0, 1}, {0, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
{1}, {1},
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
PassThrough{}); PassThrough{});
if(!device_instance.IsSupportedArgument(argument_ptr.get())) if(!device_instance.IsSupportedArgument(argument_ptr.get()))
...@@ -72,24 +77,45 @@ int run_groupnorm_example() ...@@ -72,24 +77,45 @@ int run_groupnorm_example()
bool pass = true; bool pass = true;
{ {
-       Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
+       Tensor<YDataType> host_y({M, N});
+       Tensor<SaveMeanInvStdDataType> host_save_mean({M});
+       Tensor<SaveMeanInvStdDataType> host_save_inv_std({M});
        using ReferenceInstance =
            ck::tensor_operation::host::ReferenceLayernorm<XDataType,
                                                           GammaDataType,
                                                           BetaDataType,
                                                           YDataType,
+                                                          SaveMeanInvStdDataType,
                                                           ComputeDataType,
                                                           PassThrough,
                                                           Rank,
                                                           NumReduceDim>;
        ReferenceInstance ref;
        auto ref_argument = ref.MakeArgument(x,
                                             gamma,
                                             beta,
                                             host_y,
+                                            host_save_mean,
+                                            host_save_inv_std,
                                             PassThrough{},
                                             {M, N},
                                             {1},
                                             1e-4);
        auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3);
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.FromDevice(save_mean.mData.data());
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(
save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3);
pass &= ck::utils::check_err(
save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3);
#endif
} }
return (pass ? 0 : 1); return (pass ? 0 : 1);
......
...@@ -6,11 +6,14 @@ ...@@ -6,11 +6,14 @@
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
+ using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
#define SAVE_MEAN_INV_STD
struct YElementOp struct YElementOp
{ {
...@@ -39,6 +42,7 @@ using DeviceInstance = ...@@ -39,6 +42,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementOp, YElementOp,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -53,7 +57,8 @@ using DeviceInstance = ...@@ -53,7 +57,8 @@ using DeviceInstance =
2, // GammaScalarPerVector 2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector 2, // BetaScalarPerVector
2>; // OutScalarPerVector 2, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_groupnorm_example.inc" #include "run_groupnorm_example.inc"
......
...@@ -6,12 +6,15 @@ ...@@ -6,12 +6,15 @@
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
+ using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
using YElementOp = ck::tensor_operation::element_wise::Swish;
#define SAVE_MEAN_INV_STD
using DeviceInstance = using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationSplitKImpl<XDataType, ck::tensor_operation::device::DeviceNormalizationSplitKImpl<XDataType,
...@@ -19,6 +22,7 @@ using DeviceInstance = ...@@ -19,6 +22,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementOp, YElementOp,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -33,7 +37,8 @@ using DeviceInstance = ...@@ -33,7 +37,8 @@ using DeviceInstance =
2, // GammaScalarPerVector 2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector 2, // BetaScalarPerVector
2>; // OutScalarPerVector 2, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_groupnorm_example.inc" #include "run_groupnorm_example.inc"
......
...@@ -6,12 +6,15 @@ ...@@ -6,12 +6,15 @@
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
using XDataType = ck::half_t; using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
+ using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
using YElementOp = ck::tensor_operation::element_wise::Swish;
#define SAVE_MEAN_INV_STD
using DeviceInstance = using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationImpl<XDataType, ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
...@@ -19,6 +22,7 @@ using DeviceInstance = ...@@ -19,6 +22,7 @@ using DeviceInstance =
BetaDataType, BetaDataType,
ComputeDataType, ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementOp, YElementOp,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -33,7 +37,8 @@ using DeviceInstance = ...@@ -33,7 +37,8 @@ using DeviceInstance =
2, // GammaScalarPerVector 2, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K) 1, // BetaVecDim (0=M, 1=K)
2, // BetaScalarPerVector 2, // BetaScalarPerVector
2>; // OutScalarPerVector 2, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_groupnorm_example.inc" #include "run_groupnorm_example.inc"
......
...@@ -34,6 +34,8 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -34,6 +34,8 @@ int run_groupnorm_example(int argc, char* argv[])
Tensor<YDataType> y({N, H, W, G, C}); Tensor<YDataType> y({N, H, W, G, C});
Tensor<GammaDataType> gamma({G, C}); Tensor<GammaDataType> gamma({G, C});
Tensor<BetaDataType> beta({G, C}); Tensor<BetaDataType> beta({G, C});
Tensor<SaveMeanInvStdDataType> save_mean({N, G});
Tensor<SaveMeanInvStdDataType> save_inv_std({N, G});
ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x); ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x);
ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma); ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma);
...@@ -43,6 +45,11 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -43,6 +45,11 @@ int run_groupnorm_example(int argc, char* argv[])
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
#ifdef SAVE_MEAN_INV_STD
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
#endif
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data()); gamma_dev.ToDevice(gamma.mData.data());
...@@ -57,14 +64,23 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -57,14 +64,23 @@ int run_groupnorm_example(int argc, char* argv[])
{0, 0, 0, C, 1}, {0, 0, 0, C, 1},
{0, 0, 0, C, 1}, {0, 0, 0, C, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
{1, 2, 4}, // reduction dimension: [H, W, C] {1, 2, 4}, // reduction dimension: [H, W, C]
1e-6, 1e-6,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
#else
nullptr, nullptr,
nullptr, nullptr,
#endif
y_element_op); y_element_op);
if(!device_instance.IsSupportedArgument(argument_ptr.get())) if(!device_instance.IsSupportedArgument(argument_ptr.get()))
...@@ -92,21 +108,40 @@ int run_groupnorm_example(int argc, char* argv[]) ...@@ -92,21 +108,40 @@ int run_groupnorm_example(int argc, char* argv[])
bool pass = true; bool pass = true;
{ {
        Tensor<YDataType> host_y({N, H, W, G, C});
+       Tensor<SaveMeanInvStdDataType> host_save_mean(HostTensorDescriptor{N, G});
+       Tensor<SaveMeanInvStdDataType> host_save_inv_std(HostTensorDescriptor{N, G});
        using ReferenceInstance =
            ck::tensor_operation::host::ReferenceGroupnorm<XDataType,
                                                           GammaDataType,
                                                           BetaDataType,
                                                           YDataType,
+                                                          SaveMeanInvStdDataType,
                                                           ComputeDataType,
                                                           YElementOp>;
        ReferenceInstance ref;
        auto ref_argument = ref.MakeArgument(x,
                                             gamma,
                                             beta,
                                             host_y,
+                                            host_save_mean,
+                                            host_save_inv_std,
                                             y_element_op,
                                             {N, H, W, G, C},
                                             1e-6);
        auto ref_invoker = ref.MakeInvoker();
        ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.FromDevice(save_mean.mData.data());
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(
save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3);
pass &= ck::utils::check_err(
save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3);
#endif
} }
return (pass ? 0 : 1); return (pass ? 0 : 1);
......
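The groupnorm reference check above compares the [N, G] statistics tensors as well. For orientation, a host-only sketch of what those values are (illustrative, not part of the commit; float inputs, NHWGC layout, reducing over H, W and C):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Per-(n, g) mean and inverse std for an [N, H, W, G, C] tensor.
    void groupnorm_mean_inv_std(const std::vector<float>& x, // [N*H*W*G*C], NHWGC layout
                                std::size_t N, std::size_t H, std::size_t W,
                                std::size_t G, std::size_t C,
                                float epsilon,
                                std::vector<float>& mean,    // [N*G]
                                std::vector<float>& inv_std) // [N*G]
    {
        mean.assign(N * G, 0.f);
        inv_std.assign(N * G, 0.f);
        const float count = static_cast<float>(H * W * C);

        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t g = 0; g < G; ++g)
            {
                float sum = 0.f, sq = 0.f;
                for(std::size_t h = 0; h < H; ++h)
                    for(std::size_t w = 0; w < W; ++w)
                        for(std::size_t c = 0; c < C; ++c)
                        {
                            const float v = x[(((n * H + h) * W + w) * G + g) * C + c];
                            sum += v;
                            sq += v * v;
                        }
                const float mu  = sum / count;
                const float var = sq / count - mu * mu; // biased variance
                mean[n * G + g]    = mu;
                inv_std[n * G + g] = 1.f / std::sqrt(var + epsilon);
            }
    }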
...@@ -167,20 +167,31 @@ int main() ...@@ -167,20 +167,31 @@ int main()
XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{}); XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{});
Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride)); Tensor<YDataType> host_y(f_host_tensor_descriptor2d(M, N, Stride));
Tensor<AccDataType> host_save_mean({M});
Tensor<AccDataType> host_save_inv_std({M});
using ReferenceInstance = using ReferenceInstance =
ck::tensor_operation::host::ReferenceLayernorm<XDataType, ck::tensor_operation::host::ReferenceLayernorm<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim>; NumReduceDim>;
ReferenceInstance ref; ReferenceInstance ref;
        auto ref_argument = ref.MakeArgument(x,
                                             gamma,
                                             beta,
                                             host_y,
+                                            host_save_mean,
+                                            host_save_inv_std,
                                             YElementwiseOperation{},
                                             {M, N},
                                             {1},
                                             1e-4);
        auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
......
...@@ -14,8 +14,8 @@ namespace device { ...@@ -14,8 +14,8 @@ namespace device {
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
...@@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator ...@@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
double epsilon, double epsilon,
const void* p_x, const void* p_x,
...@@ -43,16 +45,16 @@ struct DeviceNormalization : public BaseOperator ...@@ -43,16 +45,16 @@ struct DeviceNormalization : public BaseOperator
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType, using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim>>; NumReduceDim>>;
......
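To summarize the interface change above: ComputeDataType is no longer a template parameter of the abstract DeviceNormalization, and SaveMeanInvStdDataType now follows YDataType. With the types from the layernorm client example earlier in this commit, an instantiation of the updated base interface looks like this (shown for orientation; it mirrors the example code rather than adding anything new):

    using DeviceOp = ck::tensor_operation::device::DeviceNormalization<
        ck::half_t, // XDataType
        ck::half_t, // GammaDataType
        ck::half_t, // BetaDataType
        ck::half_t, // YDataType
        float,      // SaveMeanInvStdDataType
        ck::tensor_operation::element_wise::PassThrough,
        2,  // Rank
        1>; // NumReduceDim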
...@@ -28,6 +28,7 @@ template <typename XDataType, ...@@ -28,6 +28,7 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
...@@ -43,12 +44,13 @@ template <typename XDataType, ...@@ -43,12 +44,13 @@ template <typename XDataType,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize, index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool UseWelford = true> bool UseWelford = true>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim> NumReduceDim>
...@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim); // TODO
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides, const std::vector<index_t>& inStrides,
int numBlockTileIteration) int numBlockTileIteration)
{ {
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto grid_desc_m =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(InvariantDims{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded = transform_tensor_descriptor(
grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, pad_M)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return grid_desc_m_padded;
}
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
YElementwiseOperation y_elementwise_op, YElementwiseOperation y_elementwise_op,
double epsilon, double epsilon,
const XDataType* p_x, const XDataType* p_x,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y,
SaveMeanInvStdDataType* p_saveMean,
SaveMeanInvStdDataType* p_saveInvStd)
: p_x_(p_x), : p_x_(p_x),
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
p_saveMean_(p_saveMean),
p_saveInvStd_(p_saveInvStd),
y_elementwise_op_(y_elementwise_op) y_elementwise_op_(y_elementwise_op)
{ {
epsilon_ = static_cast<ComputeDataType>(epsilon); epsilon_ = static_cast<ComputeDataType>(epsilon);
...@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims); yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims); gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
saveMeanStrides_ = saveMeanStrides;
saveInvStdStrides_ = saveInvStdStrides;
-           long_index_t invariant_length;
-           long_index_t reduce_length;
-           std::tie(invariant_length, reduce_length) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
+           std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize); numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize); gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_); x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
gamma_grid_desc_m_k_ = gamma_grid_desc_m_k_ =
...@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
beta_grid_desc_m_k_ = beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_); MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_); y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
isSweeponce_ = isSweeponce_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
} }
ComputeDataType epsilon_; ComputeDataType epsilon_;
...@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
const BetaDataType* p_beta_; const BetaDataType* p_beta_;
YDataType* p_y_; YDataType* p_y_;
SaveMeanInvStdDataType* p_saveMean_;
SaveMeanInvStdDataType* p_saveInvStd_;
std::vector<index_t> Lengths_; std::vector<index_t> Lengths_;
std::vector<index_t> xStrides_; std::vector<index_t> xStrides_;
std::vector<index_t> gammaStrides_; std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_; std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_; std::vector<index_t> yStrides_;
std::vector<index_t> saveMeanStrides_;
std::vector<index_t> saveInvStdStrides_;
YElementwiseOperation y_elementwise_op_; YElementwiseOperation y_elementwise_op_;
...@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GridDesc_M_K gamma_grid_desc_m_k_; GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K beta_grid_desc_m_k_; GridDesc_M_K beta_grid_desc_m_k_;
GridDesc_M_K y_grid_desc_m_k_; GridDesc_M_K y_grid_desc_m_k_;
GridDesc_M save_mean_grid_desc_m_;
GridDesc_M save_inv_std_grid_desc_m_;
bool isSweeponce_; bool isSweeponce_;
index_t MRaw_; // invarient length
index_t KRaw_; // reduce length
index_t invariant_lowest_length_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
...@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
BetaSrcVectorSize, BetaSrcVectorSize,
XYSrcVectorDim, XYSrcVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
UseWelford>(arg.isSweeponce_); UseWelford>(arg.isSweeponce_);
float avg_time = 0; float avg_time = 0;
...@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
arg.gamma_grid_desc_m_k_, arg.gamma_grid_desc_m_k_,
arg.beta_grid_desc_m_k_, arg.beta_grid_desc_m_k_,
arg.y_grid_desc_m_k_, arg.y_grid_desc_m_k_,
arg.save_mean_grid_desc_m_,
arg.save_inv_std_grid_desc_m_,
arg.numBlockTileIteration_, arg.numBlockTileIteration_,
arg.epsilon_, arg.epsilon_,
arg.p_x_, arg.p_x_,
arg.p_gamma_, arg.p_gamma_,
arg.p_beta_, arg.p_beta_,
arg.p_y_, arg.p_y_,
arg.p_saveMean_,
arg.p_saveInvStd_,
arg.y_elementwise_op_); arg.y_elementwise_op_);
return (avg_time); return (avg_time);
...@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYSrcVectorDim == 0) if constexpr(XYSrcVectorDim == 0)
{ {
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
...@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
} }
else else
{ {
printf("!!!! %d\n", p_arg_->invariant_lowest_length_);
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false; return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
return false; return false;
if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
return false; return false;
}; };
} }
...@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return (false); return (false);
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
return (false); return (false);
} }
else // if fastest dim is reduced else // if fastest dim is reduced
...@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (false); return (false);
} }
if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
return false;
return true; return true;
}; };
...@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
double epsilon, double epsilon,
const void* p_x, const void* p_x,
...@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const void* p_beta, const void* p_beta,
void* p_y, void* p_y,
void* p_saveMean, void* p_saveMean,
void* p_saveInvVar, void* p_saveInvStd,
YElementwiseOperation y_elementwise_op) override YElementwiseOperation y_elementwise_op) override
{ {
-       // TODO
-       // Optional cache of the intermediate results (mean and InvVariance) during the
-       // forward pass could speedup in the backward
-       ignore = p_saveMean;
-       ignore = p_saveInvVar;
+       if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
+          betaStrides.size() != Rank || yStrides.size() != Rank ||
+          saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
+           throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
xStrides, xStrides,
gammaStrides, gammaStrides,
betaStrides, betaStrides,
yStrides, yStrides,
saveMeanStrides,
saveInvStdStrides,
reduceDims, reduceDims,
y_elementwise_op, y_elementwise_op,
epsilon, epsilon,
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma), static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta), static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y)); static_cast<YDataType*>(p_y),
static_cast<SaveMeanInvStdDataType*>(p_saveMean),
static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
}; };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
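MakeSaveMeanInvStdDescriptor_M above pads the invariant (M) length up to a multiple of M_BlockTileSize, the same right-padding the x/y descriptors receive. A worked example with assumed tile sizes (MThreadClusterSize = 32, MThreadSliceSize = 8, so M_BlockTileSize = 256; the helper names mirror the code above but the numbers are illustrative):

    constexpr int invariantLength = 1000;
    constexpr int M_BlockTileSize = 32 * 8; // MThreadClusterSize * MThreadSliceSize (assumed values)
    // math::integer_least_multiple rounds up to the next multiple of the block tile:
    constexpr int padded = ((invariantLength + M_BlockTileSize - 1) / M_BlockTileSize) * M_BlockTileSize; // 1024
    constexpr int pad_M  = padded - invariantLength; // 24 trailing rows added by make_right_pad_transform
    static_assert(pad_M == 24, "right padding for the saved-statistics descriptor");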
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace ck { namespace ck {
template <typename GridwiseWelford, template <typename GridwiseWelford,
typename XDataType, typename XDataType,
typename MeanVarDataType, typename WorkspaceMeanVarDataType,
typename ComputeDataType, typename ComputeDataType,
typename XGridDesc_M_K, typename XGridDesc_M_K,
typename MeanVarGridDesc_M_KBlock> typename MeanVarGridDesc_M_KBlock>
...@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, ...@@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
MeanVarDataType* const __restrict__ p_welford_mean, WorkspaceMeanVarDataType* const __restrict__ p_welford_mean,
MeanVarDataType* const __restrict__ p_welford_variance, WorkspaceMeanVarDataType* const __restrict__ p_welford_variance,
int32_t* const __restrict__ p_welford_count) int32_t* const __restrict__ p_welford_count)
{ {
GridwiseWelford::Run(x_grid_desc_m_k, GridwiseWelford::Run(x_grid_desc_m_k,
...@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, ...@@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
}; };
template <typename GridwiseWelfordNormalization, template <typename GridwiseWelfordNormalization,
typename MeanVarDataType, typename WorkspaceMeanVarDataType,
typename XDataType, typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename MeanVarGridDesc_M_KBlock, typename MeanVarGridDesc_M_KBlock,
typename CountGridDesc_M_KBlock, typename CountGridDesc_M_KBlock,
typename XYGammaBetaGridDesc_M_K> typename XYGammaBetaGridDesc_M_K,
typename SaveMeanInvStdGridDesc_M>
__global__ void __global__ void
kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
const CountGridDesc_M_KBlock count_grid_desc_m_kblock, const CountGridDesc_M_KBlock count_grid_desc_m_kblock,
...@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_ ...@@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m,
const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m,
index_t num_k_mean_var_count_iteration, index_t num_k_mean_var_count_iteration,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
index_t k_grid_size, index_t k_grid_size,
ComputeDataType epsilon, ComputeDataType epsilon,
const MeanVarDataType* const p_mean_global, const WorkspaceMeanVarDataType* const p_mean_global,
const MeanVarDataType* const p_variance_global, const WorkspaceMeanVarDataType* const p_variance_global,
const int32_t* const p_welford_count_global, const int32_t* const p_welford_count_global,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op) const YElementwiseOperation y_elementwise_op)
{ {
GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock, GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,
...@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_ ...@@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
gamma_grid_desc_m_k, gamma_grid_desc_m_k,
beta_grid_desc_m_k, beta_grid_desc_m_k,
y_grid_desc_m_k, y_grid_desc_m_k,
save_mean_grid_desc_m,
save_inv_std_grid_desc_m,
num_k_mean_var_count_iteration, num_k_mean_var_count_iteration,
num_k_block_tile_iteration, num_k_block_tile_iteration,
k_grid_size, k_grid_size,
...@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_ ...@@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_
p_gamma_global, p_gamma_global,
p_beta_global, p_beta_global,
p_y_global, p_y_global,
p_save_mean_global,
p_save_inv_std_global,
y_elementwise_op); y_elementwise_op);
}; };
} // namespace ck } // namespace ck
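The split-K path launches kernel_normalizationSplitK1st to produce partial Welford results (mean, variance, count) per K block, and the second kernel merges them before normalizing and, with this change, also writing out the merged mean and inverse std. For reference, the standard pairwise Welford/Chan combination of two partial results is shown below; this is the textbook formula, not copied from GridwiseWelford:

$$n = n_a + n_b,\qquad \delta = \mu_b - \mu_a,\qquad \mu = \mu_a + \delta\,\frac{n_b}{n},\qquad M_2 = M_{2,a} + M_{2,b} + \delta^2\,\frac{n_a n_b}{n},\qquad \sigma^2 = \frac{M_2}{n}.$$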
...@@ -107,6 +117,7 @@ template <typename XDataType, ...@@ -107,6 +117,7 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename ComputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
...@@ -121,17 +132,18 @@ template <typename XDataType, ...@@ -121,17 +132,18 @@ template <typename XDataType,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize> index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize>
struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ComputeDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
NumReduceDim> NumReduceDim>
{ {
using MeanVarDataType = ComputeDataType; using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert( static_assert(
...@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim); // TODO
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides, const std::vector<index_t>& inStrides,
int kBlockSize, int kBlockSize,
int numBlockTileIteration) int numBlockTileIteration)
{ {
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
}; };
template <typename DoPads, index_t MPerTile, index_t KPerTile> template <typename DoPads, index_t MPerTile, index_t KPerTile>
static auto MakeMeanVarDescriptor_M_K(index_t M, index_t K) static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K)
{ {
const auto grid_desc_m_k = const auto grid_desc_m_k =
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1));
...@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
} }
template <typename DoPads, index_t MPerTile, index_t KPerTile> template <typename DoPads, index_t MPerTile, index_t KPerTile>
static auto MakeCountDescriptor_M_K(index_t M, index_t K) static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K)
{ {
const auto grid_desc_m_k = const auto grid_desc_m_k =
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1)); make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1));
return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
} }
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto grid_desc_m =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(InvariantDims{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto grid_desc_m_padded = transform_tensor_descriptor(
grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, pad_M)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return grid_desc_m_padded;
}
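The helper above folds every invariant dimension into one M axis and then right-pads that axis so each M block tile is fully covered. A minimal sketch of the padding arithmetic it relies on, with hypothetical lengths (the real values come from the merged invariant lengths and M_BlockTileSize):

// Hedged sketch, not part of the diff: pad M up to the next multiple of the block tile.
#include <cstdio>

int main()
{
    const int invariant_length  = 1000; // hypothetical merged M length
    const int m_block_tile_size = 128;  // hypothetical MThreadClusterSize * MThreadSliceSize

    // integer_least_multiple(a, b): round a up to the next multiple of b
    const int rounded = (invariant_length + m_block_tile_size - 1) / m_block_tile_size * m_block_tile_size;
    const int pad_m   = rounded - invariant_length; // 1024 - 1000 = 24

    std::printf("padded M = %d, pad = %d\n", rounded, pad_m);
    return 0;
}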
using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using Kernel1MeanVarGridDesc_M_KBlock = using Kernel1MeanVarGridDesc_M_KBlock =
decltype(MakeMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1)); decltype(MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, false>, 1, 1>(1, 1));
using Kernel2MeanVarGridDesc_M_KBlock = using Kernel2MeanVarGridDesc_M_KBlock =
decltype(MakeMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1)); decltype(MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
using Kernel2CountGridDesc_M_KBlock = using Kernel2CountGridDesc_M_KBlock =
decltype(MakeCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1)); decltype(MakeWorkspaceCountDescriptor_M_K<Sequence<true, true>, 1, 1>(1, 1));
using SaveMeanInvStdGridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
using GridwiseWelford = GridwiseNormalizationSplitK1st<XDataType, using GridwiseWelford = GridwiseNormalizationSplitK1st<XDataType,
ComputeDataType, ComputeDataType,
MeanVarDataType, WorkspaceMeanVarDataType,
SrcGridDesc_M_K, SrcGridDesc_M_K,
Kernel1MeanVarGridDesc_M_KBlock, Kernel1MeanVarGridDesc_M_KBlock,
BlockSize, BlockSize,
...@@ -258,16 +307,18 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -258,16 +307,18 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
XSrcVectorSize>; XSrcVectorSize>;
using GridwiseWelfordNormalization = using GridwiseWelfordNormalization =
GridwiseNormalizationSplitK2nd<MeanVarDataType, GridwiseNormalizationSplitK2nd<WorkspaceMeanVarDataType,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
Kernel2MeanVarGridDesc_M_KBlock, Kernel2MeanVarGridDesc_M_KBlock,
Kernel2CountGridDesc_M_KBlock, Kernel2CountGridDesc_M_KBlock,
SrcGridDesc_M_K, SrcGridDesc_M_K,
SaveMeanInvStdGridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -280,7 +331,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -280,7 +331,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
BetaSrcVectorDim, BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
XYVectorDim, XYVectorDim,
YDstVectorSize>; YDstVectorSize,
SaveMeanInvStdDstVectorSize>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
YElementwiseOperation y_elementwise_op, YElementwiseOperation y_elementwise_op,
double epsilon, double epsilon,
const XDataType* p_x, const XDataType* p_x,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
const BetaDataType* p_beta, const BetaDataType* p_beta,
YDataType* p_y) YDataType* p_y,
SaveMeanInvStdDataType* p_saveMean,
SaveMeanInvStdDataType* p_saveInvStd)
: p_x_(p_x), : p_x_(p_x),
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta), p_beta_(p_beta),
p_y_(p_y), p_y_(p_y),
p_saveMean_(p_saveMean),
p_saveInvStd_(p_saveInvStd),
p_workspace_mean_{nullptr}, p_workspace_mean_{nullptr},
p_workspace_var_{nullptr}, p_workspace_var_{nullptr},
p_workspace_count_{nullptr}, p_workspace_count_{nullptr},
...@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims); yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims); gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
saveMeanStrides_ = saveMeanStrides;
saveInvStdStrides_ = saveInvStdStrides;
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_); std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
...@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
y_grid_desc_m_k_ = y_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_); MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_);
save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1. // We don't need to pad in K dimension for Welford1. Set KPerTile 1.
kernel1_mean_var_grid_desc_m_kblock_ = kernel1_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, false>, M_BlockTileSize, 1>(MRaw_, MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, false>, M_BlockTileSize, 1>(
kGridSize_); MRaw_, kGridSize_);
kernel2_mean_var_grid_desc_m_kblock_ = kernel2_mean_var_grid_desc_m_kblock_ =
MakeMeanVarDescriptor_M_K<Sequence<true, true>, MakeWorkspaceMeanVarDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize, M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
kernel2_count_grid_desc_m_kblock_ = kernel2_count_grid_desc_m_kblock_ =
MakeCountDescriptor_M_K<Sequence<true, true>, MakeWorkspaceCountDescriptor_M_K<Sequence<true, true>,
M_BlockTileSize, M_BlockTileSize,
K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
} }
ComputeDataType epsilon_; ComputeDataType epsilon_;
...@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
const BetaDataType* p_beta_; const BetaDataType* p_beta_;
YDataType* p_y_; YDataType* p_y_;
SaveMeanInvStdDataType* p_saveMean_;
SaveMeanInvStdDataType* p_saveInvStd_;
void* p_workspace_mean_; void* p_workspace_mean_;
void* p_workspace_var_; void* p_workspace_var_;
void* p_workspace_count_; void* p_workspace_count_;
...@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -377,6 +447,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
std::vector<index_t> gammaStrides_; std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_; std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_; std::vector<index_t> yStrides_;
std::vector<index_t> saveMeanStrides_;
std::vector<index_t> saveInvStdStrides_;
YElementwiseOperation y_elementwise_op_; YElementwiseOperation y_elementwise_op_;
...@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -389,6 +461,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
SrcGridDesc_M_K gamma_grid_desc_m_k_; SrcGridDesc_M_K gamma_grid_desc_m_k_;
SrcGridDesc_M_K beta_grid_desc_m_k_; SrcGridDesc_M_K beta_grid_desc_m_k_;
SrcGridDesc_M_K y_grid_desc_m_k_; SrcGridDesc_M_K y_grid_desc_m_k_;
SaveMeanInvStdGridDesc_M save_mean_grid_desc_m_;
SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m_;
Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_; Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_;
Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_; Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
...@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -396,6 +470,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
index_t MRaw_; // invariant length index_t MRaw_; // invariant length

index_t KRaw_; // reduce length index_t KRaw_; // reduce length
index_t invariant_lowest_length_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
...@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -408,60 +484,68 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
auto kernel1 = kernel_normalizationSplitK1st<GridwiseWelford, auto kernel1 = kernel_normalizationSplitK1st<GridwiseWelford,
XDataType, XDataType,
MeanVarDataType, WorkspaceMeanVarDataType,
ComputeDataType, ComputeDataType,
SrcGridDesc_M_K, SrcGridDesc_M_K,
Kernel1MeanVarGridDesc_M_KBlock>; Kernel1MeanVarGridDesc_M_KBlock>;
auto kernel2 = kernel_normalizationSplitK2nd<GridwiseWelfordNormalization, auto kernel2 = kernel_normalizationSplitK2nd<GridwiseWelfordNormalization,
MeanVarDataType, WorkspaceMeanVarDataType,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
Kernel2MeanVarGridDesc_M_KBlock, Kernel2MeanVarGridDesc_M_KBlock,
Kernel2CountGridDesc_M_KBlock, Kernel2CountGridDesc_M_KBlock,
SrcGridDesc_M_K>; SrcGridDesc_M_K,
SaveMeanInvStdGridDesc_M>;
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(
kernel1, stream_config,
dim3(arg.gridSize_), kernel1,
dim3(BlockSize), dim3(arg.gridSize_),
0, dim3(BlockSize),
arg.x_grid_desc_m_k_, 0,
arg.kernel1_mean_var_grid_desc_m_kblock_, arg.x_grid_desc_m_k_,
arg.numBlockTileIteration_, arg.kernel1_mean_var_grid_desc_m_kblock_,
arg.p_x_, arg.numBlockTileIteration_,
static_cast<MeanVarDataType*>(arg.p_workspace_mean_), arg.p_x_,
static_cast<MeanVarDataType*>(arg.p_workspace_var_), static_cast<WorkspaceMeanVarDataType*>(arg.p_workspace_mean_),
static_cast<int32_t*>(arg.p_workspace_count_)); static_cast<WorkspaceMeanVarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_));
avg_time += launch_and_time_kernel(stream_config,
kernel2, avg_time += launch_and_time_kernel(
dim3(arg.gridSize_), stream_config,
dim3(BlockSize), kernel2,
0, dim3(arg.gridSize_),
arg.kernel2_mean_var_grid_desc_m_kblock_, dim3(BlockSize),
arg.kernel2_count_grid_desc_m_kblock_, 0,
arg.x_grid_desc_m_k_, arg.kernel2_mean_var_grid_desc_m_kblock_,
arg.gamma_grid_desc_m_k_, arg.kernel2_count_grid_desc_m_kblock_,
arg.beta_grid_desc_m_k_, arg.x_grid_desc_m_k_,
arg.y_grid_desc_m_k_, arg.gamma_grid_desc_m_k_,
arg.numMeanVarCountIteration_, arg.beta_grid_desc_m_k_,
arg.numBlockTileIteration_, arg.y_grid_desc_m_k_,
arg.kGridSize_, arg.save_mean_grid_desc_m_,
arg.epsilon_, arg.save_inv_std_grid_desc_m_,
static_cast<MeanVarDataType*>(arg.p_workspace_mean_), arg.numMeanVarCountIteration_,
static_cast<MeanVarDataType*>(arg.p_workspace_var_), arg.numBlockTileIteration_,
static_cast<int32_t*>(arg.p_workspace_count_), arg.kGridSize_,
arg.p_x_, arg.epsilon_,
arg.p_gamma_, static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_mean_),
arg.p_beta_, static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_var_),
arg.p_y_, static_cast<const int32_t*>(arg.p_workspace_count_),
arg.y_elementwise_op_); arg.p_x_,
arg.p_gamma_,
arg.p_beta_,
arg.p_y_,
arg.p_saveMean_,
arg.p_saveInvStd_,
arg.y_elementwise_op_);
return avg_time; return avg_time;
}; };
...@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;
// workspace for welford intermediate mean // workspace for welford intermediate mean
workspace_size += welford_size * sizeof(MeanVarDataType) + 64; workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
// workspace for welford intermediate variance // workspace for welford intermediate variance
workspace_size += welford_size * sizeof(MeanVarDataType) + 64; workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
// workspace for welford intermediate count // workspace for welford intermediate count
workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64; workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64;
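The split-K path stages per-block partial means, partial variances, and element counts in a device workspace; each region gets 64 extra bytes so the next region can start on a 64-byte boundary (see SetWorkSpacePointer below). A small sketch of how the three regions add up, with hypothetical sizes and plain float/int32_t stand-ins for the workspace types:

// Hedged sketch, not part of the diff: total workspace for hypothetical MRaw and kGridSize.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    const std::size_t m_raw        = 512; // hypothetical invariant length
    const std::size_t k_grid_size  = 8;   // hypothetical number of split-K blocks
    const std::size_t welford_size = m_raw * k_grid_size; // one partial per (row, k-block)

    std::size_t bytes = 0;
    bytes += welford_size * sizeof(float) + 64;        // partial means (+64 alignment slack)
    bytes += welford_size * sizeof(float) + 64;        // partial variances (+64 alignment slack)
    bytes += k_grid_size * sizeof(std::int32_t) + 64;  // per-block Welford counts (+64 alignment slack)

    std::printf("workspace bytes = %zu\n", bytes);
    return 0;
}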
...@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -504,13 +588,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
// setup buffer used for intermediate welford mean // setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_); pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz = welford_size * sizeof(MeanVarDataType); index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford variance // setup buffer used for intermediate welford variance
pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz; pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
index_t variance_space_sz = welford_size * sizeof(MeanVarDataType); index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welford count // setup buffer used for intermediate welford count
...@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -522,8 +606,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYVectorDim == 0) if constexpr(XYVectorDim == 0)
{ {
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
...@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -535,10 +617,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false; return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
return false; return false;
if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
return false; return false;
}; };
} }
...@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -578,7 +660,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return false; return false;
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
return false; return false;
} }
else // if fastest dim is reduced else // if fastest dim is reduced
...@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -593,6 +675,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if(p_arg_->kGridSize_ <= 1) if(p_arg_->kGridSize_ <= 1)
return false; return false;
if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
return false;
return true; return true;
}; };
...@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -602,6 +687,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides, const std::vector<index_t> yStrides,
const std::vector<index_t> saveMeanStrides,
const std::vector<index_t> saveInvStdStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
double epsilon, double epsilon,
const void* p_x, const void* p_x,
...@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType, ...@@ -609,27 +696,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
const void* p_beta, const void* p_beta,
void* p_y, void* p_y,
void* p_saveMean, void* p_saveMean,
void* p_saveInvVar, void* p_saveInvStd,
YElementwiseOperation y_elementwise_op) override YElementwiseOperation y_elementwise_op) override
{ {
// TODO if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
// Optional cache of the intermediate results (mean and InvVariance) during the betaStrides.size() != Rank || yStrides.size() != Rank ||
// forward pass could speed up the backward pass saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
ignore = p_saveMean; throw std::runtime_error("dimension is incorrect");
ignore = p_saveInvVar;
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
xStrides, xStrides,
gammaStrides, gammaStrides,
betaStrides, betaStrides,
yStrides, yStrides,
saveMeanStrides,
saveInvStdStrides,
reduceDims, reduceDims,
y_elementwise_op, y_elementwise_op,
epsilon, epsilon,
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma), static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta), static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y)); static_cast<YDataType*>(p_y),
static_cast<SaveMeanInvStdDataType*>(p_saveMean),
static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
}; };
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -18,9 +18,11 @@ template <typename XDataType, ...@@ -18,9 +18,11 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename GridDesc_M_K, typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -34,6 +36,7 @@ template <typename XDataType, ...@@ -34,6 +36,7 @@ template <typename XDataType,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize, index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool SweepOnce> bool SweepOnce>
struct GridwiseNormalizationNaiveVariance_mk_to_mk struct GridwiseNormalizationNaiveVariance_mk_to_mk
{ {
...@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize);
...@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}))); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
...@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
reduce::Add, reduce::Add,
true>; true>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k,
const GridDesc_M_K& y_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k,
const GridDesc_M& save_mean_grid_desc_m,
const GridDesc_M& save_inv_std_grid_desc_m,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
ComputeDataType epsilon, ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op) const YElementwiseOperation y_elementwise_op)
{ {
// LDS // LDS
...@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
auto x_thread_buf = generate_tuple( auto x_thread_buf = generate_tuple(
[&](auto) { [&](auto) {
return StaticBuffer<AddressSpaceEnum::Vgpr, return StaticBuffer<AddressSpaceEnum::Vgpr,
...@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
mean_square_thread_buf; mean_square_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>& StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
var_thread_buf = mean_square_thread_buf; var_thread_buf = mean_square_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
inv_std_thread_buf = mean_square_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
thread_k_cluster_id * YDstVectorSize), thread_k_cluster_id * YDstVectorSize),
y_elementwise_op); y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k = constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
...@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// E(x), E[x^2], var(x) // E(x), E[x^2], var(x)
// FIXME: Should not hack the transform from deviceOP // FIXME: Should not hack the transform from deviceOP
int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; ComputeDataType reduce_length = type_convert<ComputeDataType>(
x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>(); mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
...@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2 // var(x) = E[x^2] - E[x]^2
var_thread_buf(I) = var_thread_buf(I) =
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
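In this naive (two-moment) path, the per-row statistics that are now also written out for the backward pass follow directly from the two running sums; in the code's notation (a restatement of what the loop above computes, not new math):

\[
\mu = \mathbb{E}[x], \qquad \operatorname{Var}(x) = \mathbb{E}[x^2] - \mathbb{E}[x]^2, \qquad
\text{inv\_std} = \frac{1}{\sqrt{\operatorname{Var}(x) + \epsilon}}, \qquad
y = \gamma\,(x - \mu)\,\text{inv\_std} + \beta,
\]

with the user-supplied y elementwise operation applied when y is stored.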
// save mean and inverse std for backward (optional)
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// normalization
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma & beta // gamma & beta
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
...@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// var(x) = E[x^2] - E[x]^2 // var(x) = E[x^2] - E[x]^2
var_thread_buf(I) = var_thread_buf(I) =
mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
auto thread_copy_tail_m_k = auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
...@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk ...@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma // gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
...@@ -12,31 +12,42 @@ template <typename GridwiseReduction, ...@@ -12,31 +12,42 @@ template <typename GridwiseReduction,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename GridDesc_M_K> typename GridDesc_M_K,
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, typename GridDesc_M>
const GridDesc_M_K gamma_grid_desc_m_k, __global__ void
const GridDesc_M_K beta_grid_desc_m_k, kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k,
index_t num_k_block_tile_iteration, const GridDesc_M_K beta_grid_desc_m_k,
ComputeDataType epsilon, const GridDesc_M_K y_grid_desc_m_k,
const XDataType* const __restrict__ p_x_global, const GridDesc_M save_mean_grid_desc_m,
const GammaDataType* const __restrict__ p_gamma_global, const GridDesc_M save_inv_std_grid_desc_m,
const BetaDataType* const __restrict__ p_beta_global, index_t num_k_block_tile_iteration,
YDataType* const __restrict__ p_y_global, ComputeDataType epsilon,
const YElementwiseOperation y_elementwise_op) const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op)
{ {
GridwiseReduction::Run(x_grid_desc_m_k, GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k, gamma_grid_desc_m_k,
beta_grid_desc_m_k, beta_grid_desc_m_k,
y_grid_desc_m_k, y_grid_desc_m_k,
save_mean_grid_desc_m,
save_inv_std_grid_desc_m,
num_k_block_tile_iteration, num_k_block_tile_iteration,
epsilon, epsilon,
p_x_global, p_x_global,
p_gamma_global, p_gamma_global,
p_beta_global, p_beta_global,
p_y_global, p_y_global,
p_save_mean_global,
p_save_inv_std_global,
y_elementwise_op); y_elementwise_op);
}; };
...@@ -44,9 +55,11 @@ template <typename XDataType, ...@@ -44,9 +55,11 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename GridDesc_M_K, typename GridDesc_M_K,
typename GridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -60,6 +73,7 @@ template <typename XDataType, ...@@ -60,6 +73,7 @@ template <typename XDataType,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize, index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize,
bool UseWelford> bool UseWelford>
auto NormalizationKernelSelector(bool isSweepOnce) auto NormalizationKernelSelector(bool isSweepOnce)
{ {
...@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorDim, YDstVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
false>; false>;
using GridwiseNormalizationSweepOnceNaive = using GridwiseNormalizationSweepOnceNaive =
GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType, GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorDim, YDstVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
true>; true>;
using GridwiseNormalizationGenericWelford = using GridwiseNormalizationGenericWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorDim, YDstVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
false>; false>;
using GridwiseNormalizationSweepOnceWelford = using GridwiseNormalizationSweepOnceWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_M,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce)
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorDim, YDstVectorDim,
YDstVectorSize, YDstVectorSize,
SaveMeanInvStdDstVectorSize,
true>; true>;
if constexpr(UseWelford) if constexpr(UseWelford)
...@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K> GridDesc_M_K,
GridDesc_M>
: kernel_normalization<GridwiseNormalizationGenericWelford, : kernel_normalization<GridwiseNormalizationGenericWelford,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K>; GridDesc_M_K,
GridDesc_M>;
} }
else else
{ {
...@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) ...@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K> GridDesc_M_K,
GridDesc_M>
: kernel_normalization<GridwiseNormalizationGenericNaive, : kernel_normalization<GridwiseNormalizationGenericNaive,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
SaveMeanInvStdDataType,
ComputeDataType, ComputeDataType,
YElementwiseOperation, YElementwiseOperation,
GridDesc_M_K>; GridDesc_M_K,
GridDesc_M>;
} }
} }
......
...@@ -17,11 +17,13 @@ template <typename MeanVarDataType, ...@@ -17,11 +17,13 @@ template <typename MeanVarDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename SaveMeanInvStdDataType,
typename ComputeDataType, typename ComputeDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
typename MeanVarGridDesc_M_KBlock, typename MeanVarGridDesc_M_KBlock,
typename CountGridDesc_M_KBlock, typename CountGridDesc_M_KBlock,
typename XYGammaBetaGridDesc_M_K, typename XYGammaBetaGridDesc_M_K,
typename SaveMeanInvStdGridDesc_M,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -34,7 +36,8 @@ template <typename MeanVarDataType, ...@@ -34,7 +36,8 @@ template <typename MeanVarDataType,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorDim, index_t YDstVectorDim,
index_t YDstVectorSize> index_t YDstVectorSize,
index_t SaveMeanInvStdDstVectorSize>
struct GridwiseNormalizationSplitK2nd struct GridwiseNormalizationSplitK2nd
{ {
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
...@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd ...@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
"configuration, please check!");
static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == YDstVectorSize);
static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize);
static_assert(XSrcVectorSize == BetaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize);
...@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd ...@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>; using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
static constexpr auto thread_buffer_desc_m_1 = static constexpr auto thread_buffer_desc_m_1 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1)); make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));
...@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd ...@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd
const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
index_t num_k_mean_var_count_iteration, index_t num_k_mean_var_count_iteration,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
index_t k_grid_size, index_t k_grid_size,
...@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd ...@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
const YElementwiseOperation y_elementwise_op) const YElementwiseOperation y_elementwise_op)
{ {
// Thread/Block id // Thread/Block id
...@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd ...@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
// VGPR // VGPR
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
in_mean_thread_buf; in_mean_thread_buf;
...@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd ...@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd
var_thread_buf; var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf; welford_count_thread_buf;
auto& inv_std_thread_buf = var_thread_buf;
auto x_thread_buf = generate_tuple( auto x_thread_buf = generate_tuple(
[&](auto) { [&](auto) {
...@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd ...@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd
thread_k_cluster_id * YDstVectorSize), thread_k_cluster_id * YDstVectorSize),
y_elementwise_op); y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
SaveMeanInvStdGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
SaveMeanInvStdGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_m_cluster_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
// step1: Merge mean and variance // step1: Merge mean and variance
constexpr auto mean_var_count_thread_copy_step_I0_k = constexpr auto mean_var_count_thread_copy_step_I0_k =
make_multi_index(I0, KThreadClusterSize); make_multi_index(I0, KThreadClusterSize);
...@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd ...@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd
BlockwiseWelford::Run( BlockwiseWelford::Run(
mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I)); mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));
inv_std_thread_buf(I) =
type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
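Before any statistic is written out, the second split-K kernel folds the per-block partial (mean, variance, count) triples into a single per-row result via BlockwiseWelford::Run and only then forms inv_std. A hedged sketch of the standard parallel Welford combine such a merge is presumably based on (not the library's actual BlockwiseWelford implementation):

// Hedged sketch, not part of the diff: merge two Welford partials over disjoint element sets.
#include <cstdio>

struct Partial { double mean; double var; int count; };

static Partial merge(const Partial& a, const Partial& b)
{
    const int    n     = a.count + b.count; // assumes n > 0
    const double delta = b.mean - a.mean;
    Partial out;
    out.count = n;
    out.mean  = a.mean + delta * b.count / n;
    // M2 (sum of squared deviations) combines as M2_a + M2_b + delta^2 * n_a * n_b / n,
    // where M2 = var * count for the biased variance tracked here.
    const double m2 = a.var * a.count + b.var * b.count + delta * delta * a.count * b.count / n;
    out.var = m2 / n;
    return out;
}

int main()
{
    const Partial a{1.0, 0.25, 4}; // e.g. {0.5, 0.5, 1.5, 1.5}
    const Partial b{3.0, 1.00, 4}; // e.g. {2, 2, 4, 4}
    const Partial c = merge(a, b); // mean = 2, var = 1.625, count = 8
    std::printf("mean=%g var=%g count=%d\n", c.mean, c.var, c.count);
    return 0;
}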
// step2: normalization // step2: save mean and inverse std for backward (optional)
if(block_k_cluster_id == 0 && thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// step3: normalization
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
for(index_t k = 0; k < num_k_block_tile_iteration; ++k) for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
...@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd ...@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd ...@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma // gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
@@ -16,9 +16,11 @@ template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K,
          typename GridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
@@ -32,6 +34,7 @@ template <typename XDataType,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
          index_t SaveMeanInvStdDstVectorSize,
          bool SweepOnce>
struct GridwiseNormalizationWelfordVariance_mk_to_mk
{
@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
                  "configuration, please check!");

    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);
@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

    using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                                   ThreadClusterLengths_M_K,
                                   ThreadClusterArrangeOrder>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                     const GridDesc_M_K& gamma_grid_desc_m_k,
                     const GridDesc_M_K& beta_grid_desc_m_k,
                     const GridDesc_M_K& y_grid_desc_m_k,
                     const GridDesc_M& save_mean_grid_desc_m,
                     const GridDesc_M& save_inv_std_grid_desc_m,
                     index_t num_k_block_tile_iteration,
                     ComputeDataType epsilon,
                     const XDataType* const __restrict__ p_x_global,
                     const GammaDataType* const __restrict__ p_gamma_global,
                     const BetaDataType* const __restrict__ p_beta_global,
                     YDataType* const __restrict__ p_y_global,
                     SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
                     SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
                     const YElementwiseOperation y_elementwise_op)
    {
        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
            mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            var_thread_buf;
        auto& inv_std_thread_buf = var_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                                                   thread_k_cluster_id * YDstVectorSize),
                y_elementwise_op);
auto threadwise_mean_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_mean_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_inv_std_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
SaveMeanInvStdDataType,
decltype(thread_buffer_desc_m),
GridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
SaveMeanInvStdDstVectorSize, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(
save_inv_std_grid_desc_m,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
        constexpr auto thread_copy_bwd_step_m_k =
            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
ck::math::sqrt(var_thread_buf(I) + epsilon);
        });
// save mean and inverse std for backward (optional)
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// normalization
        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                    // normalize
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
                        inv_std_thread_buf(iM);

                    // gamma & beta
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
            inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
        });
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
        auto thread_copy_tail_m_k =
            (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
        });

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                    constexpr auto offset_m_k =
@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                    // normalize
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
                        inv_std_thread_buf(iM);
                    // gamma
                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
...
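The mean and variance themselves come from Welford-style running updates (ThreadwiseWelford per thread, BlockwiseWelford across the thread cluster, and the same delta/delta2 form in the reference groupnorm below). A small standalone sketch of that update, plus the kind of combine step a blockwise reduction needs, assuming plain floats and a hypothetical WelfordState type:

#include <cmath>

// Sketch of the running (Welford) update: after `count` values, `mean` is the
// running mean and `m2 / count` is the (biased) variance.
struct WelfordState
{
    float mean  = 0.f;
    float m2    = 0.f;
    int   count = 0;

    void push(float x)
    {
        ++count;
        float delta = x - mean;
        mean += delta / count;
        float delta2 = x - mean;
        m2 += delta * delta2;
    }

    // Combine two partial results; roughly what a blockwise reduction does per cluster.
    void merge(const WelfordState& other)
    {
        if(other.count == 0)
            return;
        int   n     = count + other.count;
        float delta = other.mean - mean;
        mean += delta * other.count / n;
        m2 += other.m2 + delta * delta * (static_cast<float>(count) * other.count / n);
        count = n;
    }

    float variance() const { return count > 0 ? m2 / count : 0.f; }
    float inv_std(float epsilon) const { return 1.f / std::sqrt(variance() + epsilon); }
};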
@@ -20,8 +20,9 @@ template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation>
struct ReferenceGroupnorm : public device::BaseOperator
{
    // x = [N, H, W, G, C]
@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator
                 const Tensor<GammaDataType>& gamma,
                 const Tensor<BetaDataType>& beta,
                 Tensor<YDataType>& y,
                 Tensor<SaveMeanInvStdDataType>& save_mean,
                 Tensor<SaveMeanInvStdDataType>& save_inv_std,
                 YElementwiseOperation y_elementwise_op,
                 const std::vector<index_t> lengths,
                 ComputeDataType epsilon)
            : x_(x),
              gamma_(gamma),
              beta_(beta),
              y_(y),
              save_mean_(save_mean),
              save_inv_std_(save_inv_std),
              y_elementwise_op_(y_elementwise_op),
              lengths_(lengths),
              epsilon_(epsilon)
        {
@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
        const Tensor<XDataType> gamma_;
        const Tensor<XDataType> beta_;
        Tensor<YDataType>& y_;
        Tensor<SaveMeanInvStdDataType>& save_mean_;
        Tensor<SaveMeanInvStdDataType>& save_inv_std_;
        YElementwiseOperation y_elementwise_op_;
        std::vector<index_t> lengths_;
        ComputeDataType epsilon_;
    };

    // Invoker
@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator
            int G = arg.lengths_[3];
            int C = arg.lengths_[4];

            Tensor<ComputeDataType> mean({N, G});
            Tensor<ComputeDataType> var({N, G});

            // Compute mean & var in [H, W, C] by Welford Algorithm
            // TODO - parallel for each HWC
@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator
            {
                for(int g = 0; g < G; ++g)
                {
                    ComputeDataType mean_val = type_convert<ComputeDataType>(0.0f);
                    ComputeDataType var_val  = type_convert<ComputeDataType>(0.0f);
                    int32_t curr_count = 0;

                    for(int h = 0; h < H; ++h)
                    {
@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
                            for(int c = 0; c < C; ++c)
                            {
                                curr_count++;
                                ComputeDataType x =
                                    type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
                                ComputeDataType delta = x - mean_val;
                                mean_val += delta / curr_count;
                                ComputeDataType delta2 = x - mean_val;
                                var_val += delta * delta2;
                            }
                        }
@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator
                    mean(n, g) = mean_val;
                    var(n, g)  = var_val / curr_count;
arg.save_mean_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(mean(n, g));
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n, g) + arg.epsilon_);
arg.save_inv_std_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
                }
            }
@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator
                        {
                            for(int c = 0; c < C; ++c)
                            {
                                ComputeDataType x =
                                    type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
                                ComputeDataType gamma =
                                    type_convert<ComputeDataType>(arg.gamma_(g, c));
                                ComputeDataType beta =
                                    type_convert<ComputeDataType>(arg.beta_(g, c));
                                ComputeDataType mean_val =
                                    type_convert<ComputeDataType>(mean(n, g));
                                ComputeDataType var_val = type_convert<ComputeDataType>(var(n, g));
                                ComputeDataType y = gamma * (x - mean_val) /
                                                        ck::math::sqrt(arg.epsilon_ + var_val) +
                                                    beta;
                                arg.y_elementwise_op_(y, y);
                                arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
                            }
                        }
@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator
                             const Tensor<GammaDataType>& gamma,
                             const Tensor<BetaDataType>& beta,
                             Tensor<YDataType>& y,
                             Tensor<SaveMeanInvStdDataType>& save_mean,
                             Tensor<SaveMeanInvStdDataType>& save_inv_std,
                             YElementwiseOperation y_elementwise_op,
                             const std::vector<index_t> lengths,
                             ComputeDataType epsilon)
    {
        return Argument{
            x, gamma, beta, y, save_mean, save_inv_std, y_elementwise_op, lengths, epsilon};
    }

    static auto MakeInvoker() { return Invoker{}; }
...
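With save_mean and save_inv_std written out once per (n, g) group, the groupnorm output can be re-applied without reducing over (H, W, C) again, which is what a backward pass consumes. A plain-C++ sketch of that identity on flat row-major arrays (layouts and helper name assumed for illustration, not the CK API):

#include <cstddef>

// Rebuild y from x plus the saved per-(n, g) statistics:
//   y = gamma * (x - save_mean) * save_inv_std + beta.
// Assumed layouts: x/y are [N, H, W, G, C] row-major, gamma/beta are [G, C],
// save_mean/save_inv_std are [N, G].
void groupnorm_apply_saved_stats(const float* x, const float* gamma, const float* beta,
                                 const float* save_mean, const float* save_inv_std, float* y,
                                 std::size_t N, std::size_t H, std::size_t W,
                                 std::size_t G, std::size_t C)
{
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t h = 0; h < H; ++h)
            for(std::size_t w = 0; w < W; ++w)
                for(std::size_t g = 0; g < G; ++g)
                    for(std::size_t c = 0; c < C; ++c)
                    {
                        std::size_t idx = (((n * H + h) * W + w) * G + g) * C + c;
                        y[idx] = gamma[g * C + c] * (x[idx] - save_mean[n * G + g]) *
                                     save_inv_std[n * G + g] +
                                 beta[g * C + c];
                    }
}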
@@ -20,8 +20,9 @@ template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename SaveMeanInvStdDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          index_t Rank,
          index_t NumReduceDim>
struct ReferenceLayernorm : public device::BaseOperator
@@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator
                 const Tensor<GammaDataType>& gamma_n,
                 const Tensor<BetaDataType>& beta_n,
                 Tensor<YDataType>& y_m_n,
                 Tensor<SaveMeanInvStdDataType>& save_mean_m,
                 Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
                 YElementwiseOperation y_elementwise_op,
                 const std::vector<index_t> lengths,
                 const std::vector<index_t> reduceDims,
                 ComputeDataType epsilon)
            : x_m_n_(x_m_n),
              gamma_n_(gamma_n),
              beta_n_(beta_n),
              y_m_n_(y_m_n),
              save_mean_m_(save_mean_m),
              save_inv_std_m_(save_inv_std_m),
              y_elementwise_op_(y_elementwise_op),
              lengths_(lengths),
              reduceDims_(reduceDims),
              epsilon_(epsilon)
@@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator
        const Tensor<XDataType> gamma_n_;
        const Tensor<XDataType> beta_n_;
        Tensor<YDataType>& y_m_n_;
        Tensor<SaveMeanInvStdDataType>& save_mean_m_;
        Tensor<SaveMeanInvStdDataType>& save_inv_std_m_;
        YElementwiseOperation y_elementwise_op_;
        std::vector<index_t> lengths_;
        std::vector<index_t> reduceDims_;
        ComputeDataType epsilon_;
    };

    // Invoker
@@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator
            int M = arg.lengths_[0];
            int N = arg.lengths_[1];

            Tensor<ComputeDataType> mean({M});
            Tensor<ComputeDataType> var({M});

            for(int m = 0; m < M; ++m)
            {
@@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator
                for(int n = 0; n < N; ++n)
                {
                    auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
                    mean(m) += x_val;
                    var(m) += x_val * x_val;
                }
@@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator
            for(int m = 0; m < M; ++m)
            {
                ComputeDataType divisor =
                    static_cast<ComputeDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);

                for(int n = 0; n < N; ++n)
                {
                    auto x_val     = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
                    auto gamma_val = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
                    auto beta_val  = ck::type_convert<ComputeDataType>(arg.beta_n_(n));
                    auto y_val     = (x_val - mean(m)) * divisor;
                    y_val          = (y_val * gamma_val) + beta_val;
                    arg.y_elementwise_op_(y_val, y_val);
                    arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
                }
arg.save_mean_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(mean(m));
arg.save_inv_std_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
            }

            return 0;
@@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator
                             const Tensor<GammaDataType>& gamma_n,
                             const Tensor<BetaDataType>& beta_n,
                             Tensor<YDataType>& y_m_n,
                             Tensor<SaveMeanInvStdDataType>& save_mean_m,
                             Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
                             YElementwiseOperation y_elementwise_op,
                             const std::vector<index_t> lengths,
                             const std::vector<index_t> reduceDims,
                             ComputeDataType epsilon)
    {
        return Argument{x_m_n,
                        gamma_n,
                        beta_n,
                        y_m_n,
                        save_mean_m,
                        save_inv_std_m,
                        y_elementwise_op,
                        lengths,
                        reduceDims,
                        epsilon};
    }

    static auto MakeInvoker() { return Invoker{}; }
...
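A test for this change essentially has to recompute the row statistics on the host and compare them against what the forward pass stored, one value per non-reduced index m. A hedged sketch in plain C++ (tolerance, layout, and function name are illustrative, not the ckProfiler/test harness; it mirrors the accumulate-then-normalize form of the reference above):

#include <cmath>
#include <cstddef>
#include <vector>

// Recompute mean / inverse std per row and compare with the saved values.
bool check_saved_stats(const std::vector<float>& x,            // M x N, row-major
                       const std::vector<float>& save_mean,    // M
                       const std::vector<float>& save_inv_std, // M
                       std::size_t M, std::size_t N, float epsilon, float tol)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        float mean = 0.f, sq = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            mean += x[m * N + n];
            sq += x[m * N + n] * x[m * N + n];
        }
        mean /= N;
        float var     = sq / N - mean * mean; // E[x^2] - E[x]^2, as in the reference
        float inv_std = 1.f / std::sqrt(var + epsilon);
        if(std::abs(save_mean[m] - mean) > tol || std::abs(save_inv_std[m] - inv_std) > tol)
            return false;
    }
    return true;
}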