Commit cf326690 authored by rocking's avatar rocking
Browse files

[What] Use F32 as the acc of reduce sum

[Why] Prevent loss of precision
parent c16f789d
......@@ -98,15 +98,15 @@ constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX;
constexpr ck::ReduceTensorOp ReduceSumId = ck::ReduceTensorOp::ADD;
constexpr bool ReducePropagateNan = false;
using ReduceMaxOp = typename ck::reduce_binary_operator<CDataType, ReduceMaxId>::opType;
using ReduceSumOp = typename ck::reduce_binary_operator<CDataType, ReduceSumId>::opType;
using ReduceSumOp = typename ck::reduce_binary_operator<AccDataType, ReduceSumId>::opType;
using ReduceMaxInElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation;
using ReduceMaxAccElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
using ReduceSumInElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::InElementwiseOperation;
typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::InElementwiseOperation;
using ReduceSumAccElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::AccElementwiseOperation;
typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::AccElementwiseOperation;
using DeviceReduceMaxInstance =
ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
......@@ -130,7 +130,7 @@ using DeviceReduceMaxInstance =
using DeviceReduceSumInstance =
ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
CDataType,
AccDataType,
CDataType,
Rank,
NumReduceDim,
......@@ -188,7 +188,7 @@ using HostReduceMaxInstance = ReductionHost<HostReduceDataType,
false>;
using HostReduceSumInstance = ReductionHost<HostReduceDataType,
HostReduceDataType,
AccDataType,
HostReduceDataType,
ReduceSumId,
Rank,
......@@ -504,15 +504,15 @@ int main(int argc, char* argv[])
bool result = true;
if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
std::cout << "[PASS] - c_m_n" << std::endl;
std::cout << "[PASS] - gemm" << std::endl;
if(result &= ck::utils::check_err(c_n_max.mData, host_c_n_max.mData))
std::cout << "[PASS] - c_n_max" << std::endl;
std::cout << "[PASS] - reduce max" << std::endl;
if(result &= ck::utils::check_err(exp_m_n.mData, host_exp_m_n.mData))
std::cout << "[PASS] - exp_m_n" << std::endl;
std::cout << "[PASS] - broadcast sub + exp" << std::endl;
if(result &= ck::utils::check_err(exp_n_sum.mData, host_exp_n_sum.mData))
std::cout << "[PASS] - exp_n_sum" << std::endl;
std::cout << "[PASS] - reduce sum" << std::endl;
if(result &= ck::utils::check_err(softmax_m_n.mData, host_softmax_m_n.mData))
std::cout << "[PASS] - softmax_m_n" << std::endl;
std::cout << "[PASS] - broadcast div" << std::endl;
}
return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment