Commit 4c6c750a authored by Rosty Geyyer
Browse files

Add TypeConvert class and start refactoring

parent dbd8f94b
...@@ -942,19 +942,34 @@ using int8x16_t = typename vector_type<int8_t, 16>::type; ...@@ -942,19 +942,34 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type; using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
// Convert X to Y class TypeConvert
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x)
{ {
public:
// constructor
__host__ __device__ TypeConvert()
{
BF16ConvertRTN_ = false; // use round to zero by default
}
// switch bf16 conversion mode to rtn
__host__ __device__ void SetBF16ConvertRTN() { BF16ConvertRTN_ = true; }
// switch bf16 conversion mode to rtz
__host__ __device__ void SetBF16ConvertRTZ() { BF16ConvertRTN_ = false; }
// convert for all types except bf16
template <typename Y, typename X>
__host__ __device__ constexpr Y convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x); return static_cast<Y>(x);
} }
// convert bfp16 to fp32 // convert bfp16 to fp32
template <> template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr float convert<float, bhalf_t>(bhalf_t x)
{ {
union union
{ {
uint32_t int32; uint32_t int32;
...@@ -962,12 +977,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t ...@@ -962,12 +977,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
} u = {uint32_t(x) << 16}; } u = {uint32_t(x) << 16};
return u.fp32; return u.fp32;
} }
// convert fp32 to bfp16 // convert fp32 to bfp16
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x) inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, float>(float x)
{ {
// if using rtn
if(BF16ConvertRTN_)
{
union union
{ {
float fp32; float fp32;
...@@ -1002,65 +1020,82 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float ...@@ -1002,65 +1020,82 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
// the bfloat16's mantissa bits are all 0. // the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff); bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even u.int32 +=
flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
} }
// if using rtz
else
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// convert bfp16 to fp16 via fp32 return uint16_t(u.int32 >> 16);
template <> }
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x) }
{
float x_fp32 = type_convert<float>(x); // convert bfp16 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t convert<half_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = convert<float>(x);
return static_cast<half_t>(x_fp32); return static_cast<half_t>(x_fp32);
} }
// convert fp16 to bfp16 via fp32 // convert fp16 to bfp16 via fp32
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x) inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, half_t>(half_t x)
{ {
float x_fp32 = static_cast<float>(x); float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32); return convert<bhalf_t>(x_fp32);
} }
// convert bfp16 to int32 via fp32 // convert bfp16 to int32 via fp32
template <> template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr int32_t convert<int32_t, bhalf_t>(bhalf_t x)
{ {
float x_fp32 = type_convert<float>(x); float x_fp32 = convert<float>(x);
return static_cast<int32_t>(x_fp32); return static_cast<int32_t>(x_fp32);
} }
// convert int32 to bfp16 via fp32 // convert int32 to bfp16 via fp32
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x) inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, int32_t>(int32_t x)
{ {
float x_fp32 = static_cast<float>(x); float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32); return convert<bhalf_t>(x_fp32);
} }
// convert bfp16 to int8 via fp32 // convert bfp16 to int8 via fp32
template <> template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr int8_t convert<int8_t, bhalf_t>(bhalf_t x)
{ {
float x_fp32 = type_convert<float>(x); float x_fp32 = convert<float>(x);
return static_cast<int8_t>(x_fp32); return static_cast<int8_t>(x_fp32);
} }
// convert int8 to bfp16 via fp32 // convert int8 to bfp16 via fp32
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x) inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, int8_t>(int8_t x)
{ {
float x_fp32 = static_cast<float>(x); float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32); return convert<bhalf_t>(x_fp32);
} }
private:
bool BF16ConvertRTN_;
};
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
......
...@@ -87,10 +87,11 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h ...@@ -87,10 +87,11 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
#else #else
const vector_type<half_t, 2> a_vector{a}; const vector_type<half_t, 2> a_vector{a};
const vector_type<half_t, 2> b_vector{b}; const vector_type<half_t, 2> b_vector{b};
TypeConvert type_convert = TypeConvert();
static_for<0, 2, 1>{}([&](auto i) { static_for<0, 2, 1>{}([&](auto i) {
c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) * c += type_convert.convert<int32_t>(a_vector.AsType<half_t>()[i]) *
type_convert<int32_t>(b_vector.AsType<half_t>()[i]); type_convert.convert<int32_t>(b_vector.AsType<half_t>()[i]);
}); });
#endif #endif
} }
...@@ -138,7 +139,8 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h ...@@ -138,7 +139,8 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
template <> template <>
__device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c) __device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c)
{ {
c += type_convert<int32_t>(a) * type_convert<int32_t>(b); TypeConvert type_convert = TypeConvert();
c += type_convert.convert<int32_t>(a) * type_convert.convert<int32_t>(b);
} }
template <> template <>
...@@ -174,10 +176,11 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, ...@@ -174,10 +176,11 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
#else #else
const vector_type<int8_t, 4> a_vector{a}; const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b}; const vector_type<int8_t, 4> b_vector{b};
TypeConvert type_convert = TypeConvert();
static_for<0, 4, 1>{}([&](auto i) { static_for<0, 4, 1>{}([&](auto i) {
c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) * c += type_convert.convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
type_convert<int32_t>(b_vector.AsType<int8_t>()[i]); type_convert.convert<int32_t>(b_vector.AsType<int8_t>()[i]);
}); });
#endif #endif
} }
......
...@@ -270,8 +270,10 @@ struct Tensor ...@@ -270,8 +270,10 @@ struct Tensor
{ {
Tensor<OutT> ret(mDesc); Tensor<OutT> ret(mDesc);
ck::ranges::transform( ck::ranges::transform(mData, ret.mData.begin(), [](auto value) {
mData, ret.mData.begin(), [](auto value) { return ck::type_convert<OutT>(value); }); ck::TypeConvert type_convert = ck::TypeConvert();
return type_convert.convert<OutT>(value);
});
return ret; return ret;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment