Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b58d5a7d
Commit
b58d5a7d
authored
Dec 07, 2022
by
rocking
Browse files
Fix bug of mean var padding for layernorm
parent
db0a27ad
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
19 deletions
+44
-19
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
...ple/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
+2
-1
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+42
-18
No files found.
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp
View file @
b58d5a7d
...
@@ -254,7 +254,8 @@ int main()
...
@@ -254,7 +254,8 @@ int main()
h_device_buf
.
FromDevice
(
h_m_n
.
mData
.
data
());
h_device_buf
.
FromDevice
(
h_m_n
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
e_m_n
,
e_m_n_host
);
pass
&=
ck
::
utils
::
check_err
(
e_m_n
,
e_m_n_host
);
pass
&=
ck
::
utils
::
check_err
(
h_m_n
,
h_m_n_host
);
pass
&=
ck
::
utils
::
check_err
(
h_m_n
,
h_m_n_host
,
"Error: Incorrect results h_m_n"
,
1e-2
,
1e-2
);
}
}
return
pass
?
0
:
1
;
return
pass
?
0
:
1
;
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
b58d5a7d
...
@@ -111,7 +111,7 @@ template <typename GridwiseWelfordLayernorm,
...
@@ -111,7 +111,7 @@ template <typename GridwiseWelfordLayernorm,
typename
BetaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
EHGridDesc_M_N
,
typename
EHGridDesc_M_N
,
typename
MeanVarCountGridDesc_M_NBlock
,
typename
Layernorm
MeanVarCountGridDesc_M_NBlock
,
typename
GammaBetaGridDesc_N
,
typename
GammaBetaGridDesc_N
,
typename
HElementwiseOperation
>
typename
HElementwiseOperation
>
__global__
void
__global__
void
...
@@ -128,7 +128,7 @@ __global__ void
...
@@ -128,7 +128,7 @@ __global__ void
HDataType
*
__restrict__
p_h_grid
,
HDataType
*
__restrict__
p_h_grid
,
const
EHGridDesc_M_N
e_grid_desc_m_n
,
const
EHGridDesc_M_N
e_grid_desc_m_n
,
const
EHGridDesc_M_N
h_grid_desc_m_n
,
const
EHGridDesc_M_N
h_grid_desc_m_n
,
const
MeanVarCountGridDesc_M_NBlock
mean_var_count_grid_desc_m_nblock
,
const
Layernorm
MeanVarCountGridDesc_M_NBlock
mean_var_count_grid_desc_m_nblock
,
const
GammaBetaGridDesc_N
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
beta_grid_desc_n
,
const
GammaBetaGridDesc_N
beta_grid_desc_n
,
index_t
numMeanVarCountBlockTileIteration_N
,
index_t
numMeanVarCountBlockTileIteration_N
,
...
@@ -314,14 +314,25 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -314,14 +314,25 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
static
auto
MakeMeanVarCountGridDescriptor_M_NBlock
(
index_t
M
,
index_t
NBlock
)
static
auto
Make
Gemm
MeanVarCountGridDescriptor_M_NBlock
(
index_t
M
,
index_t
NBlock
)
{
{
const
auto
grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
NBlock
));
const
auto
grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
NBlock
));
// TODO - padding according to MNperBlock of Gemm and Layernorm
// TODO - padding according to MNperBlock of Gemm
// CAUTION - GetWorkSpaceSize
return
grid_desc_m_n
;
return
grid_desc_m_n
;
}
}
static
auto
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
(
index_t
M
,
index_t
NBlock
)
{
const
auto
grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
NBlock
));
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)),
Sequence
<
true
,
true
>
{});
}
static
auto
MakeDescriptor_M
(
index_t
MRaw
)
static
auto
MakeDescriptor_M
(
index_t
MRaw
)
{
{
const
auto
grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
const
auto
grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
...
@@ -375,9 +386,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -375,9 +386,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
(
1
,
1
,
1
));
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
(
1
,
1
,
1
));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
MeanVarCountGridDesc_M_NBlock
=
decltype
(
MakeMeanVarCountGridDescriptor_M_NBlock
(
1
,
1
));
// We have to separate mean var descriptor for gemm and layernorm because of different grid
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_N
(
1
));
// layout(different padding)
using
EHGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
GemmMeanVarCountGridDesc_M_NBlock
=
decltype
(
MakeGemmMeanVarCountGridDescriptor_M_NBlock
(
1
,
1
));
using
LayernormMeanVarCountGridDesc_M_NBlock
=
decltype
(
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
(
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_N
(
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
GridwiseGemmWelford
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
using
GridwiseGemmWelford
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
...
@@ -395,7 +411,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -395,7 +411,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
,
BGridDesc_N_K
,
DsGridDesc_M_N
,
DsGridDesc_M_N
,
EHGridDesc_M_N
,
EHGridDesc_M_N
,
MeanVarCountGridDesc_M_NBlock
,
Gemm
MeanVarCountGridDesc_M_NBlock
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
MPerBlock
,
MPerBlock
,
...
@@ -440,7 +456,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -440,7 +456,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType
,
BetaDataType
,
AccDataType
,
AccDataType
,
EHGridDesc_M_N
,
EHGridDesc_M_N
,
MeanVarCountGridDesc_M_NBlock
,
Layernorm
MeanVarCountGridDesc_M_NBlock
,
GammaBetaGridDesc_N
,
GammaBetaGridDesc_N
,
HElementwiseOperation
,
HElementwiseOperation
,
BlockSize
,
BlockSize
,
...
@@ -491,7 +507,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -491,7 +507,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m_n_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
mean_var_count_grid_desc_m_nblock_
{},
gemm_mean_var_count_grid_desc_m_nblock_
{},
layernorm_mean_var_count_grid_desc_m_nblock_
{},
gamma_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
gamma_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
beta_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
beta_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
...
@@ -507,8 +524,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -507,8 +524,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
epsilon_
{
epsilon
}
epsilon_
{
epsilon
}
{
{
mean_var_count_grid_desc_m_nblock_
=
gemm_mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeMeanVarCountGridDescriptor_M_NBlock
(
MRaw
,
gemm_nblock_
);
DeviceOp
::
MakeGemmMeanVarCountGridDescriptor_M_NBlock
(
MRaw
,
gemm_nblock_
);
layernorm_mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
(
MRaw
,
gemm_nblock_
);
// populate pointer, desc for Ds
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -540,7 +560,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -540,7 +560,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_mblock_mperblock_nblock_
=
mean_var_count_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
mean_var_count_grid_desc_m_nblock_
);
gemm_
mean_var_count_grid_desc_m_nblock_
);
}
}
}
}
...
@@ -572,7 +592,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -572,7 +592,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
b_grid_desc_n_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EHGridDesc_M_N
e_grid_desc_m_n_
;
EHGridDesc_M_N
e_grid_desc_m_n_
;
MeanVarCountGridDesc_M_NBlock
mean_var_count_grid_desc_m_nblock_
;
GemmMeanVarCountGridDesc_M_NBlock
gemm_mean_var_count_grid_desc_m_nblock_
;
LayernormMeanVarCountGridDesc_M_NBlock
layernorm_mean_var_count_grid_desc_m_nblock_
;
GammaBetaGridDesc_N
gamma_grid_desc_n_
;
GammaBetaGridDesc_N
gamma_grid_desc_n_
;
GammaBetaGridDesc_N
beta_grid_desc_n_
;
GammaBetaGridDesc_N
beta_grid_desc_n_
;
EHGridDesc_M_N
h_grid_desc_m_n_
;
EHGridDesc_M_N
h_grid_desc_m_n_
;
...
@@ -660,7 +681,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -660,7 +681,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BetaDataType
,
BetaDataType
,
AccDataType
,
AccDataType
,
EHGridDesc_M_N
,
EHGridDesc_M_N
,
MeanVarCountGridDesc_M_NBlock
,
Layernorm
MeanVarCountGridDesc_M_NBlock
,
GammaBetaGridDesc_N
,
GammaBetaGridDesc_N
,
HElementwiseOperation
>
;
HElementwiseOperation
>
;
...
@@ -710,7 +731,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -710,7 +731,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
p_h_grid_
,
arg
.
p_h_grid_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
mean_var_count_grid_desc_m_nblock_
,
arg
.
layernorm_
mean_var_count_grid_desc_m_nblock_
,
arg
.
gamma_grid_desc_n_
,
arg
.
gamma_grid_desc_n_
,
arg
.
beta_grid_desc_n_
,
arg
.
beta_grid_desc_n_
,
numMeanVarCountBlockTileIteration_N
,
numMeanVarCountBlockTileIteration_N
,
...
@@ -745,7 +766,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -745,7 +766,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
size_t
workspace_size
=
0
;
size_t
workspace_size
=
0
;
int
gemm_welford_size
=
pArg_
->
mean_var_count_grid_desc_m_nblock_
.
GetElementSpaceSize
();
// FIXME - padding
int
gemm_welford_size
=
pArg_
->
gemm_mean_var_count_grid_desc_m_nblock_
.
GetElementSpaceSize
();
// workspace for welford intermediate mean
// workspace for welford intermediate mean
workspace_size
+=
gemm_welford_size
*
sizeof
(
MeanDataType
)
+
64
;
workspace_size
+=
gemm_welford_size
*
sizeof
(
MeanDataType
)
+
64
;
...
@@ -765,7 +788,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -765,7 +788,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
pArg_
->
p_workspace_
=
p_workspace
;
pArg_
->
p_workspace_
=
p_workspace
;
int
gemm_welford_size
=
pArg_
->
mean_var_count_grid_desc_m_nblock_
.
GetElementSpaceSize
();
int
gemm_welford_size
=
pArg_
->
gemm_mean_var_count_grid_desc_m_nblock_
.
GetElementSpaceSize
();
// int gemm_welford_size = MRaw * pArg->gemm_nblock_;
// int gemm_welford_size = MRaw * pArg->gemm_nblock_;
// setup buffer used for intermediate welford mean
// setup buffer used for intermediate welford mean
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment