mxf8_utils.hpp 18.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"

// Select the hardware conversion path. The fast path (the fp8_impl helpers built on the
// __builtin_amdgcn_cvt_scalef32_* builtins) is compiled only for device code targeting
// gfx950. Host compilation and every other target get CK_MX_FP8_CVT_FAST_PATH == 0 and use
// the software fallback (cast_to_f8 on f / scale) further below.
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif

namespace ck {

namespace fp8_impl {
#if CK_MX_FP8_CVT_FAST_PATH
/**
 * \brief Convert one fp8/bf8 bit pattern to float with the gfx950 scaled-conversion builtins.
 *
 * \tparam interpret OCP interpretation of @p v: CK_E4M3_OCP uses the fp8 builtin,
 *         CK_E5M2_OCP uses the bf8 builtin
 * \param scale scale factor forwarded to the builtin
 * \param v fp8/bf8 bit pattern to convert
 * \return the converted float
 */
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    // The builtin takes a 32-bit operand and v is placed in its byte 0. Zero-initialise the
    // whole word first so the remaining three bytes are well defined instead of being read
    // uninitialised.
    union
    {
        unsigned int i32val;
        unsigned char i8val[4];
    } val{};
    val.i8val[0] = v;

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
    }
}

/**
 * \brief Convert a packed pair of fp8/bf8 values to two floats with the gfx950
 *        scaled-conversion builtins.
 *
 * \tparam interpret OCP interpretation of both elements of @p v
 * \param scale scale factor forwarded to the builtin
 * \param v the two packed fp8/bf8 bit patterns
 * \return both converted values as a float2_t
 */
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    // The packed builtins take both bytes as a single 16-bit word.
    const uint16_t packed = bit_cast<uint16_t>(v);

    if constexpr(interpret == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(packed, scale, 0);
    }
    else
    {
        return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(packed, scale, 0);
    }
}

/**
 * \brief Convert one float to an fp8/bf8 bit pattern with the gfx950 scaled-conversion
 *        builtins.
 *
 * \tparam interpret OCP interpretation of the result: CK_E4M3_OCP uses the fp8 builtins,
 *         CK_E5M2_OCP uses the bf8 builtins
 * \tparam stochastic_rounding if true, use the stochastic-rounding builtin seeded with
 *         @p rng; otherwise use the packed round-to-nearest-even builtin
 * \param v value to convert
 * \param rng random bits for stochastic rounding (unused for RNE)
 * \param scale scale factor forwarded to the builtin
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
                                                           unsigned int rng = 0,
                                                           float scale      = 1.0f)
{
    // Only the OCP formats have scaled-conversion builtins here. Without this check any
    // other interpretation (e.g. an FNUZ one) would silently take the bf8 branches below.
    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    // 32-bit destination register written by the builtins; byte 0 receives the result.
    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        fp8_storage_t v4i8[4];
    } ret{};

    if constexpr(stochastic_rounding)
    {
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v, rng, scale, 0);
        }
        else
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v, rng, scale, 0);
        }
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.{fp8,bf8}.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        // v is passed as both sources; only byte 0 of the result is returned.
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If v / scale > max fp8, returns Nan
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v,
                                                                 v,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If v / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v,
                                                                 v,
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
    }

    return ret.v4i8[0];
}

/**
 * \brief Convert two floats to a packed pair of fp8/bf8 values with the gfx950
 *        scaled-conversion builtins.
 *
 * \tparam interpret OCP interpretation of the result: CK_E4M3_OCP uses the fp8 builtins,
 *         CK_E5M2_OCP uses the bf8 builtins
 * \tparam stochastic_rounding if true, convert each element with the stochastic-rounding
 *         builtin (both calls receive the same @p rng); otherwise use the packed
 *         round-to-nearest-even builtin
 * \param v the two values to convert; v[0] becomes element 0 and v[1] element 1
 * \param rng random bits for stochastic rounding (unused for RNE)
 * \param scale scale factor forwarded to the builtins
 * \return fp8x2_storage_t
 */
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v,
                                                             unsigned int rng = 0,
                                                             float scale      = 1.0f)
{
    // Only the OCP formats have scaled-conversion builtins here. Without this check any
    // other interpretation (e.g. an FNUZ one) would silently take the bf8 branches below.
    static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    // 32-bit destination register written by the builtins. Its first fp8x2 element
    // (v2f8x2[0]) holds the converted bytes.
    union
    {
        uint32_t ival;
        vector_type<int16_t, 2>::type v2i16;
        StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
    } ret{};

    if constexpr(stochastic_rounding)
    {
        // Each SR call converts a single value. With selector 0 the result is read back
        // from byte 0, so convert the elements one at a time and gather them.
        fp8x2_storage_t f8x2;
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        else
        {
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
            f8x2[0]  = ret.v2f8x2(Number<0>{})[0];
            ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
            f8x2[1]  = ret.v2f8x2(Number<0>{})[0];
        }
        return f8x2;
    }
    else
    {
        // RNE CVT
        // llvm.amdgcn.cvt.scalef32.pk.{fp8,bf8}.f32
        // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
        if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
        {
            // If v[i] / scale > max fp8, returns Nan
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v[0],
                                                                 v[1],
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }
        else
        {
            // If v[i] / scale > max bf8, returns Inf
            ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
                                                                 v[0],
                                                                 v[1],
                                                                 scale,
                                                                 /*dst_lo_hi_sel*/ false);
        }

        return ret.v2f8x2(Number<0>{});
    }
}

#endif // CK_MX_FP8_CVT_FAST_PATH

#if CK_MX_FP8_CVT_FAST_PATH
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available
 *
 * \tparam interp interpretation of fp8; must be CK_E4M3_OCP or CK_E5M2_OCP
 * \tparam stochastic_rounding use stochastic rounding instead of round-to-nearest-even
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    // Reject non-OCP interpretations at compile time. The scaled hardware conversions only
    // implement OCP, and the software fallback enforces the same restriction.
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");
    __is_interpret_supported(interp);

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        // prand_generator receives the argument's address and value plus a fixed seed.
        constexpr int seed = 1254739;
        rng                = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}

/**
 * \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available
 *
 * \tparam interp interpretation of fp8; must be CK_E4M3_OCP or CK_E5M2_OCP
 * \tparam stochastic_rounding use stochastic rounding instead of round-to-nearest-even.
 *         One rng value, derived from the address of @p f and f[0], is passed for both
 *         elements.
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    // Reject non-OCP interpretations at compile time. The scaled hardware conversions only
    // implement OCP, and the software fallback enforces the same restriction.
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");
    __is_interpret_supported(interp);

    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        // prand_generator receives the argument's address and first element plus a fixed seed.
        constexpr int seed = 1254739;
        rng                = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }
    return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}

#else

/**
 * \brief convert float to @p fp8_storage_t with scaling (software fallback)
 *
 * Used when the fast path (MX FP8 hardware) is not available. The input is divided by
 * @p scale and the quotient is converted with cast_to_f8 in the encoding selected by
 * @p interp.
 *
 * \tparam interp interpretation of fp8; must be CK_E4M3_OCP or CK_E5M2_OCP
 * \tparam stochastic_rounding use stochastic rounding instead of round-to-nearest-even
 * \param f float number
 * \param scale scaling factor; f is divided by it before conversion
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    // The random bits come from the unscaled input (address and value).
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng                = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
    }

    const float scaled = f / scale;

    // The static_assert above leaves exactly two possible interpretations.
    if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(scaled, rng);
    }
    else
    {
        return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(scaled, rng);
    }
}

/**
 * \brief convert two float to @p 2xfp8_storage_t with scaling (software fallback)
 *
 * Used when the fast path (MX FP8 hardware) is not available. Each element is divided by
 * @p scale and converted with cast_to_f8. Both elements share the same rng value.
 *
 * \tparam interp interpretation of fp8; must be CK_E4M3_OCP or CK_E5M2_OCP
 * \tparam stochastic_rounding use stochastic rounding instead of round-to-nearest-even
 * \param f 2xfloat
 * \param scale scaling factor; each element is divided by it before conversion
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
                                                                          float scale)
{
    static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
                      interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
                  "Only OCP interpretations are supported");

    // The random bits come from the argument's address and its first (unscaled) element.
    uint32_t rng = 0;
    if constexpr(stochastic_rounding)
    {
        constexpr int seed = 1254739;
        rng                = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
    }

    const float lo = f[0] / scale;
    const float hi = f[1] / scale;

    // The static_assert above leaves exactly two possible interpretations.
    if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
    {
        return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(lo, rng),
                cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(hi, rng)};
    }
    else
    {
        return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(lo, rng),
                cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(hi, rng)};
    }
}

#endif // CK_MX_FP8_CVT_FAST_PATH

} // namespace fp8_impl

// Scaled conversion X -> Y with round-to-nearest-even. There is no generic definition;
// the explicit specializations below implement the supported (Y, X) pairs.
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);

// Scaled conversion X -> Y with stochastic rounding. There is no generic definition;
// the explicit specializations below implement the supported (Y, X) pairs.
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);

// Scaled fp32 -> f8_ocp_t (its default interpretation), round-to-nearest-even.
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    const auto bits       = fp8_impl::cvt_float_to_fp8_scaled<interp>(x, scale);
    return f8_ocp_t{bits};
}

// Scaled fp32 -> bf8_ocp_t (its default interpretation), round-to-nearest-even.
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    const auto bits       = fp8_impl::cvt_float_to_fp8_scaled<interp>(x, scale);
    return bf8_ocp_t{bits};
}

// Scaled fp32x2 -> f8x2_ocp_t (f8_ocp_t default interpretation), round-to-nearest-even.
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
                                                                             float scale)
{
    constexpr auto interp = f8_ocp_t::default_interpret;
    const auto packed     = fp8_impl::cvt_float_to_fp8_scaled<interp>(x, scale);
    return f8x2_ocp_t{packed};
}

// Scaled fp32x2 -> bf8x2_ocp_t (bf8_ocp_t default interpretation), round-to-nearest-even.
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                               float scale)
{
    constexpr auto interp = bf8_ocp_t::default_interpret;
    const auto packed     = fp8_impl::cvt_float_to_fp8_scaled<interp>(x, scale);
    return bf8x2_ocp_t{packed};
}

// Scaled fp32x16 -> f8x16_ocp_t, round-to-nearest-even.
// The vector is handled as eight float2_t pairs, each converted by the fp32x2 specialization.
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
                                                                                float scale)
{
    // Views of the input as 8 float pairs and of the output as 8 packed fp8 pairs.
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } src{x};

    union
    {
        f8x16_ocp_t whole;
        f8x2_ocp_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto idx) {
        dst.pairs[idx] = mxf8_convert_rne<f8x2_ocp_t>(src.pairs[idx], scale);
    });

    return dst.whole;
}

// Scaled fp32x16 -> bf8x16_ocp_t, round-to-nearest-even.
// The vector is handled as eight float2_t pairs, each converted by the fp32x2 specialization.
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                  float scale)
{
    // Views of the input as 8 float pairs and of the output as 8 packed bf8 pairs.
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } src{x};

    union
    {
        bf8x16_ocp_t whole;
        bf8x2_ocp_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto idx) {
        dst.pairs[idx] = mxf8_convert_rne<bf8x2_ocp_t>(src.pairs[idx], scale);
    });

    return dst.whole;
}

// Scaled fp32x32 -> f8x32_ocp_t, round-to-nearest-even.
// The vector is split into two float16_t halves, each converted by the fp32x16 specialization.
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
                                                                                float scale)
{
    union
    {
        float32_t whole;
        float16_t halves[2];
    } src{x};

    union
    {
        f8x32_ocp_t whole;
        f8x16_ocp_t halves[2];
    } dst{};

    dst.halves[0] = mxf8_convert_rne<f8x16_ocp_t>(src.halves[0], scale);
    dst.halves[1] = mxf8_convert_rne<f8x16_ocp_t>(src.halves[1], scale);

    return dst.whole;
}

// Scaled fp32x32 -> bf8x32_ocp_t, round-to-nearest-even.
// The vector is split into two float16_t halves, each converted by the fp32x16 specialization.
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                  float scale)
{
    union
    {
        float32_t whole;
        float16_t halves[2];
    } src{x};

    union
    {
        bf8x32_ocp_t whole;
        bf8x16_ocp_t halves[2];
    } dst{};

    dst.halves[0] = mxf8_convert_rne<bf8x16_ocp_t>(src.halves[0], scale);
    dst.halves[1] = mxf8_convert_rne<bf8x16_ocp_t>(src.halves[1], scale);

    return dst.whole;
}

// Scaled fp32 -> f8_ocp_t (its default interpretation), stochastic rounding.
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
    constexpr auto interp     = f8_ocp_t::default_interpret;
    constexpr bool stochastic = true;
    const auto bits           = fp8_impl::cvt_float_to_fp8_scaled<interp, stochastic>(x, scale);
    return f8_ocp_t{bits};
}

// Scaled fp32 -> bf8_ocp_t (its default interpretation), stochastic rounding.
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
    constexpr auto interp     = bf8_ocp_t::default_interpret;
    constexpr bool stochastic = true;
    const auto bits           = fp8_impl::cvt_float_to_fp8_scaled<interp, stochastic>(x, scale);
    return bf8_ocp_t{bits};
}

// Scaled fp32x2 -> f8x2_ocp_t (f8_ocp_t default interpretation), stochastic rounding.
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
{
    constexpr auto interp     = f8_ocp_t::default_interpret;
    constexpr bool stochastic = true;
    const auto packed         = fp8_impl::cvt_float_to_fp8_scaled<interp, stochastic>(x, scale);
    return f8x2_ocp_t{packed};
}

// Scaled fp32x2 -> bf8x2_ocp_t (bf8_ocp_t default interpretation), stochastic rounding.
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
                                                                              float scale)
{
    constexpr auto interp     = bf8_ocp_t::default_interpret;
    constexpr bool stochastic = true;
    const auto packed         = fp8_impl::cvt_float_to_fp8_scaled<interp, stochastic>(x, scale);
    return bf8x2_ocp_t{packed};
}

// Scaled fp32x16 -> f8x16_ocp_t, stochastic rounding.
// The vector is handled as eight float2_t pairs, each converted by the fp32x2 specialization.
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
                                                                               float scale)
{
    // Views of the input as 8 float pairs and of the output as 8 packed fp8 pairs.
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } src{x};

    union
    {
        f8x16_ocp_t whole;
        f8x2_ocp_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto idx) {
        dst.pairs[idx] = mxf8_convert_sr<f8x2_ocp_t>(src.pairs[idx], scale);
    });

    return dst.whole;
}

// Scaled fp32x16 -> bf8x16_ocp_t, stochastic rounding.
// The vector is handled as eight float2_t pairs, each converted by the fp32x2 specialization.
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
                                                                                 float scale)
{
    // Views of the input as 8 float pairs and of the output as 8 packed bf8 pairs.
    union
    {
        float16_t whole;
        float2_t pairs[8];
    } src{x};

    union
    {
        bf8x16_ocp_t whole;
        bf8x2_ocp_t pairs[8];
    } dst{};

    ck::static_for<0, 8, 1>{}([&](auto idx) {
        dst.pairs[idx] = mxf8_convert_sr<bf8x2_ocp_t>(src.pairs[idx], scale);
    });

    return dst.whole;
}

// Scaled fp32x32 -> f8x32_ocp_t, stochastic rounding.
// The vector is split into two float16_t halves, each converted by the fp32x16 specialization.
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
                                                                               float scale)
{
    union
    {
        float32_t whole;
        float16_t halves[2];
    } src{x};

    union
    {
        f8x32_ocp_t whole;
        f8x16_ocp_t halves[2];
    } dst{};

    dst.halves[0] = mxf8_convert_sr<f8x16_ocp_t>(src.halves[0], scale);
    dst.halves[1] = mxf8_convert_sr<f8x16_ocp_t>(src.halves[1], scale);

    return dst.whole;
}

// Scaled fp32x32 -> bf8x32_ocp_t, stochastic rounding.
// The vector is split into two float16_t halves, each converted by the fp32x16 specialization.
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
                                                                                 float scale)
{
    union
    {
        float32_t whole;
        float16_t halves[2];
    } src{x};

    union
    {
        bf8x32_ocp_t whole;
        bf8x16_ocp_t halves[2];
    } dst{};

    dst.halves[0] = mxf8_convert_sr<bf8x16_ocp_t>(src.halves[0], scale);
    dst.halves[1] = mxf8_convert_sr<bf8x16_ocp_t>(src.halves[1], scale);

    return dst.whole;
}

} // namespace ck