"trace/process.cc" did not exist on "218099b35d4c1d3ab5f500075e7d5b66f0faa13f"
Commit 85b2fee8 authored by rocking's avatar rocking
Browse files

Add compute datatype for reference code.

Prevent error in bf16
parent 7f09b8a0
......@@ -213,7 +213,7 @@ bool maxpool_bwd_test(bool do_verification,
ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument);
using ReferencePoolingBwdInstance = ck::tensor_operation::host::
ReferenceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType, PassThrough>;
ReferenceMaxPoolBwd<DOutDataType, IndexDataType, float, DInDataType, PassThrough>;
auto ref_pooling_bwd = ReferencePoolingBwdInstance{};
auto ref_pooling_bwd_invoker = ref_pooling_bwd.MakeInvoker();
......
......@@ -21,6 +21,7 @@ using namespace std;
template <typename DOutDataType,
typename IndexDataType,
typename ConputeDataType,
typename DInDataType,
typename ElementwiseOperation>
struct ReferenceMaxPoolBwd : public device::BaseOperator
......@@ -49,13 +50,17 @@ struct ReferenceMaxPoolBwd : public device::BaseOperator
{
int din_length = arg.din_.GetElementSpaceSize();
int dout_length = arg.dout_.GetElementSpaceSize();
std::vector<ConputeDataType> buf(din_length);
for(int i = 0; i < dout_length; ++i)
{
int index = arg.indices_.mData[i];
if(index >= 0 && index < din_length)
arg.din_.mData[index] += arg.dout_.mData[i];
buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]);
}
for(int i = 0; i < din_length; ++i)
arg.din_.mData[i] = ck::type_convert<DInDataType>(buf[i]);
return 0;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment