Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
055acace
Commit
055acace
authored
Jul 06, 2023
by
rocking
Browse files
Imitate the argument from conv bwd
parent
55e420ec
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
12 deletions
+44
-12
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
...r_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
+44
-12
No files found.
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
View file @
055acace
...
...
@@ -131,6 +131,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
// Problem size of reduction kernel
const
index_t
MRaw
=
N
*
DTildeSlice
*
HTildeSlice
*
WTildeSlice
*
C
;
const
index_t
MPad
=
math
::
integer_least_multiple
(
MRaw
,
M_BlockTileSize
)
-
MRaw
;
...
...
@@ -293,6 +294,20 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
return
make_tuple
(
out_grid_desc_reducem_reducek
,
in_grid_desc_reducem
);
}
using
DoutDinGridDesc
=
decltype
(
Make3DGridDescriptor_Out_M_K_In_M
({
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}));
using
DoutGridDesc_M_K
=
remove_cvref_t
<
tuple_element_t
<
0
,
DoutDinGridDesc
>>
;
using
DinGridDesc_M
=
remove_cvref_t
<
tuple_element_t
<
1
,
DoutDinGridDesc
>>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
DOutDataType
*
p_dout
,
...
...
@@ -308,18 +323,6 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
:
p_dout_grid_
{
p_dout
},
p_din_grid_
{
p_din
},
num_reduce_
{
1
}
{
ignore
=
p_dout
;
ignore
=
p_din
;
ignore
=
dout_n_c_wos_lengths
;
ignore
=
dout_n_c_wos_strides
;
ignore
=
din_n_c_wos_length
;
ignore
=
din_n_c_wos_strides
;
ignore
=
window_lengths
;
ignore
=
window_strides
;
ignore
=
window_dilations
;
ignore
=
input_left_pads
;
ignore
=
input_right_pads
;
std
::
vector
<
ck
::
index_t
>
Tildes
(
NDimSpatial
);
for
(
int
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
...
...
@@ -346,16 +349,45 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{
continue
;
}
const
auto
dout_din_grid_desc
=
Make3DGridDescriptor_Out_M_K_In_M
(
dout_n_c_wos_lengths
,
din_n_c_wos_length
,
dout_n_c_wos_strides
,
din_n_c_wos_strides
,
window_lengths
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
});
dout_grid_desc_m_k_container_
.
push_back
(
dout_din_grid_desc
[
I0
]);
din_grid_desc_m_container_
.
push_back
(
dout_din_grid_desc
[
I1
]);
}
}
}
}
void
Print
()
const
{
for
(
index_t
i
=
0
;
i
<
num_reduce_
;
i
++
)
{
std
::
cout
<<
"dout_grid_desc_m_k_container_"
<<
dout_grid_desc_m_k_container_
[
i
]
<<
std
::
endl
;
std
::
cout
<<
"din_grid_desc_m_container_"
<<
din_grid_desc_m_container_
[
i
]
<<
std
::
endl
;
}
}
// pointer
const
DOutDataType
*
p_dout_grid_
;
DInDataType
*
p_din_grid_
;
int
num_reduce_
;
std
::
vector
<
DoutGridDesc_M_K
>
dout_grid_desc_m_k_container_
;
std
::
vector
<
DinGridDesc_M
>
din_grid_desc_m_container_
;
};
struct
Invoker
:
public
BaseInvoker
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment