test_util.hpp 5.78 KB
Newer Older
1
2
3
#ifndef TEST_UTIL_HPP
#define TEST_UTIL_HPP

4
#include <algorithm>
5
6
7
8
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <iomanip>
9
#include <iterator>
10
11
12
13
#include <limits>
#include <type_traits>
#include <vector>

14
15
16
#include "data_type.hpp"

namespace test {
17
18

template <typename T>
19
20
typename std::enable_if<std::is_floating_point<T>::value && !std::is_same<T, ck::half_t>::value,
                        bool>::type
21
22
23
check_err(const std::vector<T>& out,
          const std::vector<T>& ref,
          const std::string& msg,
24
25
          double rtol = 1e-5,
          double atol = 1e-8)
26
27
28
29
30
31
32
33
34
35
{
    if(out.size() != ref.size())
    {
        std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl
                  << msg << std::endl;
        return false;
    }

    bool res{true};
36
37
38
    int err_count  = 0;
    double err     = 0;
    double max_err = std::numeric_limits<double>::min();
39
40
41
42
43
44
45
46
47
48
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        err = std::abs(out[i] - ref[i]);
        if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i]))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref["
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
                          << i << "]: " << out[i] << " != " << ref[i] << std::endl
                          << msg << std::endl;
            }
            res = false;
        }
    }
    if(!res)
    {
        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }
    return res;
}

// Element-wise comparison of two vectors of ck::bhalf_t or ck::half_t using a
// combined absolute/relative tolerance.
//
// This overload participates when T is ck::bhalf_t or ck::half_t. Every
// element is converted to float with ck::type_convert<float> before it is
// compared or printed.
//
// An element fails when |o - r| > atol + rtol * |r| (o, r being the converted
// values), or when either converted value is not finite. The first four
// failures are printed to std::cout (each followed by msg); if anything
// failed, the largest error seen among the failing elements is printed once
// at the end.
//
// out   values under test
// ref   reference values
// msg   caller-supplied text printed with every diagnostic
// rtol  relative tolerance, scaled by |r|
// atol  absolute tolerance
//
// Returns true only when the sizes match and no element fails. A size
// mismatch is reported and returns false without comparing any element.
template <typename T>
typename std::enable_if<std::is_same<T, ck::bhalf_t>::value || std::is_same<T, ck::half_t>::value,
                        bool>::type
check_err(const std::vector<T>& out,
          const std::vector<T>& ref,
          const std::string& msg,
          double rtol = 1e-5,
          double atol = 1e-8)
{
    if(out.size() != ref.size())
    {
        std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl
                  << msg << std::endl;
        return false;
    }

    bool res{true};
    int err_count = 0;
    // Absolute errors are never negative, so 0 is the neutral start for the
    // running maximum; seeding with a type-specific limit could mask a smaller
    // real error in the final report.
    double max_err = 0;
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const float o    = ck::type_convert<float>(out[i]);
        const float r    = ck::type_convert<float>(ref[i]);
        const double err = std::abs(o - r);
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            // A NaN error never compares greater, so it leaves max_err untouched.
            max_err = err > max_err ? err : max_err;
            err_count++;
            // Only the first four mismatches are printed to keep the log short.
            if(err_count < 5)
            {
                std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref["
                          << i << "]: " << o << " != " << r << std::endl
                          << msg << std::endl;
            }
            res = false;
        }
    }
    if(!res)
    {
        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }
    return res;
}

bool check_err(const std::vector<_Float16>& out,
Chao Liu's avatar
Chao Liu committed
109
110
111
112
               const std::vector<_Float16>& ref,
               const std::string& msg,
               _Float16 rtol = static_cast<_Float16>(1e-3f),
               _Float16 atol = static_cast<_Float16>(1e-3f))
113
114
115
116
117
118
119
120
121
122
{
    if(out.size() != ref.size())
    {
        std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl
                  << msg << std::endl;
        return false;
    }

    bool res{true};
Chao Liu's avatar
Chao Liu committed
123
124
125
    int err_count  = 0;
    double err     = 0;
    double max_err = std::numeric_limits<_Float16>::min();
126
127
128
129
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        double out_ = double(out[i]);
        double ref_ = double(ref[i]);
Chao Liu's avatar
Chao Liu committed
130
        err         = std::abs(out_ - ref_);
131
132
133
134
135
136
137
138
139
        if(err > atol + rtol * std::abs(ref_) || !std::isfinite(out_) || !std::isfinite(ref_))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref["
                          << i << "]: " << out_ << "!=" << ref_ << std::endl
                          << msg << std::endl;
140
141
142
143
144
145
146
147
148
149
150
151
            }
            res = false;
        }
    }
    if(!res)
    {
        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }
    return res;
}

template <typename T>
152
153
154
155
156
157
158
typename std::enable_if<std::is_integral<T>::value && !std::is_same<T, ck::bhalf_t>::value,
                        bool>::type
check_err(const std::vector<T>& out,
          const std::vector<T>& ref,
          const std::string& msg,
          double = 0,
          double = 0)
159
160
161
162
163
164
165
166
167
168
169
170
171
{
    if(out.size() != ref.size())
    {
        std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl
                  << msg << std::endl;
        return false;
    }

    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        if(out[i] != ref[i])
        {
172
            std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i]
173
174
175
176
177
178
179
180
                      << std::endl
                      << msg << std::endl;
            return false;
        }
    }
    return true;
}

181
182
183
184
185
186
187
188
} // namespace test

// Streams every element of v to os, each one followed by a single space
// (the last element included), and returns os to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    for(const T& element : v)
    {
        os << element;
        os << " ";
    }
    return os;
}
189
190

#endif