Commit 5e206700 authored by Astha Rai
Browse files

temp fix for namespace error in MIOpen

parent 251ab612
...@@ -163,8 +163,8 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator ...@@ -163,8 +163,8 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
const CDEElementwiseOperation& cde_element_op) = 0; const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
#endif #endif
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -17,7 +17,7 @@ struct PassThroughPack2 ...@@ -17,7 +17,7 @@ struct PassThroughPack2
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const
{ {
auto t = type_convert<float2_t>(x); auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t); y = type_convert<half2_t>(t);
...@@ -220,7 +220,7 @@ struct PassThrough ...@@ -220,7 +220,7 @@ struct PassThrough
template <> template <>
__host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const __host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
{ {
y = ck::type_convert<bf8_t>(x); y = type_convert<bf8_t>(x);
} }
}; };
...@@ -293,21 +293,21 @@ struct Scale ...@@ -293,21 +293,21 @@ struct Scale
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
y = ck::type_convert<Y>(ck::type_convert<float>(x) * scale_); y = type_convert<Y>(type_convert<float>(x) * scale_);
} }
template <> template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{ {
y = ck::type_convert<half_t>(scale_) * x; y = type_convert<half_t>(scale_) * x;
}; };
template <> template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{ {
const float x_tmp = ck::type_convert<float>(x); const float x_tmp = type_convert<float>(x);
const float y_tmp = scale_ * x_tmp; const float y_tmp = scale_ * x_tmp;
y = ck::type_convert<bhalf_t>(y_tmp); y = type_convert<bhalf_t>(y_tmp);
}; };
template <> template <>
...@@ -325,7 +325,7 @@ struct Scale ...@@ -325,7 +325,7 @@ struct Scale
template <> template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{ {
y = ck::type_convert<int8_t>(scale_ * ck::type_convert<float>(x)); y = type_convert<int8_t>(scale_ * type_convert<float>(x));
}; };
float scale_; float scale_;
...@@ -341,7 +341,7 @@ struct ScaleAndResetNaNToMinusInfinity ...@@ -341,7 +341,7 @@ struct ScaleAndResetNaNToMinusInfinity
template <> template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const __host__ __device__ void operator()<float, float>(float& y, const float& x) const
{ {
y = ck::math::isnan(x) ? -ck::NumericLimits<float>::Infinity() : scale_ * x; y = math::isnan(x) ? -NumericLimits<float>::Infinity() : scale_ * x;
}; };
float scale_; float scale_;
...@@ -417,7 +417,7 @@ struct UnaryAbs ...@@ -417,7 +417,7 @@ struct UnaryAbs
is_same<T, int8_t>::value, is_same<T, int8_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::abs(x); y = math::abs(x);
}; };
}; };
...@@ -429,7 +429,7 @@ struct UnarySqrt ...@@ -429,7 +429,7 @@ struct UnarySqrt
static_assert(is_same<T, float>::value || is_same<T, double>::value, static_assert(is_same<T, float>::value || is_same<T, double>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::sqrt(x); y = math::sqrt(x);
}; };
}; };
...@@ -448,9 +448,9 @@ struct Relu ...@@ -448,9 +448,9 @@ struct Relu
template <> template <>
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
{ {
float x_f32 = ck::type_convert<float>(x); float x_f32 = type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0; float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck::type_convert<bhalf_t>(y_f32); y = type_convert<bhalf_t>(y_f32);
} }
}; };
...@@ -466,7 +466,7 @@ struct FastGelu ...@@ -466,7 +466,7 @@ struct FastGelu
template <typename Y, typename X> template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const; __device__ void operator()(Y& y, const X& x) const;
#ifndef CK_CODE_GEN_RTC
template <> template <>
__host__ void operator()<float, float>(float& y, const float& x) const __host__ void operator()<float, float>(float& y, const float& x) const
{ {
...@@ -477,6 +477,7 @@ struct FastGelu ...@@ -477,6 +477,7 @@ struct FastGelu
const float emu = exp(u); const float emu = exp(u);
y = x / (1.f + emu); y = x / (1.f + emu);
} }
#endif
// device code, use lower precision "__ocml_exp_f32" and "rcp" // device code, use lower precision "__ocml_exp_f32" and "rcp"
template <> template <>
...@@ -488,7 +489,7 @@ struct FastGelu ...@@ -488,7 +489,7 @@ struct FastGelu
const float u = x * (c1 * x * x + c2); const float u = x * (c1 * x * x + c2);
const float emu = __ocml_exp_f32(u); const float emu = __ocml_exp_f32(u);
y = x * ck::math::rcp(1.f + emu); y = x * math::rcp(1.f + emu);
} }
template <> template <>
...@@ -586,10 +587,9 @@ struct Gelu ...@@ -586,10 +587,9 @@ struct Gelu
} }
template <> template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y, __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
const ck::half_t& x) const
{ {
y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x)))); y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x))));
} }
}; };
...@@ -599,11 +599,11 @@ struct Sigmoid ...@@ -599,11 +599,11 @@ struct Sigmoid
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
y = one / (one + ck::math::exp(-x)); y = one / (one + math::exp(-x));
}; };
}; };
...@@ -612,11 +612,11 @@ struct Silu ...@@ -612,11 +612,11 @@ struct Silu
template <typename T> template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, ck::half_t> || static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, half_t> ||
is_same_v<T, int8_t> || is_same_v<T, int32_t>, is_same_v<T, int8_t> || is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck::math::exp(-x))); y = x * (one / (one + math::exp(-x)));
}; };
}; };
...@@ -626,11 +626,11 @@ struct TanH ...@@ -626,11 +626,11 @@ struct TanH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::tanh(x); y = math::tanh(x);
}; };
}; };
...@@ -640,11 +640,11 @@ struct ACos ...@@ -640,11 +640,11 @@ struct ACos
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::acos(x); y = math::acos(x);
}; };
}; };
...@@ -654,11 +654,11 @@ struct Neg ...@@ -654,11 +654,11 @@ struct Neg
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::neg(x); y = math::neg(x);
}; };
}; };
...@@ -668,11 +668,11 @@ struct ATan ...@@ -668,11 +668,11 @@ struct ATan
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::atan(x); y = math::atan(x);
}; };
}; };
...@@ -682,11 +682,11 @@ struct Sin ...@@ -682,11 +682,11 @@ struct Sin
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::sin(x); y = math::sin(x);
}; };
}; };
...@@ -696,11 +696,11 @@ struct ASinH ...@@ -696,11 +696,11 @@ struct ASinH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::asinh(x); y = math::asinh(x);
}; };
}; };
...@@ -710,11 +710,11 @@ struct Cos ...@@ -710,11 +710,11 @@ struct Cos
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::cos(x); y = math::cos(x);
}; };
}; };
...@@ -724,11 +724,11 @@ struct ACosH ...@@ -724,11 +724,11 @@ struct ACosH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::acosh(x); y = math::acosh(x);
}; };
}; };
...@@ -738,11 +738,11 @@ struct Tan ...@@ -738,11 +738,11 @@ struct Tan
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::tan(x); y = math::tan(x);
}; };
}; };
...@@ -752,11 +752,11 @@ struct ATanH ...@@ -752,11 +752,11 @@ struct ATanH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::atanh(x); y = math::atanh(x);
}; };
}; };
...@@ -766,11 +766,11 @@ struct SinH ...@@ -766,11 +766,11 @@ struct SinH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::sinh(x); y = math::sinh(x);
}; };
}; };
...@@ -780,11 +780,11 @@ struct Ceil ...@@ -780,11 +780,11 @@ struct Ceil
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::ceil(x); y = math::ceil(x);
}; };
}; };
...@@ -794,11 +794,11 @@ struct Exp ...@@ -794,11 +794,11 @@ struct Exp
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::exp(x); y = math::exp(x);
}; };
}; };
...@@ -808,11 +808,11 @@ struct CosH ...@@ -808,11 +808,11 @@ struct CosH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::cosh(x); y = math::cosh(x);
}; };
}; };
...@@ -822,11 +822,11 @@ struct Floor ...@@ -822,11 +822,11 @@ struct Floor
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::floor(x); y = math::floor(x);
}; };
}; };
...@@ -836,11 +836,11 @@ struct Log ...@@ -836,11 +836,11 @@ struct Log
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::log(x); y = math::log(x);
}; };
}; };
...@@ -850,11 +850,11 @@ struct ASin ...@@ -850,11 +850,11 @@ struct ASin
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::asin(x); y = math::asin(x);
}; };
}; };
...@@ -864,11 +864,11 @@ struct Rcp ...@@ -864,11 +864,11 @@ struct Rcp
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::rcp(x); y = math::rcp(x);
}; };
}; };
...@@ -880,15 +880,15 @@ struct Swish ...@@ -880,15 +880,15 @@ struct Swish
__host__ __device__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(is_same<X, float>::value || is_same<X, double>::value || static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<X, ck::half_t>::value, is_same<X, half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value || static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, ck::half_t>::value, is_same<Y, half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x); float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx))); y = type_convert<Y>(x / (1.f + math::exp(bx)));
}; };
const float beta_; const float beta_;
...@@ -907,7 +907,7 @@ struct SoftRelu ...@@ -907,7 +907,7 @@ struct SoftRelu
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
} }
const float alpha_; const float alpha_;
}; };
...@@ -928,7 +928,7 @@ struct Power ...@@ -928,7 +928,7 @@ struct Power
T casted_beta = type_convert<T>(beta_); T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_); T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x; T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma); y = math::pow(shifted_scaled_x, casted_gamma);
} }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
...@@ -948,7 +948,7 @@ struct ClippedRelu ...@@ -948,7 +948,7 @@ struct ClippedRelu
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_); T casted_beta = type_convert<T>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); y = math::min(casted_beta, math::max(casted_alpha, x));
} }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
...@@ -983,7 +983,7 @@ struct Elu ...@@ -983,7 +983,7 @@ struct Elu
is_same<T, int8_t>::value, is_same<T, int8_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x); y = x > 0 ? x : casted_alpha * math::expm1(x);
} }
const float alpha_; const float alpha_;
}; };
...@@ -1085,10 +1085,10 @@ struct FastNumericArrayConverter ...@@ -1085,10 +1085,10 @@ struct FastNumericArrayConverter
}; };
template <> template <>
struct FastNumericArrayConverter<uint8_t, ck::half_t, 4> struct FastNumericArrayConverter<uint8_t, half_t, 4>
{ {
using InputArray = vector_type<uint8_t, 4>; using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck::half_t, 4>; using OutputArray = vector_type<half_t, 4>;
__device__ static OutputArray convert(InputArray const& Input) __device__ static OutputArray convert(InputArray const& Input)
{ {
...@@ -1118,13 +1118,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4> ...@@ -1118,13 +1118,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
}; };
template <index_t N> template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck::half_t, N> struct FastNumericArrayConverter<uint8_t, half_t, N>
{ {
static constexpr int VEC_WIDTH = 4; static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using InputArray = vector_type<uint8_t, N>; using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck::half_t, N>; using OutputArray = vector_type<half_t, N>;
__device__ static OutputArray convert(InputArray const& Input) __device__ static OutputArray convert(InputArray const& Input)
{ {
...@@ -1133,7 +1133,7 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N> ...@@ -1133,7 +1133,7 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
OutputArray Output; OutputArray Output;
using Vec_InputArray = vector_type<uint8_t, 4>; using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck::half_t, 4>; using Vec_OutputArray = vector_type<half_t, 4>;
Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output); Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input); Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
......
...@@ -981,7 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit ...@@ -981,7 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
// Create 3D grid // Create 3D grid
const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return ck::make_tuple(N0, M0, k_split); return make_tuple(N0, M0, k_split);
} }
template <typename TopIdx> template <typename TopIdx>
...@@ -1105,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK ...@@ -1105,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
uint32_t dp_for_sk_iters = k_iters_per_tile.get(); uint32_t dp_for_sk_iters = k_iters_per_tile.get();
uint32_t best_sk_score = uint32_t best_sk_score =
ck::NumericLimits<int32_t>::Max(); // we need to find the smallest sk iters NumericLimits<int32_t>::Max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++) tentative_sk_blocks++)
{ {
......
...@@ -1075,10 +1075,10 @@ using uint8x64_t = typename vector_type<uint8_t, 64>::type; ...@@ -1075,10 +1075,10 @@ using uint8x64_t = typename vector_type<uint8_t, 64>::type;
template <typename T> template <typename T>
struct NumericLimits; struct NumericLimits;
#ifndef CK_CODE_GEN_RTC
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
#ifndef CK_CODE_GEN_RTC
__host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); } __host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); } __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); } __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
...@@ -1087,8 +1087,8 @@ struct NumericLimits ...@@ -1087,8 +1087,8 @@ struct NumericLimits
return std::numeric_limits<T>::quiet_NaN(); return std::numeric_limits<T>::quiet_NaN();
} }
__host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); } __host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
#endif
}; };
#endif
template <> template <>
struct NumericLimits<int32_t> struct NumericLimits<int32_t>
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment