mxf4_utils.hpp 3.36 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"

namespace ck::utils {

template <>
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
12
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
13
14
15
                                             f4_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
16
    return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
17
18
19
20
}

// no infinity representation in ocp_e2m1_mxfp4 will always return false
template <>
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
21
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
22
23
24
25
26
27
28
                                             f4_t const data [[maybe_unused]])
{
    // no inf representation for ocp_e2m1_mxfp4
    return false;
}

template <>
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
29
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
30
31
32
33
34
{
    if(is_nan<f4_t>(scale, data))
        return false;

    // no need to check for scale as it does not have a 0 representation
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
35
    f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
36

Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
37
    return result == 0b0;
38
39
40
}

template <>
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
41
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
42
43
44
45
46
47
48
{
    if(is_nan<f4_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(is_zero<f4_t>(scale, data))
        return 0.0f;

Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
49
    f4_t prepared_data = data & 0b00001111;
50

Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
51
    int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
52

Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
53
    return convert_to_float<f4_t>(prepared_data, scale_exp);
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
}

// Saturating float -> f4 conversion.
//
//   value - float to convert
//
// Returns:
//   * the max-normal mask carrying the input's sign bit when `value` is NaN or
//     its magnitude exceeds NumericLimits<f4_t>::Max() (infinities included);
//   * a signed-zero mask when the converted result, read back with a
//     Binary_1 scale, has a magnitude below NumericLimits<f4_t>::DataMinSubnorm();
//   * otherwise whatever convert_to_type<f4_t>(value) produced.
template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
    // Take the sign from the raw bit pattern (bit 31) so that it is also
    // available for NaN inputs, where ordered comparisons are always false.
    cvt bits;
    bits.value_float        = value;
    const uint32_t sign_bit = bits.value_bitwise >> 31;

    // NaN and out-of-range magnitudes both clamp to the largest normal value.
    if(std::isnan(value) || std::abs(value) > NumericLimits<f4_t>::Max())
    {
        return sign_bit ? NumericUtils<f4_t>::data_max_negative_normal_mask
                        : NumericUtils<f4_t>::data_max_positive_normal_mask;
    }

    f4_t converted = convert_to_type<f4_t>(value);

    // Read the result back at a Binary_1 scale to see whether it fell below
    // the smallest subnormal; if so, flush to a zero mask.
    const float magnitude =
        std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));
    if(magnitude < NumericLimits<f4_t>::DataMinSubnorm())
    {
        // NOTE(review): the zero's sign follows `value < 0`, not the sign bit
        // used above, so a -0.0f input reaching this branch is returned as
        // positive_zero_mask - confirm this is intended.
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;
    }

    return converted;
}

// Saturating float -> f4 conversion that forwards a seed to convert_to_type_sr.
//
//   value - float to convert
//   seed  - passed unchanged to convert_to_type_sr<f4_t>
//           (presumably the stochastic-rounding seed - confirm against that helper)
//
// Returns:
//   * the max-normal mask carrying the input's sign bit when `value` is NaN or
//     its magnitude exceeds NumericLimits<f4_t>::Max() (infinities included);
//   * a signed-zero mask when the converted result, read back with a
//     Binary_1 scale, has a magnitude below NumericLimits<f4_t>::DataMinSubnorm();
//   * otherwise whatever convert_to_type_sr<f4_t>(value, seed) produced.
template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
    // Take the sign from the raw bit pattern (bit 31) so that it is also
    // available for NaN inputs, where ordered comparisons are always false.
    cvt bits;
    bits.value_float        = value;
    const uint32_t sign_bit = bits.value_bitwise >> 31;

    // NaN and out-of-range magnitudes both clamp to the largest normal value.
    if(std::isnan(value) || std::abs(value) > NumericLimits<f4_t>::Max())
    {
        return sign_bit ? NumericUtils<f4_t>::data_max_negative_normal_mask
                        : NumericUtils<f4_t>::data_max_positive_normal_mask;
    }

    f4_t converted = convert_to_type_sr<f4_t>(value, seed);

    // Read the result back at a Binary_1 scale to see whether it fell below
    // the smallest subnormal; if so, flush to a zero mask.
    const float magnitude =
        std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));
    if(magnitude < NumericLimits<f4_t>::DataMinSubnorm())
    {
        // NOTE(review): the zero's sign follows `value < 0`, not the sign bit
        // used above, so a -0.0f input reaching this branch is returned as
        // positive_zero_mask - confirm this is intended.
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;
    }

    return converted;
}

} // namespace ck::utils