Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
39dedce7
Commit
39dedce7
authored
Dec 13, 2022
by
rocking
Browse files
[What] Rename MakeMeanVarDescriptor_M_N
[Why] Prepare to add count version of make descriptor
parent
48c1b923
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
32 deletions
+24
-32
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+24
-32
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
39dedce7
...
@@ -286,7 +286,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -286,7 +286,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
}
template
<
typename
LayOut
>
template
<
typename
LayOut
>
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
Stride
)
static
auto
MakeE
H
GridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
Stride
)
{
{
const
auto
grid_desc_mraw_nraw
=
[
&
]()
{
const
auto
grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayOut
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayOut
>::
value
)
...
@@ -310,26 +310,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -310,26 +310,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
MRaws
[
i
],
NRaws
[
i
],
DsStride
[
i
]);
return
DeviceOp
::
MakeE
H
GridDescriptor_M_N
<
DLayout
>
(
MRaws
[
i
],
NRaws
[
i
],
DsStride
[
i
]);
},
},
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
template
<
typename
LayOut
,
typename
DoPads
,
index_t
X
PerTile
,
index_t
Y
PerTile
>
template
<
typename
LayOut
,
typename
DoPads
,
index_t
M
PerTile
,
index_t
N
PerTile
>
static
auto
MakeDescriptor_
X_Y
(
index_t
X
,
index_t
Y
)
static
auto
Make
MeanVar
Descriptor_
M_N
(
index_t
M
,
index_t
N
)
{
{
const
auto
grid_desc_m_n
=
[
&
]()
{
const
auto
grid_desc_m_n
=
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayOut
>::
value
)
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
N
,
I1
));
{
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
MPerTile
,
NPerTile
),
DoPads
{});
return
make_naive_tensor_descriptor
(
make_tuple
(
X
,
Y
),
make_tuple
(
Y
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
X
,
Y
),
make_tuple
(
I1
,
X
));
}
}();
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
XPerTile
,
YPerTile
),
DoPads
{});
}
}
template
<
index_t
XPerTile
>
template
<
index_t
XPerTile
>
...
@@ -344,17 +335,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -344,17 +335,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
// We have to separate mean var descriptor for gemm and layernorm because of different grid
// We have to separate mean var descriptor for gemm and layernorm because of different grid
// layout(different padding)
// layout(different padding)
using
GemmMeanVarCountGridDesc_M_NBlock
=
using
GemmMeanVarCountGridDesc_M_NBlock
=
decltype
(
decltype
(
Make
Descriptor_
X_Y
<
HLayout
,
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
1
,
1
));
MakeMeanVar
Descriptor_
M_N
<
HLayout
,
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
1
,
1
));
using
LayernormMeanVarCountGridDesc_M_NBlock
=
using
LayernormMeanVarCountGridDesc_M_NBlock
=
decltype
(
MakeDescriptor_
X_Y
<
HLayout
,
decltype
(
Make
MeanVar
Descriptor_
M_N
<
HLayout
,
Sequence
<
true
,
true
>
,
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
,
1
));
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeE
H
GridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
GridwiseGemmWelford
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
using
GridwiseGemmWelford
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
...
@@ -464,14 +455,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -464,14 +455,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
MRaw
,
KRaw
,
StrideA
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m_n_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeE
H
GridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
gemm_mean_var_count_grid_desc_m_nblock_
{},
gemm_mean_var_count_grid_desc_m_nblock_
{},
layernorm_mean_var_count_grid_desc_m_nblock_
{},
layernorm_mean_var_count_grid_desc_m_nblock_
{},
gamma_grid_desc_n_
{
gamma_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
beta_grid_desc_n_
{
beta_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeE
H
GridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
a_grid_desc_ak0_m_ak1_
{
a_grid_desc_ak0_m_ak1_
{
GridwiseGemmWelford
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
GridwiseGemmWelford
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
b_grid_desc_bk0_n_bk1_
{
...
@@ -487,15 +478,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -487,15 +478,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
gemm_nblock_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)},
epsilon_
{
epsilon
}
epsilon_
{
epsilon
}
{
{
gemm_mean_var_count_grid_desc_m_nblock_
=
gemm_mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
DeviceOp
::
Make
Descriptor_
X_Y
<
HLayout
,
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
MakeMeanVar
Descriptor_
M_N
<
HLayout
,
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
MRaw
,
gemm_nblock_
);
MRaw
,
gemm_nblock_
);
layernorm_mean_var_count_grid_desc_m_nblock_
=
layernorm_mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeDescriptor_X_Y
<
HLayout
,
DeviceOp
::
MakeMeanVarDescriptor_M_N
<
HLayout
,
Sequence
<
true
,
true
>
,
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
MRaw
,
gemm_nblock_
);
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
MRaw
,
gemm_nblock_
);
// populate pointer, desc for Ds
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -507,7 +499,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -507,7 +499,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// D desc
// D desc
ds_grid_desc_m_n_
(
i
)
=
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
MRaw
,
NRaw
,
StrideDs
[
i
]);
DeviceOp
::
MakeE
H
GridDescriptor_M_N
<
DLayout
>
(
MRaw
,
NRaw
,
StrideDs
[
i
]);
});
});
// populate desc for Ds/E/F/G
// populate desc for Ds/E/F/G
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment