Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7a7d50ec
Commit
7a7d50ec
authored
Feb 10, 2023
by
rocking
Browse files
Support naive variance for device_normalization
parent
f174fb09
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
223 additions
and
103 deletions
+223
-103
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
...r_operation/gpu/device/impl/device_normalization_impl.hpp
+28
-103
include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp
...or_operation/gpu/grid/gridwise_normalization_selector.hpp
+195
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
View file @
7a7d50ec
...
@@ -10,46 +10,11 @@
...
@@ -10,46 +10,11 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_
w
el
ford_variance
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_
s
el
ector
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
/// Device entry point for normalization.
///
/// The kernel contains no logic of its own: it passes every argument, in the same
/// order, to `GridwiseReduction::Run`.
///
/// @tparam GridwiseReduction      type providing the `Run` function that does the work
/// @tparam XDataType              element type behind `p_x_global`
/// @tparam GammaDataType          element type behind `p_gamma_global`
/// @tparam BetaDataType           element type behind `p_beta_global`
/// @tparam YDataType              element type behind `p_y_global`
/// @tparam ComputeDataType        type of `epsilon`
/// @tparam YElementwiseOperation  type of `y_elementwise_op`
/// @tparam GridDesc_M_K           common type of the four grid descriptors
template <typename GridwiseReduction,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
                                     const GridDesc_M_K gamma_grid_desc_m_k,
                                     const GridDesc_M_K beta_grid_desc_m_k,
                                     const GridDesc_M_K y_grid_desc_m_k,
                                     index_t num_k_block_tile_iteration,
                                     ComputeDataType epsilon,
                                     const XDataType* const __restrict__ p_x_global,
                                     const GammaDataType* const __restrict__ p_gamma_global,
                                     const BetaDataType* const __restrict__ p_beta_global,
                                     YDataType* const __restrict__ p_y_global,
                                     const YElementwiseOperation y_elementwise_op)
{
    GridwiseReduction::Run(x_grid_desc_m_k,
                           gamma_grid_desc_m_k,
                           beta_grid_desc_m_k,
                           y_grid_desc_m_k,
                           num_k_block_tile_iteration,
                           epsilon,
                           p_x_global,
                           p_gamma_global,
                           p_beta_global,
                           p_y_global,
                           y_elementwise_op);
}
}
// namespace ck
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -58,7 +23,7 @@ namespace device {
...
@@ -58,7 +23,7 @@ namespace device {
template
<
typename
XDataType
,
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
BetaDataType
,
typename
Co
n
puteDataType
,
typename
Co
m
puteDataType
,
typename
YDataType
,
typename
YDataType
,
typename
YElementwiseOperation
,
typename
YElementwiseOperation
,
index_t
Rank
,
index_t
Rank
,
...
@@ -74,11 +39,12 @@ template <typename XDataType,
...
@@ -74,11 +39,12 @@ template <typename XDataType,
index_t
GammaSrcVectorSize
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorSize
>
index_t
YDstVectorSize
,
bool
UseWelford
=
true
>
struct
DeviceNormalizationImpl
:
public
DeviceNormalization
<
XDataType
,
struct
DeviceNormalizationImpl
:
public
DeviceNormalization
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
Co
n
puteDataType
,
Co
m
puteDataType
,
YDataType
,
YDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
Rank
,
Rank
,
...
@@ -167,51 +133,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -167,51 +133,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridwiseReduceLayernormGeneric
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ConputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYSrcVectorDim
,
YDstVectorSize
,
false
>
;
using
GridwiseNormalizationSweepOnce
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ConputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYSrcVectorDim
,
YDstVectorSize
,
true
>
;
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
std
::
vector
<
index_t
>
lengths
,
Argument
(
const
std
::
vector
<
index_t
>
lengths
,
...
@@ -232,7 +153,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -232,7 +153,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
p_y_
(
p_y
),
p_y_
(
p_y
),
y_elementwise_op_
(
y_elementwise_op
)
y_elementwise_op_
(
y_elementwise_op
)
{
{
epsilon_
=
static_cast
<
Co
n
puteDataType
>
(
epsilon
);
epsilon_
=
static_cast
<
Co
m
puteDataType
>
(
epsilon
);
Lengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
lengths
,
reduceDims
);
Lengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
lengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
xStrides
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
xStrides
,
reduceDims
);
...
@@ -265,7 +186,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -265,7 +186,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
x_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
x_grid_desc_m_k_
.
GetLength
(
Number
<
1
>
{})
<=
KThreadClusterSize
*
KThreadSliceSize
;
}
}
Co
n
puteDataType
epsilon_
;
Co
m
puteDataType
epsilon_
;
const
XDataType
*
p_x_
;
const
XDataType
*
p_x_
;
const
GammaDataType
*
p_gamma_
;
const
GammaDataType
*
p_gamma_
;
...
@@ -295,23 +216,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
...
@@ -295,23 +216,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
{
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
const
auto
kernel_main
=
arg
.
isSweeponce_
auto
kernel_main
=
NormalizationKernelSelector
<
XDataType
,
?
kernel_normalization
<
GridwiseNormalizationSweepOnce
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ConputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
>
:
kernel_normalization
<
GridwiseReduceLayernormGeneric
,
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
Co
n
puteDataType
,
Co
m
puteDataType
,
YElementwiseOperation
,
YElementwiseOperation
,
GridDesc_M_K
>
;
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYSrcVectorDim
,
YDstVectorSize
,
UseWelford
>
(
arg
.
isSweeponce_
);
float
avg_time
=
0
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp
0 → 100644
View file @
7a7d50ec
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
namespace
ck
{
/// Thin device-side trampoline for the normalization kernels.
///
/// Every argument is handed, unchanged and in declaration order, to
/// `GridwiseReduction::Run`; the kernel itself does nothing else.
///
/// @tparam GridwiseReduction      type whose `Run` function implements the kernel body
/// @tparam XDataType              pointee type of `p_x_global`
/// @tparam GammaDataType          pointee type of `p_gamma_global`
/// @tparam BetaDataType           pointee type of `p_beta_global`
/// @tparam YDataType              pointee type of `p_y_global`
/// @tparam ComputeDataType        type of `epsilon`
/// @tparam YElementwiseOperation  type of `y_elementwise_op`
/// @tparam GridDesc_M_K           type shared by the x/gamma/beta/y grid descriptors
template <typename GridwiseReduction,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename GridDesc_M_K>
__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
                                     const GridDesc_M_K gamma_grid_desc_m_k,
                                     const GridDesc_M_K beta_grid_desc_m_k,
                                     const GridDesc_M_K y_grid_desc_m_k,
                                     index_t num_k_block_tile_iteration,
                                     ComputeDataType epsilon,
                                     const XDataType* const __restrict__ p_x_global,
                                     const GammaDataType* const __restrict__ p_gamma_global,
                                     const BetaDataType* const __restrict__ p_beta_global,
                                     YDataType* const __restrict__ p_y_global,
                                     const YElementwiseOperation y_elementwise_op)
{
    // Descriptors first, then scalars, then the data pointers, then the output op.
    GridwiseReduction::Run(x_grid_desc_m_k,
                           gamma_grid_desc_m_k,
                           beta_grid_desc_m_k,
                           y_grid_desc_m_k,
                           num_k_block_tile_iteration,
                           epsilon,
                           p_x_global,
                           p_gamma_global,
                           p_beta_global,
                           p_y_global,
                           y_elementwise_op);
}
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
GridDesc_M_K
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcVectorDim
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
,
bool
UseWelford
>
auto
NormalizationKernelSelector
(
bool
isSweepOnce
)
{
using
GridwiseNormalizationGenericNaive
=
GridwiseNormalizationNaiveVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorSize
,
false
>
;
using
GridwiseNormalizationSweepOnceNaive
=
GridwiseNormalizationNaiveVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorSize
,
true
>
;
using
GridwiseNormalizationGenericWelford
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorSize
,
false
>
;
using
GridwiseNormalizationSweepOnceWelford
=
GridwiseNormalizationWelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
YDstVectorDim
,
YDstVectorSize
,
true
>
;
if
constexpr
(
UseWelford
)
{
return
isSweepOnce
?
kernel_normalization
<
GridwiseNormalizationSweepOnceWelford
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
>
:
kernel_normalization
<
GridwiseNormalizationGenericWelford
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
>
;
}
else
{
return
isSweepOnce
?
kernel_normalization
<
GridwiseNormalizationSweepOnceNaive
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
>
:
kernel_normalization
<
GridwiseNormalizationGenericNaive
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
GridDesc_M_K
>
;
}
}
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment