Commit 47ac3767 authored by rocking's avatar rocking
Browse files

Save din_length_raw

parent 283f9b62
......@@ -95,6 +95,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
: p_dout_{p_dout},
p_indices_{p_indices},
p_din_{p_din},
din_length_raw_{din_length},
blockSize_{256},
gridSize_{104}, // FIXME - Calculate the grid size by number of CU in the future
windowOverlap_{false}
......@@ -112,6 +113,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const DOutDataType* p_dout_;
const IndexDataType* p_indices_;
DInDataType* p_din_;
index_t din_length_raw_;
index_t blockSize_;
index_t gridSize_;
bool windowOverlap_;
......@@ -123,12 +125,12 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t din_length_raw = arg.din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
{
hip_check_error(hipMemsetAsync(
arg.p_din_, 0, din_length_raw * sizeof(DInDataType), stream_config.stream_id_));
hip_check_error(hipMemsetAsync(arg.p_din_,
0,
arg.din_length_raw_ * sizeof(DInDataType),
stream_config.stream_id_));
if(arg.windowOverlap_)
{
......@@ -181,7 +183,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
hip_check_error(
hipMemsetAsync(arg.p_workspace_,
0,
din_length_raw * sizeof(DInDataType_AutomicAddPreCast),
arg.din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast),
stream_config.stream_id_));
const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
......@@ -236,7 +238,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
hip_check_error(hipMemsetAsync(arg.p_din_,
0,
din_length_raw * sizeof(DInDataType),
arg.din_length_raw_ * sizeof(DInDataType),
stream_config.stream_id_));
return launch_and_time_kernel(stream_config,
......@@ -270,10 +272,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
if(!needCast)
return 0;
else
{
index_t din_length = pArg_->din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
return din_length * sizeof(DInDataType_AutomicAddPreCast);
}
return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast);
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment