Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
eaeef340
Commit
eaeef340
authored
Dec 08, 2022
by
rocking
Browse files
Add layout parameter
parent
27b19e34
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
6 deletions
+27
-6
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+27
-6
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
eaeef340
...
@@ -314,17 +314,37 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -314,17 +314,37 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
template
<
typename
LayOut
>
static
auto
MakeGemmMeanVarCountGridDescriptor_M_NBlock
(
index_t
M
,
index_t
NBlock
)
static
auto
MakeGemmMeanVarCountGridDescriptor_M_NBlock
(
index_t
M
,
index_t
NBlock
)
{
{
const
auto
grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
NBlock
));
const
auto
grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
NBlock
),
make_tuple
(
NBlock
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
NBlock
),
make_tuple
(
I1
,
M
));
}
}();
return
PadTensorDescriptor
(
return
PadTensorDescriptor
(
grid_desc_m_n
,
make_tuple
(
MPerBlock
,
NPerBlock
),
Sequence
<
true
,
true
>
{});
grid_desc_m_n
,
make_tuple
(
MPerBlock
,
NPerBlock
),
Sequence
<
true
,
true
>
{});
}
}
template
<
typename
LayOut
>
static
auto
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
(
index_t
M
,
index_t
NBlock
)
static
auto
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
(
index_t
M
,
index_t
NBlock
)
{
{
const
auto
grid_desc_m_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
NBlock
));
const
auto
grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
NBlock
),
make_tuple
(
NBlock
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayOut
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
NBlock
),
make_tuple
(
I1
,
M
));
}
}();
return
PadTensorDescriptor
(
return
PadTensorDescriptor
(
grid_desc_m_n
,
grid_desc_m_n
,
...
@@ -388,9 +408,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -388,9 +408,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// We have to separate mean var descriptor for gemm and layernorm because of different grid
// We have to separate mean var descriptor for gemm and layernorm because of different grid
// layout (different padding)
// layout (different padding)
using
GemmMeanVarCountGridDesc_M_NBlock
=
using
GemmMeanVarCountGridDesc_M_NBlock
=
decltype
(
MakeGemmMeanVarCountGridDescriptor_M_NBlock
(
1
,
1
));
decltype
(
MakeGemmMeanVarCountGridDescriptor_M_NBlock
<
HLayout
>
(
1
,
1
));
using
LayernormMeanVarCountGridDesc_M_NBlock
=
using
LayernormMeanVarCountGridDesc_M_NBlock
=
decltype
(
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
(
1
,
1
));
decltype
(
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
<
HLayout
>
(
1
,
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_N
(
1
));
using
GammaBetaGridDesc_N
=
decltype
(
MakeDescriptor_N
(
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
EHGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
...
@@ -525,10 +545,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -525,10 +545,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
epsilon_
{
epsilon
}
epsilon_
{
epsilon
}
{
{
gemm_mean_var_count_grid_desc_m_nblock_
=
gemm_mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeGemmMeanVarCountGridDescriptor_M_NBlock
(
MRaw
,
gemm_nblock_
);
DeviceOp
::
MakeGemmMeanVarCountGridDescriptor_M_NBlock
<
HLayout
>
(
MRaw
,
gemm_nblock_
);
layernorm_mean_var_count_grid_desc_m_nblock_
=
layernorm_mean_var_count_grid_desc_m_nblock_
=
DeviceOp
::
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
(
MRaw
,
gemm_nblock_
);
DeviceOp
::
MakeLayernormMeanVarCountGridDescriptor_M_NBlock
<
HLayout
>
(
MRaw
,
gemm_nblock_
);
// populate pointer, desc for Ds
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment