Commit 523d5c4b authored by rocking's avatar rocking
Browse files

Move instance declaration out of common

parent e2594ff7
...@@ -55,7 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, ...@@ -55,7 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
} }
}; };
template <typename InDataType, template <typename DevicePoolFwdInstance,
typename InDataType,
typename OutDataType, typename OutDataType,
typename ComputeDataType, typename ComputeDataType,
typename IndexDataType, typename IndexDataType,
...@@ -84,20 +85,6 @@ bool pool3d_test(bool do_verification, ...@@ -84,20 +85,6 @@ bool pool3d_test(bool do_verification,
ck::index_t in_right_pad_h, ck::index_t in_right_pad_h,
ck::index_t in_right_pad_w) ck::index_t in_right_pad_w)
{ {
using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool3dFwdImpl<InDataType, // InDataType
OutDataType, // OutDataType
IndexDataType, // IndexDataType
ComputeDataType, // ComputeDataType
ReduceOpId,
OutputIndex,
64, // BlockSize
64, // ReduceMThreadClusterSize
1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize
1, // InSrcOutDstVectorSize
false>; // IsFastestDimReduced
const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1; const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
......
...@@ -15,9 +15,18 @@ using ComputeDataType = float; ...@@ -15,9 +15,18 @@ using ComputeDataType = float;
using IndexDataType = int32_t; using IndexDataType = int32_t;
#if 1
using InLayout = ck::tensor_layout::convolution::NDHWC; using InLayout = ck::tensor_layout::convolution::NDHWC;
using OutLayout = ck::tensor_layout::convolution::NDHWC; using OutLayout = ck::tensor_layout::convolution::NDHWC;
static constexpr bool IsFastestDimReduced = false;
#else
using InLayout = ck::tensor_layout::convolution::NCDHW;
using OutLayout = ck::tensor_layout::convolution::NCDHW;
static constexpr bool IsFastestDimReduced = true;
#endif
#if 1 #if 1
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
#else #else
...@@ -27,6 +36,21 @@ static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; ...@@ -27,6 +36,21 @@ static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
static constexpr bool OutputIndex = false; static constexpr bool OutputIndex = false;
static constexpr bool PropagateNan = false; static constexpr bool PropagateNan = false;
using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool3dFwdImpl<InDataType, // InDataType
OutDataType, // OutDataType
IndexDataType, // IndexDataType
ComputeDataType, // ComputeDataType
ReduceOpId,
OutputIndex,
64, // BlockSize
64, // ReduceMThreadClusterSize
1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize
1, // InSrcOutDstVectorSize
IsFastestDimReduced>;
int main() int main()
{ {
bool do_verification = true; bool do_verification = true;
...@@ -51,7 +75,8 @@ int main() ...@@ -51,7 +75,8 @@ int main()
ck::index_t in_right_pad_h = 1; ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1; ck::index_t in_right_pad_w = 1;
bool pass = pool3d_test<InDataType, bool pass = pool3d_test<DevicePoolFwdInstance,
InDataType,
OutDataType, OutDataType,
ComputeDataType, ComputeDataType,
IndexDataType, IndexDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment