// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/host_utility/io.hpp"

#include "ck/library/utility/ranges.hpp"

namespace ck {
namespace utils {

26
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
27
double get_relative_threshold(const int number_of_accumulations = 1)
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
{
    using F8   = ck::f8_t;
    using F16  = ck::half_t;
    using BF16 = ck::bhalf_t;
    using F32  = float;
    using I8   = int8_t;
    using I32  = int32_t;

    static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
                      is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
                      is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
                      is_same_v<ComputeDataType, int>,
                  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
    double compute_error = 0;
    if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
                 is_same_v<ComputeDataType, int>)
    {
        return 0;
    }
    else
    {
        compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
    }

    static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
                      is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
                      is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
                      is_same_v<OutDataType, int>,
                  "Warning: Unhandled OutDataType for setting up the relative threshold!");
    double output_error = 0;
    if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
                 is_same_v<OutDataType, int>)
    {
        return 0;
    }
    else
    {
        output_error = std::pow(2, -NumericUtils<OutDataType>::mant) * 0.5;
    }
    double midway_error = std::max(compute_error, output_error);

    static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
                      is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
                      is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
                      is_same_v<AccDataType, int>,
                  "Warning: Unhandled AccDataType for setting up the relative threshold!");
    double acc_error = 0;
    if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
                 is_same_v<AccDataType, int>)
    {
        return 0;
    }
    else
    {
82
        acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
83
84
85
86
87
    }
    return std::max(acc_error, midway_error);
}

template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
88
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
{
    using F8   = ck::f8_t;
    using F16  = ck::half_t;
    using BF16 = ck::bhalf_t;
    using F32  = float;
    using I8   = int8_t;
    using I32  = int32_t;

    static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
                      is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
                      is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
                      is_same_v<ComputeDataType, int>,
                  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
    auto expo            = std::log2(std::abs(max_possible_num));
    double compute_error = 0;
    if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
                 is_same_v<ComputeDataType, int>)
    {
        return 0;
    }
    else
    {
        compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
    }

    static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
                      is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
                      is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
                      is_same_v<OutDataType, int>,
                  "Warning: Unhandled OutDataType for setting up the absolute threshold!");
    double output_error = 0;
    if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
                 is_same_v<OutDataType, int>)
    {
        return 0;
    }
    else
    {
        output_error = std::pow(2, expo - NumericUtils<OutDataType>::mant) * 0.5;
    }
    double midway_error = std::max(compute_error, output_error);

    static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
                      is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
                      is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
                      is_same_v<AccDataType, int>,
                  "Warning: Unhandled AccDataType for setting up the absolute threshold!");
    double acc_error = 0;
    if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
                 is_same_v<AccDataType, int>)
    {
        return 0;
    }
    else
    {
        acc_error =
145
            std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
146
147
148
149
    }
    return std::max(acc_error, midway_error);
}

150
151
152
153
154
155
156
157
// Element-wise comparison of `out` against `ref` for built-in floating-point element types.
// half_t is excluded here and handled by its own overload.
//
// An element fails when |o - r| > atol + rtol * |r|, or when either value is not finite.
// The first four failures are printed to std::cerr, prefixed by `msg`. A summary follows:
// largest failing error (NaN if any failing error was NaN), failure count, and percentage.
// The precision of std::cerr is restored before returning.
//
// Returns false on a size mismatch or if any element fails; true otherwise.
template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_floating_point_v<ranges::range_value_t<Range>> &&
        !std::is_same_v<ranges::range_value_t<Range>, half_t>,
    bool>::type
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-5,
          double atol            = 3e-6)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    // The mismatch report sets std::cerr's precision; restore the caller's setting afterwards.
    const std::streamsize saved_precision = std::cerr.precision();

    bool res{true};
    int err_count = 0;
    // Errors are non-negative, so 0 is the natural seed. numeric_limits<double>::min() is the
    // smallest positive normal value and leaked into the report when no finite error was seen.
    double max_err = 0;

    // Walk both ranges in lock-step. Re-seeking with std::next(begin, i) costs O(i) per element
    // on non-random-access ranges.
    auto out_it = std::begin(out);
    auto ref_it = std::begin(ref);
    for(std::size_t i = 0; i < ref.size(); ++i, ++out_it, ++ref_it)
    {
        const double o   = *out_it;
        const double r   = *ref_it;
        const double err = std::abs(o - r);
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            // A NaN error (NaN operand, or inf - inf) would never win an ordinary comparison.
            // Record it explicitly so the summary does not hide it.
            if(std::isnan(err) || err > max_err)
            {
                max_err = err;
            }
            ++err_count;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }
    if(!res)
    {
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
    }
    std::cerr.precision(saved_precision);
    return res;
}

template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
    bool>::type
check_err(const Range& out,
          const RefRange& ref,
Chao Liu's avatar
Chao Liu committed
208
          const std::string& msg = "Error: Incorrect results!",
209
          double rtol            = 1e-1,
Chao Liu's avatar
Chao Liu committed
210
211
212
213
          double atol            = 1e-3)
{
    if(out.size() != ref.size())
    {
214
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
215
                  << std::endl;
Chao Liu's avatar
Chao Liu committed
216
217
218
219
220
221
222
223
224
225
        return false;
    }

    bool res{true};
    int err_count = 0;
    double err    = 0;
    // TODO: This is a hack. We should have proper specialization for bhalf_t data type.
    double max_err = std::numeric_limits<float>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
226
227
228
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);
Chao Liu's avatar
Chao Liu committed
229
230
231
232
233
234
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
235
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
236
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
Chao Liu's avatar
Chao Liu committed
237
238
239
240
241
242
            }
            res = false;
        }
    }
    if(!res)
    {
243
244
245
246
247
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
Chao Liu's avatar
Chao Liu committed
248
249
250
251
    }
    return res;
}

252
253
254
255
256
257
258
template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_same_v<ranges::range_value_t<Range>, half_t>,
    bool>::type
check_err(const Range& out,
          const RefRange& ref,
Chao Liu's avatar
Chao Liu committed
259
260
261
262
263
264
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-3,
          double atol            = 1e-3)
{
    if(out.size() != ref.size())
    {
265
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
266
                  << std::endl;
Chao Liu's avatar
Chao Liu committed
267
268
269
270
271
272
        return false;
    }

    bool res{true};
    int err_count  = 0;
    double err     = 0;
273
    double max_err = NumericLimits<ranges::range_value_t<Range>>::Min();
Chao Liu's avatar
Chao Liu committed
274
275
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
276
277
278
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);
Chao Liu's avatar
Chao Liu committed
279
280
281
282
283
284
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
285
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
286
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
Chao Liu's avatar
Chao Liu committed
287
288
289
290
291
292
            }
            res = false;
        }
    }
    if(!res)
    {
293
294
295
296
297
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
Chao Liu's avatar
Chao Liu committed
298
299
300
301
    }
    return res;
}

302
303
304
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_integral_v<ranges::range_value_t<Range>> &&
305
306
307
                  !std::is_same_v<ranges::range_value_t<Range>, bhalf_t> &&
                  !std::is_same_v<ranges::range_value_t<Range>, f8_t> &&
                  !std::is_same_v<ranges::range_value_t<Range>, bf8_t>)
308
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
309
                     || std::is_same_v<ranges::range_value_t<Range>, int4_t>
310
311
312
#endif
                 ,
                 bool>
313
314
check_err(const Range& out,
          const RefRange& ref,
Chao Liu's avatar
Chao Liu committed
315
316
          const std::string& msg = "Error: Incorrect results!",
          double                 = 0,
317
          double atol            = 0)
Chao Liu's avatar
Chao Liu committed
318
319
320
{
    if(out.size() != ref.size())
    {
321
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
322
                  << std::endl;
Chao Liu's avatar
Chao Liu committed
323
324
325
326
327
328
329
330
331
        return false;
    }

    bool res{true};
    int err_count   = 0;
    int64_t err     = 0;
    int64_t max_err = std::numeric_limits<int64_t>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
332
333
334
        const int64_t o = *std::next(std::begin(out), i);
        const int64_t r = *std::next(std::begin(ref), i);
        err             = std::abs(o - r);
Chao Liu's avatar
Chao Liu committed
335

336
        if(err > atol)
Chao Liu's avatar
Chao Liu committed
337
338
339
340
341
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
342
                std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
343
                          << std::endl;
Chao Liu's avatar
Chao Liu committed
344
345
346
347
348
349
            }
            res = false;
        }
    }
    if(!res)
    {
350
351
352
353
354
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
Chao Liu's avatar
Chao Liu committed
355
356
357
358
    }
    return res;
}

359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_same_v<ranges::range_value_t<Range>, f8_t>),
                 bool>
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-3,
          double atol            = 1e-3)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    bool res{true};
    int err_count  = 0;
    double err     = 0;
    double max_err = std::numeric_limits<float>::min();
380

381
382
383
384
385
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);
386

387
388
389
390
391
392
393
394
395
396
397
398
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }
399

400
401
    if(!res)
    {
402
403
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
                  << " number of errors: " << err_count << std::endl;
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
    }
    return res;
}

// Element-wise comparison of `out` against `ref` for bf8_t elements. Each value is converted
// to float with type_convert before comparing.
//
// An element fails when |o - r| > atol + rtol * |r|, or when either value is not finite.
// The first four failures are printed to std::cerr, prefixed by `msg`. A summary follows in
// the same format as the other overloads: largest failing error (NaN if any failing error was
// NaN), failure count, and percentage. The precision of std::cerr is restored before returning.
//
// Returns false on a size mismatch or if any element fails; true otherwise.
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
                 bool>
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 1e-3,
          double atol            = 1e-3)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    // The mismatch report sets std::cerr's precision; restore the caller's setting afterwards.
    const std::streamsize saved_precision = std::cerr.precision();

    bool res{true};
    int err_count = 0;
    // Errors are non-negative, so 0 is the natural seed. numeric_limits<float>::min() is the
    // smallest positive normal value, not the lowest value.
    double max_err = 0;

    // Walk both ranges in lock-step. Re-seeking with std::next(begin, i) costs O(i) per element
    // on non-random-access ranges.
    auto out_it = std::begin(out);
    auto ref_it = std::begin(ref);
    for(std::size_t i = 0; i < ref.size(); ++i, ++out_it, ++ref_it)
    {
        const double o   = type_convert<float>(*out_it);
        const double r   = type_convert<float>(*ref_it);
        const double err = std::abs(o - r);
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            // Record NaN errors explicitly; they would never win an ordinary comparison.
            if(std::isnan(err) || err > max_err)
            {
                max_err = err;
            }
            ++err_count;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }
    if(!res)
    {
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
    }
    std::cerr.precision(saved_precision);
    return res;
}

} // namespace utils
} // namespace ck