Commit e83b22e0 authored by rocking

[What] Use half_float::half instead of ck::half_t for host reduction

[Why]  std::numeric_limits<_Float16>::lowest() returns zero, so the host
max reduction is seeded with the wrong identity value
parent fe659502
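
For context, a minimal sketch of the pitfall this commit works around, assuming a toolchain (as at the time of this change) where std::numeric_limits carries no specialization for _Float16. The unspecialized primary template reports is_specialized == false and its lowest() falls back to a value-initialized T, i.e. zero, whereas half_float::half ships its own numeric_limits specialization:

#include <iostream>
#include <limits>
#include "half.hpp" // half_float::half (http://half.sourceforge.net)

int main()
{
    // ck::half_t is a 16-bit half type; _Float16 stands in for it here.
    using Half = _Float16;

    // Unspecialized numeric_limits: lowest() returns T(), i.e. 0.
    std::cout << std::boolalpha
              << "specialized: " << std::numeric_limits<Half>::is_specialized << '\n' // false
              << "lowest:      "
              << static_cast<float>(std::numeric_limits<Half>::lowest()) << '\n';     // 0

    // half_float::half specializes numeric_limits, so lowest() is -65504,
    // the most negative finite fp16 value and a valid max-reduction seed.
    std::cout << "half_float:  "
              << static_cast<float>(std::numeric_limits<half_float::half>::lowest())
              << '\n';
    return 0;
}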
@@ -41,6 +41,10 @@ using BDataType = F16;
 using CDataType   = F16;
 using AccDataType = F32;
 
+// CAUTION - host reduce_max will call numeric_limits<ck::half_t>::lowest().
+// However, numeric_limits<ck::half_t>::lowest() returns zero, so half_float::half is used instead.
+using HostReduceDataType = half_float::half;
+
 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
@@ -174,18 +178,18 @@ using DeviceElementwiseDivInstance = ck::tensor_operation::device::
 using HostGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
 
-using HostReduceMaxInstance = ReductionHost<CDataType,
-                                            CDataType,
-                                            CDataType,
+using HostReduceMaxInstance = ReductionHost<HostReduceDataType,
+                                            HostReduceDataType,
+                                            HostReduceDataType,
                                             ReduceMaxId,
                                             Rank,
                                             NumReduceDim,
                                             ReducePropagateNan,
                                             false>;
 
-using HostReduceSumInstance = ReductionHost<CDataType,
-                                            CDataType,
-                                            CDataType,
+using HostReduceSumInstance = ReductionHost<HostReduceDataType,
+                                            HostReduceDataType,
+                                            HostReduceDataType,
                                             ReduceSumId,
                                             Rank,
                                             NumReduceDim,
@@ -474,18 +478,18 @@ int main(int argc, char* argv[])
     host_gemm_invoker.Run(host_gemm_argument);
 
     host_reduce_max.Run(1, // alpha
-                        reinterpret_cast<const CDataType*>(host_c_m_n.mData.data()),
+                        reinterpret_cast<const HostReduceDataType*>(host_c_m_n.mData.data()),
                         0, // beta
-                        reinterpret_cast<CDataType*>(host_c_n_max.mData.data()),
+                        reinterpret_cast<HostReduceDataType*>(host_c_n_max.mData.data()),
                         host_indices.mData.data());
 
     host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Sub_Exp, 1>(
         host_exp_m_n, host_c_m_n, host_c_n_max, M, N, Sub_Exp{});
 
     host_reduce_sum.Run(1, // alpha
-                        reinterpret_cast<const CDataType*>(host_exp_m_n.mData.data()),
+                        reinterpret_cast<const HostReduceDataType*>(host_exp_m_n.mData.data()),
                         0, // beta
-                        reinterpret_cast<CDataType*>(host_exp_n_sum.mData.data()),
+                        reinterpret_cast<HostReduceDataType*>(host_exp_n_sum.mData.data()),
                         host_indices.mData.data());
 
     host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Div, 1>(
@@ -498,15 +502,15 @@ int main(int argc, char* argv[])
     softmax_m_n_device_buf.FromDevice(softmax_m_n.mData.data());
 
     bool result = true;
-    if (result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
+    if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
         std::cout << "[PASS] - c_m_n" << std::endl;
-    if (result &= ck::utils::check_err(c_n_max.mData, host_c_n_max.mData))
+    if(result &= ck::utils::check_err(c_n_max.mData, host_c_n_max.mData))
         std::cout << "[PASS] - c_n_max" << std::endl;
-    if (result &= ck::utils::check_err(exp_m_n.mData, host_exp_m_n.mData))
+    if(result &= ck::utils::check_err(exp_m_n.mData, host_exp_m_n.mData))
         std::cout << "[PASS] - exp_m_n" << std::endl;
-    if (result &= ck::utils::check_err(exp_n_sum.mData, host_exp_n_sum.mData))
+    if(result &= ck::utils::check_err(exp_n_sum.mData, host_exp_n_sum.mData))
         std::cout << "[PASS] - exp_n_sum" << std::endl;
-    if (result &= ck::utils::check_err(softmax_m_n.mData, host_softmax_m_n.mData))
+    if(result &= ck::utils::check_err(softmax_m_n.mData, host_softmax_m_n.mData))
         std::cout << "[PASS] - softmax_m_n" << std::endl;
 }
 return 0;
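
Why this matters downstream: a host max reduction seeds its accumulator with lowest(). If lowest() silently returns 0, every row whose true maximum is negative reduces to 0, and the softmax reference output diverges from the device result. A hypothetical illustration of that failure mode (not CK's ReductionHost):

#include <algorithm>
#include <limits>
#include <vector>
#include "half.hpp"

// Hypothetical max reduction mirroring what a host reference does internally.
template <typename T>
T reduce_max(const std::vector<T>& v)
{
    T acc = std::numeric_limits<T>::lowest(); // identity element for max
    for(const T& x : v)
        acc = std::max(acc, x);
    return acc;
}

int main()
{
    using half_float::half;
    std::vector<half> row{half(-3.0f), half(-1.5f), half(-2.25f)};

    // With a correct lowest() this yields -1.5; with a broken
    // lowest() == 0 it would yield 0 and poison the softmax reference.
    return reduce_max(row) == half(-1.5f) ? 0 : 1;
}

Note also that the reinterpret_casts in the hunks above assume ck::half_t and half_float::half are layout-compatible 16-bit types; a static_assert(sizeof(half_float::half) == sizeof(CDataType)) would make that assumption explicit.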