Commit 19b3a3e8 authored by ltqin's avatar ltqin
Browse files

reset input to zero

parent 591e9987
...@@ -180,6 +180,10 @@ int main(int argc, char* argv[]) ...@@ -180,6 +180,10 @@ int main(int argc, char* argv[])
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// reset input to zero
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
// do GEMM // do GEMM
auto conv = DeviceConvBwdDataInstance{}; auto conv = DeviceConvBwdDataInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
......
...@@ -182,8 +182,8 @@ int main(int argc, char* argv[]) ...@@ -182,8 +182,8 @@ int main(int argc, char* argv[])
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
// reset input to zero
in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{5}); in_n_c_hi_wi_device_result.GenerateTensorValue(GeneratorTensor_1<InDataType>{0});
in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi_device_result.mData.data());
// get host result // get host result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment