Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e67b958b
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "579f84c6a004be53d5948d680948fb95bb0571cc"
Commit
e67b958b
authored
Jul 06, 2023
by
rocking
Browse files
rename channel to c from k
parent
529f2507
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
66 additions
and
66 deletions
+66
-66
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
...r_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
+66
-66
No files found.
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
View file @
e67b958b
...
@@ -37,10 +37,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -37,10 +37,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
static
constexpr
ck
::
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
ck
::
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
static
auto
Make3DGridDescriptor_Out_M_K_In_M
(
const
std
::
vector
<
ck
::
index_t
>&
dout_n_
k
_wos_lengths
,
Make3DGridDescriptor_Out_M_K_In_M
(
const
std
::
vector
<
ck
::
index_t
>&
dout_n_
c
_wos_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
din_n_
k
_wos_length
,
const
std
::
vector
<
ck
::
index_t
>&
din_n_
c
_wos_length
,
const
std
::
vector
<
ck
::
index_t
>&
dout_n_
k
_wos_strides
,
const
std
::
vector
<
ck
::
index_t
>&
dout_n_
c
_wos_strides
,
const
std
::
vector
<
ck
::
index_t
>&
din_n_
k
_wos_strides
,
const
std
::
vector
<
ck
::
index_t
>&
din_n_
c
_wos_strides
,
const
std
::
vector
<
ck
::
index_t
>&
window_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
window_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
window_strides
,
const
std
::
vector
<
ck
::
index_t
>&
window_strides
,
const
std
::
vector
<
ck
::
index_t
>&
window_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
window_dilations
,
...
@@ -52,16 +52,16 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -52,16 +52,16 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
index_t
i_ytilde
=
tildes
[
1
];
index_t
i_ytilde
=
tildes
[
1
];
index_t
i_xtilde
=
tildes
[
2
];
index_t
i_xtilde
=
tildes
[
2
];
const
index_t
N
=
dout_n_
k
_wos_lengths
[
0
];
const
index_t
N
=
dout_n_
c
_wos_lengths
[
0
];
const
index_t
K
=
dout_n_
k
_wos_lengths
[
1
];
const
index_t
C
=
dout_n_
c
_wos_lengths
[
1
];
const
index_t
Di
=
din_n_
k
_wos_length
[
2
];
const
index_t
Di
=
din_n_
c
_wos_length
[
2
];
const
index_t
Hi
=
din_n_
k
_wos_length
[
3
];
const
index_t
Hi
=
din_n_
c
_wos_length
[
3
];
const
index_t
Wi
=
din_n_
k
_wos_length
[
4
];
const
index_t
Wi
=
din_n_
c
_wos_length
[
4
];
const
index_t
Do
=
dout_n_
k
_wos_lengths
[
2
];
const
index_t
Do
=
dout_n_
c
_wos_lengths
[
2
];
const
index_t
Ho
=
dout_n_
k
_wos_lengths
[
3
];
const
index_t
Ho
=
dout_n_
c
_wos_lengths
[
3
];
const
index_t
Wo
=
dout_n_
k
_wos_lengths
[
4
];
const
index_t
Wo
=
dout_n_
c
_wos_lengths
[
4
];
const
index_t
Z
=
window_lengths
[
0
];
const
index_t
Z
=
window_lengths
[
0
];
const
index_t
Y
=
window_lengths
[
1
];
const
index_t
Y
=
window_lengths
[
1
];
...
@@ -83,13 +83,13 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -83,13 +83,13 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const
index_t
ConvDilationH
=
window_dilations
[
1
];
const
index_t
ConvDilationH
=
window_dilations
[
1
];
const
index_t
ConvDilationW
=
window_dilations
[
2
];
const
index_t
ConvDilationW
=
window_dilations
[
2
];
const
auto
out_n_do_ho_wo_
k
_grid_desc
=
const
auto
out_n_do_ho_wo_
c
_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
,
K
),
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
,
C
),
make_tuple
(
dout_n_
k
_wos_strides
[
0
],
make_tuple
(
dout_n_
c
_wos_strides
[
0
],
dout_n_
k
_wos_strides
[
2
],
dout_n_
c
_wos_strides
[
2
],
dout_n_
k
_wos_strides
[
3
],
dout_n_
c
_wos_strides
[
3
],
dout_n_
k
_wos_strides
[
4
],
dout_n_
c
_wos_strides
[
4
],
dout_n_
k
_wos_strides
[
1
]));
dout_n_
c
_wos_strides
[
1
]));
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
...
@@ -132,19 +132,19 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -132,19 +132,19 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
// Out[ReduceM, ReduceK]
// Out[ReduceM, ReduceK]
const
auto
out_n_dop_hop_wop_
k
_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_n_dop_hop_wop_
c
_grid_desc
=
transform_tensor_descriptor
(
out_n_do_ho_wo_
k
_grid_desc
,
out_n_do_ho_wo_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Do
,
I0
,
I0
),
make_pad_transform
(
Do
,
I0
,
I0
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_
k
_grid_desc
=
const
auto
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_
c
_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_dop_hop_wop_
k
_grid_desc
,
out_n_dop_hop_wop_
c
_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
ZDot
,
DTilde
),
make_embed_transform
(
make_tuple
(
ZDot
,
DTilde
),
...
@@ -153,7 +153,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -153,7 +153,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
C
)),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
...
@@ -163,9 +163,9 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -163,9 +163,9 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
Sequence
<
7
>
{}));
Sequence
<
7
>
{}));
const
auto
const
auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
k
_grid_desc
=
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
c
_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_
k
_grid_desc
,
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
...
@@ -173,7 +173,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -173,7 +173,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -192,14 +192,14 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -192,14 +192,14 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
Sequence
<
7
>
{}));
Sequence
<
7
>
{}));
const
auto
out_grid_desc_reducemraw_reducekraw
=
transform_tensor_descriptor
(
const
auto
out_grid_desc_reducemraw_reducekraw
=
transform_tensor_descriptor
(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
k
_grid_desc
,
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_
c
_grid_desc
,
make_tuple
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
,
K
)),
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
,
C
)),
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
))),
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
))),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
7
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
7
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
index_t
MRaw
=
N
*
DTildeSlice
*
HTildeSlice
*
WTildeSlice
*
K
;
const
index_t
MRaw
=
N
*
DTildeSlice
*
HTildeSlice
*
WTildeSlice
*
C
;
const
index_t
MPad
=
math
::
integer_least_multiple
(
MRaw
,
M_BlockTileSize
)
-
MRaw
;
const
index_t
MPad
=
math
::
integer_least_multiple
(
MRaw
,
M_BlockTileSize
)
-
MRaw
;
const
index_t
KRaw
=
ZDotSlice
*
YDotSlice
*
XDotSlice
;
const
index_t
KRaw
=
ZDotSlice
*
YDotSlice
*
XDotSlice
;
...
@@ -212,27 +212,27 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -212,27 +212,27 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// In[ReduceM]
// In[ReduceM]
const
auto
in_n_di_hi_wi_
k
_grid_desc
=
const
auto
in_n_di_hi_wi_
c
_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
K
),
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
make_tuple
(
din_n_
k
_wos_strides
[
0
],
make_tuple
(
din_n_
c
_wos_strides
[
0
],
din_n_
k
_wos_strides
[
2
],
din_n_
c
_wos_strides
[
2
],
din_n_
k
_wos_strides
[
3
],
din_n_
c
_wos_strides
[
3
],
din_n_
k
_wos_strides
[
4
],
din_n_
c
_wos_strides
[
4
],
din_n_
k
_wos_strides
[
1
]));
din_n_
c
_wos_strides
[
1
]));
const
auto
in_n_dip_hip_wip_
k
_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_dip_hip_wip_
c
_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_
k
_grid_desc
,
in_n_di_hi_wi_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_
k
_grid_desc
=
const
auto
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_
c
_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_n_dip_hip_wip_
k
_grid_desc
,
in_n_dip_hip_wip_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
XTilde
,
DTilde
),
make_embed_transform
(
make_tuple
(
XTilde
,
DTilde
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
...
@@ -240,7 +240,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -240,7 +240,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
C
)),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
...
@@ -249,9 +249,9 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -249,9 +249,9 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
Sequence
<
5
,
6
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
Sequence
<
7
>
{}));
const
auto
in_n_dtildeslice_htildeslice_wtildeslice_
k
_grid_desc
=
const
auto
in_n_dtildeslice_htildeslice_wtildeslice_
c
_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_
k
_grid_desc
,
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_
c
_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
),
make_freeze_transform
(
i_ztilde
),
make_freeze_transform
(
i_ztilde
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
...
@@ -259,7 +259,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -259,7 +259,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_freeze_transform
(
i_xtilde
),
make_freeze_transform
(
i_xtilde
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -278,9 +278,9 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -278,9 +278,9 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
Sequence
<
4
>
{}));
Sequence
<
4
>
{}));
const
auto
in_grid_desc_reducemraw
=
transform_tensor_descriptor
(
const
auto
in_grid_desc_reducemraw
=
transform_tensor_descriptor
(
in_n_dtildeslice_htildeslice_wtildeslice_
k
_grid_desc
,
in_n_dtildeslice_htildeslice_wtildeslice_
c
_grid_desc
,
make_tuple
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
,
K
))),
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
,
C
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
0
>
{}));
...
@@ -297,10 +297,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -297,10 +297,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{
{
Argument
(
const
DOutDataType
*
p_dout
,
Argument
(
const
DOutDataType
*
p_dout
,
DInDataType
*
p_din
,
DInDataType
*
p_din
,
std
::
vector
<
ck
::
index_t
>
dout_n_
k
_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
dout_n_
c
_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
din_n_
k
_wos_length
,
std
::
vector
<
ck
::
index_t
>
din_n_
c
_wos_length
,
std
::
vector
<
ck
::
index_t
>
dout_n_
k
_wos_strides
,
std
::
vector
<
ck
::
index_t
>
dout_n_
c
_wos_strides
,
std
::
vector
<
ck
::
index_t
>
din_n_
k
_wos_strides
,
std
::
vector
<
ck
::
index_t
>
din_n_
c
_wos_strides
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
...
@@ -310,10 +310,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -310,10 +310,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{
{
ignore
=
p_dout
;
ignore
=
p_dout
;
ignore
=
p_din
;
ignore
=
p_din
;
ignore
=
dout_n_
k
_wos_lengths
;
ignore
=
dout_n_
c
_wos_lengths
;
ignore
=
dout_n_
k
_wos_strides
;
ignore
=
dout_n_
c
_wos_strides
;
ignore
=
din_n_
k
_wos_length
;
ignore
=
din_n_
c
_wos_length
;
ignore
=
din_n_
k
_wos_strides
;
ignore
=
din_n_
c
_wos_strides
;
ignore
=
window_lengths
;
ignore
=
window_lengths
;
ignore
=
window_strides
;
ignore
=
window_strides
;
ignore
=
window_dilations
;
ignore
=
window_dilations
;
...
@@ -383,10 +383,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -383,10 +383,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_dout
,
MakeArgumentPointer
(
const
void
*
p_dout
,
void
*
p_din
,
void
*
p_din
,
std
::
vector
<
ck
::
index_t
>
dout_n_
k
_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
dout_n_
c
_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
din_n_
k
_wos_length
,
std
::
vector
<
ck
::
index_t
>
din_n_
c
_wos_length
,
std
::
vector
<
ck
::
index_t
>
dout_n_
k
_wos_strides
,
std
::
vector
<
ck
::
index_t
>
dout_n_
c
_wos_strides
,
std
::
vector
<
ck
::
index_t
>
din_n_
k
_wos_strides
,
std
::
vector
<
ck
::
index_t
>
din_n_
c
_wos_strides
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
...
@@ -395,10 +395,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
...
@@ -395,10 +395,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DOutDataType
*>
(
p_dout
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
DOutDataType
*>
(
p_dout
),
static_cast
<
DInDataType
*>
(
p_din
),
static_cast
<
DInDataType
*>
(
p_din
),
dout_n_
k
_wos_lengths
,
dout_n_
c
_wos_lengths
,
din_n_
k
_wos_length
,
din_n_
c
_wos_length
,
dout_n_
k
_wos_strides
,
dout_n_
c
_wos_strides
,
din_n_
k
_wos_strides
,
din_n_
c
_wos_strides
,
window_lengths
,
window_lengths
,
window_strides
,
window_strides
,
window_dilations
,
window_dilations
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment