// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <iomanip>
#include <iterator>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/host_utility/io.hpp"

#include "ck/library/utility/ranges.hpp"
namespace ck {
namespace utils {

/// @brief Element-wise comparison of two floating-point ranges using a mixed
///        absolute/relative tolerance.
///
/// Element i passes when |out[i] - ref[i]| <= atol + rtol * |ref[i]| and both
/// values are finite. The first four failing elements are printed to
/// std::cerr, followed by the largest error seen among failing elements.
///
/// Enabled only when both ranges have the same floating-point value type and
/// that type is not half_t.
///
/// @param out  Range under test.
/// @param ref  Reference range; must have the same size as @p out.
/// @param msg  Prefix printed on every diagnostic line.
/// @param rtol Relative tolerance, scaled by |ref[i]|.
/// @param atol Absolute tolerance.
/// @return true if the sizes match and every element is within tolerance.
template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_floating_point_v<ranges::range_value_t<Range>> &&
        !std::is_same_v<ranges::range_value_t<Range>, half_t>,
    bool>::type
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-5,
          double atol            = 3e-6)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    bool res{true};
    int err_count = 0;
    // Start from 0: numeric_limits<double>::min() is the smallest positive
    // normal (~2.2e-308), not a neutral starting value for a maximum.
    double max_err = 0;

    // Advance both iterators in lock-step; std::next(begin, i) on every
    // iteration is O(i) for non-random-access ranges.
    auto out_it = std::begin(out);
    auto ref_it = std::begin(ref);
    for(std::size_t i = 0; i < ref.size(); ++i, ++out_it, ++ref_it)
    {
        const double o   = *out_it;
        const double r   = *ref_it;
        const double err = std::abs(o - r);
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            // NaN compares false against everything, so propagate it explicitly;
            // otherwise a NaN mismatch would be invisible in the summary.
            if(std::isnan(err) || err > max_err)
            {
                max_err = err;
            }
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }
    if(!res)
    {
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }
    return res;
}

template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
    bool>::type
check_err(const Range& out,
          const RefRange& ref,
Chao Liu's avatar
Chao Liu committed
80
81
82
83
84
85
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-3,
          double atol            = 1e-3)
{
    if(out.size() != ref.size())
    {
86
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
87
                  << std::endl;
Chao Liu's avatar
Chao Liu committed
88
89
90
91
92
93
94
95
96
97
        return false;
    }

    bool res{true};
    int err_count = 0;
    double err    = 0;
    // TODO: This is a hack. We should have proper specialization for bhalf_t data type.
    double max_err = std::numeric_limits<float>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
98
99
100
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);
Chao Liu's avatar
Chao Liu committed
101
102
103
104
105
106
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
107
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
108
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
Chao Liu's avatar
Chao Liu committed
109
110
111
112
113
114
            }
            res = false;
        }
    }
    if(!res)
    {
115
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
Chao Liu's avatar
Chao Liu committed
116
117
118
119
    }
    return res;
}

template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_same_v<ranges::range_value_t<Range>, half_t>,
    bool>::type
check_err(const Range& out,
          const RefRange& ref,
Chao Liu's avatar
Chao Liu committed
127
128
129
130
131
132
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-3,
          double atol            = 1e-3)
{
    if(out.size() != ref.size())
    {
133
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
134
                  << std::endl;
Chao Liu's avatar
Chao Liu committed
135
136
137
138
139
140
        return false;
    }

    bool res{true};
    int err_count  = 0;
    double err     = 0;
141
    double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
Chao Liu's avatar
Chao Liu committed
142
143
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
144
145
146
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);
Chao Liu's avatar
Chao Liu committed
147
148
149
150
151
152
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
153
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
154
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
Chao Liu's avatar
Chao Liu committed
155
156
157
158
159
160
            }
            res = false;
        }
    }
    if(!res)
    {
161
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
Chao Liu's avatar
Chao Liu committed
162
163
164
165
    }
    return res;
}

/// @brief Element-wise comparison of two integral ranges (and int4_t when
///        CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 is defined) using an
///        absolute tolerance.
///
/// Element i passes when |out[i] - ref[i]| <= atol. With the default atol of 0
/// this is an exact match. The relative-tolerance slot is accepted only for
/// signature compatibility with the floating-point overloads and is ignored.
/// The first four failing elements are printed to std::cerr, followed by the
/// largest error seen among failing elements.
///
/// bhalf_t is excluded here even though the enable_if treats it as integral;
/// it is handled by its own floating-point-style overload.
///
/// @param out  Range under test.
/// @param ref  Reference range; must have the same size as @p out.
/// @param msg  Prefix printed on every diagnostic line.
/// @param atol Absolute tolerance.
/// @return true if the sizes match and every element is within tolerance.
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_integral_v<ranges::range_value_t<Range>> &&
                  !std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                     || std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif
                 ,
                 bool>
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error: Incorrect results!",
          double                 = 0,
          double atol            = 0)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    bool res{true};
    int err_count         = 0;
    std::uint64_t max_err = 0;

    // Advance both iterators in lock-step instead of std::next(begin, i),
    // which is O(i) for non-random-access ranges.
    auto out_it = std::begin(out);
    auto ref_it = std::begin(ref);
    for(std::size_t i = 0; i < ref.size(); ++i, ++out_it, ++ref_it)
    {
        const std::int64_t o = *out_it;
        const std::int64_t r = *ref_it;
        // Compute |o - r| in unsigned arithmetic. The signed form
        // std::abs(o - r) overflows (undefined behaviour) for extreme inputs
        // such as INT64_MAX vs INT64_MIN. Unsigned wrap-around is well
        // defined and gives the exact distance for any pair of int64 values.
        const std::uint64_t err =
            o >= r ? static_cast<std::uint64_t>(o) - static_cast<std::uint64_t>(r)
                   : static_cast<std::uint64_t>(r) - static_cast<std::uint64_t>(o);

        if(err > atol)
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
                          << std::endl;
            }
            res = false;
        }
    }
    if(!res)
    {
        std::cerr << "max err: " << max_err << std::endl;
    }
    return res;
}

} // namespace utils
} // namespace ck