// hip_float8.h
#pragma once

#ifdef __HIPCC__
4
  #include <hip/hip_runtime.h>
5
#else
6
7
8
9
  #include <type_traits>
  #include <stdint.h>
  #include <math.h>
  #include <iostream>
10
11
12
13
#endif

#include "hip_float8_impl.h"

14
15
16
17
18
19
20
21
22
23
24
25
/// 8-bit floating-point value held as one raw byte.
///
/// Software conversions go through the hip_fp8_impl helpers instantiated with
/// <4, 3, float> and negative_zero_nan enabled (presumably 4 exponent bits and
/// 3 mantissa bits -- confirm against hip_float8_impl.h). When __HIP__MI300__
/// is defined, device code converts with hardware support instead, while host
/// code keeps using the software helpers.
///
/// The HIP_FP8_HOST / HIP_FP8_DEVICE / HIP_FP8_HOST_DEVICE macros are defined
/// outside this file (presumably host/device function qualifiers -- confirm).
struct alignas(1) hip_fp8 {
  /// Tag type that selects the raw-bits constructor.
  struct from_bits_t {};

  /// Returns the tag to pass as second argument of the raw-bits constructor.
  HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }

  /// Raw 8-bit encoding of the value.
  uint8_t data;

  /// Defaulted; default-initialization performs no initialization of `data`.
  hip_fp8() = default;
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;

  /// Deleted: a bare byte must be accompanied by the from_bits() tag.
  HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;

  /// Stores `v` verbatim as the encoding (no numeric conversion).
  explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
      : data(v) {}

#ifdef __HIP__MI300__
  // NOTE: ON-DEVICE... always optimal bias
  /// Device-side float -> fp8 conversion via hip_fp8_impl::to_fp8_from_fp32.
  explicit HIP_FP8_DEVICE hip_fp8(float v)
      : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}

  /// Device-side half -> fp8: widens to float, then delegates.
  explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
      : hip_fp8(static_cast<float>(v)) {}

  // Host only implementation using s/w simulation
  explicit HIP_FP8_HOST hip_fp8(float v)
#else   // __HIP__MI300__
  // both Host and DEVICE for non-MI300 using s/w simulation
  explicit HIP_FP8_HOST_DEVICE hip_fp8(float v)
#endif  // __HIP__MI300__
      // Shared tail of whichever declarator the preprocessor selected above:
      // software float -> fp8 conversion with clipping enabled.
      : data(hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
                                     true /*clip*/>(v)) {}

  /// double -> fp8: narrows to float first, then delegates to the float
  /// constructor available in the current compilation context.
  explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
      : hip_fp8(static_cast<float>(v)) {}

#ifdef __HIP__MI300__
  // upcast using device specific intrinsic
  /// Device-side fp8 -> float conversion using the v_cvt_f32_fp8 instruction.
  explicit inline HIP_FP8_DEVICE operator float() const {
    float fval;
    // The encoding is widened to 32 bits; src0_sel:BYTE_0 tells the
    // instruction to read byte 0 of that operand.
    uint32_t i32val = static_cast<uint32_t>(data);

    // upcast
    asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
                 : "=v"(fval)
                 : "v"(i32val));

    return fval;
  }

  explicit inline HIP_FP8_HOST operator float() const
#else   // __HIP__MI300__
  explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif  // __HIP__MI300__
  {
    // Shared body of whichever declarator was selected above: software
    // fp8 -> float conversion.
    return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
        data);
  }
};

73
74
75
76
77
namespace std {

/// Sine of an fp8 value: evaluated with sinf on the float conversion of `a`,
/// then converted back to hip_fp8.
inline hip_fp8 sin(hip_fp8 a) {
  const float as_float = static_cast<float>(a);
  return hip_fp8(sinf(as_float));
}

/// Cosine of an fp8 value: evaluated with cosf on the float conversion of `a`,
/// then converted back to hip_fp8.
inline hip_fp8 cos(hip_fp8 a) {
  const float as_float = static_cast<float>(a);
  return hip_fp8(cosf(as_float));
}

/// Returns `a` unchanged.
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) {
  return a;
}

}  // namespace std
78
79

// Special operator overloading
80
81
/// Writes `f8` to `os` as a float (via hip_fp8's explicit float conversion)
/// and returns the stream so insertions can be chained.
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
  return os << static_cast<float>(f8);
}

// all + operator overloading with mixed types
85
86
87
88
// mixed types, always converts to f32, does computation in f32, and returns
// float
/// float + fp8: `b` is converted to float and the sum is returned as float.
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
  return fa + static_cast<float>(b);
}

91
92
/// fp8 + float: `a` is converted to float and the sum is returned as float.
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
  return static_cast<float>(a) + fb;
}

95
96
/// fp8 + fp8: both operands are converted to float, added in float, and the
/// sum is converted back to hip_fp8.
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
  const float sum = static_cast<float>(a) + static_cast<float>(b);
  return hip_fp8(sum);
}

99
100
/// In-place fp8 addition: the sum is computed in float, converted back to
/// hip_fp8, stored into `a`, and `a` is returned by reference.
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
  a = hip_fp8(static_cast<float>(a) + static_cast<float>(b));
  return a;
}

// overloading multiplication, always returns float,
104
105
/// fp8 * fp8: both operands are converted to float; the product is returned
/// as float (it is not converted back to hip_fp8).
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
  return static_cast<float>(a) * static_cast<float>(b);
}

108
109
/// float * fp8: `b` is converted to float; the product is returned as float.
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
  return a * static_cast<float>(b);
}

112
113
/// fp8 * float: `a` is converted to float; the product is returned as float.
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
  return static_cast<float>(a) * b;
}

116
117
/// int32 * fp8: both operands are converted to float before multiplying; the
/// product is returned as float.
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
  return static_cast<float>(a) * static_cast<float>(b);
}

120
121
/// double * fp8: `a` is narrowed to float first, so the multiplication is
/// carried out in float and the product is returned as float.
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
  return static_cast<float>(a) * static_cast<float>(b);
}

// overloading for compare
125
126
/// Equality on the raw encodings: true exactly when both bytes are identical.
/// No float conversion takes place, unlike the ordering operators (>=, >),
/// which compare the float values.
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
  return a.data == b.data;
}
128
129
/// Inequality on the raw encodings: true exactly when the two bytes differ.
/// No float conversion takes place.
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
  return a.data != b.data;
}

132
133
/// Greater-or-equal, evaluated on the float conversions of both operands.
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs >= rhs;
}
135
136
/// Strictly-greater, evaluated on the float conversions of both operands.
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs > rhs;
}