Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
14d29856
Commit
14d29856
authored
Dec 23, 2022
by
rocking
Browse files
Pad different size for E and H in layernorm kernel according to different block tile
parent
d78877a7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
58 additions
and
48 deletions
+58
-48
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
...ce/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+58
-48
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
14d29856
...
@@ -202,9 +202,9 @@ template <typename ALayout,
...
@@ -202,9 +202,9 @@ template <typename ALayout,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
Gemm
MPerBlock
,
index_t
NPerBlock
,
index_t
Gemm
NPerBlock
,
index_t
KPerBlock
,
index_t
Gemm
KPerBlock
,
index_t
AK1
,
index_t
AK1
,
index_t
BK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
MPerXDL
,
...
@@ -249,8 +249,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -249,8 +249,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
CDEElementwiseOperation
,
CDEElementwiseOperation
,
HElementwiseOperation
>
HElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
;
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
;
using
ELayout
=
HLayout
;
using
ELayout
=
HLayout
;
// EDataType, MeanDataType and VarDataType must be the same.
// EDataType, MeanDataType and VarDataType must be the same.
// eg. M, N, K = [1, 1, 1],
// eg. M, N, K = [1, 1, 1],
// in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783
// in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783
...
@@ -274,8 +274,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -274,8 +274,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
Gemm
MPerBlock
,
Gemm
NPerBlock
,
Gemm
KPerBlock
};
static
auto
MakeAGridDescriptor_M_K
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_M_K
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
{
...
@@ -313,21 +313,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -313,21 +313,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
}
}
template
<
typename
LayOut
>
template
<
typename
DoPads
,
index_t
MPerTile
,
index_t
NPerTile
>
static
auto
MakeEHGridDescriptor_M_N
(
index_t
M
Raw
,
index_t
N
Raw
,
index_t
Stride
)
static
auto
MakeEHGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
Stride
)
{
{
const
auto
grid_desc_mraw_nraw
=
[
&
]()
{
// Only support row major for E and H
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayOut
>::
value
)
const
auto
grid_desc_m_n
=
{
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
Stride
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
Stride
,
I1
));
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
MPerTile
,
NPerTile
),
DoPads
{});
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
I1
,
Stride
));
}
}();
return
matrix_padder
.
PadCDescriptor_M_N
(
grid_desc_mraw_nraw
);
}
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NumDTensor
>&
MRaws
,
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NumDTensor
>&
MRaws
,
...
@@ -337,8 +329,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -337,8 +329,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
static_assert
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
DLayout
>::
value
);
return
DeviceOp
::
MakeEHGridDescriptor_M_N
<
DLayout
>
(
MRaws
[
i
],
NRaws
[
i
],
DsStride
[
i
]);
return
DeviceOp
::
MakeEHGridDescriptor_M_N
<
Sequence
<
true
,
true
>
,
GemmMPerBlock
,
GemmNPerBlock
>
(
MRaws
[
i
],
NRaws
[
i
],
DsStride
[
i
]);
},
},
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
...
@@ -373,11 +368,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -373,11 +368,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
// We have to separate mean var descriptor for gemm and layernorm because of different grid
// We have to separate mean var descriptor for gemm and layernorm because of different grid
// layout(different padding)
// layout(different padding)
using
GemmMeanVarGridDesc_M_NBlock
=
using
GemmMeanVarGridDesc_M_NBlock
=
decltype
(
decltype
(
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
1
,
1
));
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
false
>
,
Gemm
MPerBlock
,
Gemm
NPerBlock
>
(
1
,
1
));
using
GemmCountGridDesc_M_NBlock
=
using
GemmCountGridDesc_M_NBlock
=
decltype
(
decltype
(
MakeCountDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
1
,
1
));
MakeCountDescriptor_M_N
<
Sequence
<
true
,
false
>
,
Gemm
MPerBlock
,
Gemm
NPerBlock
>
(
1
,
1
));
using
LayernormMeanVarGridDesc_M_NBlock
=
using
LayernormMeanVarGridDesc_M_NBlock
=
decltype
(
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
true
>
,
decltype
(
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
true
>
,
...
@@ -390,7 +385,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -390,7 +385,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
,
1
));
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeEHGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeEHGridDescriptor_M_N
<
Sequence
<
true
,
true
>
,
1
,
1
>
(
1
,
1
,
1
));
using
GridwiseGemmWelford
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
using
GridwiseGemmWelford
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
...
@@ -412,9 +407,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -412,9 +407,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
GemmCountGridDesc_M_NBlock
,
GemmCountGridDesc_M_NBlock
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
MPerBlock
,
Gemm
MPerBlock
,
NPerBlock
,
Gemm
NPerBlock
,
KPerBlock
,
Gemm
KPerBlock
,
AK1
,
AK1
,
BK1
,
BK1
,
MPerXDL
,
MPerXDL
,
...
@@ -503,7 +498,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -503,7 +498,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
MRaw
,
KRaw
,
StrideA
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m_n_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEHGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
gemm_e_grid_desc_m_n_
{
DeviceOp
::
MakeEHGridDescriptor_M_N
<
Sequence
<
true
,
true
>
,
GemmMPerBlock
,
GemmNPerBlock
>
(
MRaw
,
NRaw
,
StrideH
)},
layernorm_e_grid_desc_m_n_
{
DeviceOp
::
MakeEHGridDescriptor_M_N
<
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
MRaw
,
NRaw
,
StrideH
)},
gemm_mean_var_grid_desc_m_nblock_
{},
gemm_mean_var_grid_desc_m_nblock_
{},
gemm_count_grid_desc_m_nblock_
{},
gemm_count_grid_desc_m_nblock_
{},
layernorm_mean_var_grid_desc_m_nblock_
{},
layernorm_mean_var_grid_desc_m_nblock_
{},
...
@@ -512,12 +515,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -512,12 +515,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
beta_grid_desc_n_
{
beta_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeEHGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeEHGridDescriptor_M_N
<
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
MRaw
,
NRaw
,
StrideH
)},
a_grid_desc_ak0_m_ak1_
{
a_grid_desc_ak0_m_ak1_
{
GridwiseGemmWelford
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
GridwiseGemmWelford
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
b_grid_desc_bk0_n_bk1_
{
GridwiseGemmWelford
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
GridwiseGemmWelford
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
block_2_etile_map_
{
GridwiseGemmWelford
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemmWelford
::
MakeDefaultBlock2ETileMap
(
gemm_e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
cde_element_op_
{
cde_element_op
},
...
@@ -525,16 +533,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -525,16 +533,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
MRaw_
{
MRaw
},
MRaw_
{
MRaw
},
NRaw_
{
NRaw
},
NRaw_
{
NRaw
},
KRaw_
{
KRaw
},
KRaw_
{
KRaw
},
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
Gemm
NPerBlock
)},
epsilon_
{
static_cast
<
AccDataType
>
(
epsilon
)}
epsilon_
{
static_cast
<
AccDataType
>
(
epsilon
)}
{
{
// We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1.
// We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1.
gemm_mean_var_grid_desc_m_nblock_
=
gemm_mean_var_grid_desc_m_nblock_
=
DeviceOp
::
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
1
>
(
DeviceOp
::
MakeMeanVarDescriptor_M_N
<
Sequence
<
true
,
false
>
,
Gemm
MPerBlock
,
1
>
(
MRaw
,
gemm_nblock_
);
MRaw
,
gemm_nblock_
);
gemm_count_grid_desc_m_nblock_
=
gemm_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeCountDescriptor_M_N
<
Sequence
<
true
,
false
>
,
MPerBlock
,
1
>
(
DeviceOp
::
MakeCountDescriptor_M_N
<
Sequence
<
true
,
false
>
,
Gemm
MPerBlock
,
1
>
(
MRaw
,
gemm_nblock_
);
MRaw
,
gemm_nblock_
);
layernorm_mean_var_grid_desc_m_nblock_
=
layernorm_mean_var_grid_desc_m_nblock_
=
...
@@ -551,7 +559,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -551,7 +559,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// populate pointer, desc for Ds
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
// D pointer
...
@@ -559,14 +566,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -559,14 +566,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// D desc
// D desc
ds_grid_desc_m_n_
(
i
)
=
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEHGridDescriptor_M_N
<
DLayout
>
(
MRaw
,
NRaw
,
StrideDs
[
i
]);
DeviceOp
::
MakeEHGridDescriptor_M_N
<
Sequence
<
true
,
true
>
,
GemmMPerBlock
,
GemmNPerBlock
>
(
MRaw
,
NRaw
,
StrideDs
[
i
]);
});
});
// populate desc for Ds/E/mean/var/count
// populate desc for Ds/E/mean/var/count
if
(
GridwiseGemmWelford
::
CheckValidity
(
a_grid_desc_m_k_
,
if
(
GridwiseGemmWelford
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
gemm_
e_grid_desc_m_n_
,
block_2_etile_map_
))
block_2_etile_map_
))
{
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
...
@@ -575,7 +584,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -575,7 +584,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemmWelford
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemmWelford
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
gemm_
e_grid_desc_m_n_
);
gemm_mean_var_grid_desc_mblock_mperblock_nblock_
=
gemm_mean_var_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
GridwiseGemmWelford
::
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
...
@@ -593,7 +602,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -593,7 +602,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
std
::
cout
<<
"B[N, K]: "
<<
b_grid_desc_n_k_
<<
std
::
endl
;
std
::
cout
<<
"B[N, K]: "
<<
b_grid_desc_n_k_
<<
std
::
endl
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
std
::
cout
<<
"Ds[M, N]: "
<<
ds_grid_desc_m_n_
[
i
]
<<
std
::
endl
;
});
[
&
](
auto
i
)
{
std
::
cout
<<
"Ds[M, N]: "
<<
ds_grid_desc_m_n_
[
i
]
<<
std
::
endl
;
});
std
::
cout
<<
"E[M, N]: "
<<
e_grid_desc_m_n_
<<
std
::
endl
;
std
::
cout
<<
"E[M, N]: "
<<
gemm_
e_grid_desc_m_n_
<<
std
::
endl
;
std
::
cout
<<
"H[M, N]: "
<<
h_grid_desc_m_n_
<<
std
::
endl
;
std
::
cout
<<
"H[M, N]: "
<<
h_grid_desc_m_n_
<<
std
::
endl
;
}
}
...
@@ -614,7 +623,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -614,7 +623,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
AGridDesc_M_K
a_grid_desc_m_k_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EHGridDesc_M_N
e_grid_desc_m_n_
;
EHGridDesc_M_N
gemm_e_grid_desc_m_n_
;
EHGridDesc_M_N
layernorm_e_grid_desc_m_n_
;
GemmMeanVarGridDesc_M_NBlock
gemm_mean_var_grid_desc_m_nblock_
;
GemmMeanVarGridDesc_M_NBlock
gemm_mean_var_grid_desc_m_nblock_
;
GemmCountGridDesc_M_NBlock
gemm_count_grid_desc_m_nblock_
;
GemmCountGridDesc_M_NBlock
gemm_count_grid_desc_m_nblock_
;
LayernormMeanVarGridDesc_M_NBlock
layernorm_mean_var_grid_desc_m_nblock_
;
LayernormMeanVarGridDesc_M_NBlock
layernorm_mean_var_grid_desc_m_nblock_
;
...
@@ -663,13 +673,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -663,13 +673,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
if
(
!
GridwiseGemmWelford
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
if
(
!
GridwiseGemmWelford
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
gemm_
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
arg
.
block_2_etile_map_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmWelford has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemmWelford has invalid setting"
);
}
}
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
gemm_
e_grid_desc_m_n_
);
const
auto
M
=
arg
.
h_grid_desc_m_n_
.
GetLength
(
I0
);
const
auto
M
=
arg
.
h_grid_desc_m_n_
.
GetLength
(
I0
);
const
auto
N
=
arg
.
h_grid_desc_m_n_
.
GetLength
(
I1
);
const
auto
N
=
arg
.
h_grid_desc_m_n_
.
GetLength
(
I1
);
...
@@ -763,7 +773,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -763,7 +773,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
arg
.
p_gamma_grid_
,
arg
.
p_gamma_grid_
,
arg
.
p_beta_grid_
,
arg
.
p_beta_grid_
,
arg
.
p_h_grid_
,
arg
.
p_h_grid_
,
arg
.
e_grid_desc_m_n_
,
arg
.
layernorm_
e_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
layernorm_mean_var_grid_desc_m_nblock_
,
arg
.
layernorm_mean_var_grid_desc_m_nblock_
,
arg
.
layernorm_count_grid_desc_m_nblock_
,
arg
.
layernorm_count_grid_desc_m_nblock_
,
...
@@ -1043,9 +1053,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
...
@@ -1043,9 +1053,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
str
<<
"DeviceGemmMultipleDLayernorm_Xdl_CShuffle"
str
<<
"DeviceGemmMultipleDLayernorm_Xdl_CShuffle"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
Gemm
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
Gemm
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
Gemm
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
BK1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
getGemmSpecializationString
(
GemmSpec
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment