Commit 1a1059ab authored by rocking

Remove layout template parameters from the maxpool bwd test

parent 1737b087
@@ -27,8 +27,6 @@ template <typename InDataType,
           typename ComputeDataType,
           typename DInDataType,
           typename DOutDataType,
-          typename InLayout,
-          typename OutLayout,
           bool PropagateNan>
 bool maxpool_bwd_test(bool do_verification,
                       bool time_kernel,
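Note on the hunk above: the test function is no longer templated on the tensor layouts. As the descriptor hunk below shows, the host tensors are now always built with NHWC strides, so the two layout parameters carried no remaining information.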
@@ -73,41 +71,30 @@ bool maxpool_bwd_test(bool do_verification,
     const std::vector<ck::index_t> input_left_pads{in_left_pad_h, in_left_pad_w};
     const std::vector<ck::index_t> input_right_pads{in_right_pad_h, in_right_pad_w};
 
-    // tensor layout
     auto f_host_tensor_descriptor =
-        [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) {
+        [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) {
             using namespace ck::literals;
 
-            // reference need Tensor with NCHW order
-            if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCHW>::value)
-            {
-                return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz});
-            }
-            else if constexpr(ck::is_same<decltype(layout),
-                                          ck::tensor_layout::convolution::NHWC>::value)
-            {
-                return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_});
-            }
+            return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_});
         };
 
     // in
-    Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
+    Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi));
     // out
-    Tensor<OutDataType> out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
-    Tensor<OutDataType> out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
+    Tensor<OutDataType> out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo));
+    Tensor<OutDataType> out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo));
     // indices
-    Tensor<IndexDataType> indices_n_c_ho_wo_device(
-        f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
-    Tensor<IndexDataType> indices_n_c_ho_wo_host(
-        f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
+    Tensor<IndexDataType> indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo));
+    Tensor<IndexDataType> indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo));
     // dout
-    Tensor<DOutDataType> dout_n_c_ho_wo(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
+    Tensor<DOutDataType> dout_n_c_ho_wo(f_host_tensor_descriptor(N, C, Ho, Wo));
     // din
-    Tensor<DInDataType> din_n_c_hi_wi_host(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
-    Tensor<DInDataType> din_n_c_hi_wi_device(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
+    Tensor<DInDataType> din_n_c_hi_wi_host(f_host_tensor_descriptor(N, C, Hi, Wi));
+    Tensor<DInDataType> din_n_c_hi_wi_device(f_host_tensor_descriptor(N, C, Hi, Wi));
 
     std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
     std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl;
@@ -212,8 +199,12 @@ bool maxpool_bwd_test(bool do_verification,
                                                      input_right_pads);
     ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument);
 
-    using ReferencePoolingBwdInstance = ck::tensor_operation::host::
-        ReferenceMaxPoolBwd<DOutDataType, IndexDataType, float, DInDataType, PassThrough>;
+    using ReferencePoolingBwdInstance =
+        ck::tensor_operation::host::ReferenceMaxPoolBwd<DOutDataType,
+                                                        IndexDataType,
+                                                        ComputeDataType,
+                                                        DInDataType,
+                                                        PassThrough>;
 
     auto ref_pooling_bwd         = ReferencePoolingBwdInstance{};
     auto ref_pooling_bwd_invoker = ref_pooling_bwd.MakeInvoker();
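Two things change in this hunk: the alias is reformatted one parameter per line, and the hardcoded `float` compute type is replaced by the `ComputeDataType` template parameter, so the host reference accumulates in the same precision as the device op under test. For orientation, a minimal sketch of the scatter a max-pool backward reference performs (the general algorithm, assumed from the operation's definition, not CK's actual implementation):

```cpp
#include <cstddef>
#include <vector>

// Scatter each output gradient to the input position recorded by the
// forward pass's argmax index; din must be zero-initialized by the caller.
template <typename DOutType, typename IndexType, typename ComputeType, typename DInType>
void max_pool_bwd_ref(const std::vector<DOutType>& dout,
                      const std::vector<IndexType>& indices, // flat argmax per output
                      std::vector<DInType>& din)
{
    for(std::size_t o = 0; o < dout.size(); ++o)
    {
        // Accumulate in ComputeType: with overlapping windows, several
        // outputs may scatter into the same input element.
        ComputeType acc = static_cast<ComputeType>(din[indices[o]]) +
                          static_cast<ComputeType>(dout[o]);
        din[indices[o]] = static_cast<DInType>(acc);
    }
}
```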
(The remaining hunks are from a second file in this commit, the example's main() driver; file paths are not shown in this view.)
@@ -16,9 +16,6 @@ using ComputeDataType = float;
 using DInDataType  = float;
 using DOutDataType = float;
 
-using InLayout  = ck::tensor_layout::convolution::NHWC;
-using OutLayout = ck::tensor_layout::convolution::NHWC;
-
 static constexpr bool PropagateNan = false;
 
 int main()
@@ -29,16 +26,16 @@ int main()
     // Pool shape
     ck::index_t N               = 1;
     ck::index_t C               = 1;
-    ck::index_t Y               = 3;
-    ck::index_t X               = 3;
-    ck::index_t Hi              = 31;
-    ck::index_t Wi              = 31;
-    ck::index_t window_stride_h = 1;
-    ck::index_t window_stride_w = 1;
+    ck::index_t Y               = 2;
+    ck::index_t X               = 2;
+    ck::index_t Hi              = 32;
+    ck::index_t Wi              = 32;
+    ck::index_t window_stride_h = 2;
+    ck::index_t window_stride_w = 2;
     ck::index_t in_left_pad_h   = 0;
     ck::index_t in_left_pad_w   = 0;
-    ck::index_t in_right_pad_h  = 1;
-    ck::index_t in_right_pad_w  = 1;
+    ck::index_t in_right_pad_h  = 0;
+    ck::index_t in_right_pad_w  = 0;
 
     bool pass = maxpool_bwd_test<InDataType,
                                  OutDataType,
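With the new shape the pooling geometry is exact: by the standard output-size formula (dilation 1), Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1 = (32 - 2)/2 + 1 = 16, and likewise Wo = 16. The 2x2 windows with stride 2 and no padding tile the 32x32 input without overlap, so every input element feeds exactly one window, which makes the backward scatter unambiguous to verify.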
@@ -46,8 +43,6 @@ int main()
                                  ComputeDataType,
                                  DInDataType,
                                  DOutDataType,
-                                 InLayout,
-                                 OutLayout,
                                  PropagateNan>(do_verification,
                                                time_kernel,
                                                N,