"docs/en/git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "9e5746d3d8903165cc884637504de634bf4d04b6"
Commit 81caa447 authored by rocking's avatar rocking
Browse files

Fix bf16 error

parent 1a1059ab
...@@ -100,8 +100,8 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -100,8 +100,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{ {
ComputeDataType currVal = ComputeDataType currVal = ck::type_convert<ComputeDataType>(
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi)); arg.in_(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal); in_elementwise_op(currVal, currVal);
...@@ -112,7 +112,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -112,7 +112,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
} }
acc_elementwise_op(accuVal, accuVal); acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal; arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
...@@ -151,8 +151,8 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -151,8 +151,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{ {
ComputeDataType currVal = ComputeDataType currVal = ck::type_convert<ComputeDataType>(
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi)); arg.in_(n, c, di, hi, wi));
IndexDataType currIndex = IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi); arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
...@@ -166,7 +166,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -166,7 +166,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
acc_elementwise_op(accuVal, accuVal); acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, do_, ho, wo) = accuVal; arg.out_(n, c, do_, ho, wo) = ck::type_convert<OutDataType>(accuVal);
arg.out_indices_(n, c, do_, ho, wo) = accuIndex; arg.out_indices_(n, c, do_, ho, wo) = accuIndex;
}; };
...@@ -212,7 +212,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -212,7 +212,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{ {
ComputeDataType currVal = ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi)); ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
in_elementwise_op(currVal, currVal); in_elementwise_op(currVal, currVal);
...@@ -222,7 +222,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -222,7 +222,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
} }
acc_elementwise_op(accuVal, accuVal); acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal; arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
...@@ -255,7 +255,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -255,7 +255,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{ {
ComputeDataType currVal = ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi)); ck::type_convert<ComputeDataType>(arg.in_(n, c, hi, wi));
IndexDataType currIndex = IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi); arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
...@@ -268,7 +268,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -268,7 +268,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
} }
acc_elementwise_op(accuVal, accuVal); acc_elementwise_op(accuVal, accuVal);
arg.out_(n, c, ho, wo) = accuVal; arg.out_(n, c, ho, wo) = ck::type_convert<OutDataType>(accuVal);
arg.out_indices_(n, c, ho, wo) = accuIndex; arg.out_indices_(n, c, ho, wo) = accuIndex;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment