f8_utils.hpp 4.86 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/statically_indexed_array.hpp"

namespace ck {

// Raw 8-bit storage for an fp8 value: 1 sign bit, 4 exponent bits, 3 mantissa bits.
using f8_t = uint8_t;

// fp8 rounding modes
enum class f8_rounding_mode
{
    standard,
    stochastic
};

// Cast fp32 to fp8 (1 sign / 4 exponent / 3 mantissa bits).
//
// Template parameters:
//   negative_zero_nan - true:  0x80 encodes NaN, Inf and NaN inputs both map to 0x80,
//                              results that round to zero are always +0 (0x00), and the
//                              all-ones exponent code holds ordinary finite values.
//                       false: the all-ones exponent code is reserved (mantissa 0 is
//                              Inf, mantissa 1 is returned for NaN inputs) and zero
//                              results keep the sign of the input.
//   clip              - true:  finite inputs above the fp8 range saturate to the largest
//                              finite fp8 magnitude of the same sign.
//                       false: such inputs return the sign bit combined with the
//                              all-ones exponent code and a zero mantissa.
//   stoch             - true:  the bits of `rng` that fall under the dropped-bit mask are
//                              added before truncation (stochastic rounding).
//                       false: the dropped bits of the value itself are added, i.e.
//                              round to nearest with ties away from zero.
//
// Parameters:
//   x   - value to convert.
//   rng - random bits for stochastic rounding; ignored when stoch is false.
//
// Returns the fp8 bit pattern.
template <bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(float x, uint32_t rng)
{
    // fp8 exponent/mantissa layout
    constexpr int we_f8 = 4;
    constexpr int wm_f8 = 3;

    // fp32 exponent/mantissa layout
    constexpr int we_f32 = 8;
    constexpr int wm_f32 = 23;

    // read the fp32 bit pattern; copying bytes avoids the strict-aliasing violation of
    // dereferencing a reinterpret_cast'ed pointer
    uint32_t x_bitwise;
    static_assert(sizeof(x_bitwise) == sizeof(x), "float is expected to be 32 bits wide");
    __builtin_memcpy(&x_bitwise, &x, sizeof(x_bitwise));

    // unpack the input
    const uint32_t head = x_bitwise & 0xFF800000;
    uint32_t mantissa   = x_bitwise & 0x7FFFFF;
    int exponent        = (head >> wm_f32) & 0xFF;
    const uint32_t sign = head >> (we_f32 + wm_f32);

    const uint32_t sign_bit    = sign << (we_f8 + wm_f8);
    const uint32_t signed_inf  = sign_bit + (((1 << we_f8) - 1) << wm_f8);
    const uint32_t signed_zero = negative_zero_nan ? 0 : sign_bit;

    // fp32 Inf/NaN (all-ones exponent)
    if((x_bitwise & 0x7F800000) == 0x7F800000)
    {
        if(negative_zero_nan)
            return 0x80;
        return signed_inf + (mantissa != 0 ? 1 : 0);
    }

    if(x_bitwise == 0)
        return 0;

    const int max_exp = (1 << we_f8) - (negative_zero_nan ? 1 : 2);
    // fp32 biased exponent of the smallest normal fp8 value
    const int exp_low_cutoff = 0x80 - (1 << (we_f8 - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    // rebase to the fp8 exponent code; values <= 0 fall in the fp8 subnormal range
    exponent -= exp_low_cutoff - 1;

    // Below this exponent the dropped-bit window is wider than 32 bits, so neither the
    // 24-bit significand nor a 32-bit rng can carry out of it: the result is always zero.
    // Returning here also keeps every shift count below in range (the unguarded shifts
    // were undefined for fp32 denormals, -0.0f and other tiny inputs).
    constexpr int min_roundable_exp = (wm_f32 - wm_f8 + 1) - 32;
    if(exponent < min_roundable_exp)
        return signed_zero;

    // mask of the low significand bits that do not fit into fp8; subnormal results drop
    // (1 - exponent) additional bits
    uint64_t drop_mask = (uint64_t{1} << (wm_f32 - wm_f8)) - 1;
    if(exponent <= 0)
        drop_mask = (uint64_t{1} << (wm_f32 - wm_f8 + 1 - exponent)) - 1;

    // significand with the implicit leading one; 64-bit so that adding up to 32 random
    // bits cannot wrap around
    uint64_t significand = uint64_t{mantissa} + (uint64_t{1} << wm_f32);
    significand += (stoch ? uint64_t{rng} : significand) & drop_mask;
    if(significand >= (uint64_t{2} << wm_f32))
    {
        // rounding carried out of the significand
        significand >>= 1;
        exponent++;
    }
    significand >>= (wm_f32 - wm_f8);

    if(exponent <= 0)
    {
        // subnormal range; represented by a subnormal float8 (exponent 0)
        // and involves loss of accuracy
        significand >>= 1 - exponent;
        exponent = 0;
    }
    // above range: quantize to maximum possible float of the same sign
    else if(exponent > max_exp)
    {
        if(clip)
        {
            significand = (1 << wm_f8) - 1;
            exponent    = max_exp;
        }
        else
        {
            return signed_inf;
        }
    }

    mantissa = static_cast<uint32_t>(significand);
    if(exponent == 0 && mantissa == 0)
        return signed_zero;
    mantissa &= (1 << wm_f8) - 1;
    return sign_bit | (exponent << wm_f8) | mantissa;
}

// Cast fp8 (1 sign / 4 exponent / 3 mantissa bits) to fp32.
//
// Template parameter:
//   negative_zero_nan - true:  0x80 decodes to NaN and every other non-zero code is a
//                              finite value.
//                       false: 0x80 decodes to -0.0f and the all-ones exponent code
//                              decodes to +/-Inf (mantissa 0) or NaN (mantissa != 0).
//
// Returns the decoded fp32 value; 0x00 always decodes to +0.0f.
template <bool negative_zero_nan>
__host__ __device__ float cast_from_f8(f8_t x)
{
    // fp8 exponent/mantissa layout
    constexpr int we_f8 = 4;
    constexpr int wm_f8 = 3;

    // fp32 exponent/mantissa layout
    constexpr int we_f32 = 8;
    constexpr int wm_f32 = 23;

    static_assert(sizeof(float) == sizeof(uint32_t), "float is expected to be 32 bits wide");

    // special fp32 values, built from their bit patterns by copying bytes (avoids the
    // strict-aliasing violation of dereferencing a reinterpret_cast'ed pointer)
    const uint32_t ifInf    = 0x7F800000;
    const uint32_t ifNegInf = 0xFF800000;
    const uint32_t ifNaN    = 0x7F800001;
    const uint32_t ifNeg0   = 0x80000000;
    float fInf, fNegInf, fNaN, fNeg0;
    __builtin_memcpy(&fInf, &ifInf, sizeof(fInf));
    __builtin_memcpy(&fNegInf, &ifNegInf, sizeof(fNegInf));
    __builtin_memcpy(&fNaN, &ifNaN, sizeof(fNaN));
    __builtin_memcpy(&fNeg0, &ifNeg0, sizeof(fNeg0));

    if(x == 0)
        return static_cast<float>(0);

    // unpack the input
    uint32_t sign     = x >> (we_f8 + wm_f8);
    uint32_t mantissa = x & ((1 << wm_f8) - 1);
    int exponent      = (x & 0x7F) >> wm_f8;

    if(negative_zero_nan)
    {
        if(x == 0x80)
            return fNaN;
    }
    else
    {
        if(x == 0x80)
            return fNeg0;
        if(exponent == ((1 << we_f8) - 1))
            return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
    }

    // fp32 biased exponent of the smallest normal fp8 value
    const int exp_low_cutoff =
        (1 << (we_f32 - 1)) - (1 << (we_f8 - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    // subnormal input
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above.
        // Normalize: shift the leading set bit up into the implicit-one position
        // (bit wm_f8) and lower the exponent by the same amount.
        int sh = 1 + __builtin_clz(mantissa) - ((1 + we_f32 + wm_f32) - wm_f8);
        mantissa <<= sh;
        exponent += 1 - sh;
        // drop the now-implicit leading one
        mantissa &= ((1 << wm_f8) - 1);
    }

    // Rebase the exponent to the fp32 bias and widen the mantissa. The smallest value
    // reached here is 1 - wm_f8 + exp_low_cutoff - 1, which is positive, so the result
    // is always a normal fp32 number.
    exponent += exp_low_cutoff - 1;
    mantissa <<= wm_f32 - wm_f8;

    const uint32_t retval = (sign << (we_f32 + wm_f32)) | (exponent << wm_f32) | mantissa;
    float result;
    __builtin_memcpy(&result, &retval, sizeof(result));
    return result;
}

} // namespace ck