Commit 283f9b62 authored by rocking's avatar rocking
Browse files

Move set din zero to the device operator

parent a2598b8a
......@@ -116,7 +116,6 @@ bool maxpool_bwd_test(bool do_verification,
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
dout_device_buf.ToDevice(dout_n_c_ho_wo.mData.data());
din_device_buf.SetZero();
auto pool_fwd = DevicePoolFwdInstance{};
auto pool_fwd_invoker_ptr = pool_fwd.MakeInvokerPointer();
......
......@@ -123,8 +123,13 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t din_length_raw = arg.din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
{
hip_check_error(hipMemsetAsync(
arg.p_din_, 0, din_length_raw * sizeof(DInDataType), stream_config.stream_id_));
if(arg.windowOverlap_)
{
const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
......@@ -173,13 +178,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
if(arg.p_workspace_ == nullptr)
throw std::runtime_error("wrong! WorkSpace pointer has not been set");
index_t din_length_raw =
arg.din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
hip_check_error(
hipMemset(arg.p_workspace_,
hipMemsetAsync(arg.p_workspace_,
0,
din_length_raw * sizeof(DInDataType_AutomicAddPreCast)));
din_length_raw * sizeof(DInDataType_AutomicAddPreCast),
stream_config.stream_id_));
const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
InOutGrid1dDesc,
......@@ -231,6 +234,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
DInDataType,
PassThrough>;
hip_check_error(hipMemsetAsync(arg.p_din_,
0,
din_length_raw * sizeof(DInDataType),
stream_config.stream_id_));
return launch_and_time_kernel(stream_config,
put_kernel,
dim3(arg.gridSize_),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment