Commit d478a389 authored by myamlak's avatar myamlak
Browse files

Review remarks: binary ops templated

parent ac9ef30b
...@@ -17,7 +17,8 @@ using ABDataType = F16; ...@@ -17,7 +17,8 @@ using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32; using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add; using Add = ck::tensor_operation::binary_element_wise::
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAddInstance = ck::tensor_operation::device:: using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 2, 8>; DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 2, 8>;
...@@ -37,19 +38,19 @@ void host_broadcast2D( ...@@ -37,19 +38,19 @@ void host_broadcast2D(
{ {
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n)); ComputeDataType Amn = ck::type_convert<ComputeDataType>(A(m, n));
ComputeDataType Cmn = 0; ComputeDataType Cmn = 0;
if constexpr(broadcastDim == 0) if constexpr(broadcastDim == 0)
{ {
ComputeDataType Bn = static_cast<ComputeDataType>(B(n)); ComputeDataType Bn = ck::type_convert<ComputeDataType>(B(n));
functor(Cmn, Amn, Bn); functor(Cmn, Amn, Bn);
} }
else else
{ {
ComputeDataType Bm = static_cast<ComputeDataType>(B(m)); ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
functor(Cmn, Amn, Bm); functor(Cmn, Amn, Bm);
} }
C(m, n) = static_cast<ctype>(Cmn); C(m, n) = ck::type_convert<ctype>(Cmn);
} }
} }
} }
......
...@@ -17,7 +17,8 @@ using ABDataType = F16; ...@@ -17,7 +17,8 @@ using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32; using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add; using Add = ck::tensor_operation::binary_element_wise::
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAddInstance = ck::tensor_operation::device:: using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 1, 8>; DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 1, 8>;
...@@ -34,11 +35,11 @@ void host_elementwise1D( ...@@ -34,11 +35,11 @@ void host_elementwise1D(
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
ComputeDataType Am = static_cast<ComputeDataType>(A(m)); ComputeDataType Am = ck::type_convert<ComputeDataType>(A(m));
ComputeDataType Bm = static_cast<ComputeDataType>(B(m)); ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
ComputeDataType Cm = 0; ComputeDataType Cm = 0;
functor(Cm, Am, Bm); functor(Cm, Am, Bm);
C(m) = static_cast<ctype>(Cm); C(m) = ck::type_convert<ctype>(Cm);
} }
} }
......
...@@ -17,7 +17,8 @@ using ABDataType = F16; ...@@ -17,7 +17,8 @@ using ABDataType = F16;
using CDataType = F16; using CDataType = F16;
using EltwiseComputeDataType = F32; using EltwiseComputeDataType = F32;
using Add = ck::tensor_operation::binary_element_wise::Add; using Add = ck::tensor_operation::binary_element_wise::
Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAddInstance = ck::tensor_operation::device:: using DeviceElementwiseAddInstance = ck::tensor_operation::device::
DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 4, 8>; DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 4, 8>;
...@@ -40,11 +41,11 @@ void host_elementwise4D(HostTensorC& C, ...@@ -40,11 +41,11 @@ void host_elementwise4D(HostTensorC& C,
for(std::size_t h = 0; h < shape[2]; ++h) for(std::size_t h = 0; h < shape[2]; ++h)
for(std::size_t w = 0; w < shape[3]; ++w) for(std::size_t w = 0; w < shape[3]; ++w)
{ {
ComputeDataType a_val = static_cast<ComputeDataType>(A(n, c, h, w)); ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(n, c, h, w));
ComputeDataType b_val = static_cast<ComputeDataType>(B(n, c, h, w)); ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(n, c, h, w));
ComputeDataType c_val = 0; ComputeDataType c_val = 0;
functor(c_val, a_val, b_val); functor(c_val, a_val, b_val);
C(n, c, h, w) = static_cast<ctype>(c_val); C(n, c, h, w) = ck::type_convert<ctype>(c_val);
} }
} }
......
...@@ -523,8 +523,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -523,8 +523,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
using Add = ck::tensor_operation::binary_element_wise::Add; using Add =
using Substract = ck::tensor_operation::binary_element_wise::Substract; ck::tensor_operation::binary_element_wise::Add<CDataType, CDataType, CDataType>;
using Substract = ck::tensor_operation::binary_element_wise::
Substract<CDataType, CDataType, CDataType>;
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType, using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
CDataType, CDataType,
CDataType, CDataType,
......
...@@ -5,26 +5,42 @@ namespace ck { ...@@ -5,26 +5,42 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace binary_element_wise { namespace binary_element_wise {
struct Add template <typename Y, typename X1, typename X2>
struct Add;
template <>
struct Add<double, double, double>
{ {
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(double& dst, const double& src1, const double& src2) const operator()(double& dst, const double& src1, const double& src2) const
{ {
dst = src1 + src2; dst = src1 + src2;
} }
};
template <>
struct Add<float, float, float>
{
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const operator()(float& dst, const float& src1, const float& src2) const
{ {
dst = src1 + src2; dst = src1 + src2;
} }
};
template <>
struct Add<half_t, half_t, half_t>
{
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const operator()(half_t& dst, const half_t& src1, const half_t& src2) const
{ {
dst = src1 + src2; dst = src1 + src2;
} }
};
template <>
struct Add<bhalf_t, bhalf_t, bhalf_t>
{
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
{ {
...@@ -35,26 +51,42 @@ struct Add ...@@ -35,26 +51,42 @@ struct Add
} }
}; };
struct Substract template <typename Y, typename X1, typename X2>
struct Substract;
template <>
struct Substract<double, double, double>
{ {
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(double& dst, const double& src1, const double& src2) const operator()(double& dst, const double& src1, const double& src2) const
{ {
dst = src1 - src2; dst = src1 - src2;
} }
};
template <>
struct Substract<float, float, float>
{
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const operator()(float& dst, const float& src1, const float& src2) const
{ {
dst = src1 - src2; dst = src1 - src2;
} }
};
template <>
struct Substract<half_t, half_t, half_t>
{
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const operator()(half_t& dst, const half_t& src1, const half_t& src2) const
{ {
dst = src1 - src2; dst = src1 - src2;
} }
};
template <>
struct Substract<bhalf_t, bhalf_t, bhalf_t>
{
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment