Commit cf326690 authored by rocking's avatar rocking
Browse files

[What] Use F32 as the acc of reduce sum

[Why] Prevent loss of precision
parent c16f789d
...@@ -98,15 +98,15 @@ constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX; ...@@ -98,15 +98,15 @@ constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX;
constexpr ck::ReduceTensorOp ReduceSumId = ck::ReduceTensorOp::ADD; constexpr ck::ReduceTensorOp ReduceSumId = ck::ReduceTensorOp::ADD;
constexpr bool ReducePropagateNan = false; constexpr bool ReducePropagateNan = false;
using ReduceMaxOp = typename ck::reduce_binary_operator<CDataType, ReduceMaxId>::opType; using ReduceMaxOp = typename ck::reduce_binary_operator<CDataType, ReduceMaxId>::opType;
using ReduceSumOp = typename ck::reduce_binary_operator<CDataType, ReduceSumId>::opType; using ReduceSumOp = typename ck::reduce_binary_operator<AccDataType, ReduceSumId>::opType;
using ReduceMaxInElementwiseOperation = using ReduceMaxInElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation; typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation;
using ReduceMaxAccElementwiseOperation = using ReduceMaxAccElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation; typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
using ReduceSumInElementwiseOperation = using ReduceSumInElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::InElementwiseOperation; typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::InElementwiseOperation;
using ReduceSumAccElementwiseOperation = using ReduceSumAccElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceSumId, true, true>::AccElementwiseOperation; typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::AccElementwiseOperation;
using DeviceReduceMaxInstance = using DeviceReduceMaxInstance =
ck::tensor_operation::device::DeviceReduceBlockWise<CDataType, ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
...@@ -130,7 +130,7 @@ using DeviceReduceMaxInstance = ...@@ -130,7 +130,7 @@ using DeviceReduceMaxInstance =
using DeviceReduceSumInstance = using DeviceReduceSumInstance =
ck::tensor_operation::device::DeviceReduceBlockWise<CDataType, ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
CDataType, AccDataType,
CDataType, CDataType,
Rank, Rank,
NumReduceDim, NumReduceDim,
...@@ -188,7 +188,7 @@ using HostReduceMaxInstance = ReductionHost<HostReduceDataType, ...@@ -188,7 +188,7 @@ using HostReduceMaxInstance = ReductionHost<HostReduceDataType,
false>; false>;
using HostReduceSumInstance = ReductionHost<HostReduceDataType, using HostReduceSumInstance = ReductionHost<HostReduceDataType,
HostReduceDataType, AccDataType,
HostReduceDataType, HostReduceDataType,
ReduceSumId, ReduceSumId,
Rank, Rank,
...@@ -504,15 +504,15 @@ int main(int argc, char* argv[]) ...@@ -504,15 +504,15 @@ int main(int argc, char* argv[])
bool result = true; bool result = true;
if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData)) if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
std::cout << "[PASS] - c_m_n" << std::endl; std::cout << "[PASS] - gemm" << std::endl;
if(result &= ck::utils::check_err(c_n_max.mData, host_c_n_max.mData)) if(result &= ck::utils::check_err(c_n_max.mData, host_c_n_max.mData))
std::cout << "[PASS] - c_n_max" << std::endl; std::cout << "[PASS] - reduce max" << std::endl;
if(result &= ck::utils::check_err(exp_m_n.mData, host_exp_m_n.mData)) if(result &= ck::utils::check_err(exp_m_n.mData, host_exp_m_n.mData))
std::cout << "[PASS] - exp_m_n" << std::endl; std::cout << "[PASS] - broadcast sub + exp" << std::endl;
if(result &= ck::utils::check_err(exp_n_sum.mData, host_exp_n_sum.mData)) if(result &= ck::utils::check_err(exp_n_sum.mData, host_exp_n_sum.mData))
std::cout << "[PASS] - exp_n_sum" << std::endl; std::cout << "[PASS] - reduce sum" << std::endl;
if(result &= ck::utils::check_err(softmax_m_n.mData, host_softmax_m_n.mData)) if(result &= ck::utils::check_err(softmax_m_n.mData, host_softmax_m_n.mData))
std::cout << "[PASS] - softmax_m_n" << std::endl; std::cout << "[PASS] - broadcast div" << std::endl;
} }
return 0; return 0;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment