Commit d478a389 authored by myamlak

Review remarks: binary ops templated

parent ac9ef30b
@@ -17,7 +17,8 @@ using ABDataType = F16;
 using CDataType = F16;
 using EltwiseComputeDataType = F32;
 
-using Add = ck::tensor_operation::binary_element_wise::Add;
+using Add = ck::tensor_operation::binary_element_wise::
+    Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
 
 using DeviceElementwiseAddInstance = ck::tensor_operation::device::
     DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 2, 8>;
@@ -37,19 +38,19 @@ void host_broadcast2D(
     {
         for(int n = 0; n < N; ++n)
         {
-            ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n));
+            ComputeDataType Amn = ck::type_convert<ComputeDataType>(A(m, n));
             ComputeDataType Cmn = 0;
             if constexpr(broadcastDim == 0)
             {
-                ComputeDataType Bn = static_cast<ComputeDataType>(B(n));
+                ComputeDataType Bn = ck::type_convert<ComputeDataType>(B(n));
                 functor(Cmn, Amn, Bn);
             }
             else
             {
-                ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
+                ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
                 functor(Cmn, Amn, Bm);
             }
-            C(m, n) = static_cast<ctype>(Cmn);
+            C(m, n) = ck::type_convert<ctype>(Cmn);
         }
     }
 }
......
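Editor's note: for context, the templated functor is instantiated with explicit compute types and applied the same way on the host side. A minimal sketch of the resulting storage-in-F16 / compute-in-F32 pattern, assuming a hipcc/HIP build, CK's usual include path for `half_t`/`type_convert`, and the op header added by this commit; the standalone `main` is illustrative, not part of the commit:

```cpp
#include "ck/utility/data_type.hpp" // assumed path for ck::half_t / ck::type_convert
// plus the binary element-wise op header introduced in this commit

using F16 = ck::half_t;
using F32 = float;

// The commit's new form: the functor is templated as Add<Y, X1, X2>.
using AddF32 = ck::tensor_operation::binary_element_wise::Add<F32, F32, F32>;

int main()
{
    // Storage in F16, arithmetic in F32 -- the same split the examples use.
    F16 a = ck::type_convert<F16>(1.5f);
    F16 b = ck::type_convert<F16>(2.0f);

    F32 acc = 0.f;
    AddF32{}(acc, ck::type_convert<F32>(a), ck::type_convert<F32>(b));

    F16 c = ck::type_convert<F16>(acc); // c holds 3.5 in half precision
    return 0;
}
```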
@@ -17,7 +17,8 @@ using ABDataType = F16;
 using CDataType = F16;
 using EltwiseComputeDataType = F32;
 
-using Add = ck::tensor_operation::binary_element_wise::Add;
+using Add = ck::tensor_operation::binary_element_wise::
+    Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
 
 using DeviceElementwiseAddInstance = ck::tensor_operation::device::
     DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 1, 8>;
@@ -34,11 +35,11 @@ void host_elementwise1D(
     for(int m = 0; m < M; ++m)
     {
-        ComputeDataType Am = static_cast<ComputeDataType>(A(m));
-        ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
+        ComputeDataType Am = ck::type_convert<ComputeDataType>(A(m));
+        ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
         ComputeDataType Cm = 0;
         functor(Cm, Am, Bm);
-        C(m) = static_cast<ctype>(Cm);
+        C(m) = ck::type_convert<ctype>(Cm);
     }
 }
......
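Editor's note: a likely motivation for replacing `static_cast` with `ck::type_convert` in these host references is `bhalf_t`, which CK stores as a raw 16-bit integer; a plain `static_cast` converts that bit pattern as an integer value instead of decoding the bfloat16 encoding. A small sketch of the difference, assuming CK's usual `bhalf_t = unsigned short` representation:

```cpp
// Assumes bhalf_t is an unsigned 16-bit bit-pattern type (CK's convention).
ck::bhalf_t b = ck::type_convert<ck::bhalf_t>(1.0f); // stores 0x3F80, bf16 for 1.0

float decoded = ck::type_convert<float>(b); // ~1.0f: interprets the bf16 bits
float raw     = static_cast<float>(b);      // 16256.0f: integer value of 0x3F80
```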
@@ -17,7 +17,8 @@ using ABDataType = F16;
 using CDataType = F16;
 using EltwiseComputeDataType = F32;
 
-using Add = ck::tensor_operation::binary_element_wise::Add;
+using Add = ck::tensor_operation::binary_element_wise::
+    Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
 
 using DeviceElementwiseAddInstance = ck::tensor_operation::device::
     DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add, 4, 8>;
@@ -40,11 +41,11 @@ void host_elementwise4D(HostTensorC& C,
     for(std::size_t h = 0; h < shape[2]; ++h)
         for(std::size_t w = 0; w < shape[3]; ++w)
         {
-            ComputeDataType a_val = static_cast<ComputeDataType>(A(n, c, h, w));
-            ComputeDataType b_val = static_cast<ComputeDataType>(B(n, c, h, w));
+            ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(n, c, h, w));
+            ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(n, c, h, w));
             ComputeDataType c_val = 0;
             functor(c_val, a_val, b_val);
-            C(n, c, h, w) = static_cast<ctype>(c_val);
+            C(n, c, h, w) = ck::type_convert<ctype>(c_val);
         }
 }
......
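Editor's note: across the three examples only the trailing `2, 8`, `1, 8`, and `4, 8` arguments differ, tracking the tensor rank of each example; the final `8` is presumably a per-thread vectorization width. Under that assumption, a hypothetical rank-3 instance would follow the same pattern (a sketch, not part of the commit):

```cpp
using ABDataType             = F16;
using CDataType              = F16;
using EltwiseComputeDataType = F32;

// Hypothetical rank-3 instance, assuming the trailing template arguments
// are (number of dimensions, scalars per vector).
using Add3D = ck::tensor_operation::binary_element_wise::
    Add<EltwiseComputeDataType, EltwiseComputeDataType, EltwiseComputeDataType>;
using DeviceElementwiseAdd3DInstance = ck::tensor_operation::device::
    DeviceBinaryElementwise<ABDataType, ABDataType, CDataType, EltwiseComputeDataType, Add3D, 3, 8>;
```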
@@ -523,8 +523,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
         float ave_time = 0;
 
-        using Add = ck::tensor_operation::binary_element_wise::Add;
-        using Substract = ck::tensor_operation::binary_element_wise::Substract;
+        using Add =
+            ck::tensor_operation::binary_element_wise::Add<CDataType, CDataType, CDataType>;
+        using Substract = ck::tensor_operation::binary_element_wise::
+            Substract<CDataType, CDataType, CDataType>;
 
         using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
                                                             CDataType,
                                                             CDataType,
......
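Editor's note: the struct name suggests this device op assembles a complex GEMM from four real GEMMs, which is where both element-wise ops come in: for (a + bi)(c + di) = (ac - bd) + (ad + bc)i, the real part needs a subtraction of partial products and the imaginary part an addition. A hedged host-side sketch of that combination step (the function and parameter names are illustrative, not from the commit; `Substract` keeps the identifier's spelling as it appears in the code):

```cpp
using AddF = ck::tensor_operation::binary_element_wise::Add<float, float, float>;
using SubF = ck::tensor_operation::binary_element_wise::Substract<float, float, float>;

// Combine the four real GEMM partial products into one complex element:
// (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
void combine_complex(float& c_real, float& c_imag,
                     float ac, float bd, float ad, float bc)
{
    SubF{}(c_real, ac, bd); // real part: ac - bd
    AddF{}(c_imag, ad, bc); // imaginary part: ad + bc
}
```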
@@ -5,26 +5,42 @@ namespace ck {
 namespace tensor_operation {
 namespace binary_element_wise {
 
-struct Add
+template <typename Y, typename X1, typename X2>
+struct Add;
+
+template <>
+struct Add<double, double, double>
 {
     __host__ __device__ constexpr void
     operator()(double& dst, const double& src1, const double& src2) const
     {
         dst = src1 + src2;
     }
+};
 
+template <>
+struct Add<float, float, float>
+{
     __host__ __device__ constexpr void
     operator()(float& dst, const float& src1, const float& src2) const
     {
         dst = src1 + src2;
     }
+};
 
+template <>
+struct Add<half_t, half_t, half_t>
+{
     __host__ __device__ constexpr void
     operator()(half_t& dst, const half_t& src1, const half_t& src2) const
     {
         dst = src1 + src2;
     }
+};
 
+template <>
+struct Add<bhalf_t, bhalf_t, bhalf_t>
+{
     __host__ __device__ constexpr void
     operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
     {
@@ -35,26 +51,42 @@ struct Add
     }
 };
 
-struct Substract
+template <typename Y, typename X1, typename X2>
+struct Substract;
+
+template <>
+struct Substract<double, double, double>
 {
     __host__ __device__ constexpr void
     operator()(double& dst, const double& src1, const double& src2) const
     {
         dst = src1 - src2;
     }
+};
 
+template <>
+struct Substract<float, float, float>
+{
     __host__ __device__ constexpr void
     operator()(float& dst, const float& src1, const float& src2) const
     {
         dst = src1 - src2;
     }
+};
 
+template <>
+struct Substract<half_t, half_t, half_t>
+{
     __host__ __device__ constexpr void
     operator()(half_t& dst, const half_t& src1, const half_t& src2) const
     {
         dst = src1 - src2;
     }
+};
 
+template <>
+struct Substract<bhalf_t, bhalf_t, bhalf_t>
+{
     __host__ __device__ constexpr void
     operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
     {
......
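Editor's note: because the primary templates are declared but never defined, only the specializations listed above are usable; any other type combination fails at compile time instead of silently mixing types. A sketch of the effect (the commented alias is the hypothetical failure case):

```cpp
namespace bew = ck::tensor_operation::binary_element_wise;

using AddOk = bew::Add<float, float, float>; // matches a specialization: fine

// using AddBad = bew::Add<float, ck::half_t, float>;
// Using AddBad{} would not compile ("implicit instantiation of undefined
// template"), since no specialization covers this combination.

void demo()
{
    float y = 0.f;
    AddOk{}(y, 1.f, 2.f); // y == 3.f
}
```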