Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
58d84615
Commit
58d84615
authored
Jul 06, 2023
by
rocking
Browse files
Implement invoker
parent
055acace
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
79 additions
and
17 deletions
+79
-17
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
...r_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
+79
-17
No files found.
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp
View file @
58d84615
...
...
@@ -10,6 +10,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -308,6 +309,32 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
// Grid descriptor types: the first and second elements of the DoutDinGridDesc tuple.
using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>;
using DinGridDesc_M    = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>;

// FIXME
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
// reduced. Assume C is the fastest dimension
static constexpr index_t InSrcOutDstVectorDim = 0;

// Short names for the element-wise operators used by the reduction below.
using PassThrough = tensor_operation::element_wise::PassThrough;
using Div         = tensor_operation::element_wise::UnaryDivide;

// Threadwise reduction used by the invoker. Arguments are positional; see
// gridwise_2d_reduction_threadwise.hpp for the meaning of each slot.
// InSrcOutDstVectorSize is deliberately passed twice (last two arguments).
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<DOutDataType,
                                                             DInDataType,
                                                             ComputeDataType,
                                                             int,
                                                             DoutGridDesc_M_K,
                                                             DinGridDesc_M,
                                                             reduce::Add,
                                                             PassThrough,
                                                             Div,
                                                             InMemoryDataOperationEnum::Set,
                                                             false, // propagate_nan
                                                             BlockSize,
                                                             MThreadSliceSize,
                                                             KThreadSliceSize,
                                                             InSrcOutDstVectorDim,
                                                             InSrcOutDstVectorSize,
                                                             InSrcOutDstVectorSize>;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
DOutDataType
*
p_dout
,
...
...
@@ -321,7 +348,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
:
p_dout_grid_
{
p_dout
},
p_din_grid_
{
p_din
},
num_reduce_
{
1
}
:
p_dout_grid_
{
p_dout
},
p_din_grid_
{
p_din
},
num_reduce_
{
1
},
div_element_op_
{
window_lengths
[
0
]
*
window_lengths
[
1
]
*
window_lengths
[
2
]}
{
std
::
vector
<
ck
::
index_t
>
Tildes
(
NDimSpatial
);
for
(
int
i
=
0
;
i
<
NDimSpatial
;
++
i
)
...
...
@@ -369,35 +399,67 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
}
}
void
Print
()
const
{
for
(
index_t
i
=
0
;
i
<
num_reduce_
;
i
++
)
{
std
::
cout
<<
"dout_grid_desc_m_k_container_"
<<
dout_grid_desc_m_k_container_
[
i
]
<<
std
::
endl
;
std
::
cout
<<
"din_grid_desc_m_container_"
<<
din_grid_desc_m_container_
[
i
]
<<
std
::
endl
;
}
}
// pointer
const
DOutDataType
*
p_dout_grid_
;
DInDataType
*
p_din_grid_
;
int
num_reduce_
;
std
::
vector
<
DoutGridDesc_M_K
>
dout_grid_desc_m_k_container_
;
std
::
vector
<
DinGridDesc_M
>
din_grid_desc_m_container_
;
Div
div_element_op_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
ignore
=
arg
;
ignore
=
stream_config
;
float
ave_time
=
0
;
for
(
index_t
i
=
0
;
i
<
arg
.
num_reduce_
;
i
++
)
{
const
auto
kernel
=
kernel_reduce_threadwise
<
gridwise_reduce
,
false
,
false
,
false
,
// don't have index input
DOutDataType
,
DInDataType
,
ComputeDataType
,
int
,
DoutGridDesc_M_K
,
DinGridDesc_M
,
PassThrough
,
Div
>
;
ck
::
index_t
M
=
arg
.
dout_grid_desc_m_k_container_
[
i
].
GetLength
(
I0
);
const
index_t
grid_size
=
(
M
/
M_BlockTileSize
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
dout_grid_desc_m_k_container_
[
i
],
arg
.
din_grid_desc_m_container_
[
i
],
PassThrough
{},
arg
.
div_element_op_
,
float
(
1
),
arg
.
p_dout_grid_
,
nullptr
,
float
(
0
),
arg
.
p_din_grid_
,
nullptr
);
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
ignore
=
p_arg
;
ignore
=
stream_config
;
return
0
;
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment