Commit a69937d3 authored by Jing Zhang's avatar Jing Zhang
Browse files

add maxpool host for validation

parent ec381569
...@@ -976,7 +976,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -976,7 +976,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
} }
#endif #endif
#if 0 #if 1
// Resize_Add // Resize_Add
if constexpr(add_type == 0) if constexpr(add_type == 0)
{ {
...@@ -1137,12 +1137,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -1137,12 +1137,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1)); make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1));
d_thread_buf(Number<d_offset>{}) = c_thread_buf[Number<c_offset_0>{}]; d_thread_buf(Number<d_offset>{}) = c_thread_buf[Number<c_offset_0>{}];
d_thread_buf(Number<d_offset>{}) = max(c_thread_buf[Number<c_offset_1>{}], d_thread_buf(Number<d_offset>{}) = fmaxf(c_thread_buf[Number<c_offset_1>{}],
d_thread_buf(Number<d_offset>{})); d_thread_buf(Number<d_offset>{}));
d_thread_buf(Number<d_offset>{}) = max(c_thread_buf[Number<c_offset_2>{}], d_thread_buf(Number<d_offset>{}) = fmaxf(c_thread_buf[Number<c_offset_2>{}],
d_thread_buf(Number<d_offset>{})); d_thread_buf(Number<d_offset>{}));
d_thread_buf(Number<d_offset>{}) = max(c_thread_buf[Number<c_offset_3>{}], d_thread_buf(Number<d_offset>{}) = fmax(c_thread_buf[Number<c_offset_3>{}],
d_thread_buf(Number<d_offset>{})); d_thread_buf(Number<d_offset>{}));
}); });
}); });
}); });
......
...@@ -284,6 +284,25 @@ void host_direct_convolution_maxpool_nchwc(const Tensor<TIn>& in, ...@@ -284,6 +284,25 @@ void host_direct_convolution_maxpool_nchwc(const Tensor<TIn>& in,
out_host.mDesc.GetLengths()[2], out_host.mDesc.GetLengths()[2],
out_host.mDesc.GetLengths()[3], out_host.mDesc.GetLengths()[3],
out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency()); out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
auto maxpool_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) {
auto hx = ho * 2;
auto wx = wo * 2;
auto v0 = out_host(n, k0, hx, wx, k1);
auto v1 = out_host(n, k0, hx, wx + 1, k1);
auto v2 = out_host(n, k0, hx + 1, wx, k1);
auto v3 = out_host(n, k0, hx + 1, wx + 1, k1);
max_host(n, k0, ho, wo, k1) = std::max({v0, v1, v2, v3});
};
make_ParallelTensorFunctor(maxpool_nchw,
max_host.mDesc.GetLengths()[0],
max_host.mDesc.GetLengths()[1],
max_host.mDesc.GetLengths()[2],
max_host.mDesc.GetLengths()[3],
max_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
} }
template <typename TIn, typename TWei, typename TOut, typename InLeftPads, typename InRightPads> template <typename TIn, typename TWei, typename TOut, typename InLeftPads, typename InRightPads>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment