Commit bfc80764 authored by rocking

[What] Fix data type for host reduction

[Why] The F16 issue for host reduction has been fixed in c1ef7319, so the half_float::half workaround is no longer needed.
parent ea09fd32
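For context on the removed workaround: a host-side max reduction seeds its accumulator with numeric_limits<T>::lowest(), so a specialization that returns zero (the ck::half_t issue noted in the deleted CAUTION comment) silently breaks all-negative inputs. Below is a minimal sketch of that pattern with a hypothetical reduce_max helper, not CK's actual ReductionHost:

```cpp
#include <iostream>
#include <limits>
#include <vector>

// Illustrative only (not CK code): a max-reduction seeds its accumulator
// with numeric_limits<T>::lowest(). If a half-precision specialization
// returned 0 (the old ck::half_t bug), any all-negative input would
// incorrectly reduce to 0 -- hence the former half_float::half workaround.
template <typename T>
T reduce_max(const std::vector<T>& data)
{
    T acc = std::numeric_limits<T>::lowest(); // must be the true minimum of T
    for(const T& v : data)
        acc = (v > acc) ? v : acc;
    return acc;
}

int main()
{
    std::vector<float> xs{-3.f, -1.f, -2.f};
    std::cout << reduce_max(xs) << '\n'; // prints -1; with lowest() == 0 it would print 0
}
```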
@@ -42,10 +42,6 @@ using CDataType = F16;
 using AccDataType = F32;
 using EltwiseComputeDataType = F32;
-// CAUTION - host reduce_max will call numeric_limits<ck::half_t>::lowest()
-// However, numeric_limits<ck::half_t>::lowest() will return zero. So, used half_float::half instead
-using HostReduceDataType = half_float::half;
 using ALayout = ck::tensor_layout::gemm::RowMajor;
 using BLayout = ck::tensor_layout::gemm::ColumnMajor;
 using CLayout = ck::tensor_layout::gemm::RowMajor;
@@ -192,18 +188,18 @@ using DeviceElementwiseDivInstance =
 using HostGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
-using HostReduceMaxInstance = ReductionHost<HostReduceDataType,
-                                            HostReduceDataType,
-                                            HostReduceDataType,
+using HostReduceMaxInstance = ReductionHost<CDataType,
+                                            CDataType,
+                                            CDataType,
                                             ReduceMaxId,
                                             Rank,
                                             NumReduceDim,
                                             ReducePropagateNan,
                                             false>;
-using HostReduceSumInstance = ReductionHost<HostReduceDataType,
+using HostReduceSumInstance = ReductionHost<CDataType,
                                             AccDataType,
-                                            HostReduceDataType,
+                                            CDataType,
                                             ReduceSumId,
                                             Rank,
                                             NumReduceDim,
@@ -510,9 +506,9 @@ int main(int argc, char* argv[])
     host_gemm_invoker.Run(host_gemm_argument);
     host_reduce_max.Run(1, // alpha
-                        reinterpret_cast<const HostReduceDataType*>(c_m_n.mData.data()),
+                        c_m_n.mData.data(),
                         0, // beta
-                        reinterpret_cast<HostReduceDataType*>(host_c_n_max.mData.data()),
+                        host_c_n_max.mData.data(),
                         host_indices.mData.data());
     host_broadcast2D<Tensor<CDataType>,
@@ -523,9 +519,9 @@ int main(int argc, char* argv[])
                      0>(host_exp_m_n, c_m_n, c_n_max, M, N, SubExp{});
     host_reduce_sum.Run(1, // alpha
-                        reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
+                        exp_m_n.mData.data(),
                         0, // beta
-                        reinterpret_cast<HostReduceDataType*>(host_exp_n_sum.mData.data()),
+                        host_exp_n_sum.mData.data(),
                         host_indices.mData.data());
     host_broadcast2D<Tensor<CDataType>,
...