"vscode:/vscode.git/clone" did not exist on "4ddee80b285368d0efe6227af989ca0d11aef379"
Commit 523d5c4b authored by rocking's avatar rocking
Browse files

Move instance declaration out of common

parent e2594ff7
......@@ -55,7 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
}
};
template <typename InDataType,
template <typename DevicePoolFwdInstance,
typename InDataType,
typename OutDataType,
typename ComputeDataType,
typename IndexDataType,
......@@ -84,20 +85,6 @@ bool pool3d_test(bool do_verification,
ck::index_t in_right_pad_h,
ck::index_t in_right_pad_w)
{
using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool3dFwdImpl<InDataType, // InDataType
OutDataType, // OutDataType
IndexDataType, // IndexDataType
ComputeDataType, // ComputeDataType
ReduceOpId,
OutputIndex,
64, // BlockSize
64, // ReduceMThreadClusterSize
1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize
1, // InSrcOutDstVectorSize
false>; // IsFastestDimReduced
const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
......
......@@ -15,9 +15,18 @@ using ComputeDataType = float;
using IndexDataType = int32_t;
#if 1
using InLayout = ck::tensor_layout::convolution::NDHWC;
using OutLayout = ck::tensor_layout::convolution::NDHWC;
static constexpr bool IsFastestDimReduced = false;
#else
using InLayout = ck::tensor_layout::convolution::NCDHW;
using OutLayout = ck::tensor_layout::convolution::NCDHW;
static constexpr bool IsFastestDimReduced = true;
#endif
#if 1
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
#else
......@@ -27,6 +36,21 @@ static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
static constexpr bool OutputIndex = false;
static constexpr bool PropagateNan = false;
using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool3dFwdImpl<InDataType, // InDataType
OutDataType, // OutDataType
IndexDataType, // IndexDataType
ComputeDataType, // ComputeDataType
ReduceOpId,
OutputIndex,
64, // BlockSize
64, // ReduceMThreadClusterSize
1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize
1, // InSrcOutDstVectorSize
IsFastestDimReduced>;
int main()
{
bool do_verification = true;
......@@ -51,7 +75,8 @@ int main()
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
bool pass = pool3d_test<InDataType,
bool pass = pool3d_test<DevicePoolFwdInstance,
InDataType,
OutDataType,
ComputeDataType,
IndexDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment