Commit 85b2fee8 authored by rocking's avatar rocking
Browse files

Add compute datatype for reference code.

Prevent error in bf16
parent 7f09b8a0
...@@ -213,7 +213,7 @@ bool maxpool_bwd_test(bool do_verification, ...@@ -213,7 +213,7 @@ bool maxpool_bwd_test(bool do_verification,
ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument); ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument);
using ReferencePoolingBwdInstance = ck::tensor_operation::host:: using ReferencePoolingBwdInstance = ck::tensor_operation::host::
ReferenceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType, PassThrough>; ReferenceMaxPoolBwd<DOutDataType, IndexDataType, float, DInDataType, PassThrough>;
auto ref_pooling_bwd = ReferencePoolingBwdInstance{}; auto ref_pooling_bwd = ReferencePoolingBwdInstance{};
auto ref_pooling_bwd_invoker = ref_pooling_bwd.MakeInvoker(); auto ref_pooling_bwd_invoker = ref_pooling_bwd.MakeInvoker();
......
...@@ -21,6 +21,7 @@ using namespace std; ...@@ -21,6 +21,7 @@ using namespace std;
template <typename DOutDataType, template <typename DOutDataType,
typename IndexDataType, typename IndexDataType,
typename ConputeDataType,
typename DInDataType, typename DInDataType,
typename ElementwiseOperation> typename ElementwiseOperation>
struct ReferenceMaxPoolBwd : public device::BaseOperator struct ReferenceMaxPoolBwd : public device::BaseOperator
...@@ -49,13 +50,17 @@ struct ReferenceMaxPoolBwd : public device::BaseOperator ...@@ -49,13 +50,17 @@ struct ReferenceMaxPoolBwd : public device::BaseOperator
{ {
int din_length = arg.din_.GetElementSpaceSize(); int din_length = arg.din_.GetElementSpaceSize();
int dout_length = arg.dout_.GetElementSpaceSize(); int dout_length = arg.dout_.GetElementSpaceSize();
std::vector<ConputeDataType> buf(din_length);
for(int i = 0; i < dout_length; ++i) for(int i = 0; i < dout_length; ++i)
{ {
int index = arg.indices_.mData[i]; int index = arg.indices_.mData[i];
if(index >= 0 && index < din_length) if(index >= 0 && index < din_length)
arg.din_.mData[index] += arg.dout_.mData[i]; buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]);
} }
for(int i = 0; i < din_length; ++i)
arg.din_.mData[i] = ck::type_convert<DInDataType>(buf[i]);
return 0; return 0;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment