// verify_args.cpp

#include <migraphx/verify_args.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

bool verify_args(const std::string& name,
8
9
                 const argument& ref_arg,
                 const argument& target_arg,
10
11
12
                 double tolerance)
{
    bool passed = true;
13
    visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
14
        double error;
15
        passed = verify_range(ref, target, tolerance, &error);
16
17
18
19
20
        if(not passed)
        {
            // TODO: Check for nans
            std::cout << "FAILED: " << name << std::endl;
            std::cout << "error: " << error << std::endl;
21
22
23
24
25
26
27
28
            if(ref.size() < 32)
                std::cout << "ref:" << ref << std::endl;
            if(target.size() < 32)
                std::cout << "target:" << target << std::endl;
            if(range_zero(ref))
                std::cout << "Ref data is all zeros" << std::endl;
            if(range_zero(target))
                std::cout << "Target data is all zeros" << std::endl;
29

30
            auto mxdiff = max_diff(ref, target);
31
32
            std::cout << "Max diff: " << mxdiff << std::endl;

33
34
            auto idx = mismatch_idx(ref, target, float_equal);
            if(idx < range_distance(ref))
35
            {
36
                std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
37
38
39
                          << std::endl;
            }

40
41
42
43
            auto ref_nan_idx = find_idx(ref, not_finite);
            if(ref_nan_idx >= 0)
                std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
                          << ref[ref_nan_idx] << std::endl;
44

45
46
47
48
            auto target_nan_idx = find_idx(target, not_finite);
            if(target_nan_idx >= 0)
                std::cout << "Non finite number found in target at " << target_nan_idx << ": "
                          << target[target_nan_idx] << std::endl;
49
50
51
52
            std::cout << std::endl;
        }
        else
        {
53
54
55
56
            if(range_zero(ref))
                std::cout << "Ref data is all zeros" << std::endl;
            if(range_zero(target))
                std::cout << "Target data is all zeros" << std::endl;
57

58
            // auto mxdiff = max_diff(ref, target);
59
60
            // std::cout << "Max diff: " << mxdiff << std::endl;

61
62
            // auto idx = mismatch_idx(ref, target, float_equal);
            // if(idx < range_distance(ref))
63
            // {
64
            //     std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
65
66
67
            //               << std::endl;
            // }

68
69
70
71
            auto ref_nan_idx = find_idx(ref, not_finite);
            if(ref_nan_idx >= 0)
                std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
                          << ref[ref_nan_idx] << std::endl;
72

73
74
75
76
            auto target_nan_idx = find_idx(target, not_finite);
            if(target_nan_idx >= 0)
                std::cout << "Non finite number found in target at " << target_nan_idx << ": "
                          << target[target_nan_idx] << std::endl;
77
78
79
80
81
82
83
84
            // std::cout << std::endl;
        }
    });
    return passed;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx