Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c0bfcf91
Commit
c0bfcf91
authored
Jul 24, 2022
by
Chao Liu
Browse files
adding group
parent
19173ab7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
174 additions
and
6 deletions
+174
-6
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
...on/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
+174
-6
No files found.
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
c0bfcf91
...
...
@@ -320,12 +320,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_g_n_k_wos_lengths
.
begin
()
+
3
,
e_g_n_k_wos_lengths
.
begin
()
+
4
,
e_g_n_k_wos_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_g_k_c_xs_lengths
.
begin
()
+
3
,
b_g_k_c_xs_lengths
.
begin
()
+
4
,
b_g_k_c_xs_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
...
...
@@ -432,12 +432,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_g_n_k_wos_lengths
.
begin
()
+
3
,
e_g_n_k_wos_lengths
.
begin
()
+
5
,
e_g_n_k_wos_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_g_k_c_xs_lengths
.
begin
()
+
3
,
b_g_k_c_xs_lengths
.
begin
()
+
5
,
b_g_k_c_xs_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
...
...
@@ -538,6 +538,146 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
template
<
typename
ALay
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
G_N_HW_C
>,
bool
>::
type
=
false
>
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
const
index_t
N
=
a_g_n_c_wis_lengths
[
1
];
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_g_n_k_wos_lengths
.
begin
()
+
3
,
e_g_n_k_wos_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_g_k_c_xs_lengths
.
begin
()
+
3
,
b_g_k_c_xs_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
Hi
=
a_g_n_c_wis_lengths
[
3
];
const
index_t
Wi
=
a_g_n_c_wis_lengths
[
4
];
const
index_t
Ho
=
e_g_n_k_wos_lengths
[
3
];
const
index_t
Wo
=
e_g_n_k_wos_lengths
[
4
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
// This is different
const
index_t
CStride
=
a_g_n_c_wis_strides
[
2
];
const
index_t
WStride
=
a_g_n_c_wis_strides
[
2
+
NDimSpatial
];
const
auto
in_gemmmraw_gemmkraw_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
GemmMRaw
,
GemmKRaw
),
make_tuple
(
WStride
,
CStride
);
const
auto
in_gemmm_gemmk_grid_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_grid_desc
);
return
in_gemmm_gemmk_grid_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// This is different
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
a_g_n_c_wis_srides
[
1
],
a_g_n_c_wis_srides
[
3
],
a_g_n_c_wis_srides
[
4
],
a_g_n_c_wis_srides
[
2
]));
const
auto
in_n_ho_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmmraw_gemmk_grid_desc
=
transform_tensor_descriptor
(
in_n_ho_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmm_gemmk_grid_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmk_grid_desc
);
return
in_gemmm_gemmk_grid_desc
;
}
else
{
const
index_t
Y
=
b_g_k_c_xs_lengths
[
3
];
const
index_t
X
=
b_g_k_c_xs_lengths
[
4
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
// This is different
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
a_g_n_c_wis_srides
[
1
],
a_g_n_c_wis_srides
[
3
],
a_g_n_c_wis_srides
[
4
],
a_g_n_c_wis_srides
[
2
]));
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmmraw_gemmk_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmm_gemmk_grid_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmk_grid_desc
);
return
in_gemmm_gemmk_grid_desc
;
}
}
template
<
typename
ALay
,
typename
std
::
enable_if
<
NDimSpatial
==
3
&&
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NDHWC
>,
...
...
@@ -558,12 +698,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_g_n_k_wos_lengths
.
begin
()
+
3
,
e_g_n_k_wos_lengths
.
begin
()
+
6
,
e_g_n_k_wos_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_g_k_c_xs_lengths
.
begin
()
+
3
,
b_g_k_c_xs_lengths
.
begin
()
+
6
,
b_g_k_c_xs_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
...
...
@@ -711,6 +851,34 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
wei_gemmn_gemmk_grid_desc
;
}
template
<
typename
BLay
,
typename
std
::
enable_if
<
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
G_K_X_C
>
||
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
G_K_YX_C
>
||
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
G_K_ZYX_C
>
,
bool
>::
type
=
false
>
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
{
const
index_t
K
=
b_g_k_c_xs_lengths
[
1
];
const
index_t
C
=
b_g_k_c_xs_lengths
[
2
];
const
index_t
GemmNRaw
=
K
;
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_g_k_c_xs_lengths
.
begin
()
+
3
,
b_g_k_c_xs_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
auto
wei_k_yxc_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmNRaw
,
GemmKRaw
));
const
auto
wei_gemmn_gemmk_grid_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_k_yxc_grid_desc
);
return
wei_gemmn_gemmk_grid_desc
;
}
template
<
typename
ELay
,
typename
std
::
enable_if
<
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NWK
>
||
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NHWK
>
||
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment