// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/statically_indexed_array.hpp"

namespace ck {

using f8_t = uint8_t;
using half_t  = _Float16;

// fp8 rounding modes
// Names the rounding strategies available when narrowing a value to fp8.
// NOTE(review): nothing in this header consumes the enum; presumably callers
// translate it into the `stoch` template flag of cast_to_f8 - confirm.
enum class f8_rounding_mode
{
    standard,  // deterministic rounding
    stochastic // rounding driven by a caller-supplied random number
};

// Encodes a float or half value as an 8-bit fp8 code (1 sign, 4 exponent, 3 mantissa bits).
//
// Template parameters:
//   T                 - source type; the bit layout of half_t is used when T is half_t,
//                       the layout of float otherwise (cast_to_f8 restricts T to these two)
//   negative_zero_nan - true:  0x80 is the single NaN code; there is no infinity and no
//                              negative zero, so NaN/Inf inputs map to 0x80 and anything that
//                              rounds to zero maps to 0x00
//                       false: an all-ones fp8 exponent encodes +/-Inf (mantissa 0) or NaN
//                              (mantissa != 0) and zero keeps its sign
//   clip              - true:  finite values above the fp8 range saturate to the largest
//                              finite fp8 value of the same sign
//                       false: such values become infinity (or 0x80 when negative_zero_nan)
//   stoch             - true:  the discarded mantissa bits are rounded using `rng`
//                       false: round to nearest, ties away from zero
//
// Parameters:
//   x   - value to convert
//   rng - random bits, only read when stoch is true
//
// Returns the fp8 bit pattern.
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{
    // check data type
    constexpr bool is_half = std::is_same<T, half_t>::value;

    // fp8 exponent/mantissa layout
    constexpr int f8_exp  = 4;
    constexpr int f8_mant = 3;

    // source type exponent/mantissa layout
    constexpr int type_exp  = is_half ? 5 : 8;
    constexpr int type_mant = is_half ? 10 : 23;

    // nan code is same for float and half
    constexpr uint8_t nan_code  = 0x80;
    constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;

    // largest fp8 exponent field that still encodes a finite value
    constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
    // source exponent field that maps onto fp8 exponent field 1
    constexpr int exp_low_cutoff =
        (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    // convert to bitwise
    using T_bitwise           = typename std::conditional<is_half, uint16_t, uint32_t>::type;
    const T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));

    // unpack the input
    const uint32_t sign = x_bitwise >> (type_exp + type_mant);
    uint32_t mantissa   = x_bitwise & ((1u << type_mant) - 1);
    int exponent        = (x_bitwise >> type_mant) & ((1 << type_exp) - 1);

    // fp8 codes carrying the sign of x; these must use the fp8 bit layout, a pattern in the
    // source layout would be truncated to its low byte by the 8-bit return type
    const uint32_t f8_sign = sign << (f8_exp + f8_mant);
    const uint32_t f8_zero = negative_zero_nan ? 0 : f8_sign;
    const uint32_t f8_inf  = f8_sign | (((1u << f8_exp) - 1) << f8_mant);

    // NaN and infinity
    if((x_bitwise & nan_mask) == nan_mask)
    {
        if(negative_zero_nan)
            return nan_code;
        return static_cast<f8_t>(f8_inf + (mantissa != 0 ? 1 : 0));
    }

    // +0.0 and -0.0 are exact and must not go through the rounding step
    if(exponent == 0 && mantissa == 0)
        return static_cast<f8_t>(f8_zero);

    if(exponent == 0)
        exponent = 1; // subnormal input: no implicit leading one, minimum exponent
    else
        mantissa |= 1u << type_mant; // normal input: restore the implicit leading one

    // rebase the exponent so that 1 is the smallest normal fp8 exponent field
    exponent -= exp_low_cutoff - 1;

    // number of low mantissa bits that do not survive the conversion
    int drop_bits = type_mant - f8_mant;
    if(exponent <= 0)
    {
        // the result is an fp8 subnormal (or zero): more bits are shifted out
        drop_bits += 1 - exponent;
        // shifting a 32-bit value by 32 or more is undefined; such inputs are far below the
        // smallest fp8 subnormal and are flushed to zero
        if(drop_bits >= 32)
            return static_cast<f8_t>(f8_zero);
    }
    const uint32_t drop_mask = (1u << drop_bits) - 1;

    // rounding: adding the discarded bits to themselves carries into the kept bits exactly
    // when they are at least one half; with stoch the random bits are added instead
    mantissa += (stoch ? rng : mantissa) & drop_mask;
    if(mantissa >= (2u << type_mant))
    {
        // rounding carried past the leading one
        mantissa >>= 1;
        exponent++;
    }
    mantissa >>= (type_mant - f8_mant);

    if(exponent <= 0)
    {
        // subnormal range; represented by a subnormal float8 (exponent 0)
        // and involves loss of accuracy
        mantissa >>= 1 - exponent;
        exponent = 0;
    }
    // above range: quantize to maximum possible float of the same sign
    else if(exponent > max_exp)
    {
        if(clip)
        {
            mantissa = (1u << f8_mant) - 1;
            exponent = max_exp;
        }
        else
        {
            // without infinity in the format, overflow is reported the same way an
            // infinite input is: with the NaN code
            return negative_zero_nan ? nan_code : static_cast<f8_t>(f8_inf);
        }
    }

    // the value rounded to zero
    if(exponent == 0 && mantissa == 0)
        return static_cast<f8_t>(f8_zero);

    mantissa &= (1u << f8_mant) - 1;
    return static_cast<f8_t>(f8_sign | (static_cast<uint32_t>(exponent) << f8_mant) | mantissa);
}

// Type-checked entry point for converting a float or half value to fp8.
// Rejects any other source type at compile time and forwards the value, the
// random bits and all template flags unchanged to run_cast_to_f8.
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
    static_assert(std::is_same<T, half_t>::value || std::is_same<T, float>::value,
                  "Only half and float can be casted to f8.");

    const f8_t encoded = run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
    return encoded;
}

// Decodes an 8-bit fp8 code (1 sign, 4 exponent, 3 mantissa bits) into float or half.
//
// Template parameters:
//   T                 - result type, float or half_t
//   negative_zero_nan - true:  0x80 decodes to NaN and every other code is a finite value
//                       false: 0x80 decodes to -0.0 and an all-ones exponent decodes to
//                              +/-Inf (mantissa 0) or NaN (mantissa != 0)
//
// Parameters:
//   x - fp8 bit pattern
//
// Returns the decoded value.
template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x)
{
    // check data type
    constexpr bool is_half  = std::is_same<T, half_t>::value;
    constexpr bool is_float = std::is_same<T, float>::value;

    // fp8 exponent/mantissa layout
    constexpr int f8_exp  = 4;
    constexpr int f8_mant = 3;

    // resulting type exponent/mantissa layout
    constexpr int type_exp  = is_half ? 5 : 8;
    constexpr int type_mant = is_half ? 10 : 23;

    // prepare the codes
    constexpr uint8_t nan_code = 0x80;
    T fInf, fNegInf, fNaN, fNeg0;
    if constexpr(is_half)
    {
        constexpr uint16_t ihInf    = 0x7C00;
        constexpr uint16_t ihNegInf = 0xFC00;
        constexpr uint16_t ihNaN    = 0x7C01;
        constexpr uint16_t ihNeg0   = 0x8000;
        fInf                        = *(reinterpret_cast<const half_t*>(&ihInf));
        fNegInf                     = *(reinterpret_cast<const half_t*>(&ihNegInf));
        fNaN                        = *(reinterpret_cast<const half_t*>(&ihNaN));
        fNeg0                       = *(reinterpret_cast<const half_t*>(&ihNeg0));
    }
    else if constexpr(is_float)
    {
        constexpr uint32_t ifInf    = 0x7F800000;
        constexpr uint32_t ifNegInf = 0xFF800000;
        constexpr uint32_t ifNaN    = 0x7F800001;
        constexpr uint32_t ifNeg0   = 0x80000000;
        fInf                        = *(reinterpret_cast<const float*>(&ifInf));
        fNegInf                     = *(reinterpret_cast<const float*>(&ifNegInf));
        fNaN                        = *(reinterpret_cast<const float*>(&ifNaN));
        fNeg0                       = *(reinterpret_cast<const float*>(&ifNeg0));
    }

    // positive zero; returning here also keeps a zero mantissa away from
    // __builtin_clz below, whose result is undefined for 0
    if(x == 0)
        return static_cast<T>(0);

    // unpack the input
    uint32_t sign     = x >> (f8_exp + f8_mant);
    uint32_t mantissa = x & ((1 << f8_mant) - 1);
    int exponent      = (x & 0x7F) >> f8_mant;

    // result exponent field that fp8 exponent field 1 maps onto
    constexpr int exp_low_cutoff =
        (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
    typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;

    if constexpr(negative_zero_nan)
    {
        if(x == nan_code)
            return fNaN;
    }
    else
    {
        if(x == nan_code)
            return fNeg0;
        if(exponent == ((1 << f8_exp) - 1))
            return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
    }

    // subnormal input
    if(exponent == 0)
    {
        // mantissa != 0 here: 0x00 returned above and 0x80 was handled as NaN / -0.0
        // Normalize: shift the highest set mantissa bit into the implicit-one position
        // (bit f8_mant) and lower the exponent by the same amount. __builtin_clz counts
        // leading zeros of a 32-bit operand, whatever the width of T is.
        const int sh = 1 + __builtin_clz(mantissa) - (32 - f8_mant);
        mantissa <<= sh;
        exponent += 1 - sh;
        // drop the now explicit leading one
        mantissa &= ((1 << f8_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
    mantissa <<= type_mant - f8_mant;

    // subnormal output: only needed when the rebased exponent is not positive, which
    // does not happen for the 4-bit fp8 exponent combined with float or half
    if(exponent <= 0)
    {
        mantissa |= 1 << type_mant;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

    retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
    return *(reinterpret_cast<const T*>(&retval));
}

// Type-checked entry point for converting an fp8 code to float or half.
// Rejects any other result type at compile time. The all-zero code is returned
// as zero directly; every other code is decoded by run_cast_from_f8.
template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x)
{
    static_assert(std::is_same<T, half_t>::value || std::is_same<T, float>::value,
                  "only half and float are supported.");

    return (x != 0) ? run_cast_from_f8<T, negative_zero_nan>(x) : static_cast<T>(0);
}

} // namespace ck