Commit 514b2d1c authored by Jing Zhang's avatar Jing Zhang
Browse files

improved verify

parent 6a4db056
...@@ -10,6 +10,7 @@ template <class TIn, ...@@ -10,6 +10,7 @@ template <class TIn,
class UpperPads> class UpperPads>
void host_direct_convolution(const Tensor<TIn>& in_nchw, void host_direct_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx, const Tensor<TWei>& wei_kcyx,
const Tensor<TOut>& add_nkhw,
Tensor<TOut>& out_nkhw, Tensor<TOut>& out_nkhw,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
...@@ -40,7 +41,7 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw, ...@@ -40,7 +41,7 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
} }
} }
} }
out_nkhw(n, k, ho, wo) += v; out_nkhw(n, k, ho, wo) = v + add_nkhw(n, k, ho, wo);
}; };
auto f_par = make_ParallelTensorFunctor(f, auto f_par = make_ParallelTensorFunctor(f,
......
...@@ -656,7 +656,7 @@ int main(int argc, char* argv[]) ...@@ -656,7 +656,7 @@ int main(int argc, char* argv[])
Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc)); Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc)); Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc));
Tensor<out_data_t> add_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc)); Tensor<out_data_t> add_nkhw(make_HostTensorDescriptor(out_nkhw_desc));
Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc)); Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc));
std::size_t num_thread = std::thread::hardware_concurrency(); std::size_t num_thread = std::thread::hardware_concurrency();
...@@ -693,9 +693,7 @@ int main(int argc, char* argv[]) ...@@ -693,9 +693,7 @@ int main(int argc, char* argv[])
}; };
wei_kcyx.GenerateTensorValue(gen_wei, num_thread); wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
#endif #endif
add_nkhw.GenerateTensorValue(GeneratorTensor_2{-1, 1}, num_thread);
out_nkhw_host.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
add_nkhw_device.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
} }
#if 0 #if 0
...@@ -778,7 +776,7 @@ int main(int argc, char* argv[]) ...@@ -778,7 +776,7 @@ int main(int argc, char* argv[])
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
out_nkhw_desc, out_nkhw_desc,
add_nkhw_device, add_nkhw,
out_nkhw_device, out_nkhw_device,
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
...@@ -791,6 +789,7 @@ int main(int argc, char* argv[]) ...@@ -791,6 +789,7 @@ int main(int argc, char* argv[])
{ {
host_direct_convolution(in_nchw, host_direct_convolution(in_nchw,
wei_kcyx, wei_kcyx,
add_nkhw,
out_nkhw_host, out_nkhw_host,
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment