Commit 4ddb62bd authored by Rostyslav Geyyer
Browse files

Add fp8_convert_sr

parent 4089bc68
......@@ -4,6 +4,7 @@
#pragma once
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/statically_indexed_array.hpp"
namespace ck {
......@@ -1130,7 +1131,8 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template <typename T, uint32_t seed, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val)
{
uint32_t x = reinterpret_cast<uint32_t&>(val);
// uint32_t x = reinterpret_cast<uint32_t&>(val);
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits ^= x >> 16;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
......@@ -1146,7 +1148,8 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val)
template <typename T, uint32_t seed, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val)
{
uint16_t x = reinterpret_cast<uint16_t&>(val);
// uint16_t x = reinterpret_cast<uint16_t&>(val);
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
drop_bits *= 0x7000149;
......@@ -1164,12 +1167,34 @@ template <typename T,
__host__ __device__ uint32_t prand_generator(int id, T val)
{
    // This overload generates nothing: each input is explicitly discarded
    // (presumably to keep unused-parameter warnings quiet) and the result is
    // always the constant 0.
    static_cast<void>(id);
    static_cast<void>(seed);
    static_cast<void>(val);
    return uint32_t{0};
}
// Primary template for fp8 conversion using stochastic rounding (SR).
// Y is the type returned and X the type of the value being converted.
// Only the declaration is given here; the behavior for a concrete <Y, X>
// pair comes from an explicit specialization of this template.
template <typename Y, typename X>
__host__ __device__ constexpr Y fp8_convert_sr(X x);
// fp32 -> fp8 conversion with stochastic rounding.
// A 32-bit pseudo-random value is obtained from prand_generator (fed with the
// result of get_thread_global_1d_id() and the input value) and handed to
// cast_to_f8 together with the input.
template <>
inline __host__ __device__ f8_t fp8_convert_sr<f8_t, float>(float x)
{
    // Rounding mode is fixed to stochastic, so the derived flag is always true.
    constexpr f8_rounding_mode k_rounding = f8_rounding_mode::stochastic;
    constexpr bool k_stochastic           = (k_rounding == f8_rounding_mode::stochastic);

    // Remaining compile-time switches forwarded to cast_to_f8.
    constexpr bool k_negative_zero_nan = true;
    constexpr bool k_clip              = true;

    // Fixed seed for the pseudo-random generator.
    constexpr int k_seed = 42;

    uint32_t random_bits = prand_generator<float, k_seed>(get_thread_global_1d_id(), x);

    return cast_to_f8<k_negative_zero_nan, k_clip, k_stochastic>(x, random_bits);
}
// fp8 -> fp32 conversion.
// Simply delegates to type_convert<float>; no pseudo-random value is
// generated on this path.
template <>
inline __host__ __device__ float fp8_convert_sr<float, f8_t>(f8_t x)
{
    const float as_float = type_convert<float>(x);
    return as_float;
}
template <typename T>
struct NumericLimits
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment