Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
003ec407
Commit
003ec407
authored
Nov 28, 2022
by
rocking
Browse files
Add welford count
parent
b7f500f0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
197 additions
and
169 deletions
+197
-169
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+106
-114
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
...dwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
+91
-55
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
003ec407
...
...
@@ -13,7 +13,8 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
// #include
// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
...
...
@@ -33,8 +34,7 @@ template <typename GridwiseGemmWelford,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
MeanGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
VarGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
Block2ETileMap
,
bool
HasMainKBlockLoop
>
__global__
void
...
...
@@ -46,8 +46,9 @@ __global__ void
const
ABDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
MeanDataType
*
__restrict__
p_mean_grid
,
VarDataType
*
__restrict__
p_var_grid
,
MeanDataType
*
__restrict__
p_welford_mean_grid
,
VarDataType
*
__restrict__
p_welford_var_grid
,
int32_t
*
__restrict__
p_welford_count_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
...
...
@@ -57,8 +58,8 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
MeanGridDescriptor_MBlock_MPerBlock_NBlock
mean_grid_desc_mblock_mperblock_nblock
,
const
VarGridDescriptor_MBlock_MPerBlock_NBlock
var
_grid_desc_mblock_mperblock_nblock
,
const
Mean
VarCount
GridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count
_grid_desc_mblock_mperblock_nblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -69,8 +70,9 @@ __global__ void
p_b_grid
,
p_ds_grid
,
p_e_grid
,
p_mean_grid
,
p_var_grid
,
p_welford_mean_grid
,
p_welford_var_grid
,
p_welford_count_grid
,
p_shared
,
a_element_op
,
b_element_op
,
...
...
@@ -79,16 +81,16 @@ __global__ void
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
mean_grid_desc_mblock_mperblock_nblock
,
var_grid_desc_mblock_mperblock_nblock
,
mean_var_count_grid_desc_mblock_mperblock_nblock
,
block_2_etile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
p_mean_grid
;
ignore
=
p_var_grid
;
ignore
=
p_welford_mean_grid
;
ignore
=
p_welford_var_grid
;
ignore
=
p_welford_count_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
...
...
@@ -96,29 +98,28 @@ __global__ void
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
mean_grid_desc_mblock_mperblock_nblock
;
ignore
=
var_grid_desc_mblock_mperblock_nblock
;
ignore
=
mean_var_count_grid_desc_mblock_mperblock_nblock
;
ignore
=
block_2_etile_map
;
#endif
}
template
<
typename
GridwiseWelfordLayernorm
,
typename
EDataType
,
typename
HDataType
,
typename
MeanDataType
,
typename
VarDataType
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_welford_layernorm2d_second_half
(
const
EDataType
*
__restrict__
p_x_grid
,
const
MeanDataType
*
__restrict__
p_mean_grid
,
const
VarDataType
*
__restrict__
p_var_grid
,
HDataType
*
__restrict__
p_y_grid
,
index_t
blkgroup_size
)
{
GridwiseWelfordLayernorm
::
Run
(
p_x_grid
,
p_mean_grid
,
p_var_grid
,
p_y_grid
,
blkgroup_size
);
}
//
template <typename GridwiseWelfordLayernorm,
//
typename EDataType,
//
typename HDataType,
//
typename MeanDataType,
//
typename VarDataType>
//
__global__ void
//
#if CK_USE_LAUNCH_BOUNDS
//
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
//
#endif
//
kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
//
const MeanDataType* __restrict__ p_mean_grid,
//
const VarDataType* __restrict__ p_var_grid,
//
HDataType* __restrict__ p_y_grid,
//
index_t blkgroup_size)
//
{
//
//
GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
//
}
}
// namespace ck
...
...
@@ -326,14 +327,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
};
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
(
1
,
1
,
1
));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
MeanVarGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_N
(
1
));
using
MeanVarGridDesc_M
=
decltype
(
MakeDescriptor_M
(
1
));
using
HGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
(
1
,
1
,
1
));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
MeanVar
Count
GridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_N
(
1
));
using
MeanVarGridDesc_M
=
decltype
(
MakeDescriptor_M
(
1
));
using
HGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
GridwiseGemmWelford
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
...
...
@@ -351,8 +352,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
MeanVarGridDesc_M_N
,
MeanVarGridDesc_M_N
,
MeanVarCountGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -384,32 +384,31 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CShuffleNXdlPerWavePerShuffle
,
PostShuffleThreadClusterSize_M_N
,
PostShuffleScalarPerVector
,
1
,
LoopSched
>
;
using
Block2ETileMap
=
typename
GridwiseGemmWelford
::
DefaultBlock2ETileMap
;
using
GridwiseWelfordLayernorm
=
GridwiseWelfordSecondHalfLayernorm2d
<
EDataType
,
HDataType
,
MeanDataType
,
VarDataType
,
AccDataType
,
HGridDesc_M_N
,
MeanVarGridDesc_M_N
,
GammaBetaGridDesc_N
,
MeanVarGridDesc_M
,
BlockSize
,
LayernormThreadClusterSize_M_N
::
At
(
I0
),
LayernormThreadClusterSize_M_N
::
At
(
I1
),
LayernormThreadSliceSize_M_N
::
At
(
I0
),
LayernormThreadSliceSize_M_N
::
At
(
I1
),
LayernormESrcHDstVectorDim
,
LayernormESrcVectorSize
,
LayernormHDstVectorSize
,
LayernormGammaSrcVectorSize
,
LayernormBetaSrcVectorSize
,
LayernormMeanVarSrcDstVectorSize
>
;
//
using GridwiseWelfordLayernorm =
//
GridwiseWelfordSecondHalfLayernorm2d<EDataType,
//
HDataType,
//
MeanDataType,
//
VarDataType,
//
AccDataType,
//
HGridDesc_M_N,
//
MeanVarGridDesc_M_N,
//
GammaBetaGridDesc_N,
//
MeanVarGridDesc_M,
//
BlockSize,
//
LayernormThreadClusterSize_M_N::At(I0),
//
LayernormThreadClusterSize_M_N::At(I1),
//
LayernormThreadSliceSize_M_N::At(I0),
//
LayernormThreadSliceSize_M_N::At(I1),
//
LayernormESrcHDstVectorDim,
//
LayernormESrcVectorSize,
//
LayernormHDstVectorSize,
//
LayernormGammaSrcVectorSize,
//
LayernormBetaSrcVectorSize,
//
LayernormMeanVarSrcDstVectorSize>;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -436,8 +435,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_e_grid_
{
nullptr
},
p_mean_grid_
{
nullptr
},
p_var_grid_
{
nullptr
},
p_welford_mean_grid_
{
nullptr
},
p_welford_var_grid_
{
nullptr
},
p_welford_count_grid_
{
nullptr
},
p_gamma_grid_
{
static_cast
<
const
GammaDataType
*>
(
p_gamma_grid
)},
p_beta_grid_
{
static_cast
<
const
BetaDataType
*>
(
p_beta_grid
)},
p_h_grid_
{
static_cast
<
HDataType
*>
(
p_h_grid
)},
...
...
@@ -445,8 +445,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
mean_grid_desc_m_n_
{},
var_grid_desc_m_n_
{},
mean_var_count_grid_desc_m_n_
{},
gamma_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
beta_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
...
...
@@ -458,16 +457,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
blkGroupSize_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
epsilon_
{
epsilon
}
{
mean_grid_desc_m_n_
=
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
blkGroupSize_
,
blkGroupSize_
);
var_grid_desc_m_n_
=
mean_var_count_grid_desc_m_n_
=
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
blkGroupSize_
,
blkGroupSize_
);
hip_check_error
(
hipMalloc
(
&
p_e_grid_
,
sizeof
(
EDataType
)
*
MRaw
*
NRaw
));
int
gemm_welford_size
=
MRaw
*
blkGroupSize_
;
hip_check_error
(
hipMalloc
(
&
p_mean_grid_
,
sizeof
(
MeanDataType
)
*
gemm_welford_size
));
hip_check_error
(
hipMalloc
(
&
p_var_grid_
,
sizeof
(
VarDataType
)
*
gemm_welford_size
));
hip_check_error
(
hipMalloc
(
&
p_welford_mean_grid_
,
sizeof
(
MeanDataType
)
*
gemm_welford_size
));
hip_check_error
(
hipMalloc
(
&
p_welford_var_grid_
,
sizeof
(
VarDataType
)
*
gemm_welford_size
));
hip_check_error
(
hipMalloc
(
&
p_welford_count_grid_
,
sizeof
(
int32_t
)
*
gemm_welford_size
));
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
...
@@ -487,8 +487,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
mean_grid_desc_m_n_
,
var_grid_desc_m_n_
,
mean_var_count_grid_desc_m_n_
,
block_2_etile_map_
))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
...
...
@@ -499,13 +498,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemmWelford
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
mean_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemmWelford
::
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
mean_grid_desc_m_n_
);
var_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemmWelford
::
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
var_grid_desc_m_n_
);
mean_var_count_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
mean_var_count_grid_desc_m_n_
);
}
// TODO - H
...
...
@@ -527,8 +522,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemmWelford
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
MeanDataType
*
p_mean_grid_
;
// mean
VarDataType
*
p_var_grid_
;
// variance * count
MeanDataType
*
p_welford_mean_grid_
;
VarDataType
*
p_welford_var_grid_
;
int32_t
*
p_welford_count_grid_
;
const
GammaDataType
*
p_gamma_grid_
;
const
BetaDataType
*
p_beta_grid_
;
HDataType
*
p_h_grid_
;
...
...
@@ -538,8 +534,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
MeanVarGridDesc_M_N
mean_grid_desc_m_n_
;
MeanVarGridDesc_M_N
var_grid_desc_m_n_
;
MeanVarCountGridDesc_M_N
mean_var_count_grid_desc_m_n_
;
GammaBetaGridDesc_N
gamma_grid_desc_n_
;
GammaBetaGridDesc_N
beta_grid_desc_n_
;
HGridDesc_M_N
h_grid_desc_m_n_
;
...
...
@@ -551,10 +546,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemmWelford
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemmWelford
::
MeanGridDescriptor_MBlock_MPerBlock_NBlock
mean_grid_desc_mblock_mperblock_nblock_
;
typename
GridwiseGemmWelford
::
VarGridDescriptor_MBlock_MPerBlock_NBlock
var_grid_desc_mblock_mperblock_nblock_
;
typename
GridwiseGemmWelford
::
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
...
...
@@ -582,8 +575,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
mean_grid_desc_m_n_
,
arg
.
var_grid_desc_m_n_
,
arg
.
mean_var_count_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmWelford has invalid setting"
);
...
...
@@ -615,17 +607,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemmWelford
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemmWelford
::
MeanGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemmWelford
::
Var
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemmWelford
::
MeanVarCount
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemmWelford
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
const
auto
kernel_welford_layernorm
=
kernel_welford_layernorm2d_second_half
<
GridwiseWelfordLayernorm
,
EDataType
,
HDataType
,
MeanDataType
,
VarDataType
>
;
//
const auto kernel_welford_layernorm =
//
kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
//
EDataType,
//
HDataType,
//
MeanDataType,
//
VarDataType>;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -637,8 +629,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
p_mean_grid_
,
arg
.
p_var_grid_
,
arg
.
p_welford_mean_grid_
,
arg
.
p_welford_var_grid_
,
arg
.
p_welford_count_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
...
...
@@ -646,20 +639,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
mean_grid_desc_mblock_mperblock_nblock_
,
arg
.
var_grid_desc_mblock_mperblock_nblock_
,
arg
.
mean_var_count_grid_desc_mblock_mperblock_nblock_
,
arg
.
block_2_etile_map_
);
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_welford_layernorm
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_e_grid_
,
arg
.
p_mean_grid_
,
arg
.
p_var_grid_
,
arg
.
p_h_grid_
,
arg
.
blkGroupSize_
);
//
avg_time += launch_and_time_kernel(stream_config,
//
kernel_welford_layernorm,
//
dim3(grid_size),
//
dim3(BlockSize),
//
0,
//
arg.p_e_grid_,
//
arg.p_
welford_
mean_grid_,
//
arg.p_
welford_
var_grid_,
//
arg.p_h_grid_,
//
arg.blkGroupSize_);
return
avg_time
;
};
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
View file @
003ec407
...
...
@@ -47,8 +47,7 @@ template <typename ABDataType,
typename
BGridDesc_N_K
,
typename
DsGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
MeanGridDesc_M_N
,
typename
VarGridDesc_M_N
,
typename
MeanVarCountGridDesc_M_N
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
...
...
@@ -80,7 +79,6 @@ template <typename ABDataType,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
PostShuffleThreadClusterSize_M_N
,
index_t
PostShuffleScalarPerVector
,
index_t
MeanVarTransferScalarPerVector
,
LoopScheduler
LoopSched
>
struct
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
...
...
@@ -242,10 +240,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number
<
NumDTensor
>
{});
}
// TODO - MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
// TODO - MakeMeanVar
Count
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template
<
typename
GridDescriptor_M_N
>
__host__
__device__
static
constexpr
auto
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
const
GridDescriptor_M_N
&
grid_desc_m_n
)
MakeMeanVar
Count
GridDescriptor_MBlock_MPerBlock_NBlock
(
const
GridDescriptor_M_N
&
grid_desc_m_n
)
{
const
auto
M
=
grid_desc_m_n
.
GetLength
(
I0
);
const
auto
NBlock
=
grid_desc_m_n
.
GetLength
(
I1
);
...
...
@@ -276,8 +274,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
MeanGridDesc_M_N
&
mean_grid_desc_m_n
,
const
VarGridDesc_M_N
&
var_grid_desc_m_n
,
const
MeanVarCountGridDesc_M_N
&
mean_var_count_grid_desc_m_n
,
const
Block2ETileMap
&
block_2_etile_map
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
...
...
@@ -290,9 +287,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// check consistency of desc
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
M
==
mean_grid_desc_m_n
.
GetLength
(
I0
)
&&
M
==
var_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
/
NPerBlock
==
mean_grid_desc_m_n
.
GetLength
(
I1
)
&&
N
/
NPerBlock
==
var_grid_desc_m_n
.
GetLength
(
I1
)))
M
==
mean_var_count_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
/
NPerBlock
==
mean_var_count_grid_desc_m_n
.
GetLength
(
I1
)))
{
return
false
;
}
...
...
@@ -356,10 +352,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
MeanGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
MeanGridDesc_M_N
{}))
>
;
using
VarGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
VarGridDesc_M_N
{}))
>
;
using
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
MeanVarCountGridDesc_M_N
{}))
>
;
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
...
...
@@ -372,26 +366,26 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
Block2ETileMap
>
__device__
static
void
Run
(
const
ABDataType
*
__restrict__
p_
a
_grid
,
const
ABDataType
*
__restrict__
p_
b
_grid
,
DsGridPointer
p_
ds
_grid
,
E
DataType
*
__restrict__
p_
e
_grid
,
Mean
DataType
*
__restrict__
p_
mean
_grid
,
VarDataType
*
__restrict__
p_
var_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
MeanGridDescriptor_MBlock_MPerBlock_NBlock
&
mean_grid_desc_mblock_mperblock_nblock
,
const
VarGridDescriptor_MBlock_MPerBlock_NBlock
&
var
_grid_desc_mblock_mperblock_nblock
,
const
Block2ETileMap
&
block_2_etile_map
)
__device__
static
void
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_
b
_grid
,
DsGridPointer
p_
ds
_grid
,
EDataType
*
__restrict__
p_
e
_grid
,
Mean
DataType
*
__restrict__
p_
welford_mean
_grid
,
Var
DataType
*
__restrict__
p_
welford_var
_grid
,
int32_t
*
__restrict__
p_
welford_count
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Mean
VarCount
GridDescriptor_MBlock_MPerBlock_NBlock
&
mean_var_count
_grid_desc_mblock_mperblock_nblock
,
const
Block2ETileMap
&
block_2_etile_map
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
...
...
@@ -411,10 +405,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
p_e_grid
,
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
mean_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_mean_grid
,
mean_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
p_welford_mean_grid
,
mean_var_count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
auto
var_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_var_grid
,
var_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
p_welford_var_grid
,
mean_var_count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
auto
welford_count_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_count
,
mean_var_count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
const
auto
block_work_idx
=
...
...
@@ -871,9 +871,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
decltype
(
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
>
(
thread_welford_dst_desc_m
.
GetElementSpaceSize
()));
using
welford_count_vgpr_type
=
decltype
(
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
>
(
thread_welford_dst_desc_m
.
GetElementSpaceSize
()));
Array
<
ThreadwiseWelford
,
num_shuffleM
>
threadwise_welfords
;
Array
<
mean_var_vgpr_type
,
num_shuffleM
>
mean_thread_bufs
;
Array
<
mean_var_vgpr_type
,
num_shuffleM
>
var_thread_bufs
;
Array
<
welford_count_vgpr_type
,
num_shuffleM
>
welford_count_thread_bufs
;
static_for
<
0
,
num_shuffleM
,
1
>
{}([
&
](
auto
i
)
{
// TODO - padding
...
...
@@ -884,9 +889,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
var_thread_bufs
(
i
)
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
>
(
thread_welford_dst_desc_m
.
GetElementSpaceSize
());
welford_count_thread_bufs
(
i
)
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
>
(
thread_welford_dst_desc_m
.
GetElementSpaceSize
());
static_for
<
0
,
PostShuffleThreadSliceSize_M
,
1
>
{}([
&
](
auto
j
)
{
mean_thread_bufs
(
i
)(
j
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_bufs
(
i
)(
j
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
mean_thread_bufs
(
i
)(
j
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_bufs
(
i
)(
j
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
welford_count_thread_bufs
(
i
)(
j
)
=
0
;
});
});
...
...
@@ -982,13 +991,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// Blockwise welford and write out
static_for
<
0
,
num_shuffleM
,
1
>
{}([
&
](
auto
i
)
{
auto
&
mean_thread_buf
=
mean_thread_bufs
(
i
);
auto
&
var_thread_buf
=
var_thread_bufs
(
i
);
int
count
=
threadwise_welfords
(
i
).
cur_count_
;
auto
&
mean_thread_buf
=
mean_thread_bufs
(
i
);
auto
&
var_thread_buf
=
var_thread_bufs
(
i
);
auto
&
count
_thread_buf
=
welford_count_thread_bufs
(
i
)
;
static_for
<
0
,
PostShuffleThreadSliceSize_M
,
1
>
{}([
&
](
auto
j
)
{
block_sync_lds
();
BlockwiseWelford
::
Run
(
mean_thread_buf
(
j
),
var_thread_buf
(
j
),
count
);
BlockwiseWelford
::
Run
(
mean_thread_buf
(
j
),
var_thread_buf
(
j
),
count_thread_buf
(
j
));
});
constexpr
auto
thread_welford_desc_I_m_I
=
make_naive_tensor_descriptor_packed
(
...
...
@@ -997,20 +1007,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
constexpr
int
shuffleMPerBlock
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I1
);
static_assert
(
PostShuffleThreadSliceSize_M
%
MeanVarTransferScalarPerVector
==
0
);
auto
mean_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanDataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
mean_grid_desc_mblock_mperblock_nblock
),
decltype
(
mean_
var_count_
grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
1
,
MeanVarTransferScalarPerVector
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
{
mean_grid_desc_mblock_mperblock_nblock
,
false
>
{
mean_
var_count_
grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
...
...
@@ -1021,32 +1030,59 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType
,
VarDataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
var_grid_desc_mblock_mperblock_nblock
),
decltype
(
mean_var_count_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
{
mean_var_count_grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
block_work_idx
[
I1
]),
// nblock
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
count_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
int32_t
,
int32_t
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
mean_var_count_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
1
,
MeanVarTransferScalarPerVector
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
{
var
_grid_desc_mblock_mperblock_nblock
,
false
>
{
mean_var_count
_grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
block_work_idx
[
I1
]),
// nblock
tensor_operation
::
element_wise
::
PassThrough
{}};
mean_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
mean_thread_buf
,
mean_grid_desc_mblock_mperblock_nblock
,
mean_grid_buf
);
mean_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
mean_thread_buf
,
mean_var_count_grid_desc_mblock_mperblock_nblock
,
mean_grid_buf
);
var_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
var_thread_buf
,
var
_grid_desc_mblock_mperblock_nblock
,
mean_var_count
_grid_desc_mblock_mperblock_nblock
,
var_grid_buf
);
count_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
count_thread_buf
,
mean_var_count_grid_desc_mblock_mperblock_nblock
,
welford_count_grid_buf
);
});
}
// shuffle C + Ds + welford + write out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment