Commit 283f9b62 authored by rocking
Browse files

Move zero-initialization of din into the device operator

parent a2598b8a
...@@ -116,7 +116,6 @@ bool maxpool_bwd_test(bool do_verification, ...@@ -116,7 +116,6 @@ bool maxpool_bwd_test(bool do_verification,
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
dout_device_buf.ToDevice(dout_n_c_ho_wo.mData.data()); dout_device_buf.ToDevice(dout_n_c_ho_wo.mData.data());
din_device_buf.SetZero();
auto pool_fwd = DevicePoolFwdInstance{}; auto pool_fwd = DevicePoolFwdInstance{};
auto pool_fwd_invoker_ptr = pool_fwd.MakeInvokerPointer(); auto pool_fwd_invoker_ptr = pool_fwd.MakeInvokerPointer();
......
...@@ -123,8 +123,13 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -123,8 +123,13 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
index_t din_length_raw = arg.din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>) if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
{ {
hip_check_error(hipMemsetAsync(
arg.p_din_, 0, din_length_raw * sizeof(DInDataType), stream_config.stream_id_));
if(arg.windowOverlap_) if(arg.windowOverlap_)
{ {
const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd, const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
...@@ -173,13 +178,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -173,13 +178,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
if(arg.p_workspace_ == nullptr) if(arg.p_workspace_ == nullptr)
throw std::runtime_error("wrong! WorkSpace pointer has not been set"); throw std::runtime_error("wrong! WorkSpace pointer has not been set");
index_t din_length_raw =
arg.din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
hip_check_error( hip_check_error(
hipMemset(arg.p_workspace_, hipMemsetAsync(arg.p_workspace_,
0, 0,
din_length_raw * sizeof(DInDataType_AutomicAddPreCast))); din_length_raw * sizeof(DInDataType_AutomicAddPreCast),
stream_config.stream_id_));
const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd, const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
InOutGrid1dDesc, InOutGrid1dDesc,
...@@ -231,6 +234,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -231,6 +234,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
DInDataType, DInDataType,
PassThrough>; PassThrough>;
hip_check_error(hipMemsetAsync(arg.p_din_,
0,
din_length_raw * sizeof(DInDataType),
stream_config.stream_id_));
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
put_kernel, put_kernel,
dim3(arg.gridSize_), dim3(arg.gridSize_),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment