"examples/vscode:/vscode.git/clone" did not exist on "a6e2c1fe5c02cae8a9f077f5d4e11b73d5791723"
Commit 71a7ac8b authored by Rosty Geyyer

Get back to template functions type_convert

parent 7c7bd091
...
@@ -942,323 +942,125 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
 using int8x32_t = typename vector_type<int8_t, 32>::type;
 using int8x64_t = typename vector_type<int8_t, 64>::type;

-template <typename Y, typename X, typename... config>
-struct TypeConvert
-{
-    template<typename X>
-    __host__ __device__ Y operator()(X& x) const
-    {
-        return static_cast<Y>(x);
-    }
-};
+// Convert X to Y
+template <typename Y, typename X>
+__host__ __device__ constexpr Y type_convert(X x)
+{
+    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+    return static_cast<Y>(x);
+}

-// convert bfp16 to fp32
-template <>
-struct TypeConvert<float, bhalf_t>
-{
-    __host__ __device__ float operator()(bhalf_t& x) const
-    {
-        union
-        {
-            uint32_t int32;
-            float fp32;
-        } u = {uint32_t(x) << 16};
-        return u.fp32;
-    }
-};
+// convert bfp16 to fp32
+template <>
+inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
+{
+    union
+    {
+        uint32_t int32;
+        float fp32;
+    } u = {uint32_t(x) << 16};
+    return u.fp32;
+}

-// convert fp32 to bfp16
-template <>
-struct TypeConvert<bhalf_t, float, integral_constant<bool, true>>
-{
-    __host__ __device__ bhalf_t operator()(float& x) const
-    {
-        union
-        {
-            float fp32;
-            uint32_t int32;
-        } u = {x};
-        return uint16_t(u.int32 >> 16);
-    }
-};
-
-// convert fp32 to bfp16
-template <>
-struct TypeConvert<bhalf_t, float, integral_constant<bool, false>>
-{
-    __host__ __device__ bhalf_t operator()(float& x) const
-    {
-        union
-        {
-            float fp32;
-            uint32_t int32;
-        } u = {x};
-
-        // When the exponent bits are not all 1s, then the value is zero, normal,
-        // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
-        // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
-        // This causes the bfloat16's mantissa to be incremented by 1 if the 16
-        // least significant bits of the float mantissa are greater than 0x8000,
-        // or if they are equal to 0x8000 and the least significant bit of the
-        // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
-        // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
-        // has the value 0x7f, then incrementing it causes it to become 0x00 and
-        // the exponent is incremented by one, which is the next higher FP value
-        // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
-        // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
-        // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
-        // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
-        // incrementing it causes it to become an exponent of 0xFF and a mantissa
-        // of 0x00, which is Inf, the next higher value to the unrounded value.
-        bool flag0 = ~u.int32 & 0x7f800000;
-
-        // When all of the exponent bits are 1, the value is Inf or NaN.
-        // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
-        // mantissa bit. Quiet NaN is indicated by the most significant mantissa
-        // bit being 1. Signaling NaN is indicated by the most significant
-        // mantissa bit being 0 but some other bit(s) being 1. If any of the
-        // lower 16 bits of the mantissa are 1, we set the least significant bit
-        // of the bfloat16 mantissa, in order to preserve signaling NaN in case
-        // the bfloat16's mantissa bits are all 0.
-        bool flag1 = !flag0 && (u.int32 & 0xffff);
-
-        u.int32 +=
-            flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
-        u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
-
-        return uint16_t(u.int32 >> 16);
-    }
-};
+// convert fp32 to bfp16
+template <>
+inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {x};
+
+    // When the exponent bits are not all 1s, then the value is zero, normal,
+    // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
+    // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
+    // This causes the bfloat16's mantissa to be incremented by 1 if the 16
+    // least significant bits of the float mantissa are greater than 0x8000,
+    // or if they are equal to 0x8000 and the least significant bit of the
+    // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
+    // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
+    // has the value 0x7f, then incrementing it causes it to become 0x00 and
+    // the exponent is incremented by one, which is the next higher FP value
+    // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
+    // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
+    // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
+    // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
+    // incrementing it causes it to become an exponent of 0xFF and a mantissa
+    // of 0x00, which is Inf, the next higher value to the unrounded value.
+    bool flag0 = ~u.int32 & 0x7f800000;
+
+    // When all of the exponent bits are 1, the value is Inf or NaN.
+    // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
+    // mantissa bit. Quiet NaN is indicated by the most significant mantissa
+    // bit being 1. Signaling NaN is indicated by the most significant
+    // mantissa bit being 0 but some other bit(s) being 1. If any of the
+    // lower 16 bits of the mantissa are 1, we set the least significant bit
+    // of the bfloat16 mantissa, in order to preserve signaling NaN in case
+    // the bfloat16's mantissa bits are all 0.
+    bool flag1 = !flag0 && (u.int32 & 0xffff);
+
+    u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
+    u.int32 |= flag1 ? 0x10000 : 0x0;                      // Preserve signaling NaN
+
+    return uint16_t(u.int32 >> 16);
+}

-// convert bfp16 to fp16 via fp32
-template <>
-struct TypeConvert<half_t, bhalf_t>
-{
-    __host__ __device__ half_t operator()(bhalf_t& x) const
-    {
-        float x_fp32 = TypeConvert<float, bhalf_t>{}(x);
-        return static_cast<half_t>(x_fp32);
-    }
-};
+// convert bfp16 to fp16 via fp32
+template <>
+inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
+{
+    float x_fp32 = type_convert<float>(x);
+    return static_cast<half_t>(x_fp32);
+}

-// convert fp16 to bfp16 via fp32
-template <>
-struct TypeConvert<bhalf_t, half_t>
-{
-    __host__ __device__ bhalf_t operator()(half_t& x) const
-    {
-        float x_fp32 = static_cast<float>(x);
-        return TypeConvert<bhalf_t, float>{}(x_fp32);
-    }
-};
+// convert fp16 to bfp16 via fp32
+template <>
+inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
+{
+    float x_fp32 = static_cast<float>(x);
+    return type_convert<bhalf_t>(x_fp32);
+}

-// convert bfp16 to int32 via fp32
-template <>
-struct TypeConvert<int32_t, bhalf_t>
-{
-    __host__ __device__ int32_t operator()(bhalf_t& x) const
-    {
-        float x_fp32 = TypeConvert<float, bhalf_t>{}(x);
-        return static_cast<int32_t>(x_fp32);
-    }
-};
+// convert bfp16 to int32 via fp32
+template <>
+inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
+{
+    float x_fp32 = type_convert<float>(x);
+    return static_cast<int32_t>(x_fp32);
+}

-// convert int32 to bfp16 via fp32
-template <>
-struct TypeConvert<bhalf_t, int32_t>
-{
-    __host__ __device__ bhalf_t operator()(int32_t& x) const
-    {
-        float x_fp32 = static_cast<float>(x);
-        return TypeConvert<bhalf_t, float>{}(x_fp32);
-    }
-};
+// convert int32 to bfp16 via fp32
+template <>
+inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
+{
+    float x_fp32 = static_cast<float>(x);
+    return type_convert<bhalf_t>(x_fp32);
+}

-// convert bfp16 to int8 via fp32
-template <>
-struct TypeConvert<int8_t, bhalf_t>
-{
-    __host__ __device__ int8_t operator()(bhalf_t& x) const
-    {
-        float x_fp32 = TypeConvert<float, bhalf_t>{}(x);
-        return static_cast<int8_t>(x_fp32);
-    }
-};
+// convert bfp16 to int8 via fp32
+template <>
+inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
+{
+    float x_fp32 = type_convert<float>(x);
+    return static_cast<int8_t>(x_fp32);
+}

-// convert int8 to bfp16 via fp32
-template <>
-struct TypeConvert<bhalf_t, int8_t>
-{
-    __host__ __device__ bhalf_t operator()(int8_t& x) const
-    {
-        float x_fp32 = static_cast<float>(x);
-        return TypeConvert<bhalf_t, float>{}(x_fp32);
-    }
-};
+// convert int8 to bfp16 via fp32
+template <>
+inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
+{
+    float x_fp32 = static_cast<float>(x);
+    return type_convert<bhalf_t>(x_fp32);
+}

-// class TypeConvert
-// {
-//     public:
-//     // constructor
-//     __host__ __device__ TypeConvert()
-//     {
-//         BF16ConvertRTN_ = false; // use round to zero by default
-//     }
-//     // switch bf16 conversion mode to rtn
-//     __host__ __device__ void SetBF16ConvertRTN() { BF16ConvertRTN_ = true; }
-//     // switch bf16 conversion mode to rtz
-//     __host__ __device__ void SetBF16ConvertRTZ() { BF16ConvertRTN_ = false; }
-//     // convert for all types except bf16
-//     template <typename Y, typename X>
-//     __host__ __device__ constexpr Y convert(X x)
-//     {
-//         static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
-//         return static_cast<Y>(x);
-//     }
-//     // convert bfp16 to fp32
-//     template <>
-//     inline __host__ __device__ constexpr float convert<float, bhalf_t>(bhalf_t x)
-//     {
-//         union
-//         {
-//             uint32_t int32;
-//             float fp32;
-//         } u = {uint32_t(x) << 16};
-//         return u.fp32;
-//     }
-//     // convert fp32 to bfp16
-//     template <>
-//     inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, float>(float x)
-//     {
-//         // if using rtn
-//         if(BF16ConvertRTN_)
-//         {
-//             union
-//             {
-//                 float fp32;
-//                 uint32_t int32;
-//             } u = {x};
-//             // When the exponent bits are not all 1s, then the value is zero, normal,
-//             // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
-//             // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
-//             // This causes the bfloat16's mantissa to be incremented by 1 if the 16
-//             // least significant bits of the float mantissa are greater than 0x8000,
-//             // or if they are equal to 0x8000 and the least significant bit of the
-//             // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
-//             // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
-//             // has the value 0x7f, then incrementing it causes it to become 0x00 and
-//             // the exponent is incremented by one, which is the next higher FP value
-//             // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
-//             // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
-//             // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
-//             // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
-//             // incrementing it causes it to become an exponent of 0xFF and a mantissa
-//             // of 0x00, which is Inf, the next higher value to the unrounded value.
-//             bool flag0 = ~u.int32 & 0x7f800000;
-//             // When all of the exponent bits are 1, the value is Inf or NaN.
-//             // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
-//             // mantissa bit. Quiet NaN is indicated by the most significant mantissa
-//             // bit being 1. Signaling NaN is indicated by the most significant
-//             // mantissa bit being 0 but some other bit(s) being 1. If any of the
-//             // lower 16 bits of the mantissa are 1, we set the least significant bit
-//             // of the bfloat16 mantissa, in order to preserve signaling NaN in case
-//             // the bfloat16's mantissa bits are all 0.
-//             bool flag1 = !flag0 && (u.int32 & 0xffff);
-//             u.int32 +=
-//                 flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
-//             u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
-//             return uint16_t(u.int32 >> 16);
-//         }
-//         // if using rtz
-//         else
-//         {
-//             union
-//             {
-//                 float fp32;
-//                 uint32_t int32;
-//             } u = {x};
-//             return uint16_t(u.int32 >> 16);
-//         }
-//     }
-//     // convert bfp16 to fp16 via fp32
-//     template <>
-//     inline __host__ __device__ constexpr half_t convert<half_t, bhalf_t>(bhalf_t x)
-//     {
-//         float x_fp32 = convert<float>(x);
-//         return static_cast<half_t>(x_fp32);
-//     }
-//     // convert fp16 to bfp16 via fp32
-//     template <>
-//     inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, half_t>(half_t x)
-//     {
-//         float x_fp32 = static_cast<float>(x);
-//         return convert<bhalf_t>(x_fp32);
-//     }
-//     // convert bfp16 to int32 via fp32
-//     template <>
-//     inline __host__ __device__ constexpr int32_t convert<int32_t, bhalf_t>(bhalf_t x)
-//     {
-//         float x_fp32 = convert<float>(x);
-//         return static_cast<int32_t>(x_fp32);
-//     }
-//     // convert int32 to bfp16 via fp32
-//     template <>
-//     inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, int32_t>(int32_t x)
-//     {
-//         float x_fp32 = static_cast<float>(x);
-//         return convert<bhalf_t>(x_fp32);
-//     }
-//     // convert bfp16 to int8 via fp32
-//     template <>
-//     inline __host__ __device__ constexpr int8_t convert<int8_t, bhalf_t>(bhalf_t x)
-//     {
-//         float x_fp32 = convert<float>(x);
-//         return static_cast<int8_t>(x_fp32);
-//     }
-//     // convert int8 to bfp16 via fp32
-//     template <>
-//     inline __host__ __device__ constexpr bhalf_t convert<bhalf_t, int8_t>(int8_t x)
-//     {
-//         float x_fp32 = static_cast<float>(x);
-//         return convert<bhalf_t>(x_fp32);
-//     }
-//     private:
-//     bool BF16ConvertRTN_;
-// };

 template <typename T>
 struct NumericLimits
...
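As a side note on the rounding comments in the hunk above: the same bit manipulation is easy to exercise on the host. The sketch below is a minimal, standalone illustration under two assumptions that differ from the library code: bf16 is held in a plain uint16_t rather than bhalf_t, and memcpy replaces the union for type punning. It is an illustration only, not part of the commit.

#include <cstdint>
#include <cstdio>
#include <cstring>

// fp32 -> bf16 with round-to-nearest-even and signaling-NaN preservation,
// mirroring type_convert<bhalf_t, float> in the hunk above.
static uint16_t float_to_bf16_rne(float x)
{
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits)); // type-pun without the union used in the header

    const bool not_inf_or_nan = (~bits & 0x7f800000u) != 0;              // exponent bits not all 1s
    const bool keep_signaling = !not_inf_or_nan && (bits & 0xffffu) != 0; // NaN with low mantissa bits set

    if(not_inf_or_nan)
    {
        bits += 0x7fffu + ((bits >> 16) & 1u); // round to nearest, ties to even
    }
    if(keep_signaling)
    {
        bits |= 0x10000u; // keep a mantissa bit so a signaling NaN stays a NaN
    }

    return static_cast<uint16_t>(bits >> 16);
}

// bf16 -> fp32, mirroring type_convert<float, bhalf_t> above.
static float bf16_to_float(uint16_t x)
{
    const uint32_t bits = static_cast<uint32_t>(x) << 16; // bf16 is the high half of an fp32
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main()
{
    // 1.00390625f lies exactly halfway between the bf16 values 1.0 and 1.0078125,
    // so ties-to-even rounds it down to 1.0 (bf16 mantissa 0x00 is even).
    const float samples[] = {1.0f, 1.00390625f, 3.14159f};

    for(float v : samples)
    {
        const uint16_t b = float_to_bf16_rne(v);
        std::printf("%.8f -> 0x%04x -> %.8f\n", v, static_cast<unsigned>(b), bf16_to_float(b));
    }

    return 0;
}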
...
@@ -87,11 +87,10 @@ __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const h
 #else
     const vector_type<half_t, 2> a_vector{a};
     const vector_type<half_t, 2> b_vector{b};

-    TypeConvert type_convert = TypeConvert();
     static_for<0, 2, 1>{}([&](auto i) {
-        c += type_convert.convert<int32_t>(a_vector.AsType<half_t>()[i]) *
-             type_convert.convert<int32_t>(b_vector.AsType<half_t>()[i]);
+        c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
+             type_convert<int32_t>(b_vector.AsType<half_t>()[i]);
     });
 #endif
 }
...
@@ -139,8 +138,7 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
 template <>
 __device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c)
 {
-    TypeConvert type_convert = TypeConvert();
-    c += type_convert.convert<int32_t>(a) * type_convert.convert<int32_t>(b);
+    c += type_convert<int32_t>(a) * type_convert<int32_t>(b);
 }

 template <>
...
@@ -176,11 +174,10 @@ inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b,
 #else
     const vector_type<int8_t, 4> a_vector{a};
     const vector_type<int8_t, 4> b_vector{b};

-    TypeConvert type_convert = TypeConvert();
     static_for<0, 4, 1>{}([&](auto i) {
-        c += type_convert.convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
-             type_convert.convert<int32_t>(b_vector.AsType<int8_t>()[i]);
+        c += type_convert<int32_t>(a_vector.AsType<int8_t>()[i]) *
+             type_convert<int32_t>(b_vector.AsType<int8_t>()[i]);
     });
 #endif
 }
...
...
@@ -271,7 +271,7 @@ struct Tensor
         Tensor<OutT> ret(mDesc);

         ck::ranges::transform(mData, ret.mData.begin(), [](auto value) {
-            return ck::TypeConvert<OutT, ck::remove_cvref_t<decltype(value)>>{}(value);
+            return ck::type_convert<OutT>(value);
         });
         return ret;
...
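The call-site hunks above are mechanical: the stateful TypeConvert functor object disappears and every conversion becomes a single type_convert<Y>(x) expression. The C++ mechanism the new header relies on is explicit full specialization of a function template. The sketch below shows that pattern with a hypothetical my_bf16 wrapper type, which is not a type from this codebase, so it illustrates the language feature rather than the library's API.

#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical 16-bit wrapper used only for this illustration; not the library's bhalf_t.
struct my_bf16
{
    uint16_t data;
};

// Primary template: a plain static_cast, as in the first hunk above.
template <typename Y, typename X>
Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
    return static_cast<Y>(x);
}

// Explicit full specialization: custom widening for the wrapper type.
template <>
float type_convert<float, my_bf16>(my_bf16 x)
{
    const uint32_t bits = static_cast<uint32_t>(x.data) << 16; // bf16 is the high half of an fp32
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main()
{
    // Call sites spell out only the destination type; the source type is deduced,
    // which is what lets the hunks above be written as type_convert<int32_t>(a).
    const float a = type_convert<float>(7);               // primary template
    const float b = type_convert<float>(my_bf16{0x3f80}); // specialization, decodes to 1.0f

    return (a == 7.0f && b == 1.0f) ? 0 : 1;
}

In a header, each such specialization has to be marked inline, as the new code in the first hunk does, to stay ODR-safe when included from multiple translation units; the sketch omits that only because it is a single file.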