// mxfp_utils.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck::utils {

// Overlays a float and a 32-bit unsigned integer on the same storage so the
// functions below can write one member and read the raw bits through the other.
union cvt
{
    float value_float;      // floating-point view of the shared storage
    uint32_t value_bitwise; // raw 32-bit pattern view of the shared storage
};

// Returns the `hasInf` flag stored in DTYPE's static `dataInfo` member.
template <typename DTYPE>
inline bool getDataHasInf()
{
    const bool has_inf = DTYPE::dataInfo.hasInf;
    return has_inf;
}

// Declaration only; no definition is visible in this section.
// NOTE(review): presumably reports whether the (scale, data) pair encodes zero —
// confirm against the per-type definitions.
template <typename T>
__host__ __device__ inline bool is_zero(e8m0_scale_t const scale, T const data);

// Declaration only; no definition is visible in this section.
// NOTE(review): presumably reports whether the (scale, data) pair encodes NaN —
// confirm against the per-type definitions.
template <typename T>
__host__ __device__ inline bool is_nan(e8m0_scale_t const scale, T const data);

// Declaration only; no definition is visible in this section.
// NOTE(review): presumably reports whether the (scale, data) pair encodes an
// infinity — confirm against the per-type definitions.
template <typename T>
__host__ __device__ inline bool is_inf(e8m0_scale_t const scale, T const data);

// Extracts the biased exponent field of an encoded value.
//
// The mantissa bits (NumericUtils<T>::mant of them) are shifted out and the next
// NumericUtils<T>::exp bits are kept; anything above them (e.g. the sign) is masked off.
template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
    const int exp_field_mask = (1 << NumericUtils<T>::exp) - 1;
    const int exp_field      = static_cast<int>(x >> NumericUtils<T>::mant) & exp_field_mask;
    return exp_field;
}

// True when the exponent field of `x` is all zeros. Note that this also holds
// for encodings whose mantissa is zero as well, not only for nonzero subnormals.
template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
    const int exponent_field = get_exponent_value<T>(x);
    return exponent_field == 0;
}

// Returns the significand of `x` as a real number.
//
// The result is the implicit leading bit (1.0 for encodings with a nonzero
// exponent field, 0.0 otherwise) plus the fraction formed by the low
// NumericUtils<T>::mant bits of `x`, where the least-significant bit weighs
// 2^-mant and the most-significant fraction bit weighs 2^-1.
template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
    double mantissa = is_subnormal<T>(x) ? 0.0 : 1.0;

    // Weight of the current bit, starting at 2^-mant for the LSB. Doubling a power
    // of two is exact, so this matches evaluating std::pow for every bit without
    // paying for a libm call per iteration.
    double bit_weight = 1.0 / static_cast<double>(uint32_t{1} << NumericUtils<T>::mant);

    // uint32_t instead of the non-standard `uint` typedef, which is not available
    // on every toolchain.
    for(uint32_t i = 0; i < NumericUtils<T>::mant; i++)
    {
        if(x & 0b1)
            mantissa += bit_weight;

        bit_weight *= 2.0;
        x >>= 1;
    }

    return mantissa;
}

// Returns NumericUtils<T>::has_inf, i.e. whether T's traits declare an
// infinity encoding.
template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
    const bool has_inf = NumericUtils<T>::has_inf;
    return has_inf;
}

// Decodes an encoded element together with a shared scale exponent into a float.
//
// data      : raw element bits laid out as [sign | exp | mant] per NumericUtils<T>.
// scale_exp : biased scale exponent; the applied scale is
//             2^(scale_exp - NumericUtils<e8m0_scale_t>::bias).
// returns   : sign * 2^(unbiased exponent) * significand * scale.
//
// No special handling of NaN/Inf encodings happens here; the bits are decoded
// arithmetically as-is.
template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
    const int elem_bias  = static_cast<int>(NumericUtils<T>::bias);
    const int scale_bias = static_cast<int>(NumericUtils<e8m0_scale_t>::bias);

    // The sign is (-1)^n with n = data >> (exp + mant). Only the parity of n
    // matters, so test its lowest bit directly instead of calling std::pow
    // with a negative base.
    const bool negative = ((data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)) & 1) != 0;
    const float d_sign  = negative ? -1.0f : 1.0f;

    // Encodings with a zero exponent field use the fixed exponent (1 - bias);
    // all others use (exponent field - bias).
    const int unbiased_exp = is_subnormal<T>(data) ? (1 - elem_bias)
                                                   : (get_exponent_value<T>(data) - elem_bias);
    const float d_exp      = static_cast<float>(std::pow(2, unbiased_exp));

    const float d_mant = static_cast<float>(get_mantissa_value<T>(data));

    const float data_value = d_sign * d_exp * d_mant;
    const float scale_value =
        static_cast<float>(std::pow(2, static_cast<float>(scale_exp - scale_bias)));

    return data_value * scale_value;
}

// Declaration only; no definition is visible in this section.
// NOTE(review): presumably the per-type entry point that decodes (scale, data)
// to float — confirm against the definitions.
template <typename T>
__host__ __device__ inline float to_float(e8m0_scale_t const scale, T const data);

// Declaration only; no definition is visible in this section.
// NOTE(review): presumably a saturating float -> T conversion — confirm
// against the definitions.
template <typename T>
__host__ __device__ T sat_convert_to_type(float value);

// Declaration only; no definition is visible in this section.
// NOTE(review): presumably the saturating conversion with stochastic rounding
// driven by `seed` — confirm against the definitions.
template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);

template <typename T>
inline T convert_to_type(float value)
{
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
99
    using bitwise_type = typename NumericUtils<T>::bitwise_type;
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205

    if(std::abs(value) > NumericLimits<T>::Max())
    {
        float max_value = NumericLimits<T>::Max();

        cvt t;

        // cppcheck-suppress redundantAssignment
        t.value_float        = max_value;
        uint32_t max_bitwise = t.value_bitwise;

        // cppcheck-suppress redundantAssignment
        t.value_float = value;
        bitwise_type sign =
            t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
        bitwise_type exp =
            ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
            (NumericUtils<float>::bias - NumericUtils<T>::bias);
        bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);

        uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
        mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
        mant_prev--;

        mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
        uint32_t prev_bit =
            ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;

        t.value_bitwise = prev_bit;
        float prev_val  = t.value_float;
        float diff      = max_value - prev_val;

        float actual_max = max_value + (diff / 2);

        if(std::abs(value) < actual_max)
        {
            return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
                   (exp << NumericUtils<T>::mant) | mantissa;
        }
        else
        {
            if(!get_data_has_inf<T>())
            {

                return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
            }
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
                       (exp << NumericUtils<T>::mant);
            }
        }
    }
    const int mfmt = NumericUtils<float>::mant;
    uint32_t x;
    // x = reinterpret_cast<uint32_t&>(value);
    x = bit_cast<uint32_t>(value);

    uint32_t head, mantissa;
    int32_t exponent, bias;
    uint32_t sign;

    head     = x & NumericUtils<float>::head_mask;
    mantissa = x & NumericUtils<float>::mant_mask;
    exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
    sign     = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
    bias     = NumericUtils<float>::bias;

    if(x == 0)
    {
        return 0b0;
    }

    const int mini_bias                  = NumericUtils<T>::bias;
    const int mini_denormal_act_exponent = 1 - mini_bias;

    int act_exponent, out_exponent, exponent_diff;

    bool is_subnorm = false;

    if(exponent == 0)
    {
        act_exponent  = exponent - bias + 1;
        exponent_diff = mini_denormal_act_exponent - act_exponent;
        is_subnorm    = true;
    }
    else
    {
        act_exponent = exponent - bias;
        if(act_exponent <= mini_denormal_act_exponent)
        {
            exponent_diff = mini_denormal_act_exponent - act_exponent;
            is_subnorm    = true;
        }
        else
        {
            exponent_diff = 0;
        }
        mantissa += (1UL << mfmt);
    }

    auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
    shift_amount      = (shift_amount >= 64) ? 63 : shift_amount;
    bool midpoint     = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));

Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
206
    float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
207

Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
208
    if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
    {
        // closer to 0
        if(std::abs(value) <= std::abs(min_subnorm - value))
            return 0;
        else
            return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
    }

    if(exponent_diff > 0)
        mantissa >>= exponent_diff;
    else if(exponent_diff == -1)
        mantissa <<= -exponent_diff;
    bool implicit_one = mantissa & (1 << mfmt);
    out_exponent      = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);

    uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
    bool odd           = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
    mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;

    if(out_exponent == 0)
    {
        if((1UL << mfmt) & mantissa)
        {
            out_exponent = 1;
        }
    }
    else
    {
        if((1UL << (mfmt + 1)) & mantissa)
        {
            mantissa >>= 1;
            out_exponent++;
        }
    }

    mantissa >>= (mfmt - NumericUtils<T>::mant);

    if(out_exponent == 0 && mantissa == 0)
    {
        return 0;
    }

    mantissa &= (1UL << NumericUtils<T>::mant) - 1;
    return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
           (out_exponent << NumericUtils<T>::mant) | mantissa;
}

template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
259
    // using bitwise_type = typename NumericUtils<T>::bitwise_type;
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303

    if(std::abs(value) > NumericLimits<T>::Max())
    {
        float max_value = NumericLimits<T>::Max();

        cvt t;

        // cppcheck-suppress redundantAssignment
        t.value_float    = max_value;
        uint max_bitwise = t.value_bitwise;

        // cppcheck-suppress redundantAssignment
        t.value_float = value;
        T sign        = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
        T exp = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
                (NumericUtils<float>::bias - NumericUtils<T>::bias);

        uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
        mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
        mant_prev--;

        mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
        uint32_t prev_bit =
            ((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;

        t.value_bitwise = prev_bit;
        float prev_val  = t.value_float;
        float diff      = max_value - prev_val;

        float actual_max = max_value + (diff / 2);

        if(std::abs(value) < actual_max)
        {
            double d_max_value  = static_cast<double>(max_value);
            double d_actual_max = static_cast<double>(actual_max);
            double d_value      = static_cast<double>(value);
            double d_is         = std::abs(d_max_value - d_actual_max);
            double d_seed       = static_cast<double>(seed);
            double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is); // prob to round down

            double thresh = UINT_MAX * d_prob;

            if(!get_data_has_inf<T>() || d_seed <= thresh)
                // return static_cast<T>(satConvertToType(getDataMax<DTYPE>())); //round down time
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
304
305
                return sign == 0 ? NumericUtils<f4_t>::data_max_positive_normal_mask
                                 : NumericUtils<f4_t>::data_max_negative_normal_mask;
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
        else
        {
            if(!get_data_has_inf<T>())
                return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
            else
            {
                exp++;
                return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
                       | (exp << NumericUtils<T>::mant);
            }
        }
    }

    // uint32_t f32 = reinterpret_cast<uint32_t&>(value);
    uint32_t f32 = bit_cast<uint32_t>(value);

    auto f32_mant = f32 & NumericUtils<float>::mant_mask;
    auto head     = f32 & NumericUtils<float>::head_mask;
    auto f32_exp  = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;

    auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
    auto sign     = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);

Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
336
    f32_exp      = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
337
338
339
340
    int32_t exp  = f32_exp;
    auto mant    = f32_mant;
    bool subnorm = false;

Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
341
    if(f32 == 0)
342
343
344
345
346
347
348
349
350
351
352
        return 0b0;

    if(exp >= NumericUtils<T>::unbiased_exp_min)
    {
        mant = f32_mant;
    }
    // if the exponent bit is 8, then the subnormal is exactly the same as f32
    else if(exp < NumericUtils<T>::unbiased_exp_min &&
            NumericUtils<T>::exp < NumericUtils<float>::exp)
    {
        subnorm   = true;
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
353
        auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
354
355
356
357
358
359
360
        if(diff >= 32)
        {
            mant     = 0;
            f32_mant = 0;
        }
        else
        {
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
361
            f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
362
363
364
365
366
367
368
369
370
371
372
373
374
            f32_mant >>= diff;
        }
        exp  = 0;
        mant = f32_mant;
    }

    uint32_t sr_shift = NumericUtils<T>::sr_shift;

    // For stochastic-rounding we add the aligned random value to the
    // mantissa and then truncate (RTZ).
    mant += seed >> sr_shift;

    // Increment exponent when mantissa overflows due to rounding
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
375
    if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
376
377
378
379
        ++exp;
    mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
    mant &= ((1 << NumericUtils<T>::mant) - 1);

Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
380
    auto biased_exp = static_cast<uint32_t>(exp);
381
    if(!subnorm)
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
382
        biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
383
384
385
386
387
388
    biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
    auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
    return val;
}

} // namespace ck::utils