migraphx_float8.hpp 23.2 KB
Newer Older
Umang Yadav's avatar
Umang Yadav committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/* ************************************************************************
 * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
 * ies of the Software, and to permit persons to whom the Software is furnished
 * to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
 * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
 * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * ************************************************************************ */

Umang Yadav's avatar
Umang Yadav committed
23
24
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP
Umang Yadav's avatar
Umang Yadav committed
25
#if defined(__clang__)
// Warnings silenced for the body of this header; the matching
// `#pragma clang diagnostic pop` sits at the end of the file.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wmacro-redefined"
#pragma clang diagnostic ignored "-Wc++20-extensions"
#endif // __clang__

#if defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__)
// The HIP runtime header (or its hipRTC stand-in) must be visible, otherwise the
// __host__ / __device__ keywords used below are unknown.
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <migraphx/kernels/hip.hpp>
#else
#include <hip/hip_runtime.h>
#endif
#define MIGRAPHX_HIP_HOST_DEVICE __host__ __device__
#define MIGRAPHX_HIP_HOST __host__
#else
// Not a HIP/AMD build: the execution-space annotations expand to nothing.
#define MIGRAPHX_HIP_HOST_DEVICE
#define MIGRAPHX_HIP_HOST
#endif // HIP_PLATFORM_AMD

// Only referenced from code guarded by the gfx94x architecture checks.
#define MIGRAPHX_HIP_DEVICE __device__

// Selects the FNUZ encodings (single NaN at 0x80, no negative zero) unless the includer
// already chose otherwise.
#ifndef MIGRAPHX_FP8_FNUZ
#define MIGRAPHX_FP8_FNUZ true
#endif // MIGRAPHX_FP8_FNUZ

// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1
55
56
57
58
59
60
#if defined(MIGRAPHX_JIT_USE_HIPRTC)
#include <migraphx/kernels/types.hpp>
using uint8_t  = migraphx::uint8_t;
using uint16_t = migraphx::uint16_t;
using uint32_t = migraphx::uint32_t;
#else
Umang Yadav's avatar
Umang Yadav committed
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <cmath>
#include <cstdint>
#include <climits>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <iostream>
#include <string>
#include <utility>
#endif

namespace migraphx_hip_f8_impl {

template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
Umang Yadav's avatar
Umang Yadav committed
76
MIGRAPHX_HIP_HOST_DEVICE constexpr uint8_t cast_to_f8(T _x, bool stoch = false, uint32_t rng = 0);
Umang Yadav's avatar
Umang Yadav committed
77
78

template <int wm, int we, typename T, bool negative_zero_nan>
Umang Yadav's avatar
Umang Yadav committed
79
MIGRAPHX_HIP_HOST_DEVICE constexpr T cast_from_f8(uint8_t x);
Umang Yadav's avatar
Umang Yadav committed
80
81
82

} // namespace migraphx_hip_f8_impl

Umang Yadav's avatar
Umang Yadav committed
83
#include <migraphx/migraphx_hip_f8_impl.hpp>
Umang Yadav's avatar
Umang Yadav committed
84
85
86
87
88

namespace migraphx_fp8 {

// Rounding behaviour used when narrowing a wider value to an 8-bit float.
enum class migraphx_hip_f8_rounding_mode : int
{
    standard   = 0, // standard rounding is doing RNE -- round to nearest even
    stochastic = 1
};

// Bit layout held by a hip_f8 value.
enum class hip_f8_type : int
{
    bf8 = 0, // s1e5m2
    fp8 = 1  // s1e4m3
};

// Primary template only; the hip_f8 specializations are provided further down in this header.
template <typename T>
class NumericLimits;

Umang Yadav's avatar
Umang Yadav committed
102
template <migraphx_fp8::hip_f8_type T = migraphx_fp8::hip_f8_type::fp8>
Umang Yadav's avatar
Umang Yadav committed
103
struct hip_f8
Umang Yadav's avatar
Umang Yadav committed
104
105
106
{
    uint8_t data;
    // default constructor
Umang Yadav's avatar
Umang Yadav committed
107
108
109
110
111
112
113
114
115
    MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8() = default;
    // default copy constructor
    MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(const hip_f8& y) = default;
    struct from_bits_t
    {
    };
    static constexpr MIGRAPHX_HIP_HOST_DEVICE from_bits_t from_bits() { return from_bits_t(); }

    MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8(uint8_t bits, from_bits_t) : data(bits) {}
Umang Yadav's avatar
Umang Yadav committed
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156

#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // device specific optimized F8 down-conversion code

    template <bool stochastic_rounding = false>
    static MIGRAPHX_HIP_DEVICE uint8_t cast_to_f8_from_f32(float v, uint32_t rng = 0)
    {
        uint8_t i8data;
        union
        {
            float fval;
            uint32_t i32val;
            uint8_t i8val[4]; // NOTE: not endian independent
        } val;

        uint32_t ival = 0;
        val.fval      = v;

#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
        if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
        {
            if((val.i32val & 0x7F800000) != 0x7F800000) /// propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
        }
        else
        {
            if((val.i32val & 0x7F800000) != 0x7F800000) // propagate NAN/INF, no clipping
                val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
        }
#endif
        if(stochastic_rounding)
        {
            if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
            {
                ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
            }
            else
            {
                ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
            }
        }
Umang Yadav's avatar
Umang Yadav committed
157
        else // RNE CVT
Umang Yadav's avatar
Umang Yadav committed
158
159
160
161
162
163
164
165
166
167
168
169
        {
            if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
            {
                ival = __builtin_amdgcn_cvt_pk_fp8_f32(
                    val.fval, val.fval, ival, false); // false -> WORD0
            }
            else
            {
                ival = __builtin_amdgcn_cvt_pk_bf8_f32(
                    val.fval, val.fval, ival, false); // false -> WORD0}
            }
        }
Umang Yadav's avatar
Umang Yadav committed
170
171
172
173
        val.i32val = ival;
        i8data     = val.i8val[0]; // little endian

        return i8data;
Umang Yadav's avatar
Umang Yadav committed
174
175
176
177
178
179
180
    }
#endif // __gfx940__

       // constructor from float
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)

    // NOTE: ON-DEVICE... always optimal bias
Umang Yadav's avatar
Umang Yadav committed
181
182
183
184
    explicit MIGRAPHX_HIP_DEVICE hip_f8(float v,
                                        migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
                                            migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
                                        uint32_t rng = 0)
Umang Yadav's avatar
Umang Yadav committed
185
186
187
188
189
190
191
192
193
194
195
196
    {
        // runtime branch, use cast_to_f8_from_f32 if want to avoid it
        if(rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic)
            data = cast_to_f8_from_f32<true>(v, rng);
        else
            data = cast_to_f8_from_f32<false>(v);
    }

    // Host only implementation using s/w simulation
    explicit MIGRAPHX_HIP_HOST
#else
    // both Host and DEVICE for non-gfx940 using s/w simulation
Umang Yadav's avatar
Umang Yadav committed
197
    explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
Umang Yadav's avatar
Umang Yadav committed
198
#endif
Umang Yadav's avatar
Umang Yadav committed
199
200
201
202
    hip_f8(float v,
           migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
               migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
           uint32_t rng = 0)
Umang Yadav's avatar
Umang Yadav committed
203
204
205
206
207
    {
        if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
        {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
            data = migraphx_hip_f8_impl::
Umang Yadav's avatar
Umang Yadav committed
208
                cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
Umang Yadav's avatar
Umang Yadav committed
209
210
211
                    v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#else  // MIGRAPHX_F8_DOWNCAST_CLIPPING
            data = migraphx_hip_f8_impl::
Umang Yadav's avatar
Umang Yadav committed
212
                cast_to_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
Umang Yadav's avatar
Umang Yadav committed
213
214
215
216
217
218
219
                    v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
        }
        else
        {
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
            data = migraphx_hip_f8_impl::
Umang Yadav's avatar
Umang Yadav committed
220
                cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, true /*clip*/>(
Umang Yadav's avatar
Umang Yadav committed
221
222
223
                    v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#else  // MIGRAPHX_F8_DOWNCAST_CLIPPING
            data = migraphx_hip_f8_impl::
Umang Yadav's avatar
Umang Yadav committed
224
                cast_to_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/, false /*clip*/>(
Umang Yadav's avatar
Umang Yadav committed
225
226
227
228
229
                    v, (rm == migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic), rng);
#endif // rocblas_F8_downcast_clipping}
        }
    }

Umang Yadav's avatar
Umang Yadav committed
230
231
232
233
234
235
236
237
238
239
    /*
        // Constructor from half
        explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
        hip_f8(migraphx::half v,
               migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
                   migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
               uint32_t rng = 0)
            : hip_f8((float)v, rm, rng)
        {
        }
Umang Yadav's avatar
Umang Yadav committed
240
241

    // constructor from int
Umang Yadav's avatar
Umang Yadav committed
242
243
244
245
246
247
    explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
    hip_f8(int v,
           migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
               migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
           uint32_t rng = 0)
        : hip_f8((float)v, rm, rng)
Umang Yadav's avatar
Umang Yadav committed
248
249
    {
    }
Umang Yadav's avatar
Umang Yadav committed
250

Umang Yadav's avatar
Umang Yadav committed
251
    // constructor from double
Umang Yadav's avatar
Umang Yadav committed
252
253
254
255
256
257
    explicit constexpr MIGRAPHX_HIP_HOST_DEVICE
    hip_f8(double v,
           migraphx_fp8::migraphx_hip_f8_rounding_mode rm =
               migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
           uint32_t rng = 0)
        : hip_f8((float)v, rm, rng)
Umang Yadav's avatar
Umang Yadav committed
258
259
    {
    }
Umang Yadav's avatar
Umang Yadav committed
260
261
    */
    /**/
Umang Yadav's avatar
Umang Yadav committed
262
    // convert to float
Umang Yadav's avatar
Umang Yadav committed
263
264
// #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if 0 // need constexpr operator(). This version can't be constexpr
Umang Yadav's avatar
Umang Yadav committed
265
    // upcast using device specific intrinsic
Umang Yadav's avatar
Umang Yadav committed
266
    inline MIGRAPHX_HIP_DEVICE operator float() const
Umang Yadav's avatar
Umang Yadav committed
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
    {
        float fval;
        uint32_t i32val = static_cast<uint32_t>(data);

        // upcast
        if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
        {
            asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
        }
        else
        {
            asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
        }

        return fval;
    }

Umang Yadav's avatar
Umang Yadav committed
284
    inline constexpr MIGRAPHX_HIP_HOST operator float() const
Umang Yadav's avatar
Umang Yadav committed
285
#else // non gfx940
Umang Yadav's avatar
Umang Yadav committed
286
    inline constexpr MIGRAPHX_HIP_HOST_DEVICE operator float() const
Umang Yadav's avatar
Umang Yadav committed
287
288
289
290
#endif
    {
        if constexpr(T == migraphx_fp8::hip_f8_type::fp8)
        {
Umang Yadav's avatar
Umang Yadav committed
291
292
            return migraphx_hip_f8_impl::
                cast_from_f8<3, 4, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
Umang Yadav's avatar
Umang Yadav committed
293
        } // else
Umang Yadav's avatar
Umang Yadav committed
294
295
        return migraphx_hip_f8_impl::
            cast_from_f8<2, 5, float, MIGRAPHX_FP8_FNUZ /*negative_zero_nan*/>(data);
Umang Yadav's avatar
Umang Yadav committed
296
297
    }

Umang Yadav's avatar
Umang Yadav committed
298
299
300
301
302
303
304
    /*
        // convert to half
        explicit inline MIGRAPHX_HIP_HOST_DEVICE operator migraphx::half() const
        {
            return migraphx::half(float(*this)); // convert to float, then convert to f16
        }
    */
Umang Yadav's avatar
Umang Yadav committed
305
306

    // check for zero
Umang Yadav's avatar
Umang Yadav committed
307
308
309
310
311
312
313
314
315
316
317
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_zero() const
    {
        if constexpr(MIGRAPHX_FP8_FNUZ)
        {
            return data == 0x00;
        }
        else
        {
            return (data == 0x00) || (data == 0x80);
        }
    }
Umang Yadav's avatar
Umang Yadav committed
318
319

    // check for nan
Umang Yadav's avatar
Umang Yadav committed
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_nan() const
    {
        if constexpr(MIGRAPHX_FP8_FNUZ)
        {
            return data == 0x80;
        }
        else
        {
            if(T == migraphx_fp8::hip_f8_type::bf8)
            {
                return (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xfd) ||
                       (data == 0xfe) || (data == 0xff);
            }
            else
            {
                return (data == 0x79) || (data == 0x7a) || (data == 0x7b) || (data == 0x7c) ||
                       (data == 0x7d) || (data == 0x7e) || (data == 0x7f) || (data == 0xf9) ||
                       (data == 0xfa) || (data == 0xfb) || (data == 0xfc) || (data == 0xfd) ||
                       (data == 0xfe) || (data == 0xff);
            }
        }
    }
Umang Yadav's avatar
Umang Yadav committed
342
343

    // check for inf
Umang Yadav's avatar
Umang Yadav committed
344
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool is_inf() const
Umang Yadav's avatar
Umang Yadav committed
345
    {
Umang Yadav's avatar
Umang Yadav committed
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
        if constexpr(MIGRAPHX_FP8_FNUZ)
        {
            return data == 0x80;
        }
        else
        {
            if(T == migraphx_fp8::hip_f8_type::bf8)
            {
                return (data == 0x7c) || (data == 0xfc);
            }
            else
            {
                return (data == 0x78) || (data == 0xf8);
            }
        }
Umang Yadav's avatar
Umang Yadav committed
361
362
    }

Umang Yadav's avatar
Umang Yadav committed
363
364
365
366
367
368
369
370
371
372
373
374
375
#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op)                                    \
    constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const hip_f8& rhs)   \
    {                                                                                 \
        const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
        *this          = static_cast<hip_f8>(tmp);                                    \
        return *this;                                                                 \
    }                                                                                 \
    constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator unary_op(const float& rhs)    \
    {                                                                                 \
        const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
        *this          = static_cast<hip_f8>(tmp);                                    \
        return *this;                                                                 \
    }
Umang Yadav's avatar
Umang Yadav committed
376

Umang Yadav's avatar
Umang Yadav committed
377
378
379
380
    MIGRAPHX_FP8_UNARY_OP(*=, *)
    MIGRAPHX_FP8_UNARY_OP(-=, -)
    MIGRAPHX_FP8_UNARY_OP(+=, +)
    MIGRAPHX_FP8_UNARY_OP(/=, /)
Umang Yadav's avatar
Umang Yadav committed
381

Umang Yadav's avatar
Umang Yadav committed
382
383
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8& operator=(const hip_f8& rhs) = default;
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr hip_f8& operator=(hip_f8&& rhs)      = default;
Umang Yadav's avatar
Umang Yadav committed
384

Umang Yadav's avatar
Umang Yadav committed
385
386
387
388
389
390
391
392
393
#if !defined(__HIP_NO_F8_CONVERSIONS__)
    // for the device kernels, this needs to be disabled since implicit_conversion op can type cast
    // any type to any other type and that results in conflicts in candidate overload resolutions.
    inline constexpr hip_f8& MIGRAPHX_HIP_HOST_DEVICE operator=(float rhs)
    {
        *this = static_cast<hip_f8>(rhs);
        return *this;
    }
#endif
Umang Yadav's avatar
Umang Yadav committed
394

Umang Yadav's avatar
Umang Yadav committed
395
396
397
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator==(const hip_f8& rhs) const
    {
        if((rhs.is_zero() && this->is_zero()) ||
398
           (fabs(rhs - *this) < migraphx_fp8::NumericLimits<hip_f8<T>>::epsilon()))
Umang Yadav's avatar
Umang Yadav committed
399
400
401
            return true;
        else if(rhs.is_nan() || rhs.is_inf() || this->is_nan() || this->is_inf())
            return false;
Umang Yadav's avatar
Umang Yadav committed
402

Umang Yadav's avatar
Umang Yadav committed
403
404
        return false;
    }
Umang Yadav's avatar
Umang Yadav committed
405

Umang Yadav's avatar
Umang Yadav committed
406
407
408
409
410
411
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator<(const hip_f8& rhs) const
    {
        const auto we   = static_cast<float>(*this);
        const auto them = static_cast<float>(rhs);
        return we < them;
    }
Umang Yadav's avatar
Umang Yadav committed
412

Umang Yadav's avatar
Umang Yadav committed
413
414
415
416
417
418
419
    inline MIGRAPHX_HIP_HOST_DEVICE constexpr bool operator>(const hip_f8& rhs) const
    {
        const auto we   = static_cast<float>(*this);
        const auto them = static_cast<float>(rhs);
        return we > them;
    }
};
Umang Yadav's avatar
Umang Yadav committed
420

421
#ifndef MIGRAPHX_JIT_USE_HIPRTC
Umang Yadav's avatar
Umang Yadav committed
422
423
424
// Special operator overloading
// Writes an fp8 value to a stream as its float conversion and returns the same stream.
template <migraphx_fp8::hip_f8_type T>
inline std::ostream& operator<<(std::ostream& os, const migraphx_fp8::hip_f8<T>& rhs)
{
    const float widened = static_cast<float>(rhs);
    os << widened;
    return os;
}
Umang Yadav's avatar
Umang Yadav committed
428
#endif
Umang Yadav's avatar
Umang Yadav committed
429

Umang Yadav's avatar
Umang Yadav committed
430
431
432
433
434
435
436
437
// Generates a free operator `op_symbol` over two hip_f8<T> operands: both sides are widened to
// float, the operation runs in float, and the outcome is converted to `result_type`.
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(op_symbol, result_type)                          \
    template <migraphx_fp8::hip_f8_type T>                                      \
    inline constexpr result_type MIGRAPHX_HIP_HOST_DEVICE operator op_symbol(   \
        const migraphx_fp8::hip_f8<T>& lhs, const migraphx_fp8::hip_f8<T>& rhs) \
    {                                                                           \
        const float lhs_f = static_cast<float>(lhs);                            \
        const float rhs_f = static_cast<float>(rhs);                            \
        return static_cast<result_type>(lhs_f op_symbol rhs_f);                 \
    }

// Arithmetic: the float result is rounded back to hip_f8<T> through its float constructor.
// TODO: these should return floats
MIGRAPHX_FP8_BINARY_OP(*, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(-, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(/, migraphx_fp8::hip_f8<T>)
MIGRAPHX_FP8_BINARY_OP(+, migraphx_fp8::hip_f8<T>)
// Comparisons: evaluated on the float conversions of both operands.
// TODO: Comparison ops shouldn't convert to float, maybe need to take care of rounding effects.
MIGRAPHX_FP8_BINARY_OP(==, bool)
MIGRAPHX_FP8_BINARY_OP(>=, bool)
MIGRAPHX_FP8_BINARY_OP(<=, bool)
MIGRAPHX_FP8_BINARY_OP(>, bool)
MIGRAPHX_FP8_BINARY_OP(<, bool)
MIGRAPHX_FP8_BINARY_OP(!=, bool)
Umang Yadav's avatar
Umang Yadav committed
451

Umang Yadav's avatar
Umang Yadav committed
452
453
// Absolute value of an fp8 number, computed by clearing the sign bit of its encoding.
// In FNUZ mode 0x80 is the only NaN encoding; masking its sign bit would produce +0, so NaN is
// returned unchanged instead (matching std::fabs, which never turns NaN into a number).
template <migraphx_fp8::hip_f8_type T>
inline constexpr MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<T> fabs(migraphx_fp8::hip_f8<T> v)
{
    if constexpr(MIGRAPHX_FP8_FNUZ)
    {
        if(v.is_nan())
            return v;
    }
    v.data = v.data & 0x7f;
    return v;
}

Umang Yadav's avatar
Umang Yadav committed
459
460
template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Max()
Umang Yadav's avatar
Umang Yadav committed
461
{
Umang Yadav's avatar
Umang Yadav committed
462
    return T{0x7F, T::from_bits()};
Umang Yadav's avatar
Umang Yadav committed
463
464
}

Umang Yadav's avatar
Umang Yadav committed
465
466
template <class T>
MIGRAPHX_HIP_HOST_DEVICE constexpr T F8_Lowest()
Umang Yadav's avatar
Umang Yadav committed
467
{
Umang Yadav's avatar
Umang Yadav committed
468
    return T{0xFF, T::from_bits()};
Umang Yadav's avatar
Umang Yadav committed
469
470
}

Umang Yadav's avatar
Umang Yadav committed
471
using fp8e4m3fnuz = hip_f8<migraphx_fp8::hip_f8_type::fp8>;
Umang Yadav's avatar
Umang Yadav committed
472

473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
template <>
class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
{
    public:
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> epsilon()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(float(0.0625));
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> quiet_NaN()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
            static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x79));
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> max()
    {
        return migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> min()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(-1.0f) *
               migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> lowest()
    {
        return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>();
    }
Umang Yadav's avatar
Umang Yadav committed
503
504
505
506
507
508
509
510
511
512
513

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> infinity()
    {
        if constexpr(MIGRAPHX_FP8_FNUZ)
        {
            return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
                static_cast<uint8_t>(0x80));
        }
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>(
            static_cast<uint8_t>(0x78));
    }
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
};

template <>
class NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
{
    public:
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> epsilon()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(0.125));
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> quiet_NaN()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
            static_cast<uint8_t>(MIGRAPHX_FP8_FNUZ ? 0X80 : 0x7d));
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> max()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
            migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>());
    }
    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> min()
    {
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(float(-1.0f)) *
               migraphx_fp8::F8_Max<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
    }

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> lowest()
    {
        return migraphx_fp8::F8_Lowest<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>();
    }
Umang Yadav's avatar
Umang Yadav committed
546
547
548
549
550
551
552
553
554
555
556

    static MIGRAPHX_HIP_HOST_DEVICE migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> infinity()
    {
        if constexpr(MIGRAPHX_FP8_FNUZ)
        {
            return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
                static_cast<uint8_t>(0x80));
        }
        return static_cast<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>(
            static_cast<uint8_t>(0x7c));
    }
557
};
Umang Yadav's avatar
Umang Yadav committed
558
/*
Umang Yadav's avatar
Umang Yadav committed
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
// Use h/w intrinsic and optimized version when __gfx940__
template <typename T,
          typename Ta,
          bool stochastic_rounding,
          typename std::enable_if<(!(migraphx::is_same<T, Ta>{}) &&
                                   (migraphx::is_same<T, migraphx_f8>{} ||
                                    migraphx::is_same<T, migraphx_bf8>{})),
                                  int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // NOTE: we are directly calling cast_to_f8_from_f32 instead of constructor to optimize
    // away one runtime branch
    T val;
    if(migraphx::is_same<T, migraphx_f8>::value)
        val.data = migraphx_f8::cast_to_f8_from_f32<stochastic_rounding>(float(a), rng);
    else
        val.data = migraphx_bf8::cast_to_bf8_from_f32<stochastic_rounding>(float(a), rng);
    return val;
#else  // non gfx940
    return T(float(a),
             stochastic_rounding ? migraphx_fp8::migraphx_hip_f8_rounding_mode::stochastic
                                 : migraphx_fp8::migraphx_hip_f8_rounding_mode::standard,
             rng);
#endif // __gfx940__
}

// NOTE NOTE: The above code is good if we don't consider HIP-GEMM code and only consider
// the quantization However, if we need HIP-GEMM for fall-back, we would need explicit_cast
// handles Tacc=f32 to To=f16/bf16 conversion
template <typename T,
          typename Ta,
          bool stochastic_rounding,
          typename std::enable_if<(!(migraphx::is_same<T, Ta>{}) &&
                                   !(migraphx::is_same<T, migraphx_f8>{} ||
                                     migraphx::is_same<T, migraphx_bf8>{})),
                                  int>::type = 0>
inline __host__ __device__ T explicit_downcast(Ta a, uint32_t rng)
{
    // the return type is not a F8 types, no SR for those types
    // not sure if we have direct conversion, so converting to float first
    // no effect if the input type is float
    return T(float(a));
}
*/
} // namespace migraphx_fp8
Umang Yadav's avatar
Umang Yadav committed
605
// define numeric limits for the new data type
606
#ifndef MIGRAPHX_JIT_USE_HIPRTC
Umang Yadav's avatar
Umang Yadav committed
607
namespace std {
Umang Yadav's avatar
Umang Yadav committed
608
609
610
611
612
613
614
615
616
617
inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> x) // NOLINT
{
    return x.is_inf();
}

inline bool isfinite(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> x) // NOLINT
{
    return x.is_inf();
}

Umang Yadav's avatar
Umang Yadav committed
618
619
620
621
622
623
624
625
626
627
inline bool isnan(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8> x) // NOLINT
{
    return x.is_nan();
}

inline bool isnan(migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8> x) // NOLINT
{
    return x.is_nan();
}

Umang Yadav's avatar
Umang Yadav committed
628
629
template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
630
    : public migraphx_fp8::NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::fp8>>
Umang Yadav's avatar
Umang Yadav committed
631
632
633
634
635
{
};

template <>
class numeric_limits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
636
    : public migraphx_fp8::NumericLimits<migraphx_fp8::hip_f8<migraphx_fp8::hip_f8_type::bf8>>
Umang Yadav's avatar
Umang Yadav committed
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
{
};

template <class T>
struct common_type<migraphx_fp8::fp8e4m3fnuz, T> : std::common_type<float, T> // NOLINT
{
};

template <class T>
struct common_type<T, migraphx_fp8::fp8e4m3fnuz> : std::common_type<float, T> // NOLINT
{
};

template <>
struct common_type<migraphx_fp8::fp8e4m3fnuz, migraphx_fp8::fp8e4m3fnuz>
{
    using type = float;
};

Umang Yadav's avatar
Umang Yadav committed
656
} // namespace std
657
#endif
Umang Yadav's avatar
Umang Yadav committed
658
// =================================================================================================
Umang Yadav's avatar
Umang Yadav committed
659
#if defined(__clang__)
Umang Yadav's avatar
Umang Yadav committed
660
#pragma clang diagnostic pop
Umang Yadav's avatar
Umang Yadav committed
661
#endif
Umang Yadav's avatar
Umang Yadav committed
662
#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP