Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5e215d49
Commit
5e215d49
authored
Dec 12, 2022
by
rocking
Browse files
Refine the MakeDescriptor
parent
328cc7f4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
81 deletions
+27
-81
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+27
-81
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
5e215d49
...
...
@@ -315,104 +315,45 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number
<
NumDTensor
>
{});
}
template
<
typename
LayOut
>
static
auto
Make
GemmMeanVarCountGrid
Descriptor_
M_NBlock
(
index_t
M
,
index_t
NBlock
)
template
<
typename
LayOut
,
typename
DoPads
,
index_t
XPerTile
,
index_t
YPerTile
>
static
auto
MakeDescriptor_
X_Y
(
index_t
X
,
index_t
Y
)
{
const
auto
grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
NBlock
),
make_tuple
(
NBlock
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
X
,
Y
),
make_tuple
(
Y
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
NBlock
),
make_tuple
(
I1
,
M
));
return
make_naive_tensor_descriptor
(
make_tuple
(
X
,
Y
),
make_tuple
(
I1
,
X
));
}
}();
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
MPerBlock
,
NBlock
),
Sequence
<
true
,
false
>
{});
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
XPerTile
,
YPerTile
),
DoPads
{});
}
template
<
typename
LayOut
>
static
auto
Make
LayernormMeanVarCountGrid
Descriptor_
M_NBlock
(
index_t
M
,
index_t
NBlock
)
template
<
index_t
XPerTile
>
static
auto
MakeDescriptor_
X
(
index_t
X
)
{
const
auto
grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
NBlock
),
make_tuple
(
NBlock
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
NBlock
),
make_tuple
(
I1
,
M
));
}
}();
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)),
Sequence
<
true
,
true
>
{});
const
auto
grid_desc_x
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
X
));
return
PadTensorDescriptor
(
grid_desc_x
,
make_tuple
(
XPerTile
),
Sequence
<
true
>
{});
}
static
auto
MakeDescriptor_M
(
index_t
MRaw
)
{
const
auto
grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
MRaw
));
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M
return
transform_tensor_descriptor
(
grid_desc_mraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
}
else
{
// not pad N
return
grid_desc_mraw
;
}
};
static
auto
MakeDescriptor_N
(
index_t
NRaw
)
{
const
auto
grid_desc_nraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
NRaw
));
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad N
return
transform_tensor_descriptor
(
grid_desc_nraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
}
else
{
// not pad N
return
grid_desc_nraw
;
}
};
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
(
1
,
1
,
1
));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
// We have to separate mean var descriptor for gemm and layernorm bacause of different grid
// layout(different padding)
using
GemmMeanVarCountGridDesc_M_NBlock
=
decltype
(
MakeGemmMeanVarCountGridDescriptor_M_NBlock
<
HLayout
>
(
1
,
1
));
decltype
(
MakeDescriptor_X_Y
<
HLayout
,
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
1
,
1
));
using
LayernormMeanVarCountGridDesc_M_NBlock
=
decltype
(
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
<
HLayout
>
(
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_N
(
1
));
decltype
(
MakeDescriptor_X_Y
<
HLayout
,
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
GridwiseGemmWelford
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
...
...
@@ -526,8 +467,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideH
)},
gemm_mean_var_count_grid_desc_m_nblock_
{},
layernorm_mean_var_count_grid_desc_m_nblock_
{},
gamma_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
beta_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_N
(
NRaw
)},
gamma_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
beta_grid_desc_n_
{
DeviceOp
::
MakeDescriptor_X
<
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
NRaw
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
a_grid_desc_ak0_m_ak1_
{
GridwiseGemmWelford
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
...
...
@@ -545,11 +488,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
epsilon_
{
epsilon
}
{
gemm_mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeGemmMeanVarCountGridDescriptor_M_NBlock
<
HLayout
>
(
MRaw
,
gemm_nblock_
);
DeviceOp
::
MakeDescriptor_X_Y
<
HLayout
,
Sequence
<
true
,
false
>
,
MPerBlock
,
NPerBlock
>
(
MRaw
,
gemm_nblock_
);
layernorm_mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
<
HLayout
>
(
MRaw
,
gemm_nblock_
);
DeviceOp
::
MakeDescriptor_X_Y
<
HLayout
,
Sequence
<
true
,
true
>
,
LayernormBlockTileSize_M_N
::
At
(
0
),
LayernormBlockTileSize_M_N
::
At
(
1
)
>
(
MRaw
,
gemm_nblock_
);
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment