"include/vscode:/vscode.git/clone" did not exist on "f4ea00fc631e60b0b7abb1d0c454c51d0c6a2ecf"
Commit 514b2d1c authored by Jing Zhang
Browse files

improved verify

parent 6a4db056
......@@ -10,6 +10,7 @@ template <class TIn,
class UpperPads>
void host_direct_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx,
+const Tensor<TOut>& add_nkhw,
Tensor<TOut>& out_nkhw,
ConvStrides,
ConvDilations,
......@@ -40,7 +41,7 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
}
}
}
-out_nkhw(n, k, ho, wo) += v;
+out_nkhw(n, k, ho, wo) = v + add_nkhw(n, k, ho, wo);
};
auto f_par = make_ParallelTensorFunctor(f,
......
......@@ -656,7 +656,7 @@ int main(int argc, char* argv[])
Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc));
-Tensor<out_data_t> add_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc));
+Tensor<out_data_t> add_nkhw(make_HostTensorDescriptor(out_nkhw_desc));
Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc));
std::size_t num_thread = std::thread::hardware_concurrency();
......@@ -693,9 +693,7 @@ int main(int argc, char* argv[])
};
wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
#endif
out_nkhw_host.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
-add_nkhw_device.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+add_nkhw.GenerateTensorValue(GeneratorTensor_2{-1, 1}, num_thread);
}
#if 0
......@@ -778,7 +776,7 @@ int main(int argc, char* argv[])
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
-add_nkhw_device,
+add_nkhw,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
......@@ -791,6 +789,7 @@ int main(int argc, char* argv[])
{
host_direct_convolution(in_nchw,
wei_kcyx,
+add_nkhw,
out_nkhw_host,
ConvStrides{},
ConvDilations{},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment