/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include <migraphx/verify_args.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

bool verify_args(const std::string& name,
31
                 const argument& target_arg,
32
33
                 const verify::expected<argument>& ref_arg,
                 verify::tolerance tols)
34
35
{
    bool passed = true;
36
37
38
39
    visit_all(ref_arg.data(), target_arg)([&](auto ref, auto target) {
        double rms_error;
        passed =
            verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error);
40
41
42
43
        if(not passed)
        {
            // TODO: Check for nans
            std::cout << "FAILED: " << name << std::endl;
44
            std::cout << "RMS Error: " << rms_error << std::endl;
45
46
47
48
            if(ref.size() < 32)
                std::cout << "ref:" << ref << std::endl;
            if(target.size() < 32)
                std::cout << "target:" << target << std::endl;
Umang Yadav's avatar
Umang Yadav committed
49
            if(verify::range_zero(ref))
50
                std::cout << "Ref data is all zeros" << std::endl;
Umang Yadav's avatar
Umang Yadav committed
51
            if(verify::range_zero(target))
52
                std::cout << "Target data is all zeros" << std::endl;
53

Umang Yadav's avatar
Umang Yadav committed
54
            auto mxdiff = verify::max_diff(ref, target);
55
56
            std::cout << "Max diff: " << mxdiff << std::endl;

Umang Yadav's avatar
Umang Yadav committed
57
58
            auto idx = verify::mismatch_idx(ref, target, float_equal);
            if(idx < verify::range_distance(ref))
59
            {
60
                std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
61
62
63
                          << std::endl;
            }

Umang Yadav's avatar
Umang Yadav committed
64
            auto ref_nan_idx = find_idx(ref, verify::not_finite);
65
66
67
            if(ref_nan_idx >= 0)
                std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
                          << ref[ref_nan_idx] << std::endl;
68

Umang Yadav's avatar
Umang Yadav committed
69
            auto target_nan_idx = find_idx(target, verify::not_finite);
70
71
72
            if(target_nan_idx >= 0)
                std::cout << "Non finite number found in target at " << target_nan_idx << ": "
                          << target[target_nan_idx] << std::endl;
73
74
75
76
            std::cout << std::endl;
        }
        else
        {
Umang Yadav's avatar
Umang Yadav committed
77
            if(verify::range_zero(ref))
78
                std::cout << "Ref data is all zeros" << std::endl;
Umang Yadav's avatar
Umang Yadav committed
79
            if(verify::range_zero(target))
80
                std::cout << "Target data is all zeros" << std::endl;
81

Umang Yadav's avatar
Umang Yadav committed
82
            auto ref_nan_idx = find_idx(ref, verify::not_finite);
83
84
85
            if(ref_nan_idx >= 0)
                std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
                          << ref[ref_nan_idx] << std::endl;
86

Umang Yadav's avatar
Umang Yadav committed
87
            auto target_nan_idx = find_idx(target, verify::not_finite);
88
89
90
            if(target_nan_idx >= 0)
                std::cout << "Non finite number found in target at " << target_nan_idx << ": "
                          << target[target_nan_idx] << std::endl;
91
92
93
94
95
        }
    });
    return passed;
}

/// Convenience wrapper around verify_args that derives the RMS tolerance from
/// an integer factor instead of taking a verify::tolerance directly.
///
/// The RMS tolerance is computed by verify::get_rms_tol for the element type
/// of @p target_arg. If the visitor never runs, the fallback value 0.001 is
/// used.
///
/// @param name       Label printed by verify_args on failure.
/// @param target_arg Result produced by the target under test.
/// @param ref_arg    Expected (reference) result.
/// @param tolerance  Factor passed to verify::get_rms_tol.
/// @return The result of verify_args.
bool verify_args_with_tolerance(const std::string& name,
                                const argument& target_arg,
                                const verify::expected<argument>& ref_arg,
                                std::size_t tolerance)
{
    double computed_rms_tol = 0.001;
    auto derive_rms_tol     = [&](auto view) {
        computed_rms_tol = verify::get_rms_tol(view, tolerance);
    };
    target_arg.visit(derive_rms_tol);

    return verify_args(name, target_arg, ref_arg, verify::tolerance{computed_rms_tol});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx