"example/36_permute/run_permute_example.inc" did not exist on "665b73ff1dd5c2b5fbdf2ddd9c19da2a583fdd2d"
Commit d0f355a3 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents 55a89c74 b305a29e
@@ -149,7 +149,7 @@ function(clang_tidy_check TARGET)
add_custom_target(${tidy_target}
# for some targets clang-tidy not able to get information from .clang-tidy
DEPENDS ${SOURCE}
-COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml"
+COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_PLATFORM_AMD__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..."
)
......
-rocm-docs-core==0.30.1
+rocm-docs-core==0.30.2
sphinxcontrib-bibtex==2.6.1
@@ -113,7 +113,7 @@ requests==2.31.0
# via
#   pygithub
#   sphinx
-rocm-docs-core==0.30.1
+rocm-docs-core==0.30.2
# via -r requirements.in
six==1.16.0
# via
......
#include <iostream>
#include <cstdlib>
#include <random>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
...@@ -48,10 +49,8 @@ void host_elementwise4D(HostTensorB& B_nhwc, ...@@ -48,10 +49,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
for(std::size_t n = 0; n < N; ++n) for(std::size_t n = 0; n < N; ++n)
{ {
ADataType tmp_val; ADataType tmp_val;
// auto a_val = A_nchw(n, c, h, w);
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor_b(tmp_val, a_val); functor_b(tmp_val, a_val);
// functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
scale * tmp_val); scale * tmp_val);
} }
...@@ -62,12 +61,14 @@ int main() ...@@ -62,12 +61,14 @@ int main()
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
std::vector<std::size_t> nchw = {4, 2, 1, 8}; std::vector<std::size_t> nchw = {16, 8, 32, 64};
std::vector<std::size_t> nhwc = {4, 1, 8, 2}; std::vector<std::size_t> nhwc = {16, 32, 64, 8};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
float scale = 1.f;
auto i = 0;
std::mt19937 gen(11939);
std::uniform_int_distribution<int> dis(0, 1);
for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
@@ -75,7 +76,7 @@ int main()
{
a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
(h * nchw[3]) + w] = i;
-i++;
+i = dis(gen);
}
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
......
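For readers checking the reference above by hand: the permute is pure index arithmetic, with n the fastest-varying index on both sides. Below is a minimal standalone sketch of the same NCHW -> NHWC mapping; the helper name is hypothetical, plain std::vector stands in for CK's Tensor, and the elementwise functors and scale are omitted.

#include <cstddef>
#include <vector>

// Sketch of the index arithmetic used by host_elementwise4D above:
// element (n, c, h, w) of the NCHW view is copied to (n, h, w, c) of the NHWC view.
// Strides follow the same layout the example uses (n is the fastest-varying index).
void permute_nchw_to_nhwc(const std::vector<float>& a_nchw,
                          std::vector<float>& b_nhwc,
                          std::size_t N, std::size_t C, std::size_t H, std::size_t W)
{
    for(std::size_t w = 0; w < W; ++w)
        for(std::size_t h = 0; h < H; ++h)
            for(std::size_t c = 0; c < C; ++c)
                for(std::size_t n = 0; n < N; ++n)
                {
                    // same offsets as A_nchw.mData[...] / B_nhwc.mData[...] in the example
                    const std::size_t src = n + c * N + h * C * N + w * H * C * N;
                    const std::size_t dst = n + h * N + w * H * N + c * W * H * N;
                    b_nhwc[dst] = a_nchw[src];
                }
}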
@@ -67,6 +67,8 @@ int main()
float scale = 1.f;
auto i = 0;
std::mt19937 gen(11939);
std::uniform_int_distribution<int> dis(0, 1);
for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
@@ -74,7 +76,7 @@ int main()
{
a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
(h * nchw[3]) + w] = i;
-i++;
+i = dis(gen);
}
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
......
add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp)
@@ -15,16 +15,17 @@
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
-using DYDataType = ck::half_t;
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
+using DYDataType = float;
+using XDataType = float;
+using GammaDataType = float;
using MeanInvStdDataType = float;
-using DGammaDataType = ck::half_t;
-using DBetaDataType = ck::half_t;
-using DXDataType = ck::half_t;
+using DGammaDataType = float;
+using DBetaDataType = float;
+using DXDataType = float;
using ComputeDataType = float;
constexpr int Rank = 2;
@@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1;
// inv_std: [M, 1]
// Output shape
// dx: [M, N]
// dgamma: [1, N]
// dbeta: [1, N]
@@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1;
// dbeta = reduce_sum(dy, axis=0)
// [CAUSION]
-// In DeviceNormalizationBwdGammaBetaImpl, M is invarient dimension, K is reduced dimension
-// Hence, M in this example and DeviceNormalizationBwdGammaBetaImpl is different
+// In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is Invariant
+// dimension, K is reduced dimension Hence, M in this example and
+// DeviceNormalizationBwdGammaBetaImpl is different
using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // MThreadClusterSize
32, // KThreadClusterSize
1, // MThreadSliceSize
4, // KThreadSliceSize
true, // IsDYFastestDimReduced
4, // DYSrcVectorSize
true, // IsXFastestDimReduced
4, // XSrcVectorSize
true, // IsGammaFastestDimReduced
4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced
4>; // DXDstVectorSize
using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl<
DYDataType,
XDataType,
@@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
Rank,
NumReduceDim,
256, // BlockSize
-8, // ClusterInvarient
+8, // MThreadClusterSize
-32, // ClusterReduce
+32, // KThreadClusterSize
-8, // SliceInvarient
+4, // MThreadSliceSize
-1, // SliceReduce
+1, // KThreadSliceSize
false, // IsDYFastestDimReduced
-8, // DYSrcVectorSize
+4, // DYSrcVectorSize
false, // IsXFastestDimReduced
-8, // XSrcVectorSize
+4, // XSrcVectorSize
true, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
-1, // DGammaDstVectorSize
+4, // DGammaDstVectorSize
-1>; // DBetaDstVectorSize
+4>; // DBetaDstVectorSize
int main()
{
@@ -96,16 +124,48 @@ int main()
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
// backward x
auto x_device_instance = XDeviceInstance{};
auto x_argument_ptr = x_device_instance.MakeArgumentPointer({M, N}, // lengths
{N, 1}, // dyStrides
{N, 1}, // xStrides
{0, 1}, // gammaStrides
{1, 0}, // meanStrides
{1, 0}, // invStdStrides
{N, 1}, // dxStrides
{1}, // reduceDims
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1;
};
auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// backward gamma & beta
auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
auto gamma_beta_argument_ptr =
gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths
@@ -126,7 +186,8 @@ int main()
if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
{
-std::cout << "The runtime parameters are not supported" << std::endl;
+std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+          << std::endl;
return 1;
};
@@ -156,9 +217,11 @@ int main()
dgamma_dev.FromDevice(dgamma.mData.data());
dbeta_dev.FromDevice(dbeta.mData.data());
dx_dev.FromDevice(dx.mData.data());
pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
}
return (pass ? 0 : 1);
......
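As a host-side cross-check of what the two device instances above compute, here is a minimal sketch that follows the formulas quoted in this example and in gridwise_normalization_bwd_data.hpp further down (ds/db/b/c for dx, reductions over the invariant dimension M for dgamma/dbeta). The function name and plain std::vector signature are illustrative only; the example itself verifies against the reference_layernorm_bwd host operator.

#include <cstddef>
#include <vector>

// Host sketch of 2-D layernorm backward for row-major [M, N] tensors.
// dx follows the ds/db/b/c formulation from gridwise_normalization_bwd_data.hpp;
// dgamma/dbeta are reductions over the invariant dimension M.
void layernorm2d_bwd_host(const std::vector<float>& dy,      // [M, N]
                          const std::vector<float>& x,       // [M, N]
                          const std::vector<float>& gamma,   // [N]
                          const std::vector<float>& mean,    // [M]
                          const std::vector<float>& inv_std, // [M]
                          std::vector<float>& dx,            // [M, N]
                          std::vector<float>& dgamma,        // [N]
                          std::vector<float>& dbeta,         // [N]
                          std::size_t M, std::size_t N)
{
    dx.assign(M * N, 0.f);
    dgamma.assign(N, 0.f);
    dbeta.assign(N, 0.f);
    for(std::size_t m = 0; m < M; ++m)
    {
        float ds = 0.f, db = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            ds += dy[m * N + n] * gamma[n] * x[m * N + n];
            db += dy[m * N + n] * gamma[n];
            dgamma[n] += dy[m * N + n] * (x[m * N + n] - mean[m]) * inv_std[m];
            dbeta[n] += dy[m * N + n];
        }
        const float rs = inv_std[m];
        const float b  = (db * mean[m] - ds) * rs * rs * rs / static_cast<float>(N);
        const float c  = -b * mean[m] - db * rs / static_cast<float>(N);
        for(std::size_t n = 0; n < N; ++n)
            dx[m * N + n] = rs * dy[m * N + n] * gamma[n] + b * x[m * N + n] + c;
    }
}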
add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp)
-add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp)
+add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp)
@@ -15,23 +15,58 @@
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
-using DYDataType = ck::half_t;
-using XDataType = ck::half_t;
-using GammaDataType = ck::half_t;
+using DYDataType = float;
+using XDataType = float;
+using GammaDataType = float;
using MeanInvStdDataType = float;
-using DGammaDataType = ck::half_t;
-using DBetaDataType = ck::half_t;
-using DXDataType = ck::half_t;
+using DGammaDataType = float;
+using DBetaDataType = float;
+using DXDataType = float;
using ComputeDataType = float;
constexpr int Rank = 5;
constexpr int NumReduceDim = 3;
// Grouprnorm
-// kernel: M , K
+// kernel 1: M , K
// dy: N, H, W, G, C -> N * G, H * W * C
// x: N, H, W, G, C -> N * G, H * W * C
// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C
// mean: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
// rstd: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
// dx: N, H, W, G, C -> N * G, H * W * C
using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // MThreadClusterSize
32, // KThreadClusterSize
1, // MThreadSliceSize
4, // KThreadSliceSize
true, // IsDYFastestDimReduced
4, // DYSrcVectorSize
true, // IsXFastestDimReduced
4, // XSrcVectorSize
true, // IsGammaFastestDimReduced
4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced
4>; // DXDstVectorSize
// kernel 2: M , K
// dy: N, H, W, G, C -> G * C, N * H * W
// x: N, H, W, G, C -> G * C, N * H * W
// mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1
@@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
Rank,
NumReduceDim,
256, // BlockSize
-8, // ClusterInvarient
+8, // ClusterInvariant
32, // ClusterReduce
-8, // SliceInvarient
+4, // SliceInvariant
1, // SliceReduce
false, // IsDYFastestDimReduced
-8, // DYSrcVectorSize
+4, // DYSrcVectorSize
false, // IsXFastestDimReduced
-8, // XSrcVectorSize
+4, // XSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
-1, // DGammaDstVectorSize
+4, // DGammaDstVectorSize
-1>; // DBetaDstVectorSize
+4>; // DBetaDstVectorSize
int main()
{
@@ -93,20 +128,55 @@ int main()
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
std::vector<ck::index_t> dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
std::vector<ck::index_t> xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<ck::index_t> gammaStrides = {0, 0, 0, C, 1};
std::vector<ck::index_t> meanStrides = {G, 0, 0, 1, 0};
std::vector<ck::index_t> invStdStrides = {G, 0, 0, 1, 0};
std::vector<ck::index_t> dxStrides{dx.mDesc.GetStrides().begin(), dx.mDesc.GetStrides().end()};
// backward x
auto x_device_instance = XDeviceInstance{};
auto x_argument_ptr = x_device_instance.MakeArgumentPointer({N, H, W, G, C}, // lengths
dyStrides, // dyStrides
xStrides, // xStrides
gammaStrides, // gammaStrides
meanStrides, // meanStrides
invStdStrides, // invStdStrides
dxStrides, // dxStrides
{1, 2, 4}, // reduceDims
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1;
};
auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// backward gamma & beta
auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
auto gamma_beta_argument_ptr =
@@ -128,7 +198,8 @@ int main()
if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
{
-std::cout << "The runtime parameters are not supported" << std::endl;
+std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
+          << std::endl;
return 1;
};
@@ -158,9 +229,11 @@ int main()
dgamma_dev.FromDevice(dgamma.mData.data());
dbeta_dev.FromDevice(dbeta.mData.data());
dx_dev.FromDevice(dx.mData.data());
pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
}
return (pass ? 0 : 1);
......
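The two device instances in this groupnorm example view the same 5-D tensors as two different 2-D (M, K) problems, as spelled out in the kernel 1 / kernel 2 comments above; broadcasts (gamma, mean, inv_std) are expressed with zero strides such as gammaStrides = {0, 0, 0, C, 1}. A small sketch of that folding, with illustrative sizes only:

#include <cstdio>

// Sketch of how the 5-D groupnorm problem [N, H, W, G, C] collapses to the
// 2-D (M, K) problems used by the two kernels in this example.
// Kernel 1 (backward data)       reduces over {H, W, C}: M = N * G, K = H * W * C.
// Kernel 2 (backward gamma/beta) reduces over {N, H, W}: M = G * C, K = N * H * W.
int main()
{
    const int N = 32, H = 16, W = 16, G = 64, C = 8; // illustrative sizes only

    const int M_data = N * G, K_data = H * W * C; // kernel 1 view
    const int M_gb   = G * C, K_gb   = N * H * W; // kernel 2 view

    std::printf("backward data      : M = %d, K = %d\n", M_data, K_data);
    std::printf("backward gamma/beta: M = %d, K = %d\n", M_gb, K_gb);
    return 0;
}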
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceNormalizationBwdData : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_gamma,
const void* p_mean,
const void* p_invStd,
void* p_dx) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
using DeviceNormalizationBwdDataPtr = std::unique_ptr<DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
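A sketch of how this type-erased interface is typically driven, mirroring the 2-D row-major [M, N] case in the layernorm example above. The helper name is hypothetical, the includes are assumed to match those of the examples, and the concrete object is expected to come from an implementation such as DeviceNormalizationBwdDataImpl below.

// Minimal sketch; op is e.g. a DeviceNormalizationBwdDataPtr<...> pointing at a concrete Impl.
template <typename DeviceOpPtr>
bool run_layernorm_bwd_data(DeviceOpPtr& op,
                            ck::index_t M, ck::index_t N,
                            const void* p_dy, const void* p_x, const void* p_gamma,
                            const void* p_mean, const void* p_inv_std, void* p_dx)
{
    auto arg = op->MakeArgumentPointer({M, N}, // lengths
                                       {N, 1}, // dyStrides
                                       {N, 1}, // xStrides
                                       {0, 1}, // gammaStrides (broadcast over M)
                                       {1, 0}, // meanStrides (broadcast over N)
                                       {1, 0}, // invStdStrides
                                       {N, 1}, // dxStrides
                                       {1},    // reduceDims
                                       p_dy, p_x, p_gamma, p_mean, p_inv_std, p_dx);

    if(!op->IsSupportedArgument(arg.get()))
        return false;

    op->MakeInvokerPointer()->Run(arg.get(), StreamConfig{nullptr, false});
    return true;
}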
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// M is Invariant dimension, K is reduced dimension
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseNormalizationBwd,
typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
typename GridDesc_M_K>
__global__ void
kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k,
const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K mean_grid_desc_m_k,
const GridDesc_M_K inv_std_grid_desc_m_k,
const GridDesc_M_K dx_grid_desc_m_k,
index_t num_k_block_tile_iteration,
const DYDataType* const __restrict__ p_dy_global,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const MeanInvStdDataType* const __restrict__ p_mean_global,
const MeanInvStdDataType* const __restrict__ p_inv_std_global,
DXDataType* const __restrict__ p_dx_global)
{
GridwiseNormalizationBwd::Run(dy_grid_desc_m_k,
x_grid_desc_m_k,
gamma_grid_desc_m_k,
mean_grid_desc_m_k,
inv_std_grid_desc_m_k,
dx_grid_desc_m_k,
num_k_block_tile_iteration,
p_dy_global,
p_x_global,
p_gamma_global,
p_mean_global,
p_inv_std_global,
p_dx_global);
};
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
bool IsDYFastestDimReduced,
index_t DYSrcVectorSize,
bool IsXFastestDimReduced,
index_t XSrcVectorSize,
bool IsGammaFastestDimReduced,
index_t GammaSrcVectorSize,
bool IsMeanInvStdFastestDimReduced,
index_t MeanInvStdSrcVectorSize,
bool IsDxFastestDimReduced,
index_t DXDstVectorSize>
struct DeviceNormalizationBwdDataImpl : public DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>
{
static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0;
static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0;
static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0;
static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
static constexpr index_t DXDstVectorDim = IsDxFastestDimReduced ? 1 : 0;
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) ||
(DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert(
(MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!");
static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) ||
(DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)),
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim);
static auto Make2dDescriptor(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides,
int numBlockTileIteration)
{
const auto tupleLengths = make_tuple_from_array(lengths, Number<Rank>{});
const auto tupleStrides = make_tuple_from_array(strides, Number<Rank>{});
const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
const auto grid_desc_m_k = [&]() {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(lengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
return transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, pad_M),
make_right_pad_transform(reduceLength, pad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return grid_desc_m_k_padded;
}
using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1));
using GridwiseNormalizationBwdDataGeneric =
GridwiseNormalizationBwdData_mk_to_mk<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
DYSrcVectorDim,
DYSrcVectorSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
DXDstVectorDim,
DXDstVectorSize,
false>;
using GridwiseNormalizationBwdDataSweepOnce =
GridwiseNormalizationBwdData_mk_to_mk<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
DYSrcVectorDim,
DYSrcVectorSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
DXDstVectorDim,
DXDstVectorSize,
true>;
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const DYDataType* p_dy,
const XDataType* p_x,
const GammaDataType* p_gamma,
const MeanInvStdDataType* p_mean,
const MeanInvStdDataType* p_invStd,
DXDataType* p_dx)
: p_dy_(p_dy),
p_x_(p_x),
p_gamma_(p_gamma),
p_mean_(p_mean),
p_invStd_(p_invStd),
p_dx_(p_dx)
{
lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
dyStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
meanStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims);
invStdStrides_ =
shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
dxStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dxStrides, reduceDims);
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths_);
numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
dy_grid_desc_m_k_ = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_);
x_grid_desc_m_k_ = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_);
mean_grid_desc_m_k_ = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_);
inv_std_grid_desc_m_k_ =
Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_);
dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_);
isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize;
}
const DYDataType* p_dy_;
const XDataType* p_x_;
const GammaDataType* p_gamma_;
const MeanInvStdDataType* p_mean_;
const MeanInvStdDataType* p_invStd_;
DXDataType* p_dx_;
std::vector<index_t> lengths_;
std::vector<index_t> dyStrides_;
std::vector<index_t> xStrides_;
std::vector<index_t> gammaStrides_;
std::vector<index_t> meanStrides_;
std::vector<index_t> invStdStrides_;
std::vector<index_t> dxStrides_;
int numBlockTileIteration_;
size_t gridSize_;
// tensor descriptor
GridDesc_M_K dy_grid_desc_m_k_;
GridDesc_M_K x_grid_desc_m_k_;
GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K mean_grid_desc_m_k_;
GridDesc_M_K inv_std_grid_desc_m_k_;
GridDesc_M_K dx_grid_desc_m_k_;
bool isSweeponce_;
index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length
};
struct Invoker : public BaseInvoker
{
auto KernelSelector(bool isSweepOnce)
{
return isSweepOnce
? kernel_normalization_bwd_data<GridwiseNormalizationBwdDataSweepOnce,
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
GridDesc_M_K>
: kernel_normalization_bwd_data<GridwiseNormalizationBwdDataGeneric,
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
GridDesc_M_K>;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel_main = KernelSelector(arg.isSweeponce_);
return launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.dy_grid_desc_m_k_,
arg.x_grid_desc_m_k_,
arg.gamma_grid_desc_m_k_,
arg.mean_grid_desc_m_k_,
arg.inv_std_grid_desc_m_k_,
arg.dx_grid_desc_m_k_,
arg.numBlockTileIteration_,
arg.p_dy_,
arg.p_x_,
arg.p_gamma_,
arg.p_mean_,
arg.p_invStd_,
arg.p_dx_);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
template <index_t SrcVectorDim, index_t SrcVectorSize>
bool IsVectorDimSizeValid(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
if constexpr(SrcVectorSize == 1)
return true;
// Fastest dimension is not reduced
if constexpr(SrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
return false;
if(strides[NumInvariantDim - 1] != 1)
return false;
if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0)
return false;
}
else // Fastest dimension is reduced
{
if(strides[Rank - 1] != 1)
return false;
if(lengths[Rank - 1] % SrcVectorSize != 0)
return false;
};
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
bool pass = true;
pass &= IsVectorDimSizeValid<DYSrcVectorDim, DYSrcVectorSize>(p_arg_->lengths_,
p_arg_->dyStrides_);
pass &= IsVectorDimSizeValid<XSrcVectorDim, XSrcVectorSize>(p_arg_->lengths_,
p_arg_->xStrides_);
pass &= IsVectorDimSizeValid<GammaSrcVectorDim, GammaSrcVectorSize>(p_arg_->lengths_,
p_arg_->gammaStrides_);
pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
p_arg_->lengths_, p_arg_->meanStrides_);
pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
p_arg_->lengths_, p_arg_->invStdStrides_);
pass &= IsVectorDimSizeValid<DXDstVectorDim, DXDstVectorSize>(p_arg_->lengths_,
p_arg_->dxStrides_);
return pass;
}
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_gamma,
const void* p_mean,
const void* p_invStd,
void* p_dx) override
{
if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank ||
gammaStrides.size() != Rank || meanStrides.size() != Rank ||
invStdStrides.size() != Rank || dxStrides.size() != Rank)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(lengths,
dyStrides,
xStrides,
gammaStrides,
meanStrides,
invStdStrides,
dxStrides,
reduceDims,
static_cast<const DYDataType*>(p_dy),
static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
static_cast<const MeanInvStdDataType*>(p_mean),
static_cast<const MeanInvStdDataType*>(p_invStd),
static_cast<DXDataType*>(p_dx));
}
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceNormalizationBwdDataImpl<" << BlockSize << ",";
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "DYSrcVectorSize" << DYSrcVectorSize << "_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_MeanRstd" << MeanInvStdSrcVectorSize << "_Dx" << DXDstVectorSize;
str << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
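For orientation, a sketch of the launch arithmetic the Argument constructor above performs: K is covered by numBlockTileIteration tiles of K_BlockTileSize, M by gridSize blocks of M_BlockTileSize, both right-padded by Make2dDescriptor, and the sweep-once kernel is selected when a single K tile covers the whole reduced length. Problem sizes below are illustrative only.

#include <cstdio>

// Uses the tuning parameters picked in the examples above
// (MThreadClusterSize = 8, KThreadClusterSize = 32, MThreadSliceSize = 1, KThreadSliceSize = 4).
int main()
{
    const int MThreadClusterSize = 8, KThreadClusterSize = 32;
    const int MThreadSliceSize = 1, KThreadSliceSize = 4;

    const int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; // rows per block
    const int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; // columns per tile

    const int M = 1024, K = 512; // illustrative invariant / reduced lengths

    const int numBlockTileIteration = (K + K_BlockTileSize - 1) / K_BlockTileSize;
    const int gridSize              = (M + M_BlockTileSize - 1) / M_BlockTileSize;

    // right-padding applied by Make2dDescriptor so every block sees full tiles
    const int pad_M = gridSize * M_BlockTileSize - M;
    const int pad_K = numBlockTileIteration * K_BlockTileSize - K;

    // the SweepOnce kernel is chosen when one K tile covers the whole reduced length
    const bool sweepOnce = (numBlockTileIteration == 1);

    std::printf("grid=%d, k-iterations=%d, pad_M=%d, pad_K=%d, sweepOnce=%d\n",
                gridSize, numBlockTileIteration, pad_M, pad_K, sweepOnce);
    return 0;
}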
@@ -14,7 +14,7 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
-// M is invarient dimension, K is reduced dimension
+// M is Invariant dimension, K is reduced dimension
namespace ck {
namespace tensor_operation {
namespace device {
@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl
Rank,
NumReduceDim>
{
static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0;
static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0;
static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
-static_assert(
-    ((MThreadSliceSize % DGammaDstVectorSize == 0) ||
-     (MThreadSliceSize % DBetaDstVectorSize == 0)),
-    "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
-    "check!");
static_assert(
    (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
        (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
    "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
    "check!");
static_assert(
((MThreadSliceSize % DGammaDstVectorSize == 0) ||
(MThreadSliceSize % DBetaDstVectorSize == 0)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl
GridDesc_M dgamma_grid_desc_m_;
GridDesc_M dbeta_grid_desc_m_;
-index_t MRaw_; // invarient length
+index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length
};
@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl
{
return std::make_unique<Invoker>();
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ",";
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize ;
str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
......
@@ -19,7 +19,7 @@ namespace tensor_operation {
namespace device {
// Y = Normalization(X, Beta, Gamma)
-// M: Invarient length
+// M: Invariant length
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
GridDesc_M save_inv_std_grid_desc_m_;
bool isSweeponce_;
-index_t MRaw_; // invarient length
+index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length
index_t invariant_lowest_length_;
@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
}
else
{
-printf("!!!! %d\n", p_arg_->invariant_lowest_length_);
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false;
......
@@ -108,7 +108,7 @@ namespace tensor_operation {
namespace device {
// Y = Normalization(X, Beta, Gamma)
-// M: Invarient length
+// M: Invariant length
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_;
-index_t MRaw_; // invarient length
+index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length
index_t invariant_lowest_length_;
......
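The new gridwise kernel that follows implements the backward-data formulas quoted in its header comment; written out for one row, with mean \mu, inverse standard deviation r, and reduce size K:

ds = \sum_k dy_k \, \gamma_k \, x_k, \qquad db = \sum_k dy_k \, \gamma_k

b = \frac{(db \, \mu - ds) \, r^{3}}{K}, \qquad c = -b \, \mu - \frac{db \, r}{K}

dx_k = r \, dy_k \, \gamma_k + b \, x_k + c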
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace ck {
// Tensor Shape
// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1]
// Flow:
// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * inv_std ** (3) / reduce_size
// c = -b * x_mean - db * inv_std / reduce_size
// dx = inv_std * dy * gamma + b * x + c
// return dx
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DXDataType,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t DYSrcVectorDim,
index_t DYSrcVectorSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t MeanInvStdSrcVectorDim,
index_t MeanInvStdSrcVectorSize,
index_t DXDstVectorDim,
index_t DXDstVectorSize,
bool SweepOnce>
struct GridwiseNormalizationBwdData_mk_to_mk
{
// if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
(DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
(XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert(
((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)),
"Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) ||
(DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)),
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using DYThreadBufferDimAccessOrder =
typename conditional<DYSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using XThreadBufferDimAccessOrder =
typename conditional<XSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using GammaThreadBufferDimAccessOrder =
typename conditional<GammaSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using MeanInvStdThreadBufferDimAccessOrder =
typename conditional<MeanInvStdSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using DXThreadBufferDimAccessOrder =
typename conditional<DXDstVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
true>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_M_K& gamma_grid_desc_m_k,
const GridDesc_M_K& mean_grid_desc_m_k,
const GridDesc_M_K& inv_std_grid_desc_m_k,
const GridDesc_M_K& dx_grid_desc_m_k,
index_t num_k_block_tile_iteration,
const DYDataType* const __restrict__ p_dy_global,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const MeanInvStdDataType* const __restrict__ p_mean_global,
const MeanInvStdDataType* const __restrict__ p_inv_std_global,
DXDataType* const __restrict__ p_dx_global)
{
// LDS
__shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
// Global
const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize());
// VGPR
auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto gamma_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto dx_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
ComputeDataType,
MThreadSliceSize * KThreadSliceSize,
true>{};
auto ds_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
auto db_thread_buf =
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>{};
// thread id
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
// IO
auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DYDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
DYThreadBufferDimAccessOrder,
DYSrcVectorDim,
DYSrcVectorSize,
1,
false>(
dy_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
XThreadBufferDimAccessOrder,
XSrcVectorDim,
XSrcVectorSize,
1,
false>(
x_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_gamma_load =
ThreadwiseTensorSliceTransfer_v2<GammaDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
XThreadBufferDimAccessOrder,
GammaSrcVectorDim,
GammaSrcVectorSize,
1,
false>(
gamma_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_mean_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
false>(
mean_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_inv_std_load =
ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
ComputeDataType,
GridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
MeanInvStdThreadBufferDimAccessOrder,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
1,
false>(
inv_std_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_dx_store =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
DXDataType,
decltype(thread_buffer_desc_m_k),
GridDesc_M_K,
PassThroughOp,
ThreadBufferLengths_M_K,
DXThreadBufferDimAccessOrder,
DXDstVectorDim,
DXDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
dx_grid_desc_m_k,
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
ComputeDataType reduce_size = type_convert<ComputeDataType>(
dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
ds_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
db_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
});
// Separate sweep once and sweep twice pipeline
// Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K
// we don't need to use loop to read x, dy, gamma twice
if constexpr(SweepOnce)
{
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf);
threadwise_mean_load.Run(mean_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
mean_thread_buf);
threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
inv_std_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
inv_std_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
gamma_thread_buf[offset_m_k] *
x_thread_buf[offset_m_k];
db_thread_buf(offset_m) +=
dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
});
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
ds_thread_buf[offset_m];
b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
inv_std_thread_buf[offset_m_k] / reduce_size;
ComputeDataType c = -b * mean_thread_buf(offset_m_k);
c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;
dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
gamma_thread_buf[offset_m_k] *
inv_std_thread_buf[offset_m_k] +
b * x_thread_buf[offset_m_k] + c;
});
});
threadwise_dx_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
dx_global_val_buf);
} // end of sweep once
else // Sweep Twice pipeline
{
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_fwd_step_m_k);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
gamma_thread_buf[offset_m_k] *
x_thread_buf[offset_m_k];
db_thread_buf(offset_m) +=
dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
});
});
} // end of first sweep
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I));
block_sync_lds();
BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I));
});
// second sweep reads in reverse so the dy, gamma and x tiles loaded last are still in cache
constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
// dy/x/gamma windows sit one tile past the end after the first sweep; step back onto the last tile
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
// mean/inv_std/dx windows still point at the start; move them straight to the tail tile
threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_dy_load.Run(dy_grid_desc_m_k,
dy_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
dy_thread_buf);
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_gamma_load.Run(gamma_grid_desc_m_k,
gamma_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
gamma_thread_buf);
threadwise_mean_load.Run(mean_grid_desc_m_k,
mean_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
mean_thread_buf);
threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
inv_std_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
inv_std_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
constexpr auto offset_m =
Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset_m_k =
Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
ds_thread_buf[offset_m];
b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
inv_std_thread_buf[offset_m_k] / reduce_size;
ComputeDataType c = -b * mean_thread_buf(offset_m_k);
c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;
dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
gamma_thread_buf[offset_m_k] *
inv_std_thread_buf[offset_m_k] +
b * x_thread_buf[offset_m_k] + c;
});
});
threadwise_dx_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
dx_thread_buf,
dx_grid_desc_m_k,
dx_global_val_buf);
threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
thread_copy_bwd_step_m_k);
threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k,
thread_copy_bwd_step_m_k);
threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
thread_copy_bwd_step_m_k);
threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
}
}
}
};
} // namespace ck
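For reference, the per-row arithmetic that both the sweep-once and sweep-twice paths compute (the b / c / dx comments above) can be written as a plain host loop. The following is a minimal sketch under assumed names and a flat std::vector layout, not part of the kernel:

// Host-side sketch of the dx formula used above, for one reduced row of length K:
//   ds = sum(dy * gamma * x),  db = sum(dy * gamma)
//   b  = (db * mean - ds) * rstd^3 / K
//   c  = -b * mean - db * rstd / K
//   dx = rstd * dy * gamma + b * x + c
#include <cstddef>
#include <vector>

std::vector<float> normalization_bwd_data_row(const std::vector<float>& dy,
                                              const std::vector<float>& x,
                                              const std::vector<float>& gamma,
                                              float mean,
                                              float rstd)
{
    const std::size_t K = x.size();
    float ds = 0.f;
    float db = 0.f;
    for(std::size_t k = 0; k < K; ++k)
    {
        ds += dy[k] * gamma[k] * x[k];
        db += dy[k] * gamma[k];
    }
    const float b = (db * mean - ds) * rstd * rstd * rstd / K;
    const float c = -b * mean - db * rstd / K;

    std::vector<float> dx(K);
    for(std::size_t k = 0; k < K; ++k)
        dx[k] = rstd * dy[k] * gamma[k] + b * x[k] + c;
    return dx;
}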
...
@@ -35,7 +35,7 @@ template <typename DYDataType,
           index_t DBetaDstVectorSize>
 struct GridwiseNormalizationBwdGammaBeta_mk_to_k
 {
-    // if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor
+    // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
     static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
                    (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
                   "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
                    (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
                   "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

+    // do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm
+    static_assert(
+        ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
+         (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)),
+        "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
+
+    static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize,
+                  "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
+
     using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

     using DYThreadBufferDimAccessOrder =
...
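The two asserts added in this hunk relax the mean/inv_std requirement from equality to divisibility (so groupnorm can read mean/rstd with a narrower vector) and pin the dgamma/dbeta store width to the M slice size. A tiny stand-alone illustration of the divisibility rule, using hypothetical sizes that are not taken from this commit:

// Hypothetical check mirroring the relaxed mean/inv_std constraint above: the slice size
// along the chosen vector dimension only needs to be divisible by the vector size.
constexpr bool mean_inv_std_config_ok(int vector_dim, int m_slice, int k_slice, int vector_size)
{
    return (vector_dim == 0 && m_slice % vector_size == 0) ||
           (vector_dim == 1 && k_slice % vector_size == 0);
}

static_assert(mean_inv_std_config_ok(1, 2, 8, 4), "K slice 8 is divisible by vector size 4");
static_assert(!mean_inv_std_config_ok(1, 2, 8, 3), "K slice 8 is not divisible by vector size 3");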
...
@@ -16,6 +16,31 @@ namespace ck {
 namespace tensor_operation {
 namespace host {

+// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
+//     ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
+//     db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
+//     b = (db * x_mean - ds) * rstd ** (3) / reduce_size
+//     c = -b * x_mean - db * rstd / reduce_size
+//     dx = rstd * dy * gamma + b * x + c
+//     return dx
+
+// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
+//     # Assume shape of gamma and beta are the same
+//     dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
+//     dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
+//     return dgamma, dbeta
+
+// def groupnorm_backward(dy, x, gamma, x_mean, rstd):
+//     # dy, x = [N, H, W, G, C], gamma = [1, 1, 1, G, C], x_mean, rstd = [N, 1, 1, G, 1]
+//     N, H, W, G, C = x.shape
+//     dx = normalization_input_backward(
+//         dy, x, gamma, x_mean, rstd, (1, 2, 4), H * W * C)
+//     dgamma, dbeta = normalization_gamma_beta_backward(
+//         dy, x, x_mean, rstd, (0, 1, 2))
+//     return dx, dgamma, dbeta
+
+// Reference (Layernorm and groupnorm):
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L655
+
 template <typename DYDataType,
           typename XDataType,
           typename GammaDataType,
...
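The dgamma/dbeta pseudocode above reduces over the row axis rather than K; spelled out as a host loop for the row-major [M, K] layernorm case (a sketch with illustrative names; the groupnorm case only changes the reduce axes):

// Host-side sketch of the dgamma/dbeta reduction above for a row-major [M, K] problem:
//   dgamma[k] = sum_m dy[m, k] * (x[m, k] - mean[m]) * rstd[m]
//   dbeta[k]  = sum_m dy[m, k]
#include <cstddef>
#include <vector>

void normalization_bwd_gamma_beta(const std::vector<float>& dy,   // M * K
                                  const std::vector<float>& x,    // M * K
                                  const std::vector<float>& mean, // M
                                  const std::vector<float>& rstd, // M
                                  std::size_t M,
                                  std::size_t K,
                                  std::vector<float>& dgamma,     // K
                                  std::vector<float>& dbeta)      // K
{
    dgamma.assign(K, 0.f);
    dbeta.assign(K, 0.f);
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t k = 0; k < K; ++k)
        {
            const float dy_v = dy[m * K + k];
            dgamma[k] += dy_v * (x[m * K + k] - mean[m]) * rstd[m];
            dbeta[k] += dy_v;
        }
    }
}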
...
@@ -16,6 +16,30 @@ namespace ck {
 namespace tensor_operation {
 namespace host {

+// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
+//     ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
+//     db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
+//     b = (db * x_mean - ds) * rstd ** (3) / reduce_size
+//     c = -b * x_mean - db * rstd / reduce_size
+//     dx = rstd * dy * gamma + b * x + c
+//     return dx
+
+// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
+//     # Assume shape of gamma and beta are the same
+//     dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
+//     dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
+//     return dgamma, dbeta
+
+// def layernorm_backward(dy, x, gamma, x_mean, rstd):
+//     # dy, x = [M, K], gamma = [1, K], x_mean, rstd = [M, 1]
+//     # dx = [M, K], dgamma, dbeta = [1, K]
+//     M, K = x.shape
+//     dx = normalization_input_backward(dy, x, gamma, x_mean, rstd, 1, K)
+//     dgamma, dbeta = normalization_gamma_beta_backward(dy, x, x_mean, rstd, 0)
+//     return dx, dgamma, dbeta
+
+// Reference (Layernorm and groupnorm):
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/layer_norm_kernel.cpp#L196
+
 template <typename DYDataType,
           typename XDataType,
           typename GammaDataType,
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP32
// FP32
void add_device_groupnorm_bwd_data_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 5, 3>>>&);
#endif
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
5,
3>>
{
using DeviceOp = DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
5,
3>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
is_same_v<GammaDataType, F32> && is_same_v<MeanInvStdDataType, F32> &&
is_same_v<DXDataType, F32>)
{
add_device_groupnorm_bwd_data_f32_instances(op_ptrs);
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
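Given the factory specialization above, a client only needs the data types and ranks to enumerate every registered FP32 groupnorm backward-data instance. A minimal usage sketch, assuming the header above is on the include path and CK_ENABLE_FP32 is defined; GetTypeString() is the usual CK base-operator query and is an assumption here, not something shown in this commit:

// Enumerate the registered FP32 groupnorm bwd-data instances and print their names.
// (Also include the factory header shown above; its path is omitted here.)
#include <iostream>

int main()
{
    using F32      = float; // matches ck's F32 alias
    using DeviceOp = ck::tensor_operation::device::
        DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 5, 3>;

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n';
    return 0;
}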