Commit 0f421d6f authored by rocking

[What] Add ComputeDataType to the eltwise kernel

[Why] Similar to the acc datatype, it increases precision
parent cf326690
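
The [Why] above is the same argument that motivates AccDataType: do the arithmetic in a wider type and round to the storage type only once at the end. A minimal standalone sketch of the effect, using float as the narrow type and double as the wide one, since plain C++ has no portable half type:

#include <cstdio>

int main()
{
    // Accumulate the same f32 inputs into a narrow and a wide accumulator.
    // The narrow accumulator loses low-order bits once the running sum grows,
    // which is exactly what a wider Acc/ComputeDataType avoids.
    float  acc_narrow = 0.0f;
    double acc_wide   = 0.0;
    for(int i = 0; i < 10000000; ++i)
    {
        acc_narrow += 0.1f;
        acc_wide += 0.1f;
    }
    // The exact sum is ~1e6 (up to the rounding of 0.1f itself); the float
    // accumulator drifts far from it, the double accumulator does not.
    std::printf("narrow: %.1f  wide: %.1f\n", acc_narrow, acc_wide);
    return 0;
}
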
@@ -40,6 +40,7 @@
 using ADataType = F16;
 using BDataType = F16;
 using CDataType = F16;
 using AccDataType = F32;
+using EltwiseComputeDataType = F32;
 // CAUSION - host reduce_max will call numeric_limits<ck::half_t>::lowest()
 // However, numeric_limits<ck::half_t>::lowest() will return zero. So, used half_float::half instead
@@ -103,10 +104,10 @@
 using ReduceMaxInElementwiseOperation =
     typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation;
 using ReduceMaxAccElementwiseOperation =
     typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
-using ReduceSumInElementwiseOperation =
-    typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::InElementwiseOperation;
-using ReduceSumAccElementwiseOperation =
-    typename ck::reduce_unary_operator<AccDataType, ReduceSumId, true, true>::AccElementwiseOperation;
+using ReduceSumInElementwiseOperation = typename ck::
+    reduce_unary_operator<AccDataType, ReduceSumId, true, true>::InElementwiseOperation;
+using ReduceSumAccElementwiseOperation = typename ck::
+    reduce_unary_operator<AccDataType, ReduceSumId, true, true>::AccElementwiseOperation;
 using DeviceReduceMaxInstance =
     ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
@@ -150,30 +151,36 @@ using DeviceReduceSumInstance =
 
 struct Sub_Exp
 {
-    __host__ __device__ constexpr void
-    operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
+    __host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
+                                                  const EltwiseComputeDataType& src1,
+                                                  const EltwiseComputeDataType& src2) const
     {
-        dst = src1 - src2;
-        // FIXME - use float16 exponential
-        float dst_f32 = static_cast<float>(dst);
-        dst = static_cast<CDataType>(exp(dst_f32));
+        dst = exp(src1 - src2);
     }
 };
 
 struct Div
 {
-    __host__ __device__ constexpr void
-    operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
+    __host__ __device__ constexpr void operator()(EltwiseComputeDataType& dst,
+                                                  const EltwiseComputeDataType& src1,
+                                                  const EltwiseComputeDataType& src2) const
     {
         dst = src1 / src2;
     }
 };
 
-using DeviceElementwiseSubExpInstance = ck::tensor_operation::device::
-    DeviceElementwise_2D<CDataType, CDataType, CDataType, Sub_Exp, 256, 32, 8>;
+using DeviceElementwiseSubExpInstance =
+    ck::tensor_operation::device::DeviceElementwise_2D<CDataType,
+                                                       CDataType,
+                                                       CDataType,
+                                                       EltwiseComputeDataType,
+                                                       Sub_Exp,
+                                                       256,
+                                                       32,
+                                                       8>;
 
-using DeviceElementwiseDivInstance = ck::tensor_operation::device::
-    DeviceElementwise_2D<CDataType, CDataType, CDataType, Div, 256, 32, 8>;
+using DeviceElementwiseDivInstance = ck::tensor_operation::device::
+    DeviceElementwise_2D<CDataType, CDataType, CDataType, EltwiseComputeDataType, Div, 256, 32, 8>;
 
 using HostGemmInstance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
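
Note how the functor bodies simplify: src1 and src2 now arrive as EltwiseComputeDataType (F32), so Sub_Exp can call exp directly and the old FIXME cast dance is gone; the F16/F32 conversions move to the kernel's load and store paths. A hypothetical standalone helper showing the resulting per-element flow (not CK code; _Float16 assumes a compiler that provides it, such as clang):

#include <cmath>
#include <cstdio>

using F16 = _Float16; // assumption: the compiler provides _Float16
using F32 = float;

// Stand-in for the example's Sub_Exp, fixed to F32 for this sketch.
struct SubExpF32
{
    void operator()(F32& dst, const F32& src1, const F32& src2) const
    {
        dst = std::exp(src1 - src2);
    }
};

// Hypothetical per-element helper mirroring what the kernel now does:
// widen on load, run the functor entirely in F32, narrow on store.
template <typename Functor>
F16 apply_in_f32(F16 a, F16 b, Functor functor)
{
    F32 c = 0.0f;
    functor(c, static_cast<F32>(a), static_cast<F32>(b));
    return static_cast<F16>(c); // one rounding, at the very end
}

int main()
{
    F16 r = apply_in_f32(F16(2.0f), F16(0.5f), SubExpF32{});
    std::printf("exp(2 - 0.5) ~= %f\n", static_cast<float>(r)); // ~4.48 after F16 rounding
}
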
@@ -199,6 +206,7 @@ using HostReduceSumInstance = ReductionHost<HostReduceDataType,
 template <typename HostTensorA,
           typename HostTensorB,
           typename HostTensorC,
+          typename ComputeDataType,
           typename Functor,
           int broadcastDim>
 void host_broadcast2D(
@@ -208,10 +216,19 @@ void host_broadcast2D(
     {
         for(int n = 0; n < N; ++n)
         {
+            ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n));
+            ComputeDataType Cmn = 0;
             if constexpr(broadcastDim == 1)
-                functor(C(m, n), A(m, n), B(n));
+            {
+                ComputeDataType Bn = static_cast<ComputeDataType>(B(n));
+                functor(Cmn, Amn, Bn);
+            }
             else
-                functor(C(m, n), A(m, n), B(m));
+            {
+                ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
+                functor(Cmn, Amn, Bm);
+            }
+            C(m, n) = static_cast<ComputeDataType>(Cmn);
         }
     }
 }
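
The host reference now mirrors the device path: A and B are widened to ComputeDataType before the functor runs, and the result is narrowed only on the final store. (The final cast is written as static_cast<ComputeDataType>(Cmn), which is a no-op since Cmn already has that type; the narrowing to CDataType happens implicitly in the assignment to C(m, n).) A standalone sketch of the broadcast semantics, with plain vectors standing in for the example's Tensor<> and illustrative names:

#include <vector>

// broadcastDim == 1 replicates B along rows (one entry per column, B[n]);
// otherwise B is replicated along columns (one entry per row, B[m]).
template <typename ComputeDataType, typename Functor, int broadcastDim>
void broadcast2d_sketch(std::vector<float>& C,
                        const std::vector<float>& A,
                        const std::vector<float>& B,
                        int M, int N, Functor functor)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            ComputeDataType a = static_cast<ComputeDataType>(A[m * N + n]);
            ComputeDataType b = static_cast<ComputeDataType>(
                broadcastDim == 1 ? B[n] : B[m]);
            ComputeDataType c = 0;
            functor(c, a, b); // the functor only ever sees ComputeDataType
            C[m * N + n] = static_cast<float>(c);
        }
}
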
@@ -490,8 +507,12 @@ int main(int argc, char* argv[])
                         reinterpret_cast<HostReduceDataType*>(host_c_n_max.mData.data()),
                         host_indices.mData.data());
 
-    host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Sub_Exp, 1>(
-        host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
+    host_broadcast2D<Tensor<CDataType>,
+                     Tensor<CDataType>,
+                     Tensor<CDataType>,
+                     EltwiseComputeDataType,
+                     Sub_Exp,
+                     1>(host_exp_m_n, c_m_n, c_n_max, M, N, Sub_Exp{});
 
     host_reduce_sum.Run(1, // alpha
                         reinterpret_cast<const HostReduceDataType*>(exp_m_n.mData.data()),
@@ -499,8 +520,12 @@ int main(int argc, char* argv[])
                         reinterpret_cast<HostReduceDataType*>(host_exp_n_sum.mData.data()),
                         host_indices.mData.data());
 
-    host_broadcast2D<Tensor<CDataType>, Tensor<CDataType>, Tensor<CDataType>, Div, 1>(
-        host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});
+    host_broadcast2D<Tensor<CDataType>,
+                     Tensor<CDataType>,
+                     Tensor<CDataType>,
+                     EltwiseComputeDataType,
+                     Div,
+                     1>(host_softmax_m_n, exp_m_n, exp_n_sum, M, N, Div{});
 
     bool result = true;
     if(result &= ck::utils::check_err(c_m_n.mData, host_c_m_n.mData))
...
@@ -13,6 +13,7 @@ namespace device {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename ComputeDataType,
           typename ElementwiseFunctor,
           index_t ThreadPerBlock,
           index_t ThreadTileSize,
@@ -43,6 +44,7 @@ struct DeviceElementwise_2D : public DeviceElementwise<ElementwiseFunctor>
     using GridwiseEltwise = GridwiseElementwise_1D<ADataType,
                                                    BDataType,
                                                    CDataType,
+                                                   ComputeDataType,
                                                    GridDesc_M0,
                                                    ElementwiseFunctor,
                                                    ThreadPerBlock,
...
@@ -33,6 +33,7 @@ __global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename ComputeDataType,
           typename GridDesc_M0,
           typename ElementwiseFunctor,
           index_t ThreadPerBlock,
@@ -70,15 +71,15 @@ struct GridwiseElementwise_1D
         auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_global, c_grid_desc_m0.GetElementSpaceSize());
 
-        StaticBuffer<AddressSpaceEnum::Vgpr, ADataType, ScalarPerVector, true> a_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, BDataType, ScalarPerVector, true> b_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, CDataType, ScalarPerVector, true> c_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;
 
         const auto thread_to_global_offset = CalculateElementwiseIndex();
 
         auto a_global_load =
             ThreadwiseTensorSliceTransfer_v2<ADataType,
-                                             ADataType,
+                                             ComputeDataType,
                                              GridDesc_M0,
                                              decltype(thread_desc_M0),
                                              Sequence<ScalarPerVector>, // SliceLengths
@@ -90,7 +91,7 @@
         auto b_global_load =
             ThreadwiseTensorSliceTransfer_v2<BDataType,
-                                             BDataType,
+                                             ComputeDataType,
                                              GridDesc_M0,
                                              decltype(thread_desc_M0),
                                              Sequence<ScalarPerVector>, // SliceLengths
@@ -101,7 +102,7 @@
                                              false>{b_grid_desc_m0, thread_to_global_offset};
 
         auto c_global_write =
-            ThreadwiseTensorSliceTransfer_v1r3<CDataType,
+            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                                CDataType,
                                                decltype(thread_desc_M0),
                                                GridDesc_M0,
...
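
Judging from the diff, the first two type parameters of ThreadwiseTensorSliceTransfer_v2 and _v1r3 are the source and destination element types, so the F16-to-F32 widening is folded into the global loads and the narrowing into the store, while all three VGPR staging buffers hold ComputeDataType. A conceptual sketch of that data flow (plain arrays and loops, not CK's transfer machinery):

// Conceptual (non-CK) sketch of the gridwise data flow after this change:
// loads widen ADataType/BDataType to ComputeDataType into registers, the
// functor runs in ComputeDataType, and the store narrows to CDataType.
template <typename ADataType, typename BDataType, typename CDataType,
          typename ComputeDataType, typename Functor, int ScalarPerVector>
__device__ void elementwise_slice(const ADataType* a,
                                  const BDataType* b,
                                  CDataType* c,
                                  Functor functor)
{
    ComputeDataType a_buf[ScalarPerVector]; // was ADataType before this commit
    ComputeDataType b_buf[ScalarPerVector]; // was BDataType
    ComputeDataType c_buf[ScalarPerVector]; // was CDataType

    for(int i = 0; i < ScalarPerVector; ++i)
    {
        a_buf[i] = static_cast<ComputeDataType>(a[i]); // conversion on load
        b_buf[i] = static_cast<ComputeDataType>(b[i]);
    }
    for(int i = 0; i < ScalarPerVector; ++i)
        functor(c_buf[i], a_buf[i], b_buf[i]); // compute at full precision
    for(int i = 0; i < ScalarPerVector; ++i)
        c[i] = static_cast<CDataType>(c_buf[i]); // single narrowing on store
}
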