"docs/vscode:/vscode.git/clone" did not exist on "9bfe03c6d9d383344ff54749f4c10fe6b0372fb0"
Commit a2598b8a authored by rocking's avatar rocking
Browse files

Move initialize of workspace to the run

parent 4f1dbdf5
......@@ -163,9 +163,6 @@ bool maxpool_bwd_test(bool do_verification,
size_t pool_bwd_workspace_sz = pool_bwd.GetWorkSpaceSize(pool_bwd_argument_ptr.get());
DeviceMem pool_bwd_workspace_device_buf(pool_bwd_workspace_sz);
// similar to din_device_buf.SetZero()
// we need to set workspace to be zero
pool_bwd_workspace_device_buf.SetZero();
pool_bwd.SetWorkSpacePointer(pool_bwd_argument_ptr.get(),
pool_bwd_workspace_device_buf.GetDeviceBuffer());
......
......@@ -173,6 +173,14 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
if(arg.p_workspace_ == nullptr)
throw std::runtime_error("wrong! WorkSpace pointer has not been set");
index_t din_length_raw =
arg.din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
hip_check_error(
hipMemset(arg.p_workspace_,
0,
din_length_raw * sizeof(DInDataType_AutomicAddPreCast)));
const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
InOutGrid1dDesc,
DOutDataType,
......@@ -244,7 +252,6 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
}
};
// User need to set the value of workspace to zero
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment