Commit 81caa447 authored by rocking's avatar rocking
Browse files

Fix bf16 error

parent 1a1059ab
......@@ -100,8 +100,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
ComputeDataType currVal = ck::type_convert<ComputeDataType>(
arg.in_(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal);
......@@ -112,7 +112,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
};
make_ParallelTensorFunctor(f_ncdhw,
......@@ -151,8 +151,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
ComputeDataType currVal = ck::type_convert<ComputeDataType>(
arg.in_(n, c, di, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
......@@ -166,7 +166,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal;
arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
arg.out_indices_(n, c, do_, ho, wo) = accuIndex;
};
......@@ -212,7 +212,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
in_elementwise_op(currVal, currVal);
......@@ -222,7 +222,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
};
make_ParallelTensorFunctor(f_nchw,
......@@ -255,7 +255,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
......@@ -268,7 +268,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
}
acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal;
arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
arg.out_indices_(n, c, ho, wo) = accuIndex;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment