Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1d7290fb
Commit
1d7290fb
authored
Nov 29, 2022
by
rocking
Browse files
Update interface
parent
003ec407
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
175 additions
and
124 deletions
+175
-124
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
...ple/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+139
-70
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
...mm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
+35
-53
No files found.
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
View file @
1d7290fb
...
...
@@ -63,7 +63,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
LayernormThreadClusterSize_M_N, LayernormThreadSliceSize_M_N
<
ALayout
,
BLayout
,
DsLayout
,
HLayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
GammaDataType
,
BetaDataType
,
HDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
HElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
64
,
4
>
,
4
,
S
<
8
,
32
>
,
S
<
1
,
8
>
,
1
,
8
,
8
,
8
,
8
,
1
>
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
1d7290fb
...
...
@@ -13,8 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
// #include
// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
...
...
@@ -103,23 +102,56 @@ __global__ void
#endif
}
// template <typename GridwiseWelfordLayernorm,
// typename EDataType,
// typename HDataType,
// typename MeanDataType,
// typename VarDataType>
// __global__ void
// #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// #endif
// kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
// const MeanDataType* __restrict__ p_mean_grid,
// const VarDataType* __restrict__ p_var_grid,
// HDataType* __restrict__ p_y_grid,
// index_t blkgroup_size)
// {
// // GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
// }
/// @brief Kernel entry point for the second half of the GEMM + layernorm pipeline.
///
/// This is a thin trampoline: every argument is forwarded, unchanged and in the
/// same order, to `GridwiseWelfordLayernorm::Run`. All computation lives there.
///
/// @tparam GridwiseWelfordLayernorm  Type providing the static `Run(...)` that does the work.
/// @tparam EDataType                 Element type behind `p_e_grid`.
/// @tparam HDataType                 Element type behind `p_h_grid`.
/// @tparam MeanDataType              Element type behind `p_in_welford_mean_grid`.
/// @tparam VarDataType               Element type behind `p_in_welford_var_grid`.
/// @tparam GammaDataType             Element type behind `p_gamma_grid`.
/// @tparam BetaDataType              Element type behind `p_beta_grid`.
/// @tparam ComputeDataType           Type of `epsilon`.
/// @tparam EHGridDesc_M_N            Descriptor type shared by the E and H grids.
/// @tparam MeanVarCountGridDesc_M_N  Descriptor type for the mean/var/count grids.
/// @tparam GammaBetaGridDesc_N       Descriptor type shared by the gamma and beta grids.
///
/// @param p_e_grid                 Read-only input buffer (E).
/// @param p_in_welford_mean_grid   Read-only Welford mean buffer.
/// @param p_in_welford_var_grid    Read-only Welford variance buffer.
/// @param p_in_welford_count_grid  Read-only Welford count buffer (`int32_t`).
/// @param p_gamma_grid             Read-only gamma buffer.
/// @param p_beta_grid              Read-only beta buffer.
/// @param p_h_grid                 Output buffer (H).
/// @param e_grid_desc_m_n, h_grid_desc_m_n, mean_var_count_grid_desc_m_n,
///        gamma_grid_desc_n, beta_grid_desc_n
///                                 Tensor descriptors for the buffers above.
/// @param blkgroup_size            Forwarded as-is; the host launcher in this file
///                                 passes `arg.gemm_nblock_` here.
/// @param num_mean_var_count_k_block_tile_iteration
///                                 Forwarded as-is (tile-iteration count for the
///                                 mean/var/count grids — presumably along N; confirm
///                                 against the gridwise implementation).
/// @param num_xy_k_block_tile_iteration
///                                 Forwarded as-is (tile-iteration count for the E/H
///                                 grids — presumably along N; confirm as above).
/// @param epsilon                  Forwarded as-is.
///
/// @note The descriptors are taken BY VALUE. Kernel arguments are copied into the
///       launch's argument buffer; a reference parameter would instead carry the
///       address of a host-side object, which is not dereferenceable from device
///       code. `Run` may still accept them by const reference, since inside the
///       kernel they are device-visible copies.
template <typename GridwiseWelfordLayernorm,
          typename EDataType,
          typename HDataType,
          typename MeanDataType,
          typename VarDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename EHGridDesc_M_N,
          typename MeanVarCountGridDesc_M_N,
          typename GammaBetaGridDesc_N>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_welford_layernorm2d_second_half(
            const EDataType* __restrict__ p_e_grid,
            const MeanDataType* __restrict__ p_in_welford_mean_grid,
            const VarDataType* __restrict__ p_in_welford_var_grid,
            const int32_t* __restrict__ p_in_welford_count_grid,
            const GammaDataType* __restrict__ p_gamma_grid,
            const BetaDataType* __restrict__ p_beta_grid,
            HDataType* __restrict__ p_h_grid,
            const EHGridDesc_M_N e_grid_desc_m_n,
            const EHGridDesc_M_N h_grid_desc_m_n,
            const MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n,
            const GammaBetaGridDesc_N gamma_grid_desc_n,
            const GammaBetaGridDesc_N beta_grid_desc_n,
            index_t blkgroup_size,
            index_t num_mean_var_count_k_block_tile_iteration,
            index_t num_xy_k_block_tile_iteration,
            ComputeDataType epsilon)
{
    GridwiseWelfordLayernorm::Run(p_e_grid,
                                  p_in_welford_mean_grid,
                                  p_in_welford_var_grid,
                                  p_in_welford_count_grid,
                                  p_gamma_grid,
                                  p_beta_grid,
                                  p_h_grid,
                                  e_grid_desc_m_n,
                                  h_grid_desc_m_n,
                                  mean_var_count_grid_desc_m_n,
                                  gamma_grid_desc_n,
                                  beta_grid_desc_n,
                                  blkgroup_size,
                                  num_mean_var_count_k_block_tile_iteration,
                                  num_xy_k_block_tile_iteration,
                                  epsilon);
}
}
// namespace ck
...
...
@@ -204,6 +236,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
using
LayernormBlockTileSize_M_N
=
Sequence
<
LayernormThreadClusterSize_M_N
::
At
(
0
)
*
LayernormThreadSliceSize_M_N
::
At
(
0
),
LayernormThreadClusterSize_M_N
::
At
(
1
)
*
LayernormThreadSliceSize_M_N
::
At
(
1
)
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
@@ -330,11 +366,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
(
1
,
1
,
1
));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
MeanVarCountGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_N
(
1
));
using
MeanVarGridDesc_M
=
decltype
(
MakeDescriptor_M
(
1
));
using
HGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
GridwiseGemmWelford
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
...
...
@@ -351,7 +385,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
AGridDesc_M_K
,
BGridDesc_N_K
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
E
H
GridDesc_M_N
,
MeanVarCountGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
...
...
@@ -388,27 +422,28 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using
Block2ETileMap
=
typename
GridwiseGemmWelford
::
DefaultBlock2ETileMap
;
// using GridwiseWelfordLayernorm =
// GridwiseWelfordSecondHalfLayernorm2d<EDataType,
// HDataType,
// MeanDataType,
// VarDataType,
// AccDataType,
// HGridDesc_M_N,
// MeanVarGridDesc_M_N,
// GammaBetaGridDesc_N,
// MeanVarGridDesc_M,
// BlockSize,
// LayernormThreadClusterSize_M_N::At(I0),
// LayernormThreadClusterSize_M_N::At(I1),
// LayernormThreadSliceSize_M_N::At(I0),
// LayernormThreadSliceSize_M_N::At(I1),
// LayernormESrcHDstVectorDim,
// LayernormESrcVectorSize,
// LayernormHDstVectorSize,
// LayernormGammaSrcVectorSize,
// LayernormBetaSrcVectorSize,
// LayernormMeanVarSrcDstVectorSize>;
using
GridwiseWelfordLayernorm
=
GridwiseWelfordSecondHalfLayernorm2d
<
EDataType
,
HDataType
,
MeanDataType
,
VarDataType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
EHGridDesc_M_N
,
MeanVarCountGridDesc_M_N
,
GammaBetaGridDesc_N
,
BlockSize
,
LayernormThreadClusterSize_M_N
::
At
(
I0
),
LayernormThreadClusterSize_M_N
::
At
(
I1
),
LayernormThreadSliceSize_M_N
::
At
(
I0
),
LayernormThreadSliceSize_M_N
::
At
(
I1
),
LayernormESrcHDstVectorDim
,
LayernormESrcVectorSize
,
LayernormHDstVectorSize
,
LayernormGammaSrcVectorSize
,
LayernormBetaSrcVectorSize
,
LayernormMeanVarSrcDstVectorSize
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -449,20 +484,24 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gamma_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
beta_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
a_grid_desc_ak0_m_ak1_
{
GridwiseGemmWelford
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
GridwiseGemmWelford
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
block_2_etile_map_
{
GridwiseGemmWelford
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
h_element_op_
{
h_element_op
},
blkGroupSize
_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
gemm_nblock
_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
epsilon_
{
epsilon
}
{
mean_var_count_grid_desc_m_n_
=
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
blkGroupSize_
,
blkGroupSize
_
);
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
gemm_nblock_
,
gemm_nblock
_
);
hip_check_error
(
hipMalloc
(
&
p_e_grid_
,
sizeof
(
EDataType
)
*
MRaw
*
NRaw
));
int
gemm_welford_size
=
MRaw
*
blkGroupSize
_
;
int
gemm_welford_size
=
MRaw
*
gemm_nblock
_
;
hip_check_error
(
hipMalloc
(
&
p_welford_mean_grid_
,
sizeof
(
MeanDataType
)
*
gemm_welford_size
));
hip_check_error
(
...
...
@@ -502,8 +541,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
mean_var_count_grid_desc_m_n_
);
}
// TODO - H
}
void
Print
()
const
...
...
@@ -533,11 +570,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
E
H
GridDesc_M_N
e_grid_desc_m_n_
;
MeanVarCountGridDesc_M_N
mean_var_count_grid_desc_m_n_
;
GammaBetaGridDesc_N
gamma_grid_desc_n_
;
GammaBetaGridDesc_N
beta_grid_desc_n_
;
HGridDesc_M_N
h_grid_desc_m_n_
;
E
HGridDesc_M_N
h_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
typename
GridwiseGemmWelford
::
DefaultAGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
...
...
@@ -558,7 +595,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CDEElementwiseOperation
cde_element_op_
;
HElementwiseOperation
h_element_op_
;
int
blkGroupSize
_
;
int
gemm_nblock
_
;
AccDataType
epsilon_
;
};
...
...
@@ -581,9 +618,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
throw
std
::
runtime_error
(
"wrong! GridwiseGemmWelford has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
const
auto
M
=
arg
.
h_grid_desc_m_n_
.
GetLength
(
I0
);
const
auto
N
=
arg
.
h_grid_desc_m_n_
.
GetLength
(
I1
);
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
...
...
@@ -612,12 +650,18 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
typename
GridwiseGemmWelford
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
// const auto kernel_welford_layernorm =
// kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
// EDataType,
// HDataType,
// MeanDataType,
// VarDataType>;
const
auto
kernel_welford_layernorm
=
kernel_welford_layernorm2d_second_half
<
GridwiseWelfordLayernorm
,
EDataType
,
HDataType
,
MeanDataType
,
VarDataType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
EHGridDesc_M_N
,
MeanVarCountGridDesc_M_N
,
GammaBetaGridDesc_N
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -642,16 +686,39 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
mean_var_count_grid_desc_mblock_mperblock_nblock_
,
arg
.
block_2_etile_map_
);
// avg_time += launch_and_time_kernel(stream_config,
// kernel_welford_layernorm,
// dim3(grid_size),
// dim3(BlockSize),
// 0,
// arg.p_e_grid_,
// arg.p_welford_mean_grid_,
// arg.p_welford_var_grid_,
// arg.p_h_grid_,
// arg.blkGroupSize_);
grid_size
=
math
::
integer_least_multiple
(
M
,
LayernormBlockTileSize_M_N
::
At
(
0
))
/
LayernormBlockTileSize_M_N
::
At
(
0
);
index_t
numMeanVarCountBlockTileIteration_N
=
math
::
integer_least_multiple
(
arg
.
gemm_nblock_
,
LayernormThreadClusterSize_M_N
::
At
(
I1
))
/
LayernormThreadClusterSize_M_N
::
At
(
I1
);
index_t
numXBlockTileIteration_N
=
math
::
integer_least_multiple
(
N
,
LayernormBlockTileSize_M_N
::
At
(
I1
))
/
LayernormBlockTileSize_M_N
::
At
(
I1
);
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_welford_layernorm
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_e_grid_
,
arg
.
p_welford_mean_grid_
,
arg
.
p_welford_var_grid_
,
arg
.
p_welford_count_grid_
,
arg
.
p_gamma_grid_
,
arg
.
p_beta_grid_
,
arg
.
p_h_grid_
,
arg
.
e_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
mean_var_count_grid_desc_m_n_
,
arg
.
gamma_grid_desc_n_
,
arg
.
beta_grid_desc_n_
,
arg
.
gemm_nblock_
,
numMeanVarCountBlockTileIteration_N
,
numXBlockTileIteration_N
,
arg
.
epsilon_
);
return
avg_time
;
};
...
...
@@ -681,6 +748,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return
false
;
}
// TODO
return
true
;
}
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
View file @
1d7290fb
...
...
@@ -23,25 +23,26 @@ template <typename EDataType,
typename
HDataType
,
typename
MeanDataType
,
typename
VarDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
XY
GridDesc_M_N
,
typename
MeanVarGridDesc_M_N
,
typename
EH
GridDesc_M_N
,
typename
MeanVar
Count
GridDesc_M_N
,
typename
GammaBetaGridDesc_N
,
typename
MeanVarGridDesc_M
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
NThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
NThreadSliceSize
,
index_t
X
SrcYDstVectorDim
,
index_t
X
SrcVectorSize
,
index_t
E
SrcYDstVectorDim
,
index_t
E
SrcVectorSize
,
index_t
YDstVectorSize
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
GridwiseWelfordSecondHalfLayernorm2d
{
static
constexpr
bool
reorder_thread_cluster
=
(
X
SrcYDstVectorDim
==
0
);
static
constexpr
bool
reorder_thread_cluster
=
(
E
SrcYDstVectorDim
==
0
);
using
ThreadClusterLengths_M_N
=
Sequence
<
MThreadClusterSize
,
NThreadClusterSize
>
;
...
...
@@ -76,57 +77,38 @@ struct GridwiseWelfordSecondHalfLayernorm2d
static
constexpr
index_t
N_BlockTileSize
=
NThreadClusterSize
*
NThreadSliceSize
;
__device__
static
void
Run
(
const
EDataType
*
__restrict__
p_e_grid
,
const
MeanDataType
*
__restrict__
p_mean_grid
,
const
VarDataType
*
__restrict__
p_var_grid
,
const
MeanDataType
*
__restrict__
p_in_welford_mean_grid
,
const
VarDataType
*
__restrict__
p_in_welford_var_grid
,
const
int32_t
*
__restrict__
p_in_welford_count_grid
,
const
GammaDataType
*
__restrict__
p_gamma_grid
,
const
BetaDataType
*
__restrict__
p_beta_grid
,
HDataType
*
__restrict__
p_h_grid
,
/*const MeanVarGridDesc_M_N& mean_grid_desc_m_k,
const MeanVarGridDesc_M_N& var_grid_desc_m_k,
const GammaBetaGridDesc_N& gamma_grid_desc_m,
const GammaBetaGridDesc_N& beta_grid_desc_m,
const MeanVarGridDesc_M& mean_var_grid_desc_m,*/
index_t
blkgroup_size
)
const
EHGridDesc_M_N
&
e_grid_desc_m_n
,
const
EHGridDesc_M_N
&
h_grid_desc_m_n
,
const
MeanVarCountGridDesc_M_N
&
mean_var_count_grid_desc_m_n
,
const
GammaBetaGridDesc_N
&
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
&
beta_grid_desc_n
,
index_t
gemm_nblock_
,
index_t
num_mean_var_count_k_block_tile_iteration
,
index_t
num_xy_k_block_tile_iteration
,
ComputeDataType
epsilon
)
{
ignore
=
p_e_grid
;
ignore
=
p_mean_grid
;
ignore
=
p_var_grid
;
ignore
=
p_in_welford_mean_grid
;
ignore
=
p_in_welford_var_grid
;
ignore
=
p_in_welford_count_grid
;
ignore
=
p_gamma_grid
;
ignore
=
p_beta_grid
;
ignore
=
p_h_grid
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
blkgroup_size
;
const
index_t
block_local_id
=
block_global_id
%
blkgroup_size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_n_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_N
=
Sequence
<
MThreadSliceSize
,
NThreadSliceSize
>
;
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
constexpr
auto
thread_buffer_desc_m_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
NThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
/*
auto threadwise_mean_load_m_n =
ThreadwiseTensorSliceTransfer_v2<MeanDataType,
ComputeDataType,
MeanVarGridDesc_M_N,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_grid_desc_m_n,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_n_cluster_id * 1));*/
ignore
=
e_grid_desc_m_n
;
ignore
=
h_grid_desc_m_n
;
ignore
=
mean_var_count_grid_desc_m_n
;
ignore
=
gamma_grid_desc_n
;
ignore
=
beta_grid_desc_n
;
ignore
=
gemm_nblock_
;
ignore
=
num_mean_var_count_k_block_tile_iteration
;
ignore
=
num_xy_k_block_tile_iteration
;
ignore
=
epsilon
;
}
// run
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment