Commit b1ad4b4f authored by Rostyslav Geyyer
Browse files

Update e8m0 casting

parent bdc1dd6f
...@@ -16,6 +16,7 @@ using bf8_t = unsigned _BitInt(8); ...@@ -16,6 +16,7 @@ using bf8_t = unsigned _BitInt(8);
struct e8m0_scale_t struct e8m0_scale_t
{ {
// E8M0 scale is biased
using type = uint8_t; using type = uint8_t;
type data; type data;
constexpr e8m0_scale_t() : data{type{}} {} constexpr e8m0_scale_t() : data{type{}} {}
......
...@@ -10,12 +10,14 @@ namespace ck::utils { ...@@ -10,12 +10,14 @@ namespace ck::utils {
/// @brief Decode an E8M0 scale into the float it represents.
///
/// The stored byte is a biased exponent, so the decoded value is
/// 2^(byte - NumericUtils<e8m0_scale_t>::bias).
///
/// @param scale E8M0 scale to decode.
/// @return The power of two encoded by @p scale, or a quiet NaN when every
///         bit of the stored byte is set.
__host__ __device__ inline float cast_to_float(e8m0_scale_t const scale)
{
    const uint8_t biased_exponent = bit_cast<uint8_t>(scale);

    // An all-ones byte is the E8M0 NaN encoding. It is also what
    // cast_from_float() returns for a NaN input (assuming
    // NumericUtils<float>::nan_mask selects the float exponent field), so it
    // must decode back to NaN rather than overflow to +Inf as 2^128 would.
    if(biased_exponent == 0xFF)
    {
        return bit_cast<float>(uint32_t{0x7FC00000}); // binary32 quiet NaN
    }

    // Force signed arithmetic so the subtraction cannot wrap, whatever integer
    // type the bias constant is declared with.
    const int exponent =
        static_cast<int>(biased_exponent) - static_cast<int>(NumericUtils<e8m0_scale_t>::bias);

    // ldexp scales by an exact power of two, so no general-purpose pow is needed.
    // TODO: check performance against a direct bit-shift implementation
    return std::ldexp(1.0f, exponent);
}
/// @brief Encode a float as an E8M0 scale.
///
/// Masks the raw bits of @p scale with NumericUtils<float>::nan_mask and
/// shifts the result right by 23, keeping the low byte. Bits below position 23
/// are discarded, not rounded.
/// NOTE(review): presumably nan_mask covers only the 8 exponent bits of
/// binary32, so the sign is dropped as well -- confirm against NumericUtils.
///
/// @param scale Float value to encode.
/// @return E8M0 scale built from the extracted byte.
__host__ __device__ inline e8m0_scale_t cast_from_float(float const scale)
{
    // Bit position where the extracted field starts in the 32-bit pattern.
    constexpr int field_offset = 23;

    const uint32_t raw_bits     = bit_cast<uint32_t>(scale);
    const uint32_t masked_field = raw_bits & NumericUtils<float>::nan_mask;

    return static_cast<uint8_t>(masked_field >> field_offset);
}
template <> template <>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment