Commit 47ac3767 authored by rocking's avatar rocking
Browse files

Save din_length_raw

parent 283f9b62
...@@ -95,6 +95,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -95,6 +95,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
: p_dout_{p_dout}, : p_dout_{p_dout},
p_indices_{p_indices}, p_indices_{p_indices},
p_din_{p_din}, p_din_{p_din},
din_length_raw_{din_length},
blockSize_{256}, blockSize_{256},
gridSize_{104}, // FIXME - Calculate the grid size by number of CU in the future gridSize_{104}, // FIXME - Calculate the grid size by number of CU in the future
windowOverlap_{false} windowOverlap_{false}
...@@ -112,6 +113,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -112,6 +113,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const DOutDataType* p_dout_; const DOutDataType* p_dout_;
const IndexDataType* p_indices_; const IndexDataType* p_indices_;
DInDataType* p_din_; DInDataType* p_din_;
index_t din_length_raw_;
index_t blockSize_; index_t blockSize_;
index_t gridSize_; index_t gridSize_;
bool windowOverlap_; bool windowOverlap_;
...@@ -123,12 +125,12 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -123,12 +125,12 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
index_t din_length_raw = arg.din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>) if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
{ {
hip_check_error(hipMemsetAsync( hip_check_error(hipMemsetAsync(arg.p_din_,
arg.p_din_, 0, din_length_raw * sizeof(DInDataType), stream_config.stream_id_)); 0,
arg.din_length_raw_ * sizeof(DInDataType),
stream_config.stream_id_));
if(arg.windowOverlap_) if(arg.windowOverlap_)
{ {
...@@ -181,7 +183,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -181,7 +183,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
hip_check_error( hip_check_error(
hipMemsetAsync(arg.p_workspace_, hipMemsetAsync(arg.p_workspace_,
0, 0,
din_length_raw * sizeof(DInDataType_AutomicAddPreCast), arg.din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast),
stream_config.stream_id_)); stream_config.stream_id_));
const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd, const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
...@@ -236,7 +238,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -236,7 +238,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
hip_check_error(hipMemsetAsync(arg.p_din_, hip_check_error(hipMemsetAsync(arg.p_din_,
0, 0,
din_length_raw * sizeof(DInDataType), arg.din_length_raw_ * sizeof(DInDataType),
stream_config.stream_id_)); stream_config.stream_id_));
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -270,10 +272,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -270,10 +272,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
if(!needCast) if(!needCast)
return 0; return 0;
else else
{ return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast);
index_t din_length = pArg_->din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
return din_length * sizeof(DInDataType_AutomicAddPreCast);
}
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment