Unverified commit a3d9a2cd authored by rocking, committed by GitHub

Layernorm4d (#1022)



* Rename folder

* Add layernorm 4d fwd example

* Rename original layernorm example

* Add layernorm 4d f16 test

* Add layernorm4d_fwd client example

* Support layernorm4D in ckProfiler

* Rename groupnorm to groupnorm fwd in example

* Rename layernorm and group fwd in test

* Rename normalization to normalization_fwd (instances)

* Add fwd to DeviceNormalization

* Rename external api header

* Rename folder, because we can also add bwd in this folder

* Add fwd in layernorm and groupnorm (profiler)

* Fix compile error

---------
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
parent ce526211
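
With the renames below, client code programs against `DeviceNormalizationFwd` and the `*_fwd_*` instance factories. As a minimal sketch (not part of this commit) of how a client might enumerate the new rank-4 forward instances through the factory specialization added here; the include path is an assumption, since the commit only notes that the external API header was renamed:

```cpp
// Sketch only: the header path below is a guess at the renamed external API header.
#include "ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp"

#include <iostream>

using F16         = ck::half_t;
using F32         = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

int main()
{
    // DeviceNormalizationFwd<X, Gamma, Beta, Y, SaveMeanInvStd, YElementwiseOp, Rank, NumReduceDim>
    using DeviceOp = ck::tensor_operation::device::
        DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 4, 3>;

    // GetInstances() collects the entries registered by
    // add_device_normalization_fwd_rank_4_3_f16_instances(...).
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "layernorm4d fwd instances: " << op_ptrs.size() << std::endl;
    return 0;
}
```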
@@ -3,7 +3,7 @@
 #pragma once
-int run_groupnorm_example(int argc, char* argv[])
+int run_groupnorm_fwd_example(int argc, char* argv[])
 {
     ck::index_t N = 32;
     ck::index_t H = 16;
@@ -65,9 +65,9 @@ int run_groupnorm_example(int argc, char* argv[])
         {0, 0, 0, C, 1},
         std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
         std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
                                  save_mean.mDesc.GetStrides().end()},
         std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
                                  save_mean.mDesc.GetStrides().end()},
         {1, 2, 4}, // reduction dimension: [H, W, C]
         1e-6,
         x_dev.GetDeviceBuffer(),
...
add_example_executable(example_layernorm4d_fwd_fp16 layernorm4d_fwd_fp16.cpp)
add_example_executable(example_layernorm4d_fwd_splitk_fp16 layernorm4d_fwd_splitk_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
using DeviceInstance =
ck::tensor_operation::device::DeviceNormalizationFwdImpl<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
PassThrough,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // XYVectorDim (0=M, 1=K)
8, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
8, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_layernorm4d_fwd_example.inc"
int main() { return run_layernorm4d_fwd_example<DeviceInstance>(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using XDataType = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using SaveMeanInvStdDataType = float;
using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
#define SAVE_MEAN_INV_STD
constexpr int Rank = 4;
constexpr int NumReduceDim = 3;
using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationFwdSplitKImpl<
XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
SaveMeanInvStdDataType,
PassThrough,
Rank,
NumReduceDim,
256, // BlockSize
8, // ClusterM
32, // ClusterK
1, // SliceM
8, // SliceK
1, // XYVectorDim (0=M, 1=K)
8, // XScalarPerVector
1, // GammaVecDim (0=M, 1=K)
8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
8, // YScalarPerVector
1>; // SaveMeanInvStdScalarPerVector
#include "run_layernorm4d_fwd_example.inc"
int main() { return run_layernorm4d_fwd_example<DeviceInstance>(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename DeviceInstance>
int run_layernorm4d_fwd_example()
{
bool time_kernel = false;
ck::index_t N = 256;
ck::index_t H = 16;
ck::index_t W = 16;
ck::index_t C = 8;
Tensor<XDataType> x({N, H, W, C});
Tensor<GammaDataType> gamma({H, W, C});
Tensor<BetaDataType> beta({H, W, C});
Tensor<YDataType> y({N, H, W, C});
Tensor<SaveMeanInvStdDataType> save_mean({N});
Tensor<SaveMeanInvStdDataType> save_inv_std({N});
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());
#ifdef SAVE_MEAN_INV_STD
DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize());
DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) *
save_inv_std.mDesc.GetElementSpaceSize());
#endif
x_dev.ToDevice(x.mData.data());
gamma_dev.ToDevice(gamma.mData.data());
beta_dev.ToDevice(beta.mData.data());
auto device_instance = DeviceInstance{};
auto argument_ptr = device_instance.MakeArgumentPointer(
{N, H, W, C},
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
{0, W * C, C, 1},
{0, W * C, C, 1},
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
save_mean.mDesc.GetStrides().end()},
{1, 2, 3},
1e-4,
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.GetDeviceBuffer(),
save_inv_std_dev.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
PassThrough{});
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported" << std::endl;
return 1;
};
size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = device_instance.MakeInvokerPointer();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
bool pass = true;
{
Tensor<YDataType> host_y({N, H, W, C});
Tensor<SaveMeanInvStdDataType> host_save_mean({N});
Tensor<SaveMeanInvStdDataType> host_save_inv_std({N});
using ReferenceInstance =
ck::tensor_operation::host::ReferenceLayernorm<XDataType,
GammaDataType,
BetaDataType,
YDataType,
SaveMeanInvStdDataType,
ComputeDataType,
PassThrough,
Rank,
NumReduceDim>;
ReferenceInstance ref;
auto ref_argument = ref.MakeArgument(x,
gamma,
beta,
host_y,
host_save_mean,
host_save_inv_std,
PassThrough{},
{N, H, W, C},
{1, 2, 3},
1e-4);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument);
y_dev.FromDevice(y.mData.data());
pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3);
#ifdef SAVE_MEAN_INV_STD
save_mean_dev.FromDevice(save_mean.mData.data());
save_inv_std_dev.FromDevice(save_inv_std.mData.data());
pass &= ck::utils::check_err(
save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3);
pass &= ck::utils::check_err(
save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3);
#endif
}
return (pass ? 0 : 1);
}
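
The example above normalizes each sample n over its (H, W, C) elements, matching what the reference `Run4D` added further below computes. In the notation of the code:

```latex
\mu_n = \frac{1}{HWC}\sum_{h,w,c} x_{n,h,w,c}, \qquad
\sigma_n^2 = \frac{1}{HWC}\sum_{h,w,c} x_{n,h,w,c}^2 - \mu_n^2, \qquad
y_{n,h,w,c} = \frac{x_{n,h,w,c} - \mu_n}{\sqrt{\sigma_n^2 + \epsilon}}\,\gamma_{h,w,c} + \beta_{h,w,c}
```

`save_mean` stores \mu_n and `save_inv_std` stores 1/\sqrt{\sigma_n^2 + \epsilon}; the gamma/beta strides `{0, W * C, C, 1}` broadcast the (H, W, C) affine parameters across the N dimension.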
@@ -19,7 +19,7 @@ template <typename XDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
-struct DeviceNormalization : public BaseOperator
+struct DeviceNormalizationFwd : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const std::vector<index_t> lengths,
@@ -50,14 +50,14 @@ template <typename XDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
-using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
+using DeviceNormalizationFwdPtr = std::unique_ptr<DeviceNormalizationFwd<XDataType,
                                                GammaDataType,
                                                BetaDataType,
                                                YDataType,
                                                SaveMeanInvStdDataType,
                                                YElementwiseOperation,
                                                Rank,
                                                NumReduceDim>>;
 } // namespace device
 } // namespace tensor_operation
...
@@ -7,7 +7,7 @@
 #include <sstream>
 #include "ck/utility/reduction_operator.hpp"
-#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp"
@@ -46,14 +46,14 @@ template <typename XDataType,
           index_t YDstVectorSize,
           index_t SaveMeanInvStdDstVectorSize,
           bool UseWelford = true>
-struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
+struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
                                                             YDataType,
                                                             SaveMeanInvStdDataType,
                                                             YElementwiseOperation,
                                                             Rank,
                                                             NumReduceDim>
 {
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
     static_assert(
@@ -461,7 +461,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceNormalizationImpl<" << BlockSize << ",";
+        str << "DeviceNormalizationFwdImpl<" << BlockSize << ",";
         str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
         str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
         str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
...
@@ -8,7 +8,7 @@
 #include "ck/utility/reduction_operator.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
 #include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
@@ -134,14 +134,14 @@ template <typename XDataType,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize,
           index_t SaveMeanInvStdDstVectorSize>
-struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
+struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataType,
                                                                   GammaDataType,
                                                                   BetaDataType,
                                                                   YDataType,
                                                                   SaveMeanInvStdDataType,
                                                                   YElementwiseOperation,
                                                                   Rank,
                                                                   NumReduceDim>
 {
     using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;
@@ -732,7 +732,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceNormalizationSplitKImpl<" << BlockSize << ",";
+        str << "DeviceNormalizationFwdSplitKImpl<" << BlockSize << ",";
         str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
         str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
         str << "XYSrcVectorDim_" << XYVectorDim << ",";
...
@@ -28,7 +28,8 @@ template <typename XDataType,
 struct ReferenceLayernorm : public device::BaseOperator
 {
     // TODO - support generic layernorm
-    static_assert((Rank == 2 && NumReduceDim == 1), "Only support 2D version so far");
+    static_assert((Rank == 2 && NumReduceDim == 1) || (Rank == 4 && NumReduceDim == 3),
+                  "Only support 2D & 4D version so far");
     // Argument
     struct Argument : public device::BaseArgument
@@ -71,7 +72,7 @@ struct ReferenceLayernorm : public device::BaseOperator
     // Invoker
     struct Invoker : public device::BaseInvoker
     {
-        float Run(const Argument& arg)
+        float Run2D(const Argument& arg)
         {
             int M = arg.lengths_[0];
             int N = arg.lengths_[1];
@@ -117,6 +118,71 @@ struct ReferenceLayernorm : public device::BaseOperator
             return 0;
         }
float Run4D(const Argument& arg)
{
int N = arg.lengths_[0];
int H = arg.lengths_[1];
int W = arg.lengths_[2];
int C = arg.lengths_[3];
Tensor<ComputeDataType> mean({N});
Tensor<ComputeDataType> var({N});
int reduce_length = H * W * C;
for(int n = 0; n < N; ++n)
{
mean(n) = 0;
var(n) = 0;
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(n, h, w, c));
mean(n) += x_val;
var(n) += x_val * x_val;
}
mean(n) = mean(n) / reduce_length;
var(n) = (var(n) / reduce_length) - (mean(n) * mean(n));
}
for(int n = 0; n < N; ++n)
{
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n) + arg.epsilon_);
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(n, h, w, c));
auto gamma_val =
ck::type_convert<ComputeDataType>(arg.gamma_n_(h, w, c));
auto beta_val = ck::type_convert<ComputeDataType>(arg.beta_n_(h, w, c));
auto y_val = (x_val - mean(n)) * divisor;
y_val = (y_val * gamma_val) + beta_val;
arg.y_elementwise_op_(y_val, y_val);
arg.y_m_n_(n, h, w, c) = ck::type_convert<YDataType>(y_val);
}
arg.save_mean_m_(n) = ck::type_convert<SaveMeanInvStdDataType>(mean(n));
arg.save_inv_std_m_(n) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
}
return 0;
}
float Run(const Argument& arg)
{
if(arg.lengths_.size() == 2)
return Run2D(arg);
else if(arg.lengths_.size() == 4)
return Run4D(arg);
return 0;
}
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
@@ -134,17 +200,16 @@ struct ReferenceLayernorm : public device::BaseOperator
     {
         const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
-        // TODO - support generic layernorm
-        if(p_arg_->lengths_.size() != 2)
-            return false;
-        if(p_arg_->reduceDims_.size() != 1)
-            return false;
-        if(p_arg_->reduceDims_[0] != 1)
-            return false;
-        return true;
+        if(p_arg_->lengths_.size() == 2 && p_arg_->reduceDims_.size() == 1 &&
+           p_arg_->reduceDims_[0] == 1)
+            return true;
+        else if(p_arg_->lengths_.size() == 4 && p_arg_->reduceDims_.size() == 3 &&
+                p_arg_->reduceDims_[0] == 1 && p_arg_->reduceDims_[1] == 2 &&
+                p_arg_->reduceDims_[2] == 3)
+            return true;
+        return false;
     }
     static auto MakeArgument(const Tensor<XDataType>& x_m_n,
...
@@ -7,7 +7,7 @@
 #include <memory>
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -18,25 +18,31 @@ namespace device {
 namespace instance {
 #ifdef CK_ENABLE_FP16
 // FP16
-void add_device_normalization_rank_2_1_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
+void add_device_normalization_fwd_rank_2_1_f16_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 2, 1>>>&);
-void add_device_normalization_rank_4_3_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
+void add_device_normalization_fwd_rank_4_3_f16_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 4, 3>>>&);
-void add_device_normalization_rank_5_3_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_f16_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, PassThrough, 5, 3>>>&);
 #endif
 #ifdef CK_ENABLE_FP32
 // FP32
-void add_device_normalization_rank_2_1_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
+void add_device_normalization_fwd_rank_2_1_f32_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
-void add_device_normalization_rank_4_3_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
+void add_device_normalization_fwd_rank_4_3_f32_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
-void add_device_normalization_rank_5_3_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_f32_instances(
+    std::vector<
+        std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
 #endif
 template <typename XDataType,
           typename GammaDataType,
@@ -45,7 +51,7 @@ template <typename XDataType,
           typename SaveMeanInvStdDataType,
           index_t Rank,
           index_t NumReduceDim>
-struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormalization<
+struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormalizationFwd<
     XDataType,
     GammaDataType,
     BetaDataType,
@@ -55,14 +61,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
     Rank,
     NumReduceDim>>
 {
-    using DeviceOp = DeviceNormalization<XDataType,
+    using DeviceOp = DeviceNormalizationFwd<XDataType,
                                          GammaDataType,
                                          BetaDataType,
                                          YDataType,
                                          SaveMeanInvStdDataType,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          Rank,
                                          NumReduceDim>;
     static auto GetInstances()
     {
@@ -74,15 +80,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
         {
             if constexpr(Rank == 2 && NumReduceDim == 1)
             {
-                add_device_normalization_rank_2_1_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_2_1_f16_instances(op_ptrs);
             }
             else if constexpr(Rank == 4 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_4_3_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_4_3_f16_instances(op_ptrs);
             }
             else if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_f16_instances(op_ptrs);
             }
         }
 #endif
@@ -93,15 +99,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
         {
             if constexpr(Rank == 2 && NumReduceDim == 1)
             {
-                add_device_normalization_rank_2_1_f32_instances(op_ptrs);
+                add_device_normalization_fwd_rank_2_1_f32_instances(op_ptrs);
             }
             else if constexpr(Rank == 4 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_4_3_f32_instances(op_ptrs);
+                add_device_normalization_fwd_rank_4_3_f32_instances(op_ptrs);
            }
             else if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_f32_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_f32_instances(op_ptrs);
             }
         }
 #endif
...
@@ -7,7 +7,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
@@ -18,16 +18,16 @@ namespace device {
 namespace instance {
 // FP16
-void add_device_normalization_rank_5_3_swish_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&);
 // FP32
-void add_device_normalization_rank_5_3_swish_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_swish_f32_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>&);
 // [x, gamma, beta, y] = [f16, f32, f32, f16]
-void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F16, F32, Swish, 5, 3>>>&);
+void add_device_normalization_fwd_rank_5_3_swish_f16_f32_f32_f16_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>&);
 template <typename XDataType,
           typename GammaDataType,
@@ -37,23 +37,23 @@ template <typename XDataType,
           index_t Rank,
           index_t NumReduceDim>
 struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceNormalization<XDataType,
+    ck::tensor_operation::device::DeviceNormalizationFwd<XDataType,
                                                       GammaDataType,
                                                       BetaDataType,
                                                       YDataType,
                                                       SaveMeanInvStdDataType,
                                                       ck::tensor_operation::element_wise::Swish,
                                                       Rank,
                                                       NumReduceDim>>
 {
-    using DeviceOp = DeviceNormalization<XDataType,
+    using DeviceOp = DeviceNormalizationFwd<XDataType,
                                          GammaDataType,
                                          BetaDataType,
                                          YDataType,
                                          SaveMeanInvStdDataType,
                                          ck::tensor_operation::element_wise::Swish,
                                          Rank,
                                          NumReduceDim>;
     static auto GetInstances()
     {
@@ -65,7 +65,7 @@ struct DeviceOperationInstanceFactory<
         {
             if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_swish_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_swish_f16_instances(op_ptrs);
             }
         }
         else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
@@ -74,7 +74,7 @@ struct DeviceOperationInstanceFactory<
         {
             if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_swish_f32_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_swish_f32_instances(op_ptrs);
             }
         }
         else if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F32> &&
@@ -83,7 +83,7 @@ struct DeviceOperationInstanceFactory<
         {
             if constexpr(Rank == 5 && NumReduceDim == 3)
             {
-                add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
+                add_device_normalization_fwd_rank_5_3_swish_f16_f32_f32_f16_instances(op_ptrs);
             }
         }
...
set(DEVICE_NORMALIZATION_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_INSTANCES
device_layernorm2d_f16_instance.cpp
device_layernorm4d_f16_instance.cpp
device_groupnorm_f16_instance.cpp
device_groupnorm_swish_f16_instance.cpp
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
device_layernorm2d_f32_instance.cpp
device_layernorm4d_f32_instance.cpp
device_groupnorm_f32_instance.cpp
device_groupnorm_swish_f32_instance.cpp)
add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES})
set(DEVICE_NORMALIZATION_FWD_INSTANCES)
list(APPEND DEVICE_NORMALIZATION_FWD_INSTANCES
device_layernorm2d_fwd_f16_instance.cpp
device_layernorm4d_fwd_f16_instance.cpp
device_groupnorm_fwd_f16_instance.cpp
device_groupnorm_fwd_swish_f16_instance.cpp
device_groupnorm_fwd_swish_f16_f32_f32_f16_instance.cpp
device_layernorm2d_fwd_f32_instance.cpp
device_layernorm4d_fwd_f32_instance.cpp
device_groupnorm_fwd_f32_instance.cpp
device_groupnorm_fwd_swish_f32_instance.cpp)
add_instance_library(device_normalization_fwd_instance ${DEVICE_NORMALIZATION_FWD_INSTANCES})
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#include "normalization_instance_common.hpp"
+#include "normalization_fwd_instance_common.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -10,8 +10,8 @@ namespace instance {
 using Pass = ck::tensor_operation::element_wise::PassThrough;
-void add_device_normalization_rank_5_3_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Pass, 5, 3>>>&
+void add_device_normalization_fwd_rank_5_3_f16_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Pass, 5, 3>>>&
         instances)
 {
     add_device_operation_instances(instances,
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#include "normalization_instance_common.hpp"
+#include "normalization_fwd_instance_common.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -10,8 +10,8 @@ namespace instance {
 using Pass = ck::tensor_operation::element_wise::PassThrough;
-void add_device_normalization_rank_5_3_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Pass, 5, 3>>>&
+void add_device_normalization_fwd_rank_5_3_f32_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Pass, 5, 3>>>&
         instances)
 {
     add_device_operation_instances(instances,
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#include "normalization_instance_common.hpp"
+#include "normalization_fwd_instance_common.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -10,8 +10,8 @@ namespace instance {
 using Swish = ck::tensor_operation::element_wise::Swish;
-void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F32, F32, F16, F32, Swish, 5, 3>>>&
+void add_device_normalization_fwd_rank_5_3_swish_f16_f32_f32_f16_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>&
         instances)
 {
     add_device_operation_instances(
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#include "normalization_instance_common.hpp"
+#include "normalization_fwd_instance_common.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -10,8 +10,8 @@ namespace instance {
 using Swish = ck::tensor_operation::element_wise::Swish;
-void add_device_normalization_rank_5_3_swish_f16_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F16, F32, Swish, 5, 3>>>&
+void add_device_normalization_fwd_rank_5_3_swish_f16_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F16, F16, F16, F32, Swish, 5, 3>>>&
         instances)
 {
     add_device_operation_instances(instances,
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-#include "normalization_instance_common.hpp"
+#include "normalization_fwd_instance_common.hpp"
 namespace ck {
 namespace tensor_operation {
@@ -10,8 +10,8 @@ namespace instance {
 using Swish = ck::tensor_operation::element_wise::Swish;
-void add_device_normalization_rank_5_3_swish_f32_instances(
-    std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&
+void add_device_normalization_fwd_rank_5_3_swish_f32_instances(
+    std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>&
         instances)
 {
     add_device_operation_instances(instances,
...