run_verify.cpp 6.78 KB
Newer Older
1
2
3
#include "run_verify.hpp"
#include "auto_print.hpp"
#include "verify_program.hpp"
4
#include "test.hpp"
5
6
#include <migraphx/env.hpp>
#include <migraphx/ref/target.hpp>
7
#include <migraphx/ranges.hpp>
8
9
10
11
12
13
14
15
16
#include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp>
#include <set>

#include <future>
#include <thread>
#include <utility>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST_COMPILE)
17
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST)
18
19
20
21
22
23
24
25
26
27
28
29

// An improved async that doesn't block: in parallel mode the work runs on a
// detached thread and the returned future comes from a std::packaged_task, so
// destroying that future does not wait for the task to finish.
//
// f        - nullary callable to execute.
// parallel - true (default): run f on a newly spawned, detached thread.
//            false: defer f; it runs on whichever thread first calls
//            get()/wait() on the returned future.
//
// Returns a future that yields f's result, or rethrows what f threw.
template <class Function>
std::future<decltype(std::declval<Function>()())> detach_async(Function&& f, bool parallel = true)
{
    // Spelled with decltype/declval instead of std::result_of, which is
    // deprecated in C++17 and removed in C++20.
    using result_type = decltype(std::declval<Function>()());

    if(not parallel)
        return std::async(std::launch::deferred, std::forward<Function>(f));

    std::packaged_task<result_type()> task(std::forward<Function>(f));
    // The future must be taken before the task is moved into the thread.
    auto fut = task.get_future();
    std::thread(std::move(task)).detach();
    return fut;
}

// Compiles `p` for target `t` and verifies that compilation did not change the
// program's output signature.
//
// p          - program to compile in place.
// t          - target to compile for.
// show_trace - when true, the compile trace is written to std::cout.
//
// Throws std::runtime_error if the number of outputs, or the lens() of any
// output shape, differs between the program before and after compilation.
inline void compile_check(migraphx::program& p, const migraphx::target& t, bool show_trace = false)
{
    const auto name            = t.name();
    const auto expected_shapes = p.get_output_shapes();

    migraphx::compile_options options;
    if(show_trace)
        options.trace = migraphx::tracer{std::cout};
    p.compile(t, options);

    // Query the compiled output shapes once rather than once per comparison.
    const auto compiled_shapes = p.get_output_shapes();
    if(expected_shapes.size() != compiled_shapes.size())
    {
        // Blank line + flush kept so the console output matches what was
        // emitted before (an empty buffer followed by std::endl).
        std::cout << std::endl;
        throw std::runtime_error("Compiling program with " + name +
                                 " alters its number of outputs");
    }

    const auto num = expected_shapes.size();
    for(std::size_t i = 0; i < num; ++i)
    {
        if(compiled_shapes[i].lens() != expected_shapes[i].lens())
        {
            std::cout << std::endl;
            throw std::runtime_error("Compiling program with " + name + " alters its shape");
        }
    }
}

// Looks up the per-target settings stored under `name`.
// Returns a copy of the stored entry, or a value-initialized target_info when
// no entry exists for that name.
target_info run_verify::get_target_info(const std::string& name) const
{
    const auto found = info.find(name);
    if(found == info.end())
        return {};
    return found->second;
}

// Invokes the validation callback registered for target `t`, if one is set,
// passing it the program `p` and the parameter map `m`. Does nothing when the
// target's settings carry no callback.
void run_verify::validate(const migraphx::target& t,
                          const migraphx::program& p,
                          const migraphx::parameter_map& m) const
{
    auto target_settings = get_target_info(t.name());
    if(not target_settings.validate)
        return;
    target_settings.validate(p, m);
}

// Compiles a private copy of `p` for the ref target (through compile_check)
// and evaluates it with `inputs`, returning the evaluation results.
std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
                                                    migraphx::parameter_map inputs) const
{
    migraphx::ref::target ref_target{};
    auto_print printer{p, ref_target.name()};
    compile_check(p, ref_target);
    return p.eval(std::move(inputs));
}
88
89
// Compiles a private copy of `p` for target `t`, evaluates it and returns the
// compiled program together with its outputs.
//
// t      - target to compile and run on.
// p      - program to compile (taken by value; the compiled copy is returned).
// inputs - parameter values; each is passed through t.copy_to before use.
//
// Parameters of the compiled program that have no entry in `inputs` are
// allocated with t.allocate. Every output is passed through t.copy_from
// before being returned.
std::pair<migraphx::program, std::vector<migraphx::argument>> run_verify::run_target(
    const migraphx::target& t, migraphx::program p, const migraphx::parameter_map& inputs) const
{
    auto_print pp{p, t.name()};
    // Compile tracing is enabled only when the env var names this target.
    auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
    compile_check(p, t, (trace_target == t.name()));

    migraphx::parameter_map m;
    for(auto&& input : inputs)
    {
        m[input.first] = t.copy_to(input.second);
    }
    // Compilation may have introduced parameters the caller did not supply.
    for(auto&& x : p.get_parameter_shapes())
    {
        if(m.count(x.first) == 0)
        {
            m[x.first] = t.allocate(x.second);
        }
    }
    validate(t, p, m);

    // NOTE(review): the program is evaluated twice and the first result is
    // discarded. Presumably deliberate (warm-up / repeatability check) —
    // confirm before removing.
    p.eval(m);
    auto tres = p.eval(m);

    std::vector<migraphx::argument> res(tres.size());
    std::transform(
        tres.begin(), tres.end(), res.begin(), [&](auto& argu) { return t.copy_from(argu); });

    // Move the results into the pair instead of copying the whole vector.
    return std::make_pair(std::move(p), std::move(res));
}

// Hashes `x` with the standard std::hash specialization for its type.
template <class T>
auto get_hash(const T& x)
{
    const std::hash<T> hasher{};
    return hasher(x);
}

// Runs program `p` on the ref target and on every other available target that
// has not disabled test `name`, then checks each target's outputs against the
// ref outputs with migraphx::verify_args.
//
// name - test name; used for the terminate handler and to look up disabled
//        tests per target.
// p    - the uncompiled program under test.
//
// Inputs are generated per parameter, seeded by the hash of the parameter
// name. On a mismatch, or when MIGRAPHX_TRACE_TEST is enabled, the programs
// are printed to std::cout. The outcome is reported through EXPECT.
void run_verify::verify(const std::string& name, const migraphx::program& p) const
{
    using result_future =
        std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;

    auto_print::set_terminate_handler(name);
    // Restore the default terminate handler on every exit path, including
    // when future::get() rethrows a failure from one of the worker tasks.
    struct terminate_reset
    {
        terminate_reset()                                  = default;
        terminate_reset(const terminate_reset&)            = delete;
        terminate_reset& operator=(const terminate_reset&) = delete;
        ~terminate_reset() { std::set_terminate(nullptr); }
    };
    const terminate_reset reset_guard{};

    std::vector<std::pair<std::string, result_future>> results;
    std::vector<std::string> target_names;
    for(const auto& tname : migraphx::get_targets())
    {
        // ref provides the gold results; it is not compared against itself
        if(tname == "ref")
            continue;

        // if tests disabled, skip running it
        target_info ti = get_target_info(tname);
        if(migraphx::contains(ti.disabled_tests, name))
            continue;

        target_names.push_back(tname);
    }
    if(target_names.empty())
        return;

    migraphx::parameter_map m;
    for(auto&& x : p.get_parameter_shapes())
    {
        m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
    }

    // The lambdas capture copies of p and m, so every task works on its own
    // program and parameter map.
    auto gold_f = detach_async([=] { return run_ref(p, m); });
    for(const auto& tname : target_names)
    {
        target_info ti = get_target_info(tname);
        auto t         = migraphx::make_target(tname);
        results.emplace_back(tname,
                             detach_async([=] { return run_target(t, p, m); }, ti.parallel));
    }

    assert(gold_f.valid());
    auto gold = gold_f.get();

    for(auto&& pp : results)
    {
        assert(pp.second.valid());
        // Bind by reference: the compiled program and the result vector are
        // only read here, so there is no need to copy them out of the pair.
        const auto& tname  = pp.first;
        const auto x       = pp.second.get();
        const auto& cp     = x.first;
        const auto& result = x.second;

        bool passed           = (gold.size() == result.size());
        const std::size_t num = gold.size();
        for(std::size_t i = 0; ((i < num) and passed); ++i)
        {
            passed &= migraphx::verify_args(tname, gold[i], result[i]);
        }

        if(not passed or migraphx::enabled(MIGRAPHX_TRACE_TEST{}))
        {
            // NOTE(review): the "ref:" section prints the uncompiled input
            // program again, because run_ref returns only the results and not
            // the ref-compiled program.
            std::cout << p << std::endl;
            std::cout << "ref:\n" << p << std::endl;
            std::cout << tname << ":\n" << cp << std::endl;
            std::cout << std::endl;
        }
        EXPECT(passed);
    }
}

void run_verify::run(int argc, const char* argv[]) const
{
193
194
    std::unordered_map<std::string, std::vector<std::string>> labels;
    for(auto&& p : get_programs())
195
    {
196
197
        labels[p.section].push_back(p.name);
        test::add_test_case(p.name, [=] { verify(p.name, p.get_program()); });
198
    }
199
200
201
202
203
204
205
    test::driver d{};
    d.get_case_names = [&](const std::string& name) -> std::vector<std::string> {
        if(labels.count(name) > 0)
            return labels.at(name);
        return {name};
    };
    d.run(argc, argv);
206
207
208
209
210
211
212
}

// Clears the `parallel` flag in the settings stored for target `name`,
// creating the settings entry if it does not exist yet.
void run_verify::disable_parallel_for(const std::string& name)
{
    auto& target_settings    = info[name];
    target_settings.parallel = false;
}
// Stores `v` as the validation callback for target `name`, creating the
// settings entry if needed and replacing any callback stored earlier.
void run_verify::add_validation_for(const std::string& name, target_info::validation_function v)
{
    auto& target_settings    = info[name];
    target_settings.validate = std::move(v);
}
213
214
215
216
217
218

// Appends `tests` to the list of disabled tests kept for target `name`,
// creating the settings entry if it does not exist yet. Existing entries are
// kept; no de-duplication is performed.
void run_verify::disable_test_for(const std::string& name, const std::vector<std::string>& tests)
{
    auto& target_settings = info[name];
    auto& disabled        = target_settings.disabled_tests;
    disabled.insert(disabled.end(), tests.begin(), tests.end());
}