Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3df07c27
Commit
3df07c27
authored
Dec 14, 2022
by
rocking
Browse files
Use 1D global memory for count
parent
39dedce7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
151 additions
and
102 deletions
+151
-102
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+82
-35
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
...dwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
+50
-50
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
...mm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
+19
-17
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
3df07c27
...
@@ -33,7 +33,8 @@ template <typename GridwiseGemmWelford,
...
@@ -33,7 +33,8 @@ template <typename GridwiseGemmWelford,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
CountGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
Block2ETileMap
,
typename
Block2ETileMap
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
...
@@ -57,8 +58,10 @@ __global__ void
...
@@ -57,8 +58,10 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
const
MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock
,
mean_var_grid_desc_mblock_mperblock_nblock
,
const
CountGridDescriptor_MBlock_MPerBlock_NBlock
count_grid_desc_mblock_mperblock_nblock
,
const
Block2ETileMap
block_2_etile_map
,
const
Block2ETileMap
block_2_etile_map
,
index_t
NRaw
)
index_t
NRaw
)
{
{
...
@@ -81,7 +84,8 @@ __global__ void
...
@@ -81,7 +84,8 @@ __global__ void
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
mean_var_count_grid_desc_mblock_mperblock_nblock
,
mean_var_grid_desc_mblock_mperblock_nblock
,
count_grid_desc_mblock_mperblock_nblock
,
block_2_etile_map
,
block_2_etile_map
,
NRaw
);
NRaw
);
#else
#else
...
@@ -99,7 +103,8 @@ __global__ void
...
@@ -99,7 +103,8 @@ __global__ void
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
mean_var_count_grid_desc_mblock_mperblock_nblock
;
ignore
=
mean_var_grid_desc_mblock_mperblock_nblock
;
ignore
=
count_grid_desc_mblock_mperblock_nblock
;
ignore
=
block_2_etile_map
;
ignore
=
block_2_etile_map
;
ignore
=
NRaw
;
ignore
=
NRaw
;
#endif
#endif
...
@@ -114,7 +119,8 @@ template <typename GridwiseWelfordLayernorm,
...
@@ -114,7 +119,8 @@ template <typename GridwiseWelfordLayernorm,
typename
BetaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
EHGridDesc_M_N
,
typename
EHGridDesc_M_N
,
typename
LayernormMeanVarCountGridDesc_M_NBlock
,
typename
LayernormMeanVarGridDesc_M_NBlock
,
typename
LayernormCountGridDesc_M_NBlock
,
typename
GammaBetaGridDesc_N
,
typename
GammaBetaGridDesc_N
,
typename
HElementwiseOperation
>
typename
HElementwiseOperation
>
__global__
void
__global__
void
...
@@ -131,7 +137,8 @@ __global__ void
...
@@ -131,7 +137,8 @@ __global__ void
HDataType
*
__restrict__
p_h_grid
,
HDataType
*
__restrict__
p_h_grid
,
const
EHGridDesc_M_N
e_grid_desc_m_n
,
const
EHGridDesc_M_N
e_grid_desc_m_n
,
const
EHGridDesc_M_N
h_grid_desc_m_n
,
const
EHGridDesc_M_N
h_grid_desc_m_n
,
const
LayernormMeanVarCountGridDesc_M_NBlock
mean_var_count_grid_desc_m_nblock
,
const
LayernormMeanVarGridDesc_M_NBlock
mean_var_grid_desc_m_nblock
,
const
LayernormCountGridDesc_M_NBlock
count_grid_desc_m_nblock
,
const
GammaBetaGridDesc_N
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
beta_grid_desc_n
,
const
GammaBetaGridDesc_N
beta_grid_desc_n
,
index_t
numMeanVarCountBlockTileIteration_N
,
index_t
numMeanVarCountBlockTileIteration_N
,
...
@@ -148,7 +155,8 @@ __global__ void
...
@@ -148,7 +155,8 @@ __global__ void
p_h_grid
,
p_h_grid
,
e_grid_desc_m_n
,
e_grid_desc_m_n
,
h_grid_desc_m_n
,
h_grid_desc_m_n
,
mean_var_count_grid_desc_m_nblock
,
mean_var_grid_desc_m_nblock
,
count_grid_desc_m_nblock
,
gamma_grid_desc_n
,
gamma_grid_desc_n
,
beta_grid_desc_n
,
beta_grid_desc_n
,
numMeanVarCountBlockTileIteration_N
,
numMeanVarCountBlockTileIteration_N
,
...
@@ -315,7 +323,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -315,7 +323,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
template
<
typename
LayOut
,
typename
DoPads
,
index_t
MPerTile
,
index_t
NPerTile
>
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
NPerTile
>
static
auto
MakeMeanVarDescriptor_M_N
(
index_t
M
,
index_t
N
)
static
auto
MakeMeanVarDescriptor_M_N
(
index_t
M
,
index_t
N
)
{
{
const
auto
grid_desc_m_n
=
const
auto
grid_desc_m_n
=
...
@@ -323,6 +331,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -323,6 +331,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
MPerTile
,
NPerTile
),
DoPads
{});
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
MPerTile
,
NPerTile
),
DoPads
{});
}
}
// Creates the (M, N) tensor descriptor used for the Welford count buffer and pads it so
// each dimension selected by DoPads becomes a multiple of the tile size.
//
// Template parameters:
//   DoPads   - per-dimension pad selector, default-constructed and forwarded to
//              PadTensorDescriptor.
//   MPerTile - tile length along M used for padding.
//   NPerTile - tile length along N used for padding.
//
// Parameters:
//   M, N - lengths of the two descriptor dimensions.
//
// NOTE(review): the strides are (I0, I1). Presumably I0 is the compile-time constant 0, in
// which case every index along M maps to the same N entries (i.e. the count is stored as a
// 1D array of length N) - confirm against the definition of I0.
template <typename DoPads, index_t MPerTile, index_t NPerTile>
static auto MakeCountDescriptor_M_N(index_t M, index_t N)
{
    const auto lengths_m_n = make_tuple(M, N);
    const auto strides_m_n = make_tuple(I0, I1);

    const auto count_desc_m_n = make_naive_tensor_descriptor(lengths_m_n, strides_m_n);

    const auto tile_lengths = make_tuple(MPerTile, NPerTile);

    return PadTensorDescriptor(count_desc_m_n, tile_lengths, DoPads{});
}
template
<
index_t
XPerTile
>
template
<
index_t
XPerTile
>
static
auto
MakeDescriptor_X
(
index_t
X
)
static
auto
MakeDescriptor_X
(
index_t
X
)
{
{
...
@@ -335,12 +351,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -335,12 +351,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
// We have to separate mean var descriptor for gemm and layernorm because of different grid
// We have to separate mean var descriptor for gemm and layernorm because of different grid
// layout(different padding)
// layout(different padding)
using
GemmMeanVarCountGridDesc_M_NBlock
=
decltype
(
using
GemmMeanVarGridDesc_M_NBlock
=
MakeMeanVarDescriptor_M_N
<
HLayout
,
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
1
,
1
));
decltype
(
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
1
,
1
));
using
GemmCountGridDesc_M_NBlock
=
decltype
(
MakeCountDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
1
,
1
));
using
LayernormMeanVarGridDesc_M_NBlock
=
decltype
(
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
,
1
));
using
LayernormMeanVarCountGridDesc_M_NBlock
=
using
LayernormCountGridDesc_M_NBlock
=
decltype
(
MakeMeanVarDescriptor_M_N
<
HLayout
,
decltype
(
MakeCountDescriptor_M_N
<
Sequence
<
true
,
true
>
,
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
,
1
));
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
,
1
));
...
@@ -363,7 +386,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -363,7 +386,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
,
BGridDesc_N_K
,
DsGridDesc_M_N
,
DsGridDesc_M_N
,
EHGridDesc_M_N
,
EHGridDesc_M_N
,
GemmMeanVarCountGridDesc_M_NBlock
,
GemmMeanVarGridDesc_M_NBlock
,
GemmCountGridDesc_M_NBlock
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
MPerBlock
,
MPerBlock
,
...
@@ -408,7 +432,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -408,7 +432,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType
,
BetaDataType
,
AccDataType
,
AccDataType
,
EHGridDesc_M_N
,
EHGridDesc_M_N
,
LayernormMeanVarCountGridDesc_M_NBlock
,
LayernormMeanVarGridDesc_M_NBlock
,
LayernormCountGridDesc_M_NBlock
,
GammaBetaGridDesc_N
,
GammaBetaGridDesc_N
,
HElementwiseOperation
,
HElementwiseOperation
,
BlockSize
,
BlockSize
,
...
@@ -456,8 +481,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -456,8 +481,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m_n_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEHGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEHGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
gemm_mean_var_count_grid_desc_m_nblock_
{},
gemm_mean_var_grid_desc_m_nblock_
{},
layernorm_mean_var_count_grid_desc_m_nblock_
{},
gemm_count_grid_desc_m_nblock_
{},
layernorm_mean_var_grid_desc_m_nblock_
{},
layernorm_count_grid_desc_m_nblock_
{},
gamma_grid_desc_n_
{
gamma_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
beta_grid_desc_n_
{
beta_grid_desc_n_
{
...
@@ -478,17 +505,26 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -478,17 +505,26 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
epsilon_
{
epsilon
}
epsilon_
{
epsilon
}
{
{
gemm_mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
gemm_mean_var_grid_desc_m_nblock_
=
MakeMeanVarDescriptor_M_N
<
HLayout
,
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
DeviceOp
::
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
MRaw
,
gemm_nblock_
);
gemm_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeCountDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
MRaw
,
gemm_nblock_
);
MRaw
,
gemm_nblock_
);
layernorm_mean_var_count_grid_desc_m_nblock_
=
layernorm_mean_var_grid_desc_m_nblock_
=
DeviceOp
::
MakeMeanVarDescriptor_M_N
<
HLayout
,
DeviceOp
::
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
true
>
,
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
MRaw
,
gemm_nblock_
);
MRaw
,
gemm_nblock_
);
layernorm_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeCountDescriptor_M_N
<
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
MRaw
,
gemm_nblock_
);
// populate pointer, desc for Ds
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
...
@@ -517,9 +553,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -517,9 +553,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemmWelford
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemmWelford
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
e_grid_desc_m_n_
);
mean_var_count_grid_desc_mblock_mperblock_nblock_
=
gemm_mean_var_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
gemm_mean_var_grid_desc_m_nblock_
);
gemm_count_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
gemm_
mean_var_
count_grid_desc_m_nblock_
);
gemm_count_grid_desc_m_nblock_
);
}
}
}
}
...
@@ -551,8 +591,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -551,8 +591,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
b_grid_desc_n_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EHGridDesc_M_N
e_grid_desc_m_n_
;
EHGridDesc_M_N
e_grid_desc_m_n_
;
GemmMeanVarCountGridDesc_M_NBlock
gemm_mean_var_count_grid_desc_m_nblock_
;
GemmMeanVarGridDesc_M_NBlock
gemm_mean_var_grid_desc_m_nblock_
;
LayernormMeanVarCountGridDesc_M_NBlock
layernorm_mean_var_count_grid_desc_m_nblock_
;
GemmCountGridDesc_M_NBlock
gemm_count_grid_desc_m_nblock_
;
LayernormMeanVarGridDesc_M_NBlock
layernorm_mean_var_grid_desc_m_nblock_
;
LayernormCountGridDesc_M_NBlock
layernorm_count_grid_desc_m_nblock_
;
GammaBetaGridDesc_N
gamma_grid_desc_n_
;
GammaBetaGridDesc_N
gamma_grid_desc_n_
;
GammaBetaGridDesc_N
beta_grid_desc_n_
;
GammaBetaGridDesc_N
beta_grid_desc_n_
;
EHGridDesc_M_N
h_grid_desc_m_n_
;
EHGridDesc_M_N
h_grid_desc_m_n_
;
...
@@ -564,8 +606,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -564,8 +606,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemmWelford
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemmWelford
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemmWelford
::
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
typename
GridwiseGemmWelford
::
MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
mean_var_count_grid_desc_mblock_mperblock_nblock_
;
gemm_mean_var_grid_desc_mblock_mperblock_nblock_
;
typename
GridwiseGemmWelford
::
CountGridDescriptor_MBlock_MPerBlock_NBlock
gemm_count_grid_desc_mblock_mperblock_nblock_
;
// block-to-e-tile map
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
Block2ETileMap
block_2_etile_map_
;
...
@@ -628,8 +672,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -628,8 +672,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemmWelford
::
typename
GridwiseGemmWelford
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemmWelford
::
typename
GridwiseGemmWelford
::
MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
,
MeanVar
CountGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemmWelford
::
CountGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemmWelford
::
DefaultBlock2ETileMap
,
typename
GridwiseGemmWelford
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
has_main_loop
>
;
...
@@ -643,7 +687,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -643,7 +687,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType
,
BetaDataType
,
AccDataType
,
AccDataType
,
EHGridDesc_M_N
,
EHGridDesc_M_N
,
LayernormMeanVarCountGridDesc_M_NBlock
,
LayernormMeanVarGridDesc_M_NBlock
,
LayernormCountGridDesc_M_NBlock
,
GammaBetaGridDesc_N
,
GammaBetaGridDesc_N
,
HElementwiseOperation
>
;
HElementwiseOperation
>
;
...
@@ -667,7 +712,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -667,7 +712,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
mean_var_count_grid_desc_mblock_mperblock_nblock_
,
arg
.
gemm_mean_var_grid_desc_mblock_mperblock_nblock_
,
arg
.
gemm_count_grid_desc_mblock_mperblock_nblock_
,
arg
.
block_2_etile_map_
,
arg
.
block_2_etile_map_
,
arg
.
NRaw_
);
arg
.
NRaw_
);
...
@@ -694,7 +740,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -694,7 +740,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
p_h_grid_
,
arg
.
p_h_grid_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
layernorm_mean_var_count_grid_desc_m_nblock_
,
arg
.
layernorm_mean_var_grid_desc_m_nblock_
,
arg
.
layernorm_count_grid_desc_m_nblock_
,
arg
.
gamma_grid_desc_n_
,
arg
.
gamma_grid_desc_n_
,
arg
.
beta_grid_desc_n_
,
arg
.
beta_grid_desc_n_
,
numMeanVarCountBlockTileIteration_N
,
numMeanVarCountBlockTileIteration_N
,
...
@@ -738,7 +785,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -738,7 +785,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
workspace_size
+=
gemm_welford_size
*
sizeof
(
VarDataType
)
+
64
;
workspace_size
+=
gemm_welford_size
*
sizeof
(
VarDataType
)
+
64
;
// workspace for welford intermediate count
// workspace for welford intermediate count
workspace_size
+=
gemm_welford_size
*
sizeof
(
int32_t
)
+
64
;
workspace_size
+=
pArg_
->
gemm_nblock_
*
sizeof
(
int32_t
)
+
64
;
return
(
workspace_size
);
return
(
workspace_size
);
};
};
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
View file @
3df07c27
...
@@ -47,7 +47,8 @@ template <typename ABDataType,
...
@@ -47,7 +47,8 @@ template <typename ABDataType,
typename
BGridDesc_N_K
,
typename
BGridDesc_N_K
,
typename
DsGridDesc_M_N
,
typename
DsGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
MeanVarCountGridDesc_M_NBlock
,
typename
MeanVarGridDesc_M_NBlock
,
typename
CountGridDesc_M_NBlock
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -347,8 +348,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -347,8 +348,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
using
MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
MeanVarCountGridDesc_M_NBlock
{}))
>
;
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
MeanVarGridDesc_M_NBlock
{}))
>
;
using
CountGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
CountGridDesc_M_NBlock
{}))
>
;
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
...
@@ -361,7 +364,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -361,7 +364,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
Block2ETileMap
>
typename
Block2ETileMap
>
__device__
static
void
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
DsGridPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
EDataType
*
__restrict__
p_e_grid
,
...
@@ -378,8 +382,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -378,8 +382,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
&
const
MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
&
mean_var_count_grid_desc_mblock_mperblock_nblock
,
mean_var_grid_desc_mblock_mperblock_nblock
,
const
CountGridDescriptor_MBlock_MPerBlock_NBlock
&
count_grid_desc_mblock_mperblock_nblock
,
const
Block2ETileMap
&
block_2_etile_map
,
const
Block2ETileMap
&
block_2_etile_map
,
index_t
NRaw
)
index_t
NRaw
)
{
{
...
@@ -401,16 +406,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -401,16 +406,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
p_e_grid
,
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_e_grid
,
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
mean_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
mean_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_mean_grid
,
p_welford_mean_grid
,
mean_var_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
mean_var_count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
auto
var_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
var_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_var_grid
,
p_welford_var_grid
,
mean_var_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
mean_var_count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
auto
welford_count_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
welford_count_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_count
,
p_welford_count
,
count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
mean_var_count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
// divide block work by [M, N]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
...
@@ -880,7 +882,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -880,7 +882,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Array
<
welford_count_vgpr_type
,
num_shuffleM
>
welford_count_thread_bufs
;
Array
<
welford_count_vgpr_type
,
num_shuffleM
>
welford_count_thread_bufs
;
int
max_count
=
PostShuffleThreadSliceSize_N
*
num_shuffleN
;
int
max_count
=
PostShuffleThreadSliceSize_N
*
num_shuffleN
;
const
auto
nblock
=
mean_var_
count_
grid_desc_mblock_mperblock_nblock
.
GetLength
(
I2
);
const
auto
nblock
=
mean_var_grid_desc_mblock_mperblock_nblock
.
GetLength
(
I2
);
// tail block
// tail block
if
(
block_work_idx
[
I1
]
%
nblock
==
nblock
-
1
)
if
(
block_work_idx
[
I1
]
%
nblock
==
nblock
-
1
)
...
@@ -1038,7 +1040,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -1038,7 +1040,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType
,
AccDataType
,
MeanDataType
,
MeanDataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
mean_var_
count_
grid_desc_mblock_mperblock_nblock
),
decltype
(
mean_var_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
>
,
...
@@ -1046,7 +1048,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -1046,7 +1048,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
false
>
{
mean_var_
count_
grid_desc_mblock_mperblock_nblock
,
false
>
{
mean_var_grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
shuffleMPerBlock
*
i
+
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
...
@@ -1057,7 +1059,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -1057,7 +1059,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType
,
AccDataType
,
VarDataType
,
VarDataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
mean_var_
count_
grid_desc_mblock_mperblock_nblock
),
decltype
(
mean_var_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
>
,
...
@@ -1065,7 +1067,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -1065,7 +1067,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
false
>
{
mean_var_
count_
grid_desc_mblock_mperblock_nblock
,
false
>
{
mean_var_grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
shuffleMPerBlock
*
i
+
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
...
@@ -1076,7 +1078,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -1076,7 +1078,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
int32_t
,
int32_t
,
int32_t
,
int32_t
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
mean_var_
count_grid_desc_mblock_mperblock_nblock
),
decltype
(
count_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
>
,
...
@@ -1084,31 +1086,29 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -1084,31 +1086,29 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
false
>
{
mean_var_
count_grid_desc_mblock_mperblock_nblock
,
false
>
{
count_grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
shuffleMPerBlock
*
i
+
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
block_work_idx
[
I1
]),
// nblock
block_work_idx
[
I1
]),
// nblock
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
mean_thread_copy_vgpr_to_global
.
Run
(
mean_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
mean_thread_buf
,
mean_thread_buf
,
mean_var_
count_
grid_desc_mblock_mperblock_nblock
,
mean_var_grid_desc_mblock_mperblock_nblock
,
mean_grid_buf
);
mean_grid_buf
);
var_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
var_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
var_thread_buf
,
var_thread_buf
,
mean_var_
count_
grid_desc_mblock_mperblock_nblock
,
mean_var_grid_desc_mblock_mperblock_nblock
,
var_grid_buf
);
var_grid_buf
);
count_thread_copy_vgpr_to_global
.
Run
(
count_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
count_thread_buf
,
count_thread_buf
,
mean_var_
count_grid_desc_mblock_mperblock_nblock
,
count_grid_desc_mblock_mperblock_nblock
,
welford_count_grid_buf
);
welford_count_grid_buf
);
});
});
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
View file @
3df07c27
...
@@ -27,7 +27,8 @@ template <typename EDataType,
...
@@ -27,7 +27,8 @@ template <typename EDataType,
typename
BetaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
EHGridDesc_M_N
,
typename
EHGridDesc_M_N
,
typename
MeanVarCountGridDesc_M_NBlock
,
typename
MeanVarGridDesc_M_NBlock
,
typename
CountGridDesc_M_NBlock
,
typename
GammaBetaGridDesc_N
,
typename
GammaBetaGridDesc_N
,
typename
HElementwiseOperation
,
typename
HElementwiseOperation
,
index_t
BlockSize
,
index_t
BlockSize
,
...
@@ -95,7 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -95,7 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
HDataType
*
__restrict__
p_h_grid
,
HDataType
*
__restrict__
p_h_grid
,
const
EHGridDesc_M_N
&
e_grid_desc_m_n
,
const
EHGridDesc_M_N
&
e_grid_desc_m_n
,
const
EHGridDesc_M_N
&
h_grid_desc_m_n
,
const
EHGridDesc_M_N
&
h_grid_desc_m_n
,
const
MeanVarCountGridDesc_M_NBlock
&
mean_var_count_grid_desc_m_n
,
const
MeanVarGridDesc_M_NBlock
&
mean_var_grid_desc_m_n
,
const
CountGridDesc_M_NBlock
&
count_grid_desc_m_n
,
const
GammaBetaGridDesc_N
&
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
&
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
&
beta_grid_desc_n
,
const
GammaBetaGridDesc_N
&
beta_grid_desc_n
,
index_t
numMeanVarCountBlockTileIteration_N
,
index_t
numMeanVarCountBlockTileIteration_N
,
...
@@ -116,13 +118,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -116,13 +118,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
p_e_grid
,
e_grid_desc_m_n
.
GetElementSpaceSize
());
p_e_grid
,
e_grid_desc_m_n
.
GetElementSpaceSize
());
const
auto
welford_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
welford_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_mean_grid
,
mean_var_
count_
grid_desc_m_n
.
GetElementSpaceSize
());
p_in_welford_mean_grid
,
mean_var_grid_desc_m_n
.
GetElementSpaceSize
());
const
auto
welford_var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
welford_var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_var_grid
,
mean_var_
count_
grid_desc_m_n
.
GetElementSpaceSize
());
p_in_welford_var_grid
,
mean_var_grid_desc_m_n
.
GetElementSpaceSize
());
const
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_count_grid
,
mean_var_
count_grid_desc_m_n
.
GetElementSpaceSize
());
p_in_welford_count_grid
,
count_grid_desc_m_n
.
GetElementSpaceSize
());
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_gamma_grid
,
gamma_grid_desc_n
.
GetElementSpaceSize
());
p_gamma_grid
,
gamma_grid_desc_n
.
GetElementSpaceSize
());
...
@@ -173,7 +175,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -173,7 +175,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
auto
threadwise_mean_load_m_nblock
=
auto
threadwise_mean_load_m_nblock
=
ThreadwiseTensorSliceTransfer_v2
<
MeanDataType
,
ThreadwiseTensorSliceTransfer_v2
<
MeanDataType
,
ComputeDataType
,
ComputeDataType
,
MeanVar
Count
GridDesc_M_NBlock
,
MeanVarGridDesc_M_NBlock
,
decltype
(
thread_buffer_desc_m_1
),
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
ThreadBufferLengths_M_1
,
ThreadBufferDimAccessOrder
,
ThreadBufferDimAccessOrder
,
...
@@ -181,7 +183,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -181,7 +183,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1
,
1
,
1
,
1
,
true
>
(
true
>
(
mean_var_
count_
grid_desc_m_n
,
mean_var_grid_desc_m_n
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
,
thread_n_cluster_id
));
thread_n_cluster_id
));
...
@@ -189,7 +191,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -189,7 +191,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
auto
threadwise_var_load_m_nblock
=
auto
threadwise_var_load_m_nblock
=
ThreadwiseTensorSliceTransfer_v2
<
VarDataType
,
ThreadwiseTensorSliceTransfer_v2
<
VarDataType
,
ComputeDataType
,
ComputeDataType
,
MeanVar
Count
GridDesc_M_NBlock
,
MeanVarGridDesc_M_NBlock
,
decltype
(
thread_buffer_desc_m_1
),
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
ThreadBufferLengths_M_1
,
ThreadBufferDimAccessOrder
,
ThreadBufferDimAccessOrder
,
...
@@ -197,7 +199,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -197,7 +199,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1
,
1
,
1
,
1
,
true
>
(
true
>
(
mean_var_
count_
grid_desc_m_n
,
mean_var_grid_desc_m_n
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
,
thread_n_cluster_id
));
thread_n_cluster_id
));
...
@@ -205,7 +207,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -205,7 +207,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
auto
threadwise_count_load_m_nblock
=
auto
threadwise_count_load_m_nblock
=
ThreadwiseTensorSliceTransfer_v2
<
int32_t
,
ThreadwiseTensorSliceTransfer_v2
<
int32_t
,
int32_t
,
int32_t
,
MeanVar
CountGridDesc_M_NBlock
,
CountGridDesc_M_NBlock
,
decltype
(
thread_buffer_desc_m_1
),
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
ThreadBufferLengths_M_1
,
ThreadBufferDimAccessOrder
,
ThreadBufferDimAccessOrder
,
...
@@ -213,7 +215,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -213,7 +215,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1
,
1
,
1
,
1
,
true
>
(
true
>
(
mean_var_
count_grid_desc_m_n
,
count_grid_desc_m_n
,
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
make_multi_index
(
block_global_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
,
thread_n_cluster_id
));
thread_n_cluster_id
));
...
@@ -292,19 +294,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -292,19 +294,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
numMeanVarCountBlockTileIteration_N
;
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
numMeanVarCountBlockTileIteration_N
;
++
reducedTiles
)
++
reducedTiles
)
{
{
threadwise_mean_load_m_nblock
.
Run
(
mean_var_
count_
grid_desc_m_n
,
threadwise_mean_load_m_nblock
.
Run
(
mean_var_grid_desc_m_n
,
welford_mean_global_val_buf
,
welford_mean_global_val_buf
,
thread_buffer_desc_m_1
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
in_welford_mean_thread_buf
);
in_welford_mean_thread_buf
);
threadwise_var_load_m_nblock
.
Run
(
mean_var_
count_
grid_desc_m_n
,
threadwise_var_load_m_nblock
.
Run
(
mean_var_grid_desc_m_n
,
welford_var_global_val_buf
,
welford_var_global_val_buf
,
thread_buffer_desc_m_1
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
in_welford_var_thread_buf
);
in_welford_var_thread_buf
);
threadwise_count_load_m_nblock
.
Run
(
mean_var_
count_grid_desc_m_n
,
threadwise_count_load_m_nblock
.
Run
(
count_grid_desc_m_n
,
welford_count_global_val_buf
,
welford_count_global_val_buf
,
thread_buffer_desc_m_1
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
...
@@ -317,11 +319,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -317,11 +319,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
welford_var_thread_buf
,
welford_var_thread_buf
,
welford_count_thread_buf
);
welford_count_thread_buf
);
threadwise_mean_load_m_nblock
.
MoveSrcSliceWindow
(
mean_var_
count_
grid_desc_m_n
,
threadwise_mean_load_m_nblock
.
MoveSrcSliceWindow
(
mean_var_grid_desc_m_n
,
mean_var_count_thread_copy_step_m_n
);
mean_var_count_thread_copy_step_m_n
);
threadwise_var_load_m_nblock
.
MoveSrcSliceWindow
(
mean_var_
count_
grid_desc_m_n
,
threadwise_var_load_m_nblock
.
MoveSrcSliceWindow
(
mean_var_grid_desc_m_n
,
mean_var_count_thread_copy_step_m_n
);
mean_var_count_thread_copy_step_m_n
);
threadwise_count_load_m_nblock
.
MoveSrcSliceWindow
(
mean_var_
count_grid_desc_m_n
,
threadwise_count_load_m_nblock
.
MoveSrcSliceWindow
(
count_grid_desc_m_n
,
mean_var_count_thread_copy_step_m_n
);
mean_var_count_thread_copy_step_m_n
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment