"...composable_kernel_rocm.git" did not exist on "fab2f10a554974998e8a979d7992c02784bfc848"
Commit a2598b8a authored by rocking's avatar rocking
Browse files

Move initialize of workspace to the run

parent 4f1dbdf5
...@@ -163,9 +163,6 @@ bool maxpool_bwd_test(bool do_verification, ...@@ -163,9 +163,6 @@ bool maxpool_bwd_test(bool do_verification,
size_t pool_bwd_workspace_sz = pool_bwd.GetWorkSpaceSize(pool_bwd_argument_ptr.get()); size_t pool_bwd_workspace_sz = pool_bwd.GetWorkSpaceSize(pool_bwd_argument_ptr.get());
DeviceMem pool_bwd_workspace_device_buf(pool_bwd_workspace_sz); DeviceMem pool_bwd_workspace_device_buf(pool_bwd_workspace_sz);
// similar to din_device_buf.SetZero()
// we need to set workspace to be zero
pool_bwd_workspace_device_buf.SetZero();
pool_bwd.SetWorkSpacePointer(pool_bwd_argument_ptr.get(), pool_bwd.SetWorkSpacePointer(pool_bwd_argument_ptr.get(),
pool_bwd_workspace_device_buf.GetDeviceBuffer()); pool_bwd_workspace_device_buf.GetDeviceBuffer());
......
...@@ -173,6 +173,14 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -173,6 +173,14 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
if(arg.p_workspace_ == nullptr) if(arg.p_workspace_ == nullptr)
throw std::runtime_error("wrong! WorkSpace pointer has not been set"); throw std::runtime_error("wrong! WorkSpace pointer has not been set");
index_t din_length_raw =
arg.din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
hip_check_error(
hipMemset(arg.p_workspace_,
0,
din_length_raw * sizeof(DInDataType_AutomicAddPreCast)));
const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd, const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
InOutGrid1dDesc, InOutGrid1dDesc,
DOutDataType, DOutDataType,
...@@ -244,7 +252,6 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -244,7 +252,6 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
} }
}; };
// User need to set the value of workspace to zero
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{ {
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg); const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment