Commit 0f421d6f authored by rocking

[What] Add ComputeDataType to the eltwise kernel

[Why] Similar to the acc datatype, it increases precision (see the sketch below)
parent cf326690
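
As an illustration of the [Why] above (a minimal standalone sketch, not part of the commit): doing the elementwise math in F32 and rounding to F16 only on the final store avoids rounding the intermediate `src1 - src2` to half precision first. `_Float16` is a clang/GCC extension standing in for `ck::half_t` here, and the input values are arbitrary.

// Illustrative only: the F16 path rounds the intermediate difference,
// while an EltwiseComputeDataType = F32 path rounds once at the end.
#include <cmath>
#include <cstdio>

int main()
{
    float src1 = 0.3333f;
    float src2 = 10.0f;

    // Old path: subtraction performed and stored in F16, then exp in F32.
    _Float16 diff_f16 = static_cast<_Float16>(src1) - static_cast<_Float16>(src2);
    _Float16 old_out  = static_cast<_Float16>(std::exp(static_cast<float>(diff_f16)));

    // New path: subtraction and exp both in F32, a single rounding on the store.
    _Float16 new_out = static_cast<_Float16>(std::exp(src1 - src2));

    std::printf("old: %g  new: %g  fp32 reference: %g\n",
                static_cast<float>(old_out),
                static_cast<float>(new_out),
                std::exp(src1 - src2));
    return 0;
}

On the device, this single-rounding behavior is what the new EltwiseComputeDataType template parameter enables for the Sub_Exp and Div functors below.
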
@@ -36,10 +36,11 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using AccDataType = F32;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using AccDataType = F32;
using EltwiseComputeDataType = F32;
// CAUTION - host reduce_max will call numeric_limits<ck::half_t>::lowest()
// However, numeric_limits<ck::half_t>::lowest() returns zero, so half_float::half is used instead
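// Editor's illustration (not part of the commit): reduce_max seeds its
// accumulator with numeric_limits<...>::lowest(); if lowest() returned 0,
// an all-negative row would wrongly reduce to 0. half_float::half has a
// std::numeric_limits specialization that returns the true minimum, so the
// host reference uses it instead. Host-only sketch; the half.hpp include
// path is an assumption.
#include <half.hpp>
#include <algorithm>
#include <limits>
#include <vector>

float row_max(const std::vector<half_float::half>& row)
{
    half_float::half acc = std::numeric_limits<half_float::half>::lowest();
    for(const auto& v : row)
        acc = std::max(acc, v);
    return static_cast<float>(acc);
}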
@@ -103,10 +104,10 @@ using ReduceMaxInElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation;
using ReduceMaxAccElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
using ReduceSumInElementwiseOperation =
typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::InElementwiseOperation;
using ReduceSumAccElementwiseOperation =
typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::AccElementwiseOperation;
using ReduceSumInElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceSumId, true, true>::InElementwiseOperation;
using ReduceSumAccElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceSumId, true, true>::AccElementwiseOperation;
using DeviceReduceMaxInstance =
ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
@@ -150,30 +151,36 @@ using DeviceReduceSumInstance =
struct Sub_Exp
{
__host__ __device__ constexpr void
operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
__host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
const EltwiseComputeDataType& src1,
const EltwiseComputeDataType& src2) const
{
dst = src1 - src2;
// FIXME - use float16 exponential
float dst_f32 = static_cast<float>(dst);
dst = static_cast<CDataType>(exp(dst_f32));
dst = exp(src1 - src2);
}
};
struct Div
{
__host__ __device__ constexpr void
operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
__host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
const EltwiseComputeDataType& src1,
const EltwiseComputeDataType& src2) const
{
dst = src1 / src2;
}
};
using DeviceElementwiseSubExpInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 256, 32, 8>;
using DeviceElementwiseSubExpInstance =
ck::tensor_operation::device::DeviceElementwise_2D<CDataType,
CDataType,
CDataType,
EltwiseComputeDataType,
Sub_Exp,
256,
32,
8>;
using DeviceElementwiseDivInstance = ck::tensor_operation::device::
DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 256, 32, 8>;
DeviceElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Div, 256, 32, 8>;
using HostGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
@@ -199,6 +206,7 @@ using HostReduceSumInstance = ReductionHost<HostReduceDataType,
template <typename HostTensorA,
typename HostTensorB,
typename HostTensorC,
typename ComputeDataType,
typename Functor,
int broadcastDim>
void host_broadcast2D(
@@ -208,10 +216,19 @@ void host_broadcast2D(
{
for(int n = 0; n < N; ++n)
{
ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n));
ComputeDataType Cmn = 0;
if constexpr(broadcastDim == 1)
functor(C(m, n), A(m, n), B(n));
{
ComputeDataType Bn = static_cast<ComputeDataType>(B(n));
functor(Cmn, Amn, Bn);
}
else
functor(C(m, n), A(m, n), B(m));
{
ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
functor(Cmn, Amn, Bm);
}
C(m, n) = static_cast<ComputeDataType>(Cmn);
}
}
}
@@ -490,8 +507,12 @@ int main(int argc, char* argv[])
reinterpret_cast<HostReduceDataType*>(host_c_n_max.mData.data()),
host_indices.mData.data());
host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Sub_Exp, 1>(
host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
host_broadcast2D<Tensor<CDataType>,
Tensor<CDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Sub_Exp,
1>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
host_reduce_sum.Run(1, // alpha
reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
@@ -499,8 +520,12 @@ int main(int argc, char* argv[])
reinterpret_cast<HostReduceDataType*>(host_exp_n_sum.mData.data()),
host_indices.mData.data());
host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Div, 1>(
host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});
host_broadcast2D<Tensor<CDataType>,
Tensor<CDataType>,
Tensor<CDataType>,
EltwiseComputeDataType,
Div,
1>(host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});
bool result = true;
if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
......
@@ -13,6 +13,7 @@ namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t ThreadPerBlock,
index_t ThreadTileSize,
@@ -43,6 +44,7 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
using GridwiseEltwise = GridwiseElementwise_1D<ADataType,
BDataType,
CDataType,
ComputeDataType,
GridDesc_M0,
ElementwiseFunctor,
ThreadPerBlock,
......
@@ -33,6 +33,7 @@ __global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ComputeDataType,
typename GridDesc_M0,
typename ElementwiseFunctor,
index_t ThreadPerBlock,
@@ -70,15 +71,15 @@ struct GridwiseElementwise_1D
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m0.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, ScalarPerVector, true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;
const auto thread_to_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_M0),
Sequence<ScalarPerVector>, // SliceLengths
@@ -90,7 +91,7 @@ struct GridwiseElementwise_1D
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
ComputeDataType,
GridDesc_M0,
decltype(thread_desc_M0),
Sequence<ScalarPerVector>, // SliceLengths
@@ -101,7 +102,7 @@ struct GridwiseElementwise_1D
false>{b_grid_desc_m0, thread_to_global_offset};
auto c_global_write =
ThreadwiseTensorSliceTransfer_v1r3<CDataType,
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
CDataType,
decltype(thread_desc_M0),
GridDesc_M0,
......