f8_utils.hpp 4.99 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/statically_indexed_array.hpp"

namespace ck {

// Raw storage for an 8-bit float: 1 sign bit, 4 exponent bits, 3 mantissa bits.
using f8_t = uint8_t;

// fp8 rounding modes
enum class f8_rounding_mode
{
    standard,
    stochastic
};

// Cast fp32 to fp8 (1 sign / 4 exponent / 3 mantissa bits).
//
// Template parameters:
//   negative_zero_nan - true:  there is a single zero (0x00), the pattern 0x80 means NaN,
//                              every fp32 Inf/NaN maps to 0x80 and the exponent bias is 8.
//                       false: IEEE-like encoding with signed zero, exponent field 15
//                              reserved for Inf (mantissa 0) / NaN (mantissa != 0), bias 7.
//   clip              - true:  finite inputs above the largest fp8 magnitude saturate to it.
//                       false: such inputs return the `signed_inf` bit pattern.
//                       NOTE(review): with negative_zero_nan == true that pattern (0x78/0xF8)
//                       is decoded by cast_from_f8<true> as the finite value +-128 - confirm
//                       that this is the intended overflow result for that mode.
//   stoch             - true:  stochastic rounding, the low bits of `rng` are added to the
//                              bits that are about to be dropped.
//                       false: the dropped bits are added to themselves, so a dropped
//                              fraction of one half or more rounds the magnitude up.
//
// Parameters:
//   x   - value to convert.
//   rng - random bits, only read when stoch == true.
//
// Returns the fp8 bit pattern. Inputs so small that 32 or more low bits would have to be
// dropped are flushed to zero (keeping the sign when negative_zero_nan == false).
template <bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
{
    // fp8 exponent/mantissa layout
    constexpr int we_f8 = 4;
    constexpr int wm_f8 = 3;

    // fp32 exponent/mantissa layout
    constexpr int we_f32 = 8;
    constexpr int wm_f32 = 23;

    // width of the integer type the significand arithmetic is carried out in
    constexpr int word_bits = 32;

    // Read the bit pattern with a byte copy; dereferencing a uint32_t* that points at a
    // float violates the strict-aliasing rule.
    uint32_t x_bitwise;
    __builtin_memcpy(&x_bitwise, &x, sizeof(x_bitwise));

    // unpack the input
    const uint32_t head = x_bitwise & 0xFF800000;
    uint32_t mantissa   = x_bitwise & 0x7FFFFF;
    int exponent        = (head >> wm_f32) & 0xFF;
    const uint32_t sign = head >> (we_f32 + wm_f32);

    const uint32_t sign_bit    = sign << (we_f8 + wm_f8);
    const uint32_t signed_inf  = sign_bit + (((1 << we_f8) - 1) << wm_f8);
    const uint32_t signed_zero = negative_zero_nan ? 0 : sign_bit;

    // fp32 mantissa bits that do not fit into an fp8 mantissa (normal-range case)
    uint32_t drop_mask = (1u << (wm_f32 - wm_f8)) - 1;
    int max_exp;
    int exp_low_cutoff;

    if constexpr(negative_zero_nan)
    {
        // Inf and NaN both become the single NaN encoding
        if((x_bitwise & 0x7F800000) == 0x7F800000)
            return 0x80;
        max_exp        = (1 << we_f8) - 1;
        exp_low_cutoff = 0x80 - (1 << (we_f8 - 1));
    }
    else
    {
        // Inf keeps mantissa 0, any NaN becomes mantissa 1
        if((x_bitwise & 0x7F800000) == 0x7F800000)
            return signed_inf + (mantissa != 0 ? 1 : 0);
        max_exp        = (1 << we_f8) - 2;
        exp_low_cutoff = 0x80 - (1 << (we_f8 - 1)) + 1;
    }

    if(x_bitwise == 0)
        return 0;

    // re-bias the exponent; values <= 0 fall into the fp8 subnormal range
    exponent -= exp_low_cutoff - 1;
    if(exponent <= 0)
    {
        // A subnormal result drops (1 - exponent) additional bits.
        const int drop_bits = wm_f32 - wm_f8 + 1 - exponent;

        // Shifting a 32-bit value by 32 or more is undefined. Such inputs (this includes
        // fp32 denormals and -0.0f) lie below 2^-8 of the smallest fp8 subnormal, so the
        // result is zero.
        if(drop_bits >= word_bits)
            return signed_zero;

        drop_mask = (1u << drop_bits) - 1;
    }

    // make the implicit leading one explicit, then round by adding into the dropped bits
    mantissa += 1u << wm_f32;
    mantissa += (stoch ? rng : mantissa) & drop_mask;
    if(mantissa >= (2u << wm_f32))
    {
        // rounding carried past the leading one: renormalize
        mantissa >>= 1;
        exponent++;
    }
    mantissa >>= (wm_f32 - wm_f8);

    if(exponent <= 0)
    {
        // subnormal range; represented by a subnormal float8 (exponent 0)
        // and involves loss of accuracy
        // (drop_bits < 32 above guarantees 1 - exponent <= 11 here)
        mantissa >>= 1 - exponent;
        exponent = 0;
    }
    // above range: quantize to maximum possible float of the same sign
    else if(exponent > max_exp)
    {
        if(clip)
        {
            mantissa = (1 << wm_f8) - 1;
            exponent = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }

    // everything was rounded away
    if(exponent == 0 && mantissa == 0)
        return signed_zero;

    // strip the explicit leading one and pack the fields
    mantissa &= (1 << wm_f8) - 1;
    return sign_bit | (exponent << wm_f8) | mantissa;
}

// Cast fp8 (1 sign / 4 exponent / 3 mantissa bits) to fp32.
//
// Template parameter:
//   negative_zero_nan - true:  0x80 decodes to NaN, every other pattern is a finite value
//                              (exponent bias 8).
//                       false: 0x80 decodes to -0.0f, exponent field 15 decodes to +-Inf
//                              (mantissa 0) or NaN (mantissa != 0), exponent bias 7.
//
// Parameter:
//   x - fp8 bit pattern.
//
// Returns the decoded fp32 value.
template <bool negative_zero_nan>
__host__ __device__ float cast_from_f8(f8_t x)
{
    // fp8 exponent/mantissa layout
    constexpr int we_f8 = 4;
    constexpr int wm_f8 = 3;

    // fp32 exponent/mantissa layout
    constexpr int we_f32 = 8;
    constexpr int wm_f32 = 23;

    // fp32 bit patterns of the special values that may be returned
    const uint32_t ifInf    = 0x7F800000;
    const uint32_t ifNegInf = 0xFF800000;
    const uint32_t ifNaN    = 0x7F800001;
    const uint32_t ifNeg0   = 0x80000000;

    // Byte copies instead of pointer casts, to stay clear of strict-aliasing violations.
    float fInf, fNegInf, fNaN, fNeg0;
    __builtin_memcpy(&fInf, &ifInf, sizeof(fInf));
    __builtin_memcpy(&fNegInf, &ifNegInf, sizeof(fNegInf));
    __builtin_memcpy(&fNaN, &ifNaN, sizeof(fNaN));
    __builtin_memcpy(&fNeg0, &ifNeg0, sizeof(fNeg0));

    if(x == 0)
        return 0.0f;

    // unpack the input
    const uint32_t sign = x >> (we_f8 + wm_f8);
    uint32_t mantissa   = x & ((1 << wm_f8) - 1);
    int exponent        = (x & 0x7F) >> wm_f8;

    int exp_low_cutoff;

    if constexpr(negative_zero_nan)
    {
        if(x == 0x80)
            return fNaN;
        exp_low_cutoff = (1 << (we_f32 - 1)) - (1 << (we_f8 - 1));
    }
    else
    {
        if(x == 0x80)
            return fNeg0;
        if(exponent == ((1 << we_f8) - 1))
            return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
        exp_low_cutoff = (1 << (we_f32 - 1)) - (1 << (we_f8 - 1)) + 1;
    }

    // subnormal input
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
        // Normalize: `sh` is the left shift (1..3) that moves the highest set mantissa bit
        // into the implicit-one position (bit wm_f8); the exponent drops accordingly.
        int sh = 1 + __builtin_clz(mantissa) - ((1 + we_f32 + wm_f32) - wm_f8);
        mantissa <<= sh;
        exponent += 1 - sh;
        // remove the now explicit leading one
        mantissa &= ((1 << wm_f8) - 1);
    }

    // Re-bias to fp32. The smallest possible result here is 117 (subnormal input with
    // mantissa 1), so the output is always a normal fp32 number.
    exponent += exp_low_cutoff - 1;
    mantissa <<= wm_f32 - wm_f8;

    const uint32_t retval = (sign << (we_f32 + wm_f32)) | (exponent << wm_f32) | mantissa;

    float result;
    __builtin_memcpy(&result, &retval, sizeof(result));
    return result;
}

} // namespace ck