Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8332414c
"examples/vscode:/vscode.git/clone" did not exist on "e62dd5cfa8cf2d6c538366b0b9b85ce3fe613500"
Commit
8332414c
authored
Apr 28, 2023
by
rocking
Browse files
Expand the base class of pool2d, prepare to share base class with pool3d
parent
54c90aae
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
15 deletions
+12
-15
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
+8
-11
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
...operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
+4
-4
No files found.
include/ck/tensor_operation/gpu/device/device_pool
2d
_fwd.hpp
→
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
View file @
8332414c
...
@@ -13,8 +13,8 @@ namespace ck {
...
@@ -13,8 +13,8 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
ck
::
ReduceTensorOp
ReduceOpId
>
template
<
index_t
NDimSpatial
,
ReduceTensorOp
ReduceOpId
>
struct
DevicePool
2d
Fwd
:
public
BaseOperator
struct
DevicePoolFwd
:
public
BaseOperator
{
{
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
in_dev
,
MakeArgumentPointer
(
const
void
*
in_dev
,
...
@@ -22,19 +22,16 @@ struct DevicePool2dFwd : public BaseOperator
...
@@ -22,19 +22,16 @@ struct DevicePool2dFwd : public BaseOperator
void
*
out_indices_dev
,
void
*
out_indices_dev
,
ck
::
index_t
N
,
ck
::
index_t
N
,
ck
::
index_t
C
,
ck
::
index_t
C
,
std
::
array
<
ck
::
index_t
,
2
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
window_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
window_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
2
>
window_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
window_strides
,
std
::
array
<
ck
::
index_t
,
2
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
2
>
input_right_pads
)
=
0
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
template
<
ck
::
ReduceTensorOp
ReduceOpId
>
using
DevicePool2dFwdPtr
=
std
::
unique_ptr
<
DevicePool2dFwd
<
ReduceOpId
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
View file @
8332414c
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool
2d
_fwd.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -29,7 +29,7 @@ template <typename InDataType,
...
@@ -29,7 +29,7 @@ template <typename InDataType,
ck
::
index_t
ReduceMThreadSliceSize
,
ck
::
index_t
ReduceMThreadSliceSize
,
ck
::
index_t
ReduceKThreadSliceSize
,
ck
::
index_t
ReduceKThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
>
ck
::
index_t
InSrcOutDstVectorSize
>
struct
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
:
public
DevicePool
2d
Fwd
<
ReduceOpId
>
struct
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
:
public
DevicePoolFwd
<
2
,
ReduceOpId
>
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -141,8 +141,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
...
@@ -141,8 +141,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
}
}
using
ABGridDescs
=
decltype
(
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
(
MakeABGridDescriptor_A_M_K_B_M
(
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}));
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}));
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment