Commit 5ebe74e6 authored by Jehandad Khan
Browse files

fix out_ptr bug

parent 2090160a
...@@ -411,7 +411,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -411,7 +411,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 8, 1>::type{}, arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{}); Number<1>{});
#elif 0 #elif 1
p_out_global[0] = p_out_thread[0]; p_out_global[0] = p_out_thread[0];
#endif #endif
} }
......
...@@ -241,8 +241,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -241,8 +241,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
dim3(BlockSize), dim3(BlockSize),
0, 0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()), static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())); static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()), static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n", printf("Elapsed time : %f ms, %f TFlop/s\n",
time, time,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment