Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
56532f77
"test/vscode:/vscode.git/clone" did not exist on "44d600cd675f6899f4001dd9e3f8c2c7208d1863"
Commit
56532f77
authored
Mar 01, 2023
by
rocking
Browse files
Add second kernel of normalization splitK
parent
28ebcfe7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
595 additions
and
45 deletions
+595
-45
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
...tion/gpu/device/impl/device_normalization_splitk_impl.hpp
+175
-44
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
.../grid/normalization/gridwise_normalization_splitk_1st.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
.../grid/normalization/gridwise_normalization_splitk_2nd.hpp
+419
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
View file @
56532f77
...
...
@@ -12,12 +12,13 @@
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
template
<
typename
GridwiseWelford
1
,
template
<
typename
GridwiseWelford
,
typename
XDataType
,
typename
MeanVarDataType
,
typename
ComputeDataType
,
...
...
@@ -32,7 +33,7 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
)
{
GridwiseWelford
1
::
Run
(
x_grid_desc_m_k
,
GridwiseWelford
::
Run
(
x_grid_desc_m_k
,
mean_var_grid_desc_m_kblock
,
num_k_block_tile_iteration
,
p_x_global
,
...
...
@@ -40,6 +41,57 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
p_welford_variance
,
p_welford_count
);
};
template
<
typename
GridwiseWelfordNormalization
,
typename
MeanVarDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
MeanVarGridDesc_M_KBlock
,
typename
CountVarGridDesc_M_KBlock
,
typename
XYGammaBetaGridDesc_M_K
>
__global__
void
kernel_normalizationSplitK2nd
(
const
MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock
,
const
CountVarGridDesc_M_KBlock
count_grid_desc_m_kblock
,
const
XYGammaBetaGridDesc_M_K
x_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
gamma_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
beta_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
y_grid_desc_m_k
,
index_t
num_k_mean_var_count_iteration
,
index_t
num_k_block_tile_iteration
,
index_t
k_grid_size
,
ComputeDataType
epsilon
,
const
MeanVarDataType
*
const
p_mean_global
,
const
MeanVarDataType
*
const
p_variance_global
,
const
int32_t
*
const
p_welford_count_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
const
YElementwiseOperation
y_elementwise_op
)
{
GridwiseWelfordNormalization
::
Run
(
mean_var_grid_desc_m_kblock
,
count_grid_desc_m_kblock
,
x_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
y_grid_desc_m_k
,
num_k_mean_var_count_iteration
,
num_k_block_tile_iteration
,
k_grid_size
,
epsilon
,
p_mean_global
,
p_variance_global
,
p_welford_count_global
,
p_x_global
,
p_gamma_global
,
p_beta_global
,
p_y_global
,
y_elementwise_op
);
};
}
// namespace ck
namespace
ck
{
...
...
@@ -64,7 +116,7 @@ template <typename XDataType,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XY
Src
VectorDim
,
index_t
XYVectorDim
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
...
...
@@ -184,22 +236,53 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
}
using
SrcGridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
Welford
1MeanVarGridDesc_M_KBlock
=
using
Kernel
1MeanVarGridDesc_M_KBlock
=
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
1
,
1
>
(
1
,
1
));
using
GridwiseWelford1
=
GridwiseNormalizationSplitK1st
<
XDataType
,
using
Kernel2MeanVarGridDesc_M_KBlock
=
decltype
(
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
Kernel2CountGridDesc_M_KBlock
=
decltype
(
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
));
using
GridwiseWelford
=
GridwiseNormalizationSplitK1st
<
XDataType
,
ComputeDataType
,
MeanVarDataType
,
SrcGridDesc_M_K
,
Welford
1MeanVarGridDesc_M_KBlock
,
Kernel
1MeanVarGridDesc_M_KBlock
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XY
Src
VectorDim
,
XYVectorDim
,
XSrcVectorSize
>
;
using
GridwiseWelfordNormalization
=
GridwiseNormalizationSplitK2nd
<
MeanVarDataType
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XYVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
XYVectorDim
,
YDstVectorSize
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
vector
<
index_t
>
lengths
,
...
...
@@ -236,19 +319,19 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
numBlockTileIteration_
=
1
;
while
(
true
)
{
int
testKGridSize
_
=
int
testKGridSize
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
*
numBlockTileIteration_
);
// we want the
testK
GridSize_ be not more than 128
if
(
testKGridSize
_
<=
128
)
// we want the
k
GridSize_ be not more than 128
if
(
testKGridSize
<=
128
)
break
;
++
numBlockTileIteration_
;
};
kGridSize_
=
math
::
integer_divide_ceil
(
KRaw_
,
K_BlockTileSize
*
numBlockTileIteration_
);
gridSize_
=
math
::
integer_divide_ceil
(
MRaw_
,
M_BlockTileSize
)
*
kGridSize_
;
numMeanVarCountIteration_
=
math
::
integer_divide_ceil
(
kGridSize_
,
KThreadClusterSize
);
x_grid_desc_m_k_
=
MakeSrc2dDescriptor
(
Lengths_
,
xStrides_
,
kGridSize_
,
numBlockTileIteration_
);
...
...
@@ -260,9 +343,17 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
MakeSrc2dDescriptor
(
Lengths_
,
yStrides_
,
kGridSize_
,
numBlockTileIteration_
);
// We don't need to pad in K dimension for Welford1. Set KPerTile 1.
mean_var_grid_desc_m_kblock_
=
kernel1_
mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
false
>
,
M_BlockTileSize
,
1
>
(
MRaw_
,
kGridSize_
);
kernel2_mean_var_grid_desc_m_kblock_
=
MakeMeanVarDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_BlockTileSize
>
(
MRaw_
,
kGridSize_
);
kernel2_count_grid_desc_m_kblock_
=
MakeCountDescriptor_M_K
<
Sequence
<
true
,
true
>
,
M_BlockTileSize
,
K_BlockTileSize
>
(
MRaw_
,
kGridSize_
);
}
ComputeDataType
epsilon_
;
...
...
@@ -284,6 +375,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
YElementwiseOperation
y_elementwise_op_
;
int
kGridSize_
;
int
numMeanVarCountIteration_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
...
...
@@ -292,7 +384,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
SrcGridDesc_M_K
beta_grid_desc_m_k_
;
SrcGridDesc_M_K
y_grid_desc_m_k_
;
Welford1MeanVarGridDesc_M_KBlock
mean_var_grid_desc_m_kblock_
;
Kernel1MeanVarGridDesc_M_KBlock
kernel1_mean_var_grid_desc_m_kblock_
;
Kernel2MeanVarGridDesc_M_KBlock
kernel2_mean_var_grid_desc_m_kblock_
;
Kernel2CountGridDesc_M_KBlock
kernel2_count_grid_desc_m_kblock_
;
index_t
MRaw_
;
// invarient length
index_t
KRaw_
;
// reduce length
...
...
@@ -306,12 +400,24 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
arg
.
p_workspace_count_
==
nullptr
)
throw
std
::
runtime_error
(
"wrong! WorkSpace pointer has not been set"
);
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
1
,
auto
kernel1
=
kernel_normalizationSplitK1st
<
GridwiseWelford
,
XDataType
,
MeanVarDataType
,
ComputeDataType
,
SrcGridDesc_M_K
,
Welford1MeanVarGridDesc_M_KBlock
>
;
Kernel1MeanVarGridDesc_M_KBlock
>
;
auto
kernel2
=
kernel_normalizationSplitK2nd
<
GridwiseWelfordNormalization
,
MeanVarDataType
,
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
ComputeDataType
,
YElementwiseOperation
,
Kernel2MeanVarGridDesc_M_KBlock
,
Kernel2CountGridDesc_M_KBlock
,
SrcGridDesc_M_K
>
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -320,16 +426,38 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
mean_var_grid_desc_m_kblock_
,
arg
.
kernel1_
mean_var_grid_desc_m_kblock_
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
));
// TODO - welford2 + elementwise
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel2
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
kernel2_mean_var_grid_desc_m_kblock_
,
arg
.
kernel2_count_grid_desc_m_kblock_
,
arg
.
x_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
beta_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
numMeanVarCountIteration_
,
arg
.
numBlockTileIteration_
,
arg
.
kGridSize_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_y_
,
arg
.
y_elementwise_op_
);
return
(
avg_time
)
;
return
avg_time
;
};
float
Run
(
const
BaseArgument
*
p_arg
,
...
...
@@ -390,7 +518,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
constexpr
index_t
NumInvariantDim
=
Rank
-
NumReduceDim
;
if
constexpr
(
XY
Src
VectorDim
==
0
)
if
constexpr
(
XYVectorDim
==
0
)
{
if
constexpr
(
NumInvariantDim
==
0
)
{
...
...
@@ -423,38 +551,41 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
if
constexpr
(
GammaSrcVectorDim
==
0
)
{
if
(
p_arg_
->
gammaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
(
false
)
;
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
return
(
false
)
;
return
false
;
}
else
// if fastest dim is reduced
{
if
(
p_arg_
->
gammaStrides_
[
Rank
-
1
]
!=
1
)
return
(
false
)
;
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
GammaSrcVectorSize
!=
0
)
return
(
false
)
;
return
false
;
}
// if fastest dim is not reduced
if
constexpr
(
BetaSrcVectorDim
==
0
)
{
if
(
p_arg_
->
betaStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
(
false
)
;
return
false
;
if
(
p_arg_
->
invariant_lowest_length
%
BetaSrcVectorSize
!=
0
)
return
(
false
)
;
return
false
;
}
else
// if fastest dim is reduced
{
if
(
p_arg_
->
betaStrides_
[
Rank
-
1
]
!=
1
)
return
(
false
)
;
return
false
;
if
(
p_arg_
->
Lengths_
[
Rank
-
1
]
%
BetaSrcVectorSize
!=
0
)
return
(
false
)
;
return
false
;
}
if
(
p_arg_
->
kGridSize_
<=
1
)
return
false
;
return
true
;
};
...
...
@@ -507,7 +638,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
str
<<
"DeviceNormalizationImpl<"
<<
BlockSize
<<
","
;
str
<<
"Cluster_MK_"
<<
MThreadClusterSize
<<
"_"
<<
KThreadClusterSize
<<
","
;
str
<<
"Slice_MK_"
<<
MThreadSliceSize
<<
"_"
<<
KThreadSliceSize
<<
","
;
str
<<
"XYSrcVectorDim_"
<<
XY
Src
VectorDim
<<
","
;
str
<<
"XYSrcVectorDim_"
<<
XYVectorDim
<<
","
;
str
<<
"VectorSize_X"
<<
XSrcVectorSize
<<
"_Gamma"
<<
GammaSrcVectorSize
<<
"_Beta"
<<
BetaSrcVectorSize
<<
"_Y"
<<
YDstVectorSize
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
View file @
56532f77
...
...
@@ -203,7 +203,7 @@ struct GridwiseNormalizationSplitK1st
var_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
for
(
index_t
k
=
0
;
k
<
num_k_block_tile_iteration
;
++
k
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
0 → 100644
View file @
56532f77
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
MeanVarDataType
,
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
typename
ComputeDataType
,
typename
YElementwiseOperation
,
typename
MeanVarGridDesc_M_KBlock
,
typename
CountVarGridDesc_M_KBlock
,
typename
XYGammaBetaGridDesc_M_K
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcVectorDim
,
index_t
XSrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorDim
,
index_t
YDstVectorSize
>
struct
GridwiseNormalizationSplitK2nd
{
static_assert
((
XSrcVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
((
YDstVectorDim
==
0
&&
MThreadSliceSize
%
YDstVectorSize
==
0
)
||
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcVectorDim
==
0
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
XSrcVectorSize
>
;
static
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
XSrcVectorSize
>
{}));
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
static
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
I1
));
using
ThreadWelfordSrcDesc_M_1
=
decltype
(
thread_buffer_desc_m_1
);
using
ThreadWelfordDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadwiseWelford
=
ThreadwiseWelfordMerge
<
ComputeDataType
,
ThreadWelfordSrcDesc_M_1
,
ThreadWelfordDstDesc_M
>
;
using
BlockwiseWelford
=
BlockwiseWelford
<
ComputeDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileStepSize
=
KThreadClusterSize
*
XSrcVectorSize
;
static
constexpr
auto
ThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
__device__
static
void
Run
(
const
MeanVarGridDesc_M_KBlock
&
mean_var_grid_desc_m_kblock
,
const
CountVarGridDesc_M_KBlock
&
count_grid_desc_m_kblock
,
const
XYGammaBetaGridDesc_M_K
&
x_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
gamma_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
beta_grid_desc_m_k
,
const
XYGammaBetaGridDesc_M_K
&
y_grid_desc_m_k
,
index_t
num_k_mean_var_count_iteration
,
index_t
num_k_block_tile_iteration
,
index_t
k_grid_size
,
ComputeDataType
epsilon
,
const
MeanVarDataType
*
const
p_mean_global
,
const
MeanVarDataType
*
const
p_variance_global
,
const
int32_t
*
const
p_welford_count_global
,
const
XDataType
*
const
__restrict__
p_x_global
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
const
YElementwiseOperation
y_elementwise_op
)
{
// Thread/Block id
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
block_m_cluster_id
=
block_global_id
/
k_grid_size
;
const
index_t
block_k_cluster_id
=
block_global_id
%
k_grid_size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
// Global Memory
const
auto
mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_mean_global
,
mean_var_grid_desc_m_kblock
.
GetElementSpaceSize
());
const
auto
var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_variance_global
,
mean_var_grid_desc_m_kblock
.
GetElementSpaceSize
());
const
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_count_global
,
count_grid_desc_m_kblock
.
GetElementSpaceSize
());
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x_global
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_gamma_global
,
gamma_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
beta_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_beta_global
,
beta_grid_desc_m_k
.
GetElementSpaceSize
());
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
// VGPR
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
in_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
in_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
in_welford_count_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
welford_count_thread_buf
;
auto
x_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
XSrcVectorSize
,
true
>
{};
},
Number
<
ThreadBufferNumber
>
{});
auto
gamma_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
MThreadSliceSize
*
GammaSrcVectorSize
,
true
>
{};
},
Number
<
ThreadBufferNumber
>
{});
auto
&
beta_thread_buf
=
gamma_thread_buf
;
auto
&
y_thread_buf
=
x_thread_buf
;
// IO
auto
threadwise_mean_var_load_m_kblock
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
ComputeDataType
,
MeanVarGridDesc_M_KBlock
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
1
,
true
>
(
mean_var_grid_desc_m_kblock
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
));
auto
threadwise_count_load_m_kblock
=
ThreadwiseTensorSliceTransfer_v2
<
int32_t
,
int32_t
,
CountVarGridDesc_M_KBlock
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
1
,
1
,
true
>
(
count_grid_desc_m_kblock
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
));
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
ComputeDataType
,
XYGammaBetaGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcVectorDim
,
XSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_k_cluster_id
*
K_BlockTileSize
*
num_k_block_tile_iteration
+
thread_k_cluster_id
*
XSrcVectorSize
));
auto
threadwise_gamma_load
=
ThreadwiseTensorSliceTransfer_v2
<
GammaDataType
,
ComputeDataType
,
XYGammaBetaGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
1
,
true
>
(
gamma_grid_desc_m_k
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_k_cluster_id
*
K_BlockTileSize
*
num_k_block_tile_iteration
+
thread_k_cluster_id
*
GammaSrcVectorSize
));
auto
threadwise_beta_load
=
ThreadwiseTensorSliceTransfer_v2
<
BetaDataType
,
ComputeDataType
,
XYGammaBetaGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
1
,
true
>
(
beta_grid_desc_m_k
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_k_cluster_id
*
K_BlockTileSize
*
num_k_block_tile_iteration
+
thread_k_cluster_id
*
BetaSrcVectorSize
));
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
YDataType
,
decltype
(
thread_buffer_desc_m_k
),
XYGammaBetaGridDesc_M_K
,
YElementwiseOperation
,
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
YDstVectorDim
,
YDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
y_grid_desc_m_k
,
make_multi_index
(
block_m_cluster_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_k_cluster_id
*
K_BlockTileSize
*
num_k_block_tile_iteration
+
thread_k_cluster_id
*
YDstVectorSize
),
y_elementwise_op
);
// step1: Merge mean and variance
constexpr
auto
mean_var_count_thread_copy_step_I0_k
=
make_multi_index
(
I0
,
KThreadClusterSize
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
welford_count_thread_buf
(
I
)
=
0
;
});
for
(
index_t
k
=
0
;
k
<
num_k_mean_var_count_iteration
;
++
k
)
{
threadwise_mean_var_load_m_kblock
.
Run
(
mean_var_grid_desc_m_kblock
,
mean_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
in_mean_thread_buf
);
threadwise_mean_var_load_m_kblock
.
Run
(
mean_var_grid_desc_m_kblock
,
var_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
in_var_thread_buf
);
threadwise_count_load_m_kblock
.
Run
(
count_grid_desc_m_kblock
,
welford_count_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
in_welford_count_thread_buf
);
ThreadwiseWelford
::
Run
(
in_mean_thread_buf
,
in_var_thread_buf
,
in_welford_count_thread_buf
,
mean_thread_buf
,
var_thread_buf
,
welford_count_thread_buf
);
threadwise_mean_var_load_m_kblock
.
MoveSrcSliceWindow
(
mean_var_grid_desc_m_kblock
,
mean_var_count_thread_copy_step_I0_k
);
threadwise_count_load_m_kblock
.
MoveSrcSliceWindow
(
count_grid_desc_m_kblock
,
mean_var_count_thread_copy_step_I0_k
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
welford_count_thread_buf
(
I
));
});
// step2: normalization
constexpr
auto
thread_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileStepSize
);
for
(
index_t
k
=
0
;
k
<
num_k_mean_var_count_iteration
;
++
k
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
(
i
));
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
});
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
gamma_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
gamma_thread_buf
(
i
));
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK1
));
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
// gamma
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
*
gamma_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{});
});
});
});
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_beta_load
.
Run
(
beta_grid_desc_m_k
,
beta_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
beta_thread_buf
(
i
));
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK1
));
// beta
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
+
beta_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{});
});
});
});
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_y_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
y_thread_buf
(
i
),
y_grid_desc_m_k
,
y_global_val_buf
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
});
}
// end for (normalization)
}
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment