Commit d0f355a3 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 55a89c74 b305a29e
...@@ -149,7 +149,7 @@ function(clang_tidy_check TARGET) ...@@ -149,7 +149,7 @@ function(clang_tidy_check TARGET)
add_custom_target(${tidy_target} add_custom_target(${tidy_target}
# for some targets clang-tidy not able to get information from .clang-tidy # for some targets clang-tidy not able to get information from .clang-tidy
DEPENDS ${SOURCE} DEPENDS ${SOURCE}
COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_PLATFORM_AMD__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..." COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..."
) )
......
rocm-docs-core==0.30.1 rocm-docs-core==0.30.2
sphinxcontrib-bibtex==2.6.1 sphinxcontrib-bibtex==2.6.1
...@@ -113,7 +113,7 @@ requests==2.31.0 ...@@ -113,7 +113,7 @@ requests==2.31.0
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==0.30.1 rocm-docs-core==0.30.2
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via # via
......
#include <iostream> #include <iostream>
#include <cstdlib> #include <cstdlib>
#include <random>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
...@@ -48,10 +49,8 @@ void host_elementwise4D(HostTensorB& B_nhwc, ...@@ -48,10 +49,8 @@ void host_elementwise4D(HostTensorB& B_nhwc,
for(std::size_t n = 0; n < N; ++n) for(std::size_t n = 0; n < N; ++n)
{ {
ADataType tmp_val; ADataType tmp_val;
// auto a_val = A_nchw(n, c, h, w);
auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)];
functor_b(tmp_val, a_val); functor_b(tmp_val, a_val);
// functor_a(B_nhwc(n, h, w, c), scale * tmp_val);
functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)],
scale * tmp_val); scale * tmp_val);
} }
...@@ -62,12 +61,14 @@ int main() ...@@ -62,12 +61,14 @@ int main()
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
std::vector<std::size_t> nchw = {4, 2, 1, 8}; std::vector<std::size_t> nchw = {16, 8, 32, 64};
std::vector<std::size_t> nhwc = {4, 1, 8, 2}; std::vector<std::size_t> nhwc = {16, 32, 64, 8};
Tensor<ADataType> a(nchw); Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc); Tensor<BDataType> b(nhwc);
float scale = 1.f; float scale = 1.f;
auto i = 0; auto i = 0;
std::mt19937 gen(11939);
std::uniform_int_distribution<int> dis(0, 1);
for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w) for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
...@@ -75,7 +76,7 @@ int main() ...@@ -75,7 +76,7 @@ int main()
{ {
a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
(h * nchw[3]) + w] = i; (h * nchw[3]) + w] = i;
i++; i = dis(gen);
} }
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
......
...@@ -67,6 +67,8 @@ int main() ...@@ -67,6 +67,8 @@ int main()
float scale = 1.f; float scale = 1.f;
auto i = 0; auto i = 0;
std::mt19937 gen(11939);
std::uniform_int_distribution<int> dis(0, 1);
for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w) for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w)
for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h)
for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c)
...@@ -74,7 +76,7 @@ int main() ...@@ -74,7 +76,7 @@ int main()
{ {
a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) +
(h * nchw[3]) + w] = i; (h * nchw[3]) + w] = i;
i++; i = dis(gen);
} }
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
......
add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp)
...@@ -15,16 +15,17 @@ ...@@ -15,16 +15,17 @@
#include "ck/library/utility/literals.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
using DYDataType = ck::half_t; using DYDataType = float;
using XDataType = ck::half_t; using XDataType = float;
using GammaDataType = ck::half_t; using GammaDataType = float;
using MeanInvStdDataType = float; using MeanInvStdDataType = float;
using DGammaDataType = ck::half_t; using DGammaDataType = float;
using DBetaDataType = ck::half_t; using DBetaDataType = float;
using DXDataType = ck::half_t; using DXDataType = float;
using ComputeDataType = float; using ComputeDataType = float;
constexpr int Rank = 2; constexpr int Rank = 2;
...@@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1; ...@@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1;
// inv_std: [M, 1] // inv_std: [M, 1]
// Output shape // Output shape
// dx: [M, N]
// dgamma: [1, N] // dgamma: [1, N]
// dbeta: [1, N] // dbeta: [1, N]
...@@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1; ...@@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1;
// dbeta = reduce_sum(dy, axis=0) // dbeta = reduce_sum(dy, axis=0)
// [CAUSION] // [CAUSION]
// In DeviceNormalizationBwdGammaBetaImpl, M is invarient dimension, K is reduced dimension // In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is Invariant
// Hence, M in this example and DeviceNormalizationBwdGammaBetaImpl is different // dimension, K is reduced dimension Hence, M in this example and
// DeviceNormalizationBwdGammaBetaImpl is different
using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // MThreadClusterSize
32, // KThreadClusterSize
1, // MThreadSliceSize
4, // KThreadSliceSize
true, // IsDYFastestDimReduced
4, // DYSrcVectorSize
true, // IsXFastestDimReduced
4, // XSrcVectorSize
true, // IsGammaFastestDimReduced
4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced
4>; // DXDstVectorSize
using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl<
DYDataType, DYDataType,
XDataType, XDataType,
...@@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio ...@@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
Rank, Rank,
NumReduceDim, NumReduceDim,
256, // BlockSize 256, // BlockSize
8, // ClusterInvarient 8, // MThreadClusterSize
32, // ClusterReduce 32, // KThreadClusterSize
8, // SliceInvarient 4, // MThreadSliceSize
1, // SliceReduce 1, // KThreadSliceSize
false, // IsDYFastestDimReduced false, // IsDYFastestDimReduced
8, // DYSrcVectorSize 4, // DYSrcVectorSize
false, // IsXFastestDimReduced false, // IsXFastestDimReduced
8, // XSrcVectorSize 4, // XSrcVectorSize
true, // IsMeanInvStdFastestDimReduced true, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize 1, // MeanInvStdSrcVectorSize
1, // DGammaDstVectorSize 4, // DGammaDstVectorSize
1>; // DBetaDstVectorSize 4>; // DBetaDstVectorSize
int main() int main()
{ {
...@@ -96,16 +124,48 @@ int main() ...@@ -96,16 +124,48 @@ int main()
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data()); dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
mean_dev.ToDevice(mean.mData.data()); mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data()); inv_std_dev.ToDevice(inv_std.mData.data());
// backward x
auto x_device_instance = XDeviceInstance{};
auto x_argument_ptr = x_device_instance.MakeArgumentPointer({M, N}, // lengths
{N, 1}, // dyStrides
{N, 1}, // xStrides
{0, 1}, // gammaStrides
{1, 0}, // meanStrides
{1, 0}, // invStdStrides
{N, 1}, // dxStrides
{1}, // reduceDims
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1;
};
auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// backward gamma & beta
auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
auto gamma_beta_argument_ptr = auto gamma_beta_argument_ptr =
gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths
...@@ -126,7 +186,8 @@ int main() ...@@ -126,7 +186,8 @@ int main()
if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
{ {
std::cout << "The runtime parameters are not supported" << std::endl; std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1; return 1;
}; };
...@@ -156,9 +217,11 @@ int main() ...@@ -156,9 +217,11 @@ int main()
dgamma_dev.FromDevice(dgamma.mData.data()); dgamma_dev.FromDevice(dgamma.mData.data());
dbeta_dev.FromDevice(dbeta.mData.data()); dbeta_dev.FromDevice(dbeta.mData.data());
dx_dev.FromDevice(dx.mData.data());
pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
} }
return (pass ? 0 : 1); return (pass ? 0 : 1);
......
add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp)
add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp) add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp)
...@@ -15,23 +15,58 @@ ...@@ -15,23 +15,58 @@
#include "ck/library/utility/literals.hpp" #include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
using DYDataType = ck::half_t; using DYDataType = float;
using XDataType = ck::half_t; using XDataType = float;
using GammaDataType = ck::half_t; using GammaDataType = float;
using MeanInvStdDataType = float; using MeanInvStdDataType = float;
using DGammaDataType = ck::half_t; using DGammaDataType = float;
using DBetaDataType = ck::half_t; using DBetaDataType = float;
using DXDataType = ck::half_t; using DXDataType = float;
using ComputeDataType = float; using ComputeDataType = float;
constexpr int Rank = 5; constexpr int Rank = 5;
constexpr int NumReduceDim = 3; constexpr int NumReduceDim = 3;
// Grouprnorm // Grouprnorm
// kernel: M , K // kernel 1: M , K
// dy: N, H, W, G, C -> N * G, H * W * C
// x: N, H, W, G, C -> N * G, H * W * C
// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C
// mean: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
// rstd: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1
// dx: N, H, W, G, C -> N * G, H * W * C
using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl<
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // MThreadClusterSize
32, // KThreadClusterSize
1, // MThreadSliceSize
4, // KThreadSliceSize
true, // IsDYFastestDimReduced
4, // DYSrcVectorSize
true, // IsXFastestDimReduced
4, // XSrcVectorSize
true, // IsGammaFastestDimReduced
4, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced
4>; // DXDstVectorSize
// kernel 2: M , K
// dy: N, H, W, G, C -> G * C, N * H * W // dy: N, H, W, G, C -> G * C, N * H * W
// x: N, H, W, G, C -> G * C, N * H * W // x: N, H, W, G, C -> G * C, N * H * W
// mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1 // mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1
...@@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio ...@@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
Rank, Rank,
NumReduceDim, NumReduceDim,
256, // BlockSize 256, // BlockSize
8, // ClusterInvarient 8, // ClusterInvariant
32, // ClusterReduce 32, // ClusterReduce
8, // SliceInvarient 4, // SliceInvariant
1, // SliceReduce 1, // SliceReduce
false, // IsDYFastestDimReduced false, // IsDYFastestDimReduced
8, // DYSrcVectorSize 4, // DYSrcVectorSize
false, // IsXFastestDimReduced false, // IsXFastestDimReduced
8, // XSrcVectorSize 4, // XSrcVectorSize
false, // IsMeanInvStdFastestDimReduced false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize 1, // MeanInvStdSrcVectorSize
1, // DGammaDstVectorSize 4, // DGammaDstVectorSize
1>; // DBetaDstVectorSize 4>; // DBetaDstVectorSize
int main() int main()
{ {
...@@ -93,20 +128,55 @@ int main() ...@@ -93,20 +128,55 @@ int main()
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data()); dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data()); x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
mean_dev.ToDevice(mean.mData.data()); mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data()); inv_std_dev.ToDevice(inv_std.mData.data());
std::vector<ck::index_t> dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; std::vector<ck::index_t> dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
std::vector<ck::index_t> xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; std::vector<ck::index_t> xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<ck::index_t> gammaStrides = {0, 0, 0, C, 1};
std::vector<ck::index_t> meanStrides = {G, 0, 0, 1, 0}; std::vector<ck::index_t> meanStrides = {G, 0, 0, 1, 0};
std::vector<ck::index_t> invStdStrides = {G, 0, 0, 1, 0}; std::vector<ck::index_t> invStdStrides = {G, 0, 0, 1, 0};
std::vector<ck::index_t> dxStrides{dx.mDesc.GetStrides().begin(), dx.mDesc.GetStrides().end()};
// backward x
auto x_device_instance = XDeviceInstance{};
auto x_argument_ptr = x_device_instance.MakeArgumentPointer({N, H, W, G, C}, // lengths
dyStrides, // dyStrides
xStrides, // xStrides
gammaStrides, // gammaStrides
meanStrides, // meanStrides
invStdStrides, // invStdStrides
dxStrides, // dxStrides
{1, 2, 4}, // reduceDims
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1;
};
auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
// backward gamma & beta
auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
auto gamma_beta_argument_ptr = auto gamma_beta_argument_ptr =
...@@ -128,7 +198,8 @@ int main() ...@@ -128,7 +198,8 @@ int main()
if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get()))
{ {
std::cout << "The runtime parameters are not supported" << std::endl; std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__
<< std::endl;
return 1; return 1;
}; };
...@@ -158,9 +229,11 @@ int main() ...@@ -158,9 +229,11 @@ int main()
dgamma_dev.FromDevice(dgamma.mData.data()); dgamma_dev.FromDevice(dgamma.mData.data());
dbeta_dev.FromDevice(dbeta.mData.data()); dbeta_dev.FromDevice(dbeta.mData.data());
dx_dev.FromDevice(dx.mData.data());
pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3);
} }
return (pass ? 0 : 1); return (pass ? 0 : 1);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceNormalizationBwdData : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_gamma,
const void* p_mean,
const void* p_invStd,
void* p_dx) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
using DeviceNormalizationBwdDataPtr = std::unique_ptr<DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
// M is invarient dimension, K is reduced dimension // M is Invariant dimension, K is reduced dimension
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl ...@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl
Rank, Rank,
NumReduceDim> NumReduceDim>
{ {
static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0;
static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0;
static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
...@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl ...@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"); "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((MThreadSliceSize % DGammaDstVectorSize == 0) ||
(MThreadSliceSize % DBetaDstVectorSize == 0)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!");
static_assert( static_assert(
(MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!"); "check!");
static_assert(
((MThreadSliceSize % DGammaDstVectorSize == 0) ||
(MThreadSliceSize % DBetaDstVectorSize == 0)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl ...@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl
GridDesc_M dgamma_grid_desc_m_; GridDesc_M dgamma_grid_desc_m_;
GridDesc_M dbeta_grid_desc_m_; GridDesc_M dbeta_grid_desc_m_;
index_t MRaw_; // invarient length index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length index_t KRaw_; // reduce length
}; };
...@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl ...@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl
{ {
return std::make_unique<Invoker>(); return std::make_unique<Invoker>();
} }
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ",";
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize ;
str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">";
// clang-format on
return str.str();
}
}; };
} // namespace device } // namespace device
......
...@@ -19,7 +19,7 @@ namespace tensor_operation { ...@@ -19,7 +19,7 @@ namespace tensor_operation {
namespace device { namespace device {
// Y = Normalization(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
// M: Invarient length // M: Invariant length
// K: Reduce length (Calculate mean and variance along K dimension) // K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W] // eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W // Then, M = N, K = C * H * W
...@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType, ...@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
GridDesc_M save_inv_std_grid_desc_m_; GridDesc_M save_inv_std_grid_desc_m_;
bool isSweeponce_; bool isSweeponce_;
index_t MRaw_; // invarient length index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length index_t KRaw_; // reduce length
index_t invariant_lowest_length_; index_t invariant_lowest_length_;
...@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType, ...@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
} }
else else
{ {
printf("!!!! %d\n", p_arg_->invariant_lowest_length_);
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false; return false;
......
...@@ -108,7 +108,7 @@ namespace tensor_operation { ...@@ -108,7 +108,7 @@ namespace tensor_operation {
namespace device { namespace device {
// Y = Normalization(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
// M: Invarient length // M: Invariant length
// K: Reduce length (Calculate mean and variance along K dimension) // K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W] // eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W // Then, M = N, K = C * H * W
...@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp ...@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_; Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_; Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_;
index_t MRaw_; // invarient length index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length index_t KRaw_; // reduce length
index_t invariant_lowest_length_; index_t invariant_lowest_length_;
......
...@@ -35,7 +35,7 @@ template <typename DYDataType, ...@@ -35,7 +35,7 @@ template <typename DYDataType,
index_t DBetaDstVectorSize> index_t DBetaDstVectorSize>
struct GridwiseNormalizationBwdGammaBeta_mk_to_k struct GridwiseNormalizationBwdGammaBeta_mk_to_k
{ {
// if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
(DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
...@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k ...@@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
(XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"); "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
// do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm
static_assert(
((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)),
"Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize,
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>; using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using DYThreadBufferDimAccessOrder = using DYThreadBufferDimAccessOrder =
......
...@@ -16,6 +16,31 @@ namespace ck { ...@@ -16,6 +16,31 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
// return dx
// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
// # Assume shape of gamma and beta are the same
// dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
// dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
// return dgamma, dbeta
// def groupnorm_backward(dy, x, gamma, x_mean, rstd):
// # dy, x = [N, H, W, G, C], gamma = [1, 1, 1, G, C], x_mean, rstd = [N, 1, 1, G, 1]
// N, H, W, G, C = x.shape
// dx = normalization_input_backward(
// dy, x, gamma, x_mean, rstd, (1, 2, 4), H * W * C)
// dgamma, dbeta = normalization_gamma_beta_backward(
// dy, x, x_mean, rstd, (0, 1, 2))
// return dx, dgamma, dbeta
// Reference (Layernorm and groupnorm):
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L655
template <typename DYDataType, template <typename DYDataType,
typename XDataType, typename XDataType,
typename GammaDataType, typename GammaDataType,
......
...@@ -16,6 +16,30 @@ namespace ck { ...@@ -16,6 +16,30 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * rstd ** (3) / reduce_size
// c = -b * x_mean - db * rstd / reduce_size
// dx = rstd * dy * gamma + b * x + c
// return dx
// def normalization_beta_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis):
// # Assume shape of gamma and beta are the same
// dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True)
// dbeta = np.sum(dy, axis=reduce_axis, keepdims=True)
// return dgamma, dbeta
// def layernorm_backward(dy, x, gamma, x_mean, rstd):
// # dy, x = [M, K], gamma = [1, K], x_mean, rstd = [M, 1]
// # dx = [M, K], dgamma, dbeta = [1, K]
// M, K = x.shape
// dx = normalization_input_backward(dy, x, gamma, x_mean, rstd, 1, K)
// dgamma, dbeta = normalization_gamma_beta_backward(dy, x, x_mean, rstd, 0)
// return dx, dgamma, dbeta
// Reference (Layernorm and groupnorm):
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/layer_norm_kernel.cpp#L196
template <typename DYDataType, template <typename DYDataType,
typename XDataType, typename XDataType,
typename GammaDataType, typename GammaDataType,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP32
// FP32
void add_device_groupnorm_bwd_data_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdData<F32, F32, F32, F32, F32, 5, 3>>>&);
#endif
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
5,
3>>
{
using DeviceOp = DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
5,
3>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<DYDataType, F32> && is_same_v<XDataType, F32> &&
is_same_v<GammaDataType, F32> && is_same_v<MeanInvStdDataType, F32> &&
is_same_v<DXDataType, F32>)
{
add_device_groupnorm_bwd_data_f32_instances(op_ptrs);
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment