Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
156423d6
Commit
156423d6
authored
Jul 24, 2023
by
rocking
Browse files
Do not hardcode stride
parent
74284dd2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
132 additions
and
98 deletions
+132
-98
example/48_pool3d_fwd/pool3d_fwd_common.hpp
example/48_pool3d_fwd/pool3d_fwd_common.hpp
+54
-40
include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_impl.hpp
...nsor_operation/gpu/device/impl/device_pool3d_fwd_impl.hpp
+70
-50
library/src/tensor_operation_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp
...ration_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp
+8
-8
No files found.
example/48_pool3d_fwd/pool3d_fwd_common.hpp
View file @
156423d6
...
...
@@ -8,7 +8,7 @@
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_
ndhwc_ndhwc
.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_
impl
.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
@@ -18,6 +18,43 @@
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
template
<
typename
TensorLayout
>
std
::
vector
<
ck
::
index_t
>
f_tensor_strides_ncdhw
(
ck
::
index_t
N_
,
ck
::
index_t
C_
,
ck
::
index_t
D
,
ck
::
index_t
H
,
ck
::
index_t
W
,
TensorLayout
layout
)
{
using
namespace
ck
::
literals
;
(
void
)
N_
;
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NCDHW
>::
value
)
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
};
template
<
typename
TensorLayout
>
HostTensorDescriptor
f_host_tensor_descriptor
(
std
::
size_t
N_
,
std
::
size_t
C_
,
std
::
size_t
D
,
std
::
size_t
H
,
std
::
size_t
W
,
TensorLayout
layout
)
{
using
namespace
ck
::
literals
;
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NCDHW
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
});
}
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
};
template
<
typename
InDataType
,
typename
OutDataType
,
typename
ComputeDataType
,
...
...
@@ -48,20 +85,19 @@ bool pool3d_test(bool do_verification,
ck
::
index_t
in_right_pad_w
)
{
using
DevicePoolFwdInstance
=
ck
::
tensor_operation
::
device
::
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
<
InDataType
,
// InDataType
OutDataType
,
// OutDataType
IndexDataType
,
// IndexDataType
ComputeDataType
,
// ComputeDataType
ReduceOpId
,
OutputIndex
,
64
,
// BlockSize
64
,
// ReduceMThreadClusterSize
1
,
// ReduceKThreadClusterSize
4
,
// ReduceMThreadSliceSize
1
,
// ReduceKThreadSliceSize
4
>
;
// InSrcOutDstVectorSize
ck
::
tensor_operation
::
device
::
DevicePool3dFwdImpl
<
InDataType
,
// InDataType
OutDataType
,
// OutDataType
IndexDataType
,
// IndexDataType
ComputeDataType
,
// ComputeDataType
ReduceOpId
,
OutputIndex
,
64
,
// BlockSize
64
,
// ReduceMThreadClusterSize
1
,
// ReduceKThreadClusterSize
1
,
// ReduceMThreadSliceSize
1
,
// ReduceKThreadSliceSize
1
,
// InSrcOutDstVectorSize
false
>
;
// IsFastestDimReduced
const
ck
::
index_t
Do
=
(
Di
+
in_left_pad_d
+
in_right_pad_d
-
Z
)
/
window_stride_d
+
1
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
...
...
@@ -72,28 +108,6 @@ bool pool3d_test(bool do_verification,
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{
in_left_pad_d
,
in_left_pad_h
,
in_left_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{
in_right_pad_d
,
in_right_pad_h
,
in_right_pad_w
};
// tensor layout
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
N_
,
std
::
size_t
C_
,
std
::
size_t
D
,
std
::
size_t
H
,
std
::
size_t
W
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NCDHW
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
});
}
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
};
Tensor
<
InDataType
>
in_n_c_di_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
Di
,
Hi
,
Wi
,
InLayout
{}));
Tensor
<
OutDataType
>
out_n_c_do_ho_wo_host
(
f_host_tensor_descriptor
(
N
,
C
,
Do
,
Ho
,
Wo
,
OutLayout
{}));
...
...
@@ -126,9 +140,9 @@ bool pool3d_test(bool do_verification,
{
N
,
C
,
Di
,
Hi
,
Wi
},
{
Z
,
Y
,
X
},
{
N
,
C
,
Do
,
Ho
,
Wo
},
{
Di
*
C
*
Hi
*
Wi
,
1
,
C
*
Hi
*
Wi
,
Wi
*
C
,
C
}
,
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
}
,
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
}
,
f_tensor_strides_ncdhw
(
N
,
C
,
Di
,
Hi
,
Wi
,
InLayout
{})
,
f_tensor_strides_ncdhw
(
N
,
C
,
Do
,
Ho
,
Wo
,
OutLayout
{})
,
f_tensor_strides_ncdhw
(
N
,
C
,
Do
,
Ho
,
Wo
,
OutLayout
{})
,
window_strides
,
input_left_pads
,
input_right_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_
ndhwc_ndhwc
.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_
impl
.hpp
View file @
156423d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -10,6 +10,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -29,8 +30,9 @@ template <typename InDataType,
ck
::
index_t
KThreadClusterSize
,
ck
::
index_t
MThreadSliceSize
,
ck
::
index_t
KThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
>
struct
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
ck
::
index_t
InSrcOutDstVectorSize
,
bool
IsFastestDimReduced
>
struct
DevicePool3dFwdImpl
:
public
DevicePoolFwd
<
5
,
3
,
InDataType
,
OutDataType
,
IndexDataType
,
ReduceOpId
,
OutputIndex
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -51,29 +53,27 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
// reduced.
static
constexpr
index_t
InSrcOutDstVectorDim
=
0
;
static
constexpr
ck
::
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
ck
::
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeABGridDescriptor_A_M_K_B_M
(
ck
::
index_t
N
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
static
auto
MakeABGridDescriptor_A_M_K_B_M
(
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
N
=
input_lengths
[
0
];
const
index_t
C
=
input_lengths
[
1
];
const
index_t
Di
=
input_lengths
[
2
];
const
index_t
Hi
=
input_lengths
[
3
];
const
index_t
Wi
=
input_lengths
[
4
];
const
index_t
Do
=
output_
spatial_
lengths
[
0
];
const
index_t
Ho
=
output_
spatial_
lengths
[
1
];
const
index_t
Wo
=
output_
spatial_
lengths
[
2
];
const
index_t
Do
=
output_lengths
[
2
];
const
index_t
Ho
=
output_lengths
[
3
];
const
index_t
Wo
=
output_lengths
[
4
];
const
index_t
Z
=
window_spatial_lengths
[
0
];
const
index_t
Y
=
window_spatial_lengths
[
1
];
...
...
@@ -98,8 +98,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
const
index_t
KPad
=
math
::
integer_least_multiple
(
KRaw
,
K_BlockTileSize
)
-
KRaw
;
// A[ReduceM, ReduceK]
const
auto
in_grid_desc_n_di_hi_wi_c
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
const
index_t
Ni_stride
=
input_stride
[
0
];
const
index_t
Ci_stride
=
input_stride
[
1
];
const
index_t
Di_stride
=
input_stride
[
2
];
const
index_t
Hi_stride
=
input_stride
[
3
];
const
index_t
Wi_stride
=
input_stride
[
4
];
const
auto
in_grid_desc_n_di_hi_wi_c
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
make_tuple
(
Ni_stride
,
Di_stride
,
Hi_stride
,
Wi_stride
,
Ci_stride
));
const
auto
in_grid_desc_n_dip_hip_wip_c
=
transform_tensor_descriptor
(
in_grid_desc_n_di_hi_wi_c
,
...
...
@@ -139,8 +146,21 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B[ReduceM]
const
auto
out_grid_desc_reducemraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
*
C
));
const
index_t
No_stride
=
output_stride
[
0
];
const
index_t
Co_stride
=
output_stride
[
1
];
const
index_t
Do_stride
=
output_stride
[
2
];
const
index_t
Ho_stride
=
output_stride
[
3
];
const
index_t
Wo_stride
=
output_stride
[
4
];
const
auto
out_grid_desc_n_do_ho_wo_c
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
make_tuple
(
No_stride
,
Do_stride
,
Ho_stride
,
Wo_stride
,
Co_stride
));
const
auto
out_grid_desc_reducemraw
=
transform_tensor_descriptor
(
out_grid_desc_n_do_ho_wo_c
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
,
C
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
out_grid_desc_reducem
=
transform_tensor_descriptor
(
out_grid_desc_reducemraw
,
...
...
@@ -151,7 +171,8 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
}
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
(
1
,
1
,
{},
{},
{},
{},
{},
{}));
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
({},
{},
{},
{},
{},
{},
{},
{}));
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
...
...
@@ -160,11 +181,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
Argument
(
const
InDataType
*
p_in_dev
,
OutDataType
*
p_out_dev
,
IndexDataType
*
p_out_indices_dev
,
ck
::
index_t
N
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>&
input_lengths
,
std
::
vector
<
ck
::
index_t
>&
output_lengths
,
std
::
vector
<
ck
::
index_t
>&
input_stride
,
std
::
vector
<
ck
::
index_t
>&
output_stride
,
std
::
vector
<
ck
::
index_t
>&
,
// indices_stride
std
::
vector
<
ck
::
index_t
>&
window_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>&
window_strides
,
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
...
...
@@ -174,11 +196,11 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
a_grid_desc_m_k_
{},
b_grid_desc_m_
{}
{
const
auto
descs
=
MakeABGridDescriptor_A_M_K_B_M
(
N
,
C
,
input_spatial_lengths
,
const
auto
descs
=
MakeABGridDescriptor_A_M_K_B_M
(
input_lengths
,
output_lengths
,
input_stride
,
output_stride
,
window_spatial_lengths
,
output_spatial_lengths
,
window_strides
,
input_left_pads
,
input_right_pads
);
...
...
@@ -186,7 +208,8 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
a_grid_desc_m_k_
=
descs
[
I0
];
b_grid_desc_m_
=
descs
[
I1
];
invariant_lowest_length_
=
C
;
// C
invariant_lowest_length_
=
input_lengths
[
1
];
int32_t
reduceLength
=
window_spatial_lengths
[
0
]
*
window_spatial_lengths
[
1
]
*
window_spatial_lengths
[
2
];
...
...
@@ -211,6 +234,11 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is
// not reduced.
// reduce [D, H, W]
static
constexpr
index_t
InSrcOutDstVectorDim
=
IsFastestDimReduced
?
1
:
0
;
using
gridwise_reduce
=
GridwiseReduction_mk_to_m_threadwise
<
InDataType
,
OutDataType
,
...
...
@@ -291,9 +319,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
,
// Suppose tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
indices_stride
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
...
...
@@ -307,26 +335,18 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
if
(
pooling_dims
!=
std
::
vector
<
ck
::
index_t
>
{
2
,
3
,
4
})
throw
std
::
runtime_error
(
"pooling_dims only support {2, 3, 4} in pool3d so far"
);
index_t
N
=
input_lengths
[
0
];
index_t
C
=
input_lengths
[
1
];
index_t
Di
=
input_lengths
[
2
];
index_t
Hi
=
input_lengths
[
3
];
index_t
Wi
=
input_lengths
[
4
];
index_t
Do
=
output_lengths
[
2
];
index_t
Ho
=
output_lengths
[
3
];
index_t
Wo
=
output_lengths
[
4
];
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
=
{
Di
,
Hi
,
Wi
};
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
=
{
Do
,
Ho
,
Wo
};
if
(
output_stride
!=
indices_stride
)
throw
std
::
runtime_error
(
"output_stride need to be equal to indices_stride for now"
);
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_dev
),
static_cast
<
OutDataType
*>
(
p_out_dev
),
static_cast
<
IndexDataType
*>
(
p_out_indices_dev
),
N
,
C
,
input_spatial_lengths
,
input_lengths
,
output_lengths
,
input_stride
,
output_stride
,
indices_stride
,
window_lengths
,
output_spatial_lengths
,
window_strides
,
input_left_pads
,
input_right_pads
);
...
...
@@ -342,7 +362,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DevicePool3dFwd
_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
<"
<<
BlockSize
<<
","
;
str
<<
"DevicePool3dFwd
Impl
<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcOutDstVectorSize_"
<<
InSrcOutDstVectorSize
<<
">"
;
...
...
library/src/tensor_operation_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp
View file @
156423d6
...
...
@@ -5,7 +5,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_
ndhwc_ndhwc
.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_
impl
.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
...
@@ -31,8 +31,8 @@ using device_pool2d_fwd_nhwc_instances =
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
1
,
1
,
1
>
,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
2
,
1
,
2
>
,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
4
,
1
,
4
>
// clang-format on
>
;
// clang-format on
>
;
template
<
typename
InDataType
,
typename
OutDataType
,
...
...
@@ -43,11 +43,11 @@ template <typename InDataType,
using
device_pool3d_fwd_ndhwc_instances
=
// clang-format off
std
::
tuple
<
DevicePool3dFwd
_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
1
,
1
,
1
>
,
DevicePool3dFwd
_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
2
,
1
,
2
>
,
DevicePool3dFwd
_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
4
,
1
,
4
>
// clang-format on
>
;
DevicePool3dFwd
Impl
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
1
,
1
,
1
>
,
DevicePool3dFwd
Impl
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
2
,
1
,
2
>
,
DevicePool3dFwd
Impl
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
256
,
256
,
1
,
4
,
1
,
4
>
// clang-format on
>
;
}
// namespace instance
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment