migraphx_float8.hpp 21.4 KB
Newer Older
Umang Yadav's avatar
Umang Yadav committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/* ************************************************************************
 * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
 * ies of the Software, and to permit persons to whom the Software is furnished
 * to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
 * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
 * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * ************************************************************************ */

Umang Yadav's avatar
Umang Yadav committed
23
24
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
Umang Yadav's avatar
Umang Yadav committed
25
#if defined(__clang__)
Umang Yadav's avatar
Umang Yadav committed
26
27
28
29
30
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions"
Umang Yadav's avatar
Umang Yadav committed
31
#endif
Umang Yadav's avatar
Umang Yadav committed
32
33
34
35
36
37
38
39

#if defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__)
// The __host__ / __device__ execution-space attributes used below come from the HIP runtime
// header (or, when compiling with hipRTC, from MIGraphX's kernel-side hip.hpp).
#ifdef __HIPCC_RTC__
#include <migraphx/kernels/hip.hpp>
#else
#include <hip/hip_runtime.h>
#endif
#define MIGRAPHX_HIP_HOST __host__
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#else
// Plain host compiler: the execution-space annotations expand to nothing.
#define MIGRAPHX_HIP_HOST
#define MIGRAPHX_HIP_HOST_DEVICE
#endif

// Within this header MIGRAPHX_HIP_DEVICE only appears inside gfx94x-specific (or disabled)
// sections, which exist only when compiling device code for those targets.
#define MIGRAPHX_HIP_DEVICE __device__

// Default to the FNUZ encodings (no negative zero; 0x80 encodes NaN). Define MIGRAPHX_FP8_FNUZ
// before including this header to override.
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif

// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
Umang Yadav's avatar
Umang Yadav committed
55

Umang Yadav's avatar
Umang Yadav committed
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#ifndef __HIPCC_RTC__
#include <cmath>
#include <cstdint>
#include <climits>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <iostream>
#include <string>
#include <utility>
#endif

namespace migraphx_hip_f8_impl {

template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
Umang Yadav's avatar
Umang Yadav committed
72
MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);
Umang Yadav's avatar
Umang Yadav committed
73
74

template <int wm, int we, typename T, bool negative_zero_nan>
Umang Yadav's avatar
Umang Yadav committed
75
MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x);
Umang Yadav's avatar
Umang Yadav committed
76
77
78

} // namespace migraphx_hip_f8_impl

Umang Yadav's avatar
Umang Yadav committed
79
#include <migraphx/migraphx_hip_f8_impl.hpp>
Umang Yadav's avatar
Umang Yadav committed
80
81
82
83
84

namespace migraphx_fp8 {

// Rounding policy applied when narrowing a wider floating-point value to an 8-bit float.
enum class migraphx_hip_f8_rounding_mode : int
{
    standard   = 0, // round to nearest, ties to even (RNE)
    stochastic = 1  // stochastic rounding; the caller supplies the random bits (`rng`)
};

// Which 8-bit floating-point layout a hip_f8 value uses.
enum class hip_f8_type : int
{
    bf8 = 0, // s1e5m2: 1 sign, 5 exponent, 2 mantissa bits
    fp8 = 1, // s1e4m3: 1 sign, 4 exponent, 3 mantissa bits
};

template <migraphx_fp8::hip_f8_type T = migraphx_fp8::hip_f8_type::fp8>
Umang Yadav's avatar
Umang Yadav committed
96
struct hip_f8
Umang Yadav's avatar
Umang Yadav committed
97
98
99
{
    uint8_t data;
    // default constructor
Umang Yadav's avatar
Umang Yadav committed
100
101
102
103
104
105
106
107
108
    MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8() = default;
    // default copy constructor
    MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(const hip_f8& y) = default;
    struct from_bits_t
    {
    };
    static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); }

    MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(uint8_t bits, from_bits_t) : data(bits) {}
Umang Yadav's avatar
Umang Yadav committed
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149

#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // device specific optimized F8 down-conversion code

    template <bool stochastic_rounding = false>
    static MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0)
    {
        uint8_t i8data;
        union
        {
            float fval;
            uint32_t i32val;
            uint8_t i8val[4]; // NOTE: not endian independent
        } val;

        uint32_t ival = 0;
        val.fval      = v;

#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
        if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
        {
            if((val.i32val & 0x7F800000) != 0x7F800000) /// propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
        }
        else
        {
            if((val.i32val & 0x7F800000) != 0x7F800000) // propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
        }
#endif
        if(stochastic_rounding)
        {
            if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
            {
                ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
            }
            else
            {
                ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
            }
        }
Umang Yadav's avatar
Umang Yadav committed
150
        else // RNE CVT
Umang Yadav's avatar
Umang Yadav committed
151
152
153
154
155
156
157
158
159
160
161
162
        {
            if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
            {
                ival = __builtin_amdgcn_cvt_pk_fp8_f32(
                    val.fval, val.fval, ival, false); // false -> WORD0
            }
            else
            {
                ival = __builtin_amdgcn_cvt_pk_bf8_f32(
                    val.fval, val.fval, ival, false); // false -> WORD0}
            }
        }
Umang Yadav's avatar
Umang Yadav committed
163
164
165
166
        val.i32val = ival;
        i8data     = val.i8val[0]; // little endian

        return i8data;
Umang Yadav's avatar
Umang Yadav committed
167
168
169
170
171
172
173
    }
#endif // __gfx940__

       // constructor from float
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)

    // NOTE: ON-DEVICE... always optimal bias
Umang Yadav's avatar
Umang Yadav committed
174
175
176
177
    explicit MIGRAPHX_HIP_DEVICE hip_f8(float v,
                                        migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
                                            migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
                                        uint32_t rng = 0)
Umang Yadav's avatar
Umang Yadav committed
178
179
180
181
182
183
184
185
186
187
188
189
    {
        // runtime branch, use cast_to_f8_from_f32 if want to avoid it
        if(rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic)
            data = cast_to_f8_from_f32<true>(v, rng);
        else
            data = cast_to_f8_from_f32<false>(v);
    }

    // Host only implementation using s/w simulation
    explicit MIGRAPHX_HIP_HOST
#else
    // both Host and DEVICE for non-gfx940 using s/w simulation
Umang Yadav's avatar
Umang Yadav committed
190
    explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
Umang Yadav's avatar
Umang Yadav committed
191
#endif
Umang Yadav's avatar
Umang Yadav committed
192
193
194
195
    hip_f8(float v,
           migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
               migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
           uint32_t rng = 0)
Umang Yadav's avatar
Umang Yadav committed
196
197
198
199
200
    {
        if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
        {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
            data = migraphx_hip_f8_impl::
Umang Yadav's avatar
Umang Yadav committed
201
                cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
Umang Yadav's avatar
Umang Yadav committed
202
203
204
                    v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#else  // MIGRAPHX_F8_DOWNCAST_CLIPPING
            data = migraphx_hip_f8_impl::
Umang Yadav's avatar
Umang Yadav committed
205
                cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
Umang Yadav's avatar
Umang Yadav committed
206
207
208
209
210
211
212
                    v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
        }
        else
        {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
            data = migraphx_hip_f8_impl::
Umang Yadav's avatar
Umang Yadav committed
213
                cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
Umang Yadav's avatar
Umang Yadav committed
214
215
216
                    v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#else  // MIGRAPHX_F8_DOWNCAST_CLIPPING
            data = migraphx_hip_f8_impl::
Umang Yadav's avatar
Umang Yadav committed
217
                cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
Umang Yadav's avatar
Umang Yadav committed
218
219
220
221
222
                    v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping}
        }
    }

Umang Yadav's avatar
Umang Yadav committed
223
224
225
226
227
228
229
230
231
232
    /*
        // Constructor from half
        explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
        hip_f8(migraphx::half v,
               migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
                   migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
               uint32_t rng = 0)
            : hip_f8((float)v, rm, rng)
        {
        }
Umang Yadav's avatar
Umang Yadav committed
233
234

    // constructor from int
Umang Yadav's avatar
Umang Yadav committed
235
236
237
238
239
240
    explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
    hip_f8(int v,
           migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
               migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
           uint32_t rng = 0)
        : hip_f8((float)v, rm, rng)
Umang Yadav's avatar
Umang Yadav committed
241
242
    {
    }
Umang Yadav's avatar
Umang Yadav committed
243

Umang Yadav's avatar
Umang Yadav committed
244
    // constructor from double
Umang Yadav's avatar
Umang Yadav committed
245
246
247
248
249
250
    explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
    hip_f8(double v,
           migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
               migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
           uint32_t rng = 0)
        : hip_f8((float)v, rm, rng)
Umang Yadav's avatar
Umang Yadav committed
251
252
    {
    }
Umang Yadav's avatar
Umang Yadav committed
253
254
    */
    /**/
Umang Yadav's avatar
Umang Yadav committed
255
    // convert to float
Umang Yadav's avatar
Umang Yadav committed
256
257
// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if 0 // need constexpr operator(). This version can't be constexpr
Umang Yadav's avatar
Umang Yadav committed
258
    // upcast using device specific intrinsic
Umang Yadav's avatar
Umang Yadav committed
259
    inline MIGRAPHX_HIP_DEVICE operator float() const
Umang Yadav's avatar
Umang Yadav committed
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
    {
        float fval;
        uint32_t i32val = static_cast<uint32_t>(data);

        // upcast
        if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
        {
            asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
        }
        else
        {
            asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
        }

        return fval;
    }

Umang Yadav's avatar
Umang Yadav committed
277
    inline constexpr MIGRAPHX_HIP_HOST operator float() const
Umang Yadav's avatar
Umang Yadav committed
278
#else // non gfx940
Umang Yadav's avatar
Umang Yadav committed
279
    inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
Umang Yadav's avatar
Umang Yadav committed
280
281
282
283
#endif
    {
        if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
        {
Umang Yadav's avatar
Umang Yadav committed
284
285
            return migraphx_hip_f8_impl::
                cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
Umang Yadav's avatar
Umang Yadav committed
286
        } // else
Umang Yadav's avatar
Umang Yadav committed
287
288
        return migraphx_hip_f8_impl::
            cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
Umang Yadav's avatar
Umang Yadav committed
289
290
    }

Umang Yadav's avatar
Umang Yadav committed
291
292
293
294
295
296
297
    /*
        // convert to half
        explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const
        {
            return migraphx::half(float(*this)); // convert to float, then convert to f16
        }
    */
Umang Yadav's avatar
Umang Yadav committed
298
299

    // check for zero
Umang Yadav's avatar
Umang Yadav committed
300
301
302
303
304
305
306
307
308
309
310
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const
    {
        if constexpr(MIGRAPHX_FP8_FNUZ)
        {
            return data == 0x00;
        }
        else
        {
            return (data == 0x00) || (data == 0x80);
        }
    }
Umang Yadav's avatar
Umang Yadav committed
311
312

    // check for nan
Umang Yadav's avatar
Umang Yadav committed
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const
    {
        if constexpr(MIGRAPHX_FP8_FNUZ)
        {
            return data == 0x80;
        }
        else
        {
            if(T == migraphx_fp8::hip_f8_type::bf8)
            {
                return (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xfd) ||
                       (data == 0xfe) || (data == 0xff);
            }
            else
            {
                return (data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) ||
                       (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xf9) ||
                       (data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) ||
                       (data == 0xfe) || (data == 0xff);
            }
        }
    }
Umang Yadav's avatar
Umang Yadav committed
335
336

    // check for inf
Umang Yadav's avatar
Umang Yadav committed
337
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const
Umang Yadav's avatar
Umang Yadav committed
338
    {
Umang Yadav's avatar
Umang Yadav committed
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
        if constexpr(MIGRAPHX_FP8_FNUZ)
        {
            return data == 0x80;
        }
        else
        {
            if(T == migraphx_fp8::hip_f8_type::bf8)
            {
                return (data == 0x7c) || (data == 0xfc);
            }
            else
            {
                return (data == 0x78) || (data == 0xf8);
            }
        }
Umang Yadav's avatar
Umang Yadav committed
354
355
    }

Umang Yadav's avatar
Umang Yadav committed
356
357
358
359
360
361
362
363
364
365
366
367
368
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op)                                    \
    constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const hip_f8& rhs)   \
    {                                                                                 \
        const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
        *this          = static_cast<hip_f8>(tmp);                                    \
        return *this;                                                                 \
    }                                                                                 \
    constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs)    \
    {                                                                                 \
        const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
        *this          = static_cast<hip_f8>(tmp);                                    \
        return *this;                                                                 \
    }
Umang Yadav's avatar
Umang Yadav committed
369

Umang Yadav's avatar
Umang Yadav committed
370
371
372
373
    MIGRAPHX_FP8_UNARY_OP(*=, *)
    MIGRAPHX_FP8_UNARY_OP(-=, -)
    MIGRAPHX_FP8_UNARY_OP(+=, +)
    MIGRAPHX_FP8_UNARY_OP(/=, /)
Umang Yadav's avatar
Umang Yadav committed
374

Umang Yadav's avatar
Umang Yadav committed
375
376
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8& operator=(const hip_f8& rhs) = default;
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8& operator=(hip_f8&& rhs)      = default;
Umang Yadav's avatar
Umang Yadav committed
377

Umang Yadav's avatar
Umang Yadav committed
378
379
380
381
382
383
384
385
386
#if !defined(__HIP_NO_F8_CONVERSIONS__)
    // for the device kernels, this needs to be disabled since implicit_conversion op can type cast
    // any type to any other type and that results in conflicts in candidate overload resolutions.
    inline constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs)
    {
        *this = static_cast<hip_f8>(rhs);
        return *this;
    }
#endif
Umang Yadav's avatar
Umang Yadav committed
387

Umang Yadav's avatar
Umang Yadav committed
388
389
390
391
392
393
394
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const hip_f8& rhs) const
    {
        if((rhs.is_zero() && this->is_zero()) ||
           (fabs(rhs - *this) < std::numeric_limits<hip_f8<T>>::epsilon()))
            return true;
        else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf())
            return false;
Umang Yadav's avatar
Umang Yadav committed
395

Umang Yadav's avatar
Umang Yadav committed
396
397
        return false;
    }
Umang Yadav's avatar
Umang Yadav committed
398

Umang Yadav's avatar
Umang Yadav committed
399
400
401
402
403
404
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator<(const hip_f8& rhs) const
    {
        const auto we   = static_cast<float>(*this);
        const auto them = static_cast<float>(rhs);
        return we < them;
    }
Umang Yadav's avatar
Umang Yadav committed
405

Umang Yadav's avatar
Umang Yadav committed
406
407
408
409
410
411
412
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator>(const hip_f8& rhs) const
    {
        const auto we   = static_cast<float>(*this);
        const auto them = static_cast<float>(rhs);
        return we > them;
    }
};
Umang Yadav's avatar
Umang Yadav committed
413

Umang Yadav's avatar
Umang Yadav committed
414
415
416
417
#ifndef __HIPCC_RTC__
// Special operator overloading
template <migraphx_fp8::hip_f8_type T>
inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>& rhs)
{
    // Print the decoded numeric value, not the raw 8-bit pattern.
    const float decoded = static_cast<float>(rhs);
    os << decoded;
    return os;
}
Umang Yadav's avatar
Umang Yadav committed
421
#endif
Umang Yadav's avatar
Umang Yadav committed
422

Umang Yadav's avatar
Umang Yadav committed
423
424
425
426
427
428
429
430
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, U)                                    \
    template <migraphx_fp8::hip_f8_type T>                                      \
    inline constexpr U MIGRAPHX_HIP_HOST_DEVICE operator binary_op(             \
        const migraphx_fp8::hip_f8<T>& lhs, const migraphx_fp8::hip_f8<T>& rhs) \
    {                                                                           \
        return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs));    \
    }
Umang Yadav's avatar
Umang Yadav committed
431

Umang Yadav's avatar
Umang Yadav committed
432
// TODO: these should return floats
Umang Yadav's avatar
Umang Yadav committed
433
434
435
436
437
438
439
440
441
442
443
MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(+, migraphx_fp8::hip_f8<T>)
// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects.
MIGRAPHX_FP8_BINARY_OP(==, bool)
MIGRAPHX_FP8_BINARY_OP(>=, bool)
MIGRAPHX_FP8_BINARY_OP(<=, bool)
MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
Umang Yadav's avatar
Umang Yadav committed
444

Umang Yadav's avatar
Umang Yadav committed
445
446
// Absolute value of an 8-bit float, computed by clearing the sign bit.
// Under MIGRAPHX_FP8_FNUZ the only NaN encoding is 0x80 (the sign bit alone), so clearing the
// sign bit would turn NaN into +0; a NaN input is therefore returned unchanged in that mode.
template <migraphx_fp8::hip_f8_type T>
inline constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<T> fabs(migraphx_fp8::hip_f8<T> v)
{
    if(MIGRAPHX_FP8_FNUZ && v.is_nan())
        return v;
    v.data = static_cast<uint8_t>(v.data & 0x7f);
    return v;
}

Umang Yadav's avatar
Umang Yadav committed
452
453
// Largest finite value of the 8-bit float type T (a hip_f8 specialization).
// Scans downward from the all-ones positive pattern 0x7F and returns the first encoding that is
// neither NaN nor Inf. With MIGRAPHX_FP8_FNUZ that is 0x7F itself. The non-FNUZ encodings
// reserve the top patterns for NaN/Inf, so a fixed 0x7F would be NaN there.
template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Max()
{
    for(uint8_t bits = 0x7F;; --bits)
    {
        const T candidate{bits, T::from_bits()};
        if(!candidate.is_nan() && !candidate.is_inf())
            return candidate; // always reached: 0x00 (zero) is finite
    }
}

Umang Yadav's avatar
Umang Yadav committed
458
459
// Most negative finite value of the 8-bit float type T (a hip_f8 specialization).
// Scans downward from the all-ones pattern 0xFF and returns the first encoding that is neither
// NaN nor Inf. With MIGRAPHX_FP8_FNUZ that is 0xFF itself. The non-FNUZ encodings reserve the
// top negative patterns for NaN/-Inf, so a fixed 0xFF would be NaN there.
template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
{
    for(uint8_t bits = 0xFF;; --bits)
    {
        const T candidate{bits, T::from_bits()};
        if(!candidate.is_nan() && !candidate.is_inf())
            return candidate; // always reached: finite negative patterns exist below 0xFF
    }
}

Umang Yadav's avatar
Umang Yadav committed
464
// Shorthand for the s1e4m3 layout. The "fnuz" name reflects the default MIGRAPHX_FP8_FNUZ
// encoding (no negative zero; 0x80 encodes NaN).
using fp8e4m3fnuz = hip_f8<hip_f8_type::fp8>;
Umang Yadav's avatar
Umang Yadav committed
465

Umang Yadav's avatar
Umang Yadav committed
466
/*
Umang Yadav's avatar
Umang Yadav committed
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
// Use h/w intrinsic and optimized version when __gfx940__
template <typename T,
          typename Ta,
          bool stochastic_rounding,
          typename std::enable_if<(!(migraphx::is_same<T, Ta>{}) &&
                                   (migraphx::is_same<T, migraphx_f8>{} ||
                                    migraphx::is_same<T, migraphx_bf8>{})),
                                  int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // NOTE: we are directly calling cast_to_f8_from_f32 instead of constructor to optimize
    // away one runtime branch
    T val;
    if(migraphx::is_same<T, migraphx_f8>::value)
        val.data = migraphx_f8::cast_to_f8_from_f32<stochastic_rounding>(float(a), rng);
    else
        val.data = migraphx_bf8::cast_to_bf8_from_f32<stochastic_rounding>(float(a), rng);
    return val;
#else  // non gfx940
    return T(float(a),
             stochastic_rounding ? migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic
                                 : migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
             rng);
#endif // __gfx940__
}

// NOTE NOTE: The above code is good if we don't consider HIP-GEMM code and only consider
// the quantization However, if we need HIP-GEMM for fall-back, we would need explicit_cast
// handles Tacc=f32 to To=f16/bf16 conversion
template <typename T,
          typename Ta,
          bool stochastic_rounding,
          typename std::enable_if<(!(migraphx::is_same<T, Ta>{}) &&
                                   !(migraphx::is_same<T, migraphx_f8>{} ||
                                     migraphx::is_same<T, migraphx_bf8>{})),
                                  int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
{
    // the return type is not a F8 types, no SR for those types
    // not sure if we have direct conversion, so converting to float first
    // no effect if the input type is float
    return T(float(a));
}
*/
} // namespace migraphx_fp8
Umang Yadav's avatar
Umang Yadav committed
513
// define numeric limits for the new data type
Umang Yadav's avatar
Umang Yadav committed
514
namespace std {
Umang Yadav's avatar
Umang Yadav committed
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
// True when x is neither NaN nor infinite. (Under MIGRAPHX_FP8_FNUZ, is_inf() and is_nan() both
// test the single special encoding 0x80.)
inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> x) // NOLINT
{
    return !(x.is_inf() || x.is_nan());
}

// True when x is neither NaN nor infinite. (Under MIGRAPHX_FP8_FNUZ, is_inf() and is_nan() both
// test the single special encoding 0x80.)
inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> x) // NOLINT
{
    return !(x.is_inf() || x.is_nan());
}

template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
{
    public:
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> epsilon()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(float(0.0625));
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> quiet_NaN()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
            static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x79));
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> max()
    {
        return migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> min()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(-1.0f) *
               migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> lowest()
    {
        return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
    }
};

template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
{
    public:
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> epsilon()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(0.125));
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> quiet_NaN()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
            static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x7d));
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> max()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
            migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>());
    }
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> min()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(-1.0f)) *
               migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> lowest()
    {
        return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
    }
};

template <class T>
struct common_type<migraphx_fp8::fp8e4m3fnuz, T> : std::common_type<float, T> // NOLINT
{
};

template <class T>
struct common_type<T, migraphx_fp8::fp8e4m3fnuz> : std::common_type<float, T> // NOLINT
{
};

template <>
struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
{
    using type = float;
};

Umang Yadav's avatar
Umang Yadav committed
605
606
} // namespace std
// =================================================================================================
Umang Yadav's avatar
Umang Yadav committed
607
#if defined(__clang__)
Umang Yadav's avatar
Umang Yadav committed
608
#pragma clang diagnostic pop
Umang Yadav's avatar
Umang Yadav committed
609
#endif
Umang Yadav's avatar
Umang Yadav committed
610
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP