Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
91e37486
Commit
91e37486
authored
May 23, 2023
by
rocking
Browse files
Adjust the order of device_pool_fwd
parent
3fa7f1eb
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
35 additions
and
39 deletions
+35
-39
client_example/19_pool_fwd/avg_pool3d_fwd.cpp
client_example/19_pool_fwd/avg_pool3d_fwd.cpp
+6
-6
client_example/19_pool_fwd/max_pool2d_fwd.cpp
client_example/19_pool_fwd/max_pool2d_fwd.cpp
+6
-6
example/13_pool2d_fwd/pool2d_fwd_common.hpp
example/13_pool2d_fwd/pool2d_fwd_common.hpp
+3
-3
example/48_pool3d_fwd/pool3d_fwd_common.hpp
example/48_pool3d_fwd/pool3d_fwd_common.hpp
+3
-3
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
...operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
+4
-6
include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp
...eration/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp
+4
-6
profiler/include/profiler/profile_pool2d_fwd_impl.hpp
profiler/include/profiler/profile_pool2d_fwd_impl.hpp
+3
-3
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
+3
-3
No files found.
client_example/19_pool_fwd/avg_pool3d_fwd.cpp
View file @
91e37486
...
@@ -115,12 +115,12 @@ int main(int argc, char* argv[])
...
@@ -115,12 +115,12 @@ int main(int argc, char* argv[])
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
in_length
,
in_length
,
window_spatial_lengths
,
window_spatial_lengths
,
out_length
,
out_length
,
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
window_strides
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
...
@@ -174,12 +174,12 @@ int main(int argc, char* argv[])
...
@@ -174,12 +174,12 @@ int main(int argc, char* argv[])
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
in_length
,
in_length
,
window_spatial_lengths
,
window_spatial_lengths
,
out_length
,
out_length
,
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
window_strides
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
...
...
client_example/19_pool_fwd/max_pool2d_fwd.cpp
View file @
91e37486
...
@@ -109,12 +109,12 @@ int main(int argc, char* argv[])
...
@@ -109,12 +109,12 @@ int main(int argc, char* argv[])
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
in_length
,
in_length
,
window_spatial_lengths
,
window_spatial_lengths
,
out_length
,
out_length
,
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
window_strides
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
...
@@ -168,12 +168,12 @@ int main(int argc, char* argv[])
...
@@ -168,12 +168,12 @@ int main(int argc, char* argv[])
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
in_length
,
in_length
,
window_spatial_lengths
,
window_spatial_lengths
,
out_length
,
out_length
,
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
window_strides
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
...
...
example/13_pool2d_fwd/pool2d_fwd_common.hpp
View file @
91e37486
...
@@ -116,12 +116,12 @@ bool pool_test(bool do_verification,
...
@@ -116,12 +116,12 @@ bool pool_test(bool do_verification,
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
N
,
C
,
Hi
,
Wi
},
{
N
,
C
,
Hi
,
Wi
},
{
Y
,
X
},
{
Y
,
X
},
{
N
,
C
,
Ho
,
Wo
},
{
N
,
C
,
Ho
,
Wo
},
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
window_strides
,
window_strides
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
...
...
example/48_pool3d_fwd/pool3d_fwd_common.hpp
View file @
91e37486
...
@@ -123,12 +123,12 @@ bool pool3d_test(bool do_verification,
...
@@ -123,12 +123,12 @@ bool pool3d_test(bool do_verification,
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
{
Di
*
C
*
Hi
*
Wi
,
1
,
C
*
Hi
*
Wi
,
Wi
*
C
,
C
},
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
{
N
,
C
,
Di
,
Hi
,
Wi
},
{
N
,
C
,
Di
,
Hi
,
Wi
},
{
Z
,
Y
,
X
},
{
Z
,
Y
,
X
},
{
N
,
C
,
Do
,
Ho
,
Wo
},
{
N
,
C
,
Do
,
Ho
,
Wo
},
{
Di
*
C
*
Hi
*
Wi
,
1
,
C
*
Hi
*
Wi
,
Wi
*
C
,
C
},
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
window_strides
,
window_strides
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
...
...
include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp
View file @
91e37486
...
@@ -25,12 +25,12 @@ struct DevicePoolFwd : public BaseOperator
...
@@ -25,12 +25,12 @@ struct DevicePoolFwd : public BaseOperator
MakeArgumentPointer
(
const
void
*
p_in_dev
,
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
indices_stride
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
indices_stride
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
View file @
91e37486
...
@@ -144,9 +144,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
...
@@ -144,9 +144,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
}
}
using
ABGridDescs
=
decltype
(
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
(
1
,
1
,
{},
{},
{},
{},
{},
{}));
MakeABGridDescriptor_A_M_K_B_M
(
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}));
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
...
@@ -285,12 +283,12 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
...
@@ -285,12 +283,12 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
MakeArgumentPointer
(
const
void
*
p_in_dev
,
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp
View file @
91e37486
...
@@ -151,9 +151,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
...
@@ -151,9 +151,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
}
}
using
ABGridDescs
=
decltype
(
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
(
1
,
1
,
{},
{},
{},
{},
{},
{}));
MakeABGridDescriptor_A_M_K_B_M
(
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}));
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
...
@@ -290,12 +288,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
...
@@ -290,12 +288,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
MakeArgumentPointer
(
const
void
*
p_in_dev
,
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
...
...
profiler/include/profiler/profile_pool2d_fwd_impl.hpp
View file @
91e37486
...
@@ -145,12 +145,12 @@ bool profile_pool2d_fwd_impl(int do_verification,
...
@@ -145,12 +145,12 @@ bool profile_pool2d_fwd_impl(int do_verification,
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
in_length
,
in_length
,
window_spatial_lengths
,
window_spatial_lengths
,
out_length
,
out_length
,
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
window_strides
,
window_strides
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
...
...
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
View file @
91e37486
...
@@ -150,12 +150,12 @@ bool profile_pool3d_fwd_impl(int do_verification,
...
@@ -150,12 +150,12 @@ bool profile_pool3d_fwd_impl(int do_verification,
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
{
Di
*
C
*
Hi
*
Wi
,
1
,
C
*
Hi
*
Wi
,
Wi
*
C
,
C
},
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
in_length
,
in_length
,
window_spatial_lengths
,
window_spatial_lengths
,
out_length
,
out_length
,
{
Di
*
C
*
Hi
*
Wi
,
1
,
C
*
Hi
*
Wi
,
Wi
*
C
,
C
},
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
window_strides
,
window_strides
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment