Commit bfc80764 authored by rocking's avatar rocking
Browse files

[What] Fix data type for host reduction

[Why] F16 issue for host reduction has been fix in c1ef7319
parent ea09fd32
......@@ -42,10 +42,6 @@ using CDataType = F16;
using AccDataType = F32;
using EltwiseComputeDataType = F32;
// CAUTION - host reduce_max will call numeric_limits<ck::half_t>::lowest()
// However, numeric_limits<ck::half_t>::lowest() will return zero. So, used half_float::half instead
using HostReduceDataType = half_float::half;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
......@@ -192,18 +188,18 @@ using DeviceElementwiseDivInstance =
using HostGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
using HostReduceMaxInstance = ReductionHost<HostReduceDataType,
HostReduceDataType,
HostReduceDataType,
using HostReduceMaxInstance = ReductionHost<CDataType,
CDataType,
CDataType,
ReduceMaxId,
Rank,
NumReduceDim,
ReducePropagateNan,
false>;
using HostReduceSumInstance = ReductionHost<HostReduceDataType,
using HostReduceSumInstance = ReductionHost<CDataType,
AccDataType,
HostReduceDataType,
CDataType,
ReduceSumId,
Rank,
NumReduceDim,
......@@ -510,9 +506,9 @@ int main(int argc, char* argv[])
host_gemm_invoker.Run(host_gemm_argument);
host_reduce_max.Run(1, // alpha
reinterpret_cast<const HostReduceDataType*>(c_m_n.mData.data()),
c_m_n.mData.data(),
0, // beta
reinterpret_cast<HostReduceDataType*>(host_c_n_max.mData.data()),
host_c_n_max.mData.data(),
host_indices.mData.data());
host_broadcast2D<Tensor<CDataType>,
......@@ -523,9 +519,9 @@ int main(int argc, char* argv[])
0>(host_exp_m_n, c_m_n, c_n_max, M, N, SubExp{});
host_reduce_sum.Run(1, // alpha
reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
exp_m_n.mData.data(),
0, // beta
reinterpret_cast<HostReduceDataType*>(host_exp_n_sum.mData.data()),
host_exp_n_sum.mData.data(),
host_indices.mData.data());
host_broadcast2D<Tensor<CDataType>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment