Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
78ff5f81
Commit
78ff5f81
authored
Dec 02, 2022
by
rocking
Browse files
Implement layernorm
parent
a4e34d88
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
273 additions
and
144 deletions
+273
-144
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+38
-39
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
...dwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
...mm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
+233
-103
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
78ff5f81
...
...
@@ -111,8 +111,9 @@ template <typename GridwiseWelfordLayernorm,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
EHGridDesc_M_N
,
typename
MeanVarCountGridDesc_M_N
,
typename
GammaBetaGridDesc_N
>
typename
MeanVarCountGridDesc_M_NBlock
,
typename
GammaBetaGridDesc_N
,
typename
HElementwiseOperation
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
...
@@ -127,13 +128,13 @@ __global__ void
HDataType
*
__restrict__
p_h_grid
,
const
EHGridDesc_M_N
e_grid_desc_m_n
,
const
EHGridDesc_M_N
h_grid_desc_m_n
,
const
MeanVarCountGridDesc_M_N
mean_var_count_grid_desc_m_n
,
const
MeanVarCountGridDesc_M_N
Block
mean_var_count_grid_desc_m_n
block
,
const
GammaBetaGridDesc_N
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
beta_grid_desc_n
,
index_t
blkgroup_size
,
index_t
num
_mean_var_count_k_b
lock
_t
ile
_i
teration
,
index_t
num_xy_k_block_tile_iterati
on
,
ComputeDataType
epsilon
)
index_t
numMeanVarCountBlockTileIteration_N
,
index_t
num
NormB
lock
T
ile
I
teration
_N
,
ComputeDataType
epsil
on
,
HElementwiseOperation
h_element_op
)
{
GridwiseWelfordLayernorm
::
Run
(
p_e_grid
,
p_in_welford_mean_grid
,
...
...
@@ -144,13 +145,13 @@ __global__ void
p_h_grid
,
e_grid_desc_m_n
,
h_grid_desc_m_n
,
mean_var_count_grid_desc_m_n
,
mean_var_count_grid_desc_m_n
block
,
gamma_grid_desc_n
,
beta_grid_desc_n
,
blkgroup_size
,
num
_mean_var_count_k_b
lock
_t
ile
_i
teration
,
num_xy_k_block_tile_iterati
on
,
epsilon
);
numMeanVarCountBlockTileIteration_N
,
num
NormB
lock
T
ile
I
teration
_N
,
epsil
on
,
h_element_op
);
}
}
// namespace ck
...
...
@@ -374,7 +375,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
(
1
,
1
,
1
));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
MeanVarCountGridDesc_M_N
=
decltype
(
MakeMeanVarCountGridDescriptor_M_NBlock
(
1
,
1
));
using
MeanVarCountGridDesc_M_N
Block
=
decltype
(
MakeMeanVarCountGridDescriptor_M_NBlock
(
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_N
(
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
...
...
@@ -394,7 +395,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
,
DsGridDesc_M_N
,
EHGridDesc_M_N
,
MeanVarCountGridDesc_M_N
,
MeanVarCountGridDesc_M_N
Block
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -439,8 +440,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType
,
AccDataType
,
EHGridDesc_M_N
,
MeanVarCountGridDesc_M_N
,
MeanVarCountGridDesc_M_N
Block
,
GammaBetaGridDesc_N
,
HElementwiseOperation
,
BlockSize
,
LayernormThreadClusterSize_M_N
::
At
(
I0
),
LayernormThreadClusterSize_M_N
::
At
(
I1
),
...
...
@@ -488,7 +490,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
mean_var_count_grid_desc_m_n_
{},
mean_var_count_grid_desc_m_n
block
_
{},
gamma_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
beta_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
...
...
@@ -504,7 +506,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
epsilon_
{
epsilon
}
{
mean_var_count_grid_desc_m_n_
=
mean_var_count_grid_desc_m_n
block
_
=
DeviceOp
::
MakeMeanVarCountGridDescriptor_M_NBlock
(
MRaw
,
gemm_nblock_
);
hip_check_error
(
hipMalloc
(
&
p_e_grid_
,
sizeof
(
EDataType
)
*
MRaw
*
NRaw
));
...
...
@@ -546,7 +548,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
mean_var_count_grid_desc_m_n_
);
mean_var_count_grid_desc_m_n
block
_
);
}
}
...
...
@@ -578,7 +580,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EHGridDesc_M_N
e_grid_desc_m_n_
;
MeanVarCountGridDesc_M_N
mean_var_count_grid_desc_m_n_
;
MeanVarCountGridDesc_M_N
Block
mean_var_count_grid_desc_m_n
block
_
;
GammaBetaGridDesc_N
gamma_grid_desc_n_
;
GammaBetaGridDesc_N
beta_grid_desc_n_
;
EHGridDesc_M_N
h_grid_desc_m_n_
;
...
...
@@ -666,8 +668,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType
,
AccDataType
,
EHGridDesc_M_N
,
MeanVarCountGridDesc_M_N
,
GammaBetaGridDesc_N
>
;
MeanVarCountGridDesc_M_NBlock
,
GammaBetaGridDesc_N
,
HElementwiseOperation
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -692,17 +695,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
mean_var_count_grid_desc_mblock_mperblock_nblock_
,
arg
.
block_2_etile_map_
);
grid_size
=
math
::
integer_least_multiple
(
M
,
LayernormBlockTileSize_M_N
::
At
(
0
))
/
LayernormBlockTileSize_M_N
::
At
(
0
);
grid_size
=
math
::
integer_divide_ceil
(
M
,
LayernormBlockTileSize_M_N
::
At
(
0
));
index_t
numMeanVarCountBlockTileIteration_N
=
math
::
integer_least_multiple
(
arg
.
gemm_nblock_
,
LayernormThreadClusterSize_M_N
::
At
(
I1
))
/
LayernormThreadClusterSize_M_N
::
At
(
I1
);
index_t
numMeanVarCountBlockTileIteration_N
=
math
::
integer_divide_ceil
(
arg
.
gemm_nblock_
,
LayernormThreadClusterSize_M_N
::
At
(
I1
));
index_t
numEBlockTileIteration_N
=
math
::
integer_least_multiple
(
N
,
LayernormBlockTileSize_M_N
::
At
(
I1
))
/
LayernormBlockTileSize_M_N
::
At
(
I1
);
index_t
numNormBlockTileIteration_N
=
math
::
integer_divide_ceil
(
N
,
LayernormBlockTileSize_M_N
::
At
(
I1
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_welford_layernorm
,
...
...
@@ -718,13 +717,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
p_h_grid_
,
arg
.
e_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
mean_var_count_grid_desc_m_n_
,
arg
.
mean_var_count_grid_desc_m_n
block
_
,
arg
.
gamma_grid_desc_n_
,
arg
.
beta_grid_desc_n_
,
arg
.
gemm_nblock_
,
numMeanVarCountBlockTileIteration_N
,
numEBlockTileIteration_N
,
arg
.
epsilon_
);
numNormBlockTileIteration_N
,
arg
.
epsilon_
,
arg
.
h_element_op_
);
return
avg_time
;
};
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
View file @
78ff5f81
...
...
@@ -47,7 +47,7 @@ template <typename ABDataType,
typename
BGridDesc_N_K
,
typename
DsGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
MeanVarCountGridDesc_M_N
,
typename
MeanVarCountGridDesc_M_N
Block
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
...
...
@@ -349,7 +349,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
MeanVarCountGridDesc_M_N
{}))
>
;
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
MeanVarCountGridDesc_M_N
Block
{}))
>
;
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
View file @
78ff5f81
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment