// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"

// The software conversions below are compiled only when native conversions are
// unavailable, i.e. when none of __gfx940__, __gfx941__, __gfx942__ is defined.
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8

namespace ck {

// Rounding behaviour selectable for fp8 conversions:
//   standard   - rounding to nearest; the faster of the two modes
//   stochastic - stochastic rounding; helps to avoid error accumulation
enum class f8_rounding_mode
{
    standard,
    stochastic
};

} // namespace ck

namespace ck::utils {

namespace {

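// Bit-level conversion routines; the public cast_to_f8 / cast_from_f8 wrappers
// further down in this file forward to them.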
// Converts `x` (layout described by NumericUtils<X>) to the bit pattern of an
// 8-bit float whose layout is described by NumericUtils<Y>.
//
// Template parameters:
//   negative_zero_nan - true: the code 0x80 is NaN, -0.0 is never produced and
//                       there is no Inf encoding; false: the all-ones exponent
//                       is reserved for Inf/NaN and -0.0 keeps its sign.
//   clip              - true: magnitudes above the largest finite value saturate
//                       to it; false: they become Inf (or NaN when
//                       negative_zero_nan is set).
//   stoch             - true: `rng` supplies the rounding bits; false: the
//                       dropped bits themselves are added, which rounds to
//                       nearest in the normal range.
// Parameters:
//   x   - value to convert.
//   rng - random bits, only read when `stoch` is true.
// Returns the encoded 8-bit value.
//
// NOTE(review): the rounding arithmetic in the subnormal-output range is kept
// exactly as it was; it has not been re-derived as part of this change.
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{
    // fp8/bf8 exponent/mantissa layout
    constexpr int out_exp  = NumericUtils<Y>::exp;
    constexpr int out_mant = NumericUtils<Y>::mant;

    // original type exponent/mantissa layout
    constexpr int in_exp  = NumericUtils<X>::exp;
    constexpr int in_mant = NumericUtils<X>::mant;

    // Shifting a 32-bit operand by 32 or more is undefined; every operand
    // shifted below is smaller than 2^31, so clamping the count to 31 gives
    // the same result an arbitrarily wide shift would.
    constexpr int max_shift = 31;

    // nan code is same for float and half
    constexpr Y nan_code        = 0x80;
    constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;

    // convert to bitwise
    using T_bitwise     = typename NumericUtils<X>::bitwise_type;
    T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));

    // unpack the input, depends on datatype
    const uint32_t head = x_bitwise & NumericUtils<X>::head_mask;
    uint32_t mantissa   = x_bitwise & NumericUtils<X>::mant_mask;
    int exponent        = (head >> in_mant) & NumericUtils<X>::exp_mask;
    const uint32_t sign = head >> (in_exp + in_mant);

    // Special codes are built in the OUTPUT layout: the return type is the
    // 8-bit Y, so a code assembled at the input's bit positions would lose its
    // sign and exponent bits when converted to Y.
    const uint32_t signed_zero = sign << (out_exp + out_mant);
    const uint32_t signed_inf  = signed_zero | (((1u << out_exp) - 1) << out_mant);

    uint32_t drop_mask    = (1u << (in_mant - out_mant)) - 1;
    constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
    constexpr int exp_low_cutoff =
        (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    // input matches nan_mask
    if constexpr(negative_zero_nan)
    {
        if((x_bitwise & nan_mask) == nan_mask)
            return nan_code;
    }
    else
    {
        // zero mantissa -> signed Inf, otherwise Inf code + 1 (a NaN code)
        if((x_bitwise & nan_mask) == nan_mask)
            return signed_inf + (mantissa != 0 ? 1 : 0);
    }

    // +0.0 / -0.0: handled before the normalization loop below, which would
    // never terminate for a zero mantissa
    if(exponent == 0 && mantissa == 0)
        return negative_zero_nan ? 0 : signed_zero;

    // subnormal input (mantissa != 0 here): there is no implicit leading one,
    // so shift the mantissa up until one appears and lower the exponent to
    // match, then strip that leading one again
    if(exponent == 0)
    {
        exponent += 1;
        while(mantissa < (1u << in_mant))
        {
            mantissa <<= 1;
            exponent -= 1;
        }
        mantissa &= ~(1u << in_mant);
    }

    // re-bias the exponent for the output format
    exponent -= exp_low_cutoff - 1;
    if(exponent <= 0)
    {
        // subnormal result: more low-order bits are dropped
        const int drop_bits = in_mant - out_mant + 1 - exponent;
        drop_mask           = (1u << (drop_bits < max_shift ? drop_bits : max_shift)) - 1;
    }
    // make the implicit leading one explicit
    mantissa += 1u << in_mant;
    // apply random number if needed
    mantissa += (stoch ? rng : mantissa) & drop_mask;
    // rounding carried past the leading one: renormalize
    if(mantissa >= (2u << in_mant))
    {
        mantissa >>= 1;
        exponent++;
    }
    mantissa >>= (in_mant - out_mant);

    // check negative exponent
    if(exponent <= 0)
    {
        // subnormal range; represented by a subnormal float8 (exponent 0)
        // and involves loss of accuracy
        const int denorm_shift = 1 - exponent;
        mantissa >>= (denorm_shift < max_shift ? denorm_shift : max_shift);
        exponent = 0;
    }
    // above range: quantize to maximum possible float of the same sign
    else if(exponent > max_exp)
    {
        if(clip)
        {
            mantissa = (1 << out_mant) - 1;
            exponent = max_exp;
        }
        else
        {
            // no Inf encoding exists when 0x80 is the NaN code
            if(negative_zero_nan)
                return nan_code;
            return signed_inf;
        }
    }

    // result rounded to zero: never emit 0x80 when it is the NaN code
    if(exponent == 0 && mantissa == 0)
        return negative_zero_nan ? 0 : signed_zero;
    mantissa &= (1 << out_mant) - 1;
    return signed_zero | (exponent << out_mant) | mantissa;
}

// Expands the 8-bit float bit pattern `x` (layout from NumericUtils<X>) into a
// value of type Y (layout from NumericUtils<Y>).
//
// negative_zero_nan selects the meaning of the code 0x80: NaN when true,
// -0.0 when false. When it is false, an all-ones exponent field decodes to
// Inf (zero mantissa) or NaN (non-zero mantissa).
template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ Y run_cast_from_f8(X x)
{
    // field widths of the 8-bit source format
    constexpr int in_exp  = NumericUtils<X>::exp;
    constexpr int in_mant = NumericUtils<X>::mant;

    // field widths of the destination format
    constexpr int out_exp  = NumericUtils<Y>::exp;
    constexpr int out_mant = NumericUtils<Y>::mant;

    using T_bitwise = typename NumericUtils<Y>::bitwise_type;

    // the all-zero pattern is exactly 0.0
    if(x == 0)
        return static_cast<Y>(0);

    // special values of the destination type, reinterpreted from their bit patterns
    constexpr X nan_code               = 0x80;
    constexpr T_bitwise pos_inf_bits   = NumericUtils<Y>::Inf;
    constexpr T_bitwise neg_inf_bits   = NumericUtils<Y>::NegInf;
    constexpr T_bitwise quiet_nan_bits = NumericUtils<Y>::NaN;
    constexpr T_bitwise neg_zero_bits  = NumericUtils<Y>::Neg0;

    const Y pos_inf   = *(reinterpret_cast<const Y*>(&pos_inf_bits));
    const Y neg_inf   = *(reinterpret_cast<const Y*>(&neg_inf_bits));
    const Y quiet_nan = *(reinterpret_cast<const Y*>(&quiet_nan_bits));
    const Y neg_zero  = *(reinterpret_cast<const Y*>(&neg_zero_bits));

    // split the input into sign / exponent / mantissa fields
    uint32_t sign_bit = x >> (in_exp + in_mant);
    uint32_t frac     = x & ((1 << in_mant) - 1);
    int biased_exp    = (x & 0x7F) >> in_mant;

    // amount added to the source exponent field to re-bias it for Y
    constexpr int exp_low_cutoff =
        (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    // 0x80 is NaN in negative_zero_nan mode and -0.0 otherwise
    if(x == nan_code)
        return negative_zero_nan ? quiet_nan : neg_zero;

    if constexpr(!negative_zero_nan)
    {
        // all-ones exponent: Inf or NaN depending on the mantissa
        if(biased_exp == ((1 << in_exp) - 1))
        {
            if(frac != 0)
                return quiet_nan;
            return sign_bit ? neg_inf : pos_inf;
        }
    }

    // 2 mantissa bits in, 10 mantissa bits out, 0x80 not repurposed: the byte
    // is placed in the upper half of the result with zeros below it
    constexpr bool direct_widen =
        (NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan;
    if(direct_widen)
    {
        T_bitwise widened = x;
        widened <<= 8;
        return *(reinterpret_cast<const Y*>(&widened));
    }

    // subnormal input
    if(biased_exp == 0)
    {
        // frac != 0 is guaranteed: 0x00 and 0x80 were both returned above.
        // Shift until the leading one reaches the implicit-bit position,
        // lowering the exponent by one per step, then drop that bit.
        constexpr uint32_t implicit_one = 1 << in_mant;
        biased_exp                      = 1;
        do
        {
            frac <<= 1;
            --biased_exp;
        } while(frac < implicit_one);
        frac &= implicit_one - 1;
    }

    // re-bias the exponent and left-align the mantissa in the wider field
    biased_exp += exp_low_cutoff - 1;
    frac <<= out_mant - in_mant;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(biased_exp <= 0)
    {
        frac |= 1 << out_mant;
        frac >>= 1 - biased_exp;
        biased_exp = 0;
    }

    T_bitwise packed = (sign_bit << (out_exp + out_mant)) | (biased_exp << out_mant) | frac;
    return *(reinterpret_cast<const Y*>(&packed));
}

} // namespace

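// Public entry points: each one checks the wider floating-point type at compile
// time and then forwards to the matching run_cast_* routine.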
// Converts a half_t or float value `x` to the 8-bit type Y by forwarding to
// run_cast_to_f8 with the same template arguments; `rng` is passed through
// unchanged. Any other source type X is rejected at compile time.
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
{
    // check datatypes
    static_assert(std::is_same<X, half_t>::value || std::is_same<X, float>::value,
                  "Only half and float can be casted.");

    return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}

// Converts the 8-bit value `x` to half_t or float by forwarding to
// run_cast_from_f8 with the same template arguments. Any other destination
// type Y is rejected at compile time.
template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ Y cast_from_f8(X x)
{
    // check datatype
    static_assert(std::is_same<Y, half_t>::value || std::is_same<Y, float>::value,
                  "only half and float are supported.");

    return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}

} // namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)