// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"

namespace ck::utils {

// NaN test for an f4_t element paired with its e8m0 block scale.
//
// scale     - shared exponent; the only operand inspected.
// dataBytes - the f4_t element; ignored because f4_t has no NaN representation.
//
// Returns whatever scale.is_nan() reports.
template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
                                             f4_t const dataBytes [[maybe_unused]])
{
    // The element cannot encode NaN, so the scale alone decides the answer.
    const bool scale_is_nan = scale.is_nan();
    return scale_is_nan;
}

// Infinity test for an f4_t element paired with its e8m0 block scale.
//
// ocp_e2m1_mxfp4 has no infinity representation, so both arguments are ignored
// and the result is always false.
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                             f4_t const data [[maybe_unused]])
{
    constexpr bool format_encodes_inf = false;
    return format_encodes_inf;
}

template <>
Rostyslav Geyyer's avatar
Rostyslav Geyyer committed
29
30
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
                                              f4_t const data)
Illia Silin's avatar
Illia Silin committed
31
32
33
34
35
36
37
38
39
40
41
{
    // no need to check for scale as it does not have a 0 representation
    f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;

    return result == 0b0;
}

// Converts a scaled f4_t element to float.
//
// scale - e8m0 block exponent applied to the element.
// data  - the f4_t element; only its low four bits are used.
//
// Returns quiet NaN when is_nan<f4_t> reports NaN, 0.0f when is_zero<f4_t>
// reports zero, and otherwise the result of convert_to_float<f4_t> on the
// masked element and the exponent value extracted from the scale.
template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
    // NaN is tested first so that a NaN scale wins over a zero element.
    if(is_nan<f4_t>(scale, data))
        return std::numeric_limits<float>::quiet_NaN();

    if(!is_zero<f4_t>(scale, data))
    {
        // Drop anything above the low nibble before handing the bits to the converter.
        f4_t element_bits = data & 0b00001111;
        int exponent      = get_exponent_value<e8m0_bexp_t>(scale);

        return convert_to_float<f4_t>(element_bits, exponent);
    }

    return 0.0f;
}

// Saturating float -> f4_t conversion.
//
// value - the float to convert.
//
// Returns:
//   - the max-magnitude normal mask of matching sign when value is NaN or when
//     |value| exceeds NumericLimits<f4_t>::Max() (this also catches +/-inf);
//   - a signed zero mask when the converted result, read back with the
//     Binary_1() scale, has magnitude below NumericLimits<f4_t>::DataMinSubnorm();
//   - otherwise the result of convert_to_type<f4_t>(value).
template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
    // Take the sign from bit 31 of the bitwise view so that NaN inputs still carry one.
    cvt raw;
    raw.value_float     = value;
    const bool negative = (raw.value_bitwise >> 31) != 0u;

    // NaN and overflow saturate identically, so both are handled by one guard.
    // NOTE(review): Max() is compared directly with a float here while the underflow
    // test below uses DataMinSubnorm(); confirm Max() yields the numeric maximum.
    if(std::isnan(value) || std::abs(value) > NumericLimits<f4_t>::Max())
    {
        return negative ? NumericUtils<f4_t>::data_max_negative_normal_mask
                        : NumericUtils<f4_t>::data_max_positive_normal_mask;
    }

    f4_t converted = convert_to_type<f4_t>(value);

    // Read the result back (Binary_1() is presumably the unit scale; verify) to detect underflow.
    const float round_trip_magnitude =
        std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));

    if(round_trip_magnitude < NumericLimits<f4_t>::DataMinSubnorm())
    {
        // The zero's sign comes from an ordinary comparison, so -0.0f (which is not < 0)
        // selects positive_zero_mask if it reaches this branch.
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;
    }

    return converted;
}

// Saturating float -> f4_t conversion that delegates to convert_to_type_sr.
//
// value - the float to convert.
// seed  - forwarded unchanged to convert_to_type_sr<f4_t> (presumably the
//         stochastic-rounding seed; verify against its declaration).
//
// Returns:
//   - the max-magnitude normal mask of matching sign when value is NaN or when
//     |value| exceeds NumericLimits<f4_t>::Max() (this also catches +/-inf);
//   - a signed zero mask when the converted result, read back with the
//     Binary_1() scale, has magnitude below NumericLimits<f4_t>::DataMinSubnorm();
//   - otherwise the result of convert_to_type_sr<f4_t>(value, seed).
template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
    // Take the sign from bit 31 of the bitwise view so that NaN inputs still carry one.
    cvt raw;
    raw.value_float     = value;
    const bool negative = (raw.value_bitwise >> 31) != 0u;

    // NaN and overflow saturate identically, so both are handled by one guard.
    // NOTE(review): Max() is compared directly with a float here while the underflow
    // test below uses DataMinSubnorm(); confirm Max() yields the numeric maximum.
    if(std::isnan(value) || std::abs(value) > NumericLimits<f4_t>::Max())
    {
        return negative ? NumericUtils<f4_t>::data_max_negative_normal_mask
                        : NumericUtils<f4_t>::data_max_positive_normal_mask;
    }

    f4_t converted = convert_to_type_sr<f4_t>(value, seed);

    // Read the result back (Binary_1() is presumably the unit scale; verify) to detect underflow.
    const float round_trip_magnitude =
        std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), converted));

    if(round_trip_magnitude < NumericLimits<f4_t>::DataMinSubnorm())
    {
        // The zero's sign comes from an ordinary comparison, so -0.0f (which is not < 0)
        // selects positive_zero_mask if it reaches this branch.
        return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
                         : NumericUtils<f4_t>::positive_zero_mask;
    }

    return converted;
}

} // namespace ck::utils