Commit 7a7d50ec authored by rocking's avatar rocking
Browse files

Support naive variance for device_normalization

parent f174fb09
...@@ -10,46 +10,11 @@ ...@@ -10,46 +10,11 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename ConputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
ConputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const YElementwiseOperation y_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
y_elementwise_op);
};
} // namespace ck
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -58,7 +23,7 @@ namespace device { ...@@ -58,7 +23,7 @@ namespace device {
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename ConputeDataType, typename ComputeDataType,
typename YDataType, typename YDataType,
typename YElementwiseOperation, typename YElementwiseOperation,
index_t Rank, index_t Rank,
...@@ -74,11 +39,12 @@ template <typename XDataType, ...@@ -74,11 +39,12 @@ template <typename XDataType,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize> index_t YDstVectorSize,
bool UseWelford = true>
struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
ConputeDataType, ComputeDataType,
YDataType, YDataType,
YElementwiseOperation, YElementwiseOperation,
Rank, Rank,
...@@ -167,51 +133,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -167,51 +133,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduceLayernormGeneric =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ConputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
false>;
using GridwiseNormalizationSweepOnce =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ConputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
true>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> lengths, Argument(const std::vector<index_t> lengths,
...@@ -232,7 +153,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -232,7 +153,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
p_y_(p_y), p_y_(p_y),
y_elementwise_op_(y_elementwise_op) y_elementwise_op_(y_elementwise_op)
{ {
epsilon_ = static_cast<ConputeDataType>(epsilon); epsilon_ = static_cast<ComputeDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims); Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims); xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
...@@ -265,7 +186,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -265,7 +186,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
} }
ConputeDataType epsilon_; ComputeDataType epsilon_;
const XDataType* p_x_; const XDataType* p_x_;
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
...@@ -295,23 +216,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -295,23 +216,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto kernel_main = arg.isSweeponce_ auto kernel_main = NormalizationKernelSelector<XDataType,
? kernel_normalization<GridwiseNormalizationSweepOnce, GammaDataType,
XDataType, BetaDataType,
GammaDataType, YDataType,
BetaDataType, ComputeDataType,
YDataType, YElementwiseOperation,
ConputeDataType, GridDesc_M_K,
YElementwiseOperation, BlockSize,
GridDesc_M_K> MThreadClusterSize,
: kernel_normalization<GridwiseReduceLayernormGeneric, KThreadClusterSize,
XDataType, MThreadSliceSize,
GammaDataType, KThreadSliceSize,
BetaDataType, XYSrcVectorDim,
YDataType, XSrcVectorSize,
ConputeDataType, GammaSrcVectorDim,
YElementwiseOperation, GammaSrcVectorSize,
GridDesc_M_K>; BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
UseWelford>(arg.isSweeponce_);
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
namespace ck {
template <typename GridwiseReduction,
typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration,
ComputeDataType epsilon,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global,
const YElementwiseOperation y_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k,
beta_grid_desc_m_k,
y_grid_desc_m_k,
num_k_block_tile_iteration,
epsilon,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global,
y_elementwise_op);
};
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename YDataType,
typename ComputeDataType,
typename YElementwiseOperation,
typename GridDesc_M_K,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcVectorDim,
index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
bool UseWelford>
auto NormalizationKernelSelector(bool isSweepOnce)
{
using GridwiseNormalizationGenericNaive =
GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
false>;
using GridwiseNormalizationSweepOnceNaive =
GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
true>;
using GridwiseNormalizationGenericWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
false>;
using GridwiseNormalizationSweepOnceWelford =
GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
YDstVectorDim,
YDstVectorSize,
true>;
if constexpr(UseWelford)
{
return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceWelford,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>
: kernel_normalization<GridwiseNormalizationGenericWelford,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>;
}
else
{
return isSweepOnce ? kernel_normalization<GridwiseNormalizationSweepOnceNaive,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>
: kernel_normalization<GridwiseNormalizationGenericNaive,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
ComputeDataType,
YElementwiseOperation,
GridDesc_M_K>;
}
}
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment