run_verify.cpp 9.31 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include "run_verify.hpp"
#include "auto_print.hpp"
#include "verify_program.hpp"
#include "test.hpp"
#include <migraphx/env.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/tmp_dir.hpp>
#include <migraphx/verify_args.hpp>

#include <algorithm>
#include <cassert>
#include <future>
#include <iostream>
#include <set>
#include <thread>
#include <type_traits>
#include <utility>

// Value is compared against each target's name; the matching target is
// compiled with its compiler trace written to std::cout.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST_COMPILE)
// When enabled, programs are printed for every verified target, not only
// for targets whose results fail verification.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST)
// When enabled, each test program is saved as "<test name>.mxr".
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST)

// An improved async that doesn't block: unlike the future returned by
// std::async(std::launch::async, ...), the future returned here does not
// wait for the task to finish when it is destroyed.
//
// parallel == true : `f` runs on a new detached thread; its result (or any
//                    exception it throws) is delivered through the future.
// parallel == false: `f` is deferred and runs on the thread that first calls
//                    get()/wait() on the returned future.
template <class Function>
std::future<std::invoke_result_t<Function>> detach_async(Function&& f, bool parallel = true)
{
    using result_type = std::invoke_result_t<Function>;
    if(parallel)
    {
        std::packaged_task<result_type()> task(std::forward<Function>(f));
        auto fut = task.get_future();
        // The thread is never joined: everything `f` references must stay
        // alive until the future becomes ready.
        std::thread(std::move(task)).detach();
        return fut;
    }
    return std::async(std::launch::deferred, std::forward<Function>(f));
}

// Saves `p` as an .mxr file inside a temporary "migraphx_test" directory,
// loads it back and expects the reloaded program to compare equal to `p`.
inline void verify_load_save(const migraphx::program& p)
{
    migraphx::tmp_dir scratch{"migraphx_test"};
    const std::string mxr_file = (scratch.path / "test.mxr").string();
    migraphx::save(p, mxr_file);
    const auto reloaded = migraphx::load(mxr_file);
    EXPECT(p == reloaded);
}

70
71
72
73
inline void compile_check(migraphx::program& p,
                          const migraphx::target& t,
                          migraphx::compile_options c_opts,
                          bool show_trace = false)
74
75
76
77
{
    auto name   = t.name();
    auto shapes = p.get_output_shapes();
    std::stringstream ss;
78
    if(show_trace)
79
80
        c_opts.trace = migraphx::tracer{std::cout};
    p.compile(t, c_opts);
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
    if(shapes.size() != p.get_output_shapes().size())
    {
        std::cout << ss.str() << std::endl;
        throw std::runtime_error("Compiling program with " + name +
                                 " alters its number of outputs");
    }

    auto num = shapes.size();
    for(std::size_t i = 0; i < num; ++i)
    {
        if(p.get_output_shapes()[i].lens() != shapes[i].lens())
        {
            std::cout << ss.str() << std::endl;
            throw std::runtime_error("Compiling program with " + name + " alters its shape");
        }
    }
Paul Fultz II's avatar
Paul Fultz II committed
97
98
    if(t.name() != "ref")
        verify_load_save(p);
99
100
101
102
103
104
105
106
107
108
109
110
111
}

// Looks up the per-target settings registered under `name`; targets with no
// registered settings get a default-constructed target_info.
target_info run_verify::get_target_info(const std::string& name) const
{
    const auto entry = info.find(name);
    return entry == info.end() ? target_info{} : entry->second;
}

// Invokes the validation callback registered for target `t` (if any) with
// the compiled program `p` and the parameter map `m` it will be evaluated with.
void run_verify::validate(const migraphx::target& t,
                          const migraphx::program& p,
                          const migraphx::parameter_map& m) const
{
    auto ti = get_target_info(t.name());
    if(ti.validate)
        ti.validate(p, m);
}

// Compiles `p` (a private copy) for the "ref" target and evaluates it with
// `inputs`. verify() uses the returned outputs as the gold values that
// every other target is compared against.
std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
                                                    migraphx::parameter_map inputs,
                                                    const migraphx::compile_options& c_opts) const
{
    migraphx::target t = migraphx::make_target("ref");
    auto_print pp{p, t.name()};
    compile_check(p, t, c_opts);
    return p.eval(std::move(inputs));
}

// Compiles `p` (a private copy) for target `t` and evaluates it. Supplied
// `inputs` are passed through t.copy_to(); every other parameter of the
// compiled program gets a buffer from t.allocate(). Outputs are passed through
// t.copy_from() before being returned together with the compiled program.
std::pair<migraphx::program, std::vector<migraphx::argument>>
run_verify::run_target(const migraphx::target& t,
                       migraphx::program p,
                       const migraphx::parameter_map& inputs,
                       const migraphx::compile_options& c_opts) const
{
    auto_print pp{p, t.name()};
    auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
    compile_check(p, t, c_opts, (trace_target == t.name()));

    migraphx::parameter_map m;
    for(auto&& input : inputs)
        m[input.first] = t.copy_to(input.second);
    for(auto&& x : p.get_parameter_shapes())
    {
        if(m.count(x.first) == 0)
            m[x.first] = t.allocate(x.second);
    }
    validate(t, p, m);

    // NOTE(review): the program is evaluated twice and only the second run's
    // outputs are checked. This looks deliberate (it exercises re-running an
    // already-evaluated compiled program) -- confirm before removing the first call.
    p.eval(m);
    auto tres = p.eval(m);

    std::vector<migraphx::argument> res(tres.size());
    std::transform(
        tres.begin(), tres.end(), res.begin(), [&](auto& argu) { return t.copy_from(argu); });

    return std::make_pair(std::move(p), std::move(res));
}

// Hashes `x` with std::hash<T>. verify() uses this to derive a deterministic
// per-parameter seed (from the parameter name) for generated inputs.
template <class T>
auto get_hash(const T& x)
{
    const std::hash<T> hasher{};
    return hasher(x);
}

166
167
168
void run_verify::verify(const std::string& name,
                        const migraphx::program& p,
                        const migraphx::compile_options& c_opts) const
169
170
171
172
{
    using result_future =
        std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
    auto_print::set_terminate_handler(name);
173
    if(migraphx::enabled(MIGRAPHX_DUMP_TEST{}))
Charlie Lin's avatar
Charlie Lin committed
174
        migraphx::save(p, name + ".mxr");
Paul Fultz II's avatar
Paul Fultz II committed
175
    verify_load_save(p);
176
177
178
    std::vector<std::string> target_names;
    for(const auto& tname : migraphx::get_targets())
    {
varunsh's avatar
varunsh committed
179
        // TODO(varunsh): once verify tests can run, remove fpga
180
        if(tname == "ref" or tname == "fpga")
181
            continue;
182
183
184
185
186
187

        // if tests disabled, skip running it
        target_info ti = get_target_info(tname);
        if(migraphx::contains(ti.disabled_tests, name))
            continue;

188
189
190
191
        target_names.push_back(tname);
    }
    if(not target_names.empty())
    {
192
        std::vector<std::pair<std::string, result_future>> results;
193
        migraphx::parameter_map m;
194
195
        for(auto&& x : p.get_parameter_shapes())
        {
Charlie Lin's avatar
Charlie Lin committed
196
197
198
199
200
201
202
203
204
205
            if(x.second.dynamic())
            {
                // create static shape using maximum dimensions
                migraphx::shape static_shape{x.second.type(), x.second.max_lens()};
                m[x.first] = migraphx::generate_argument(static_shape, get_hash(x.first));
            }
            else
            {
                m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
            }
206
207
        }

208
        auto gold_f = detach_async([=] { return run_ref(p, m, c_opts); });
209
210
211
212
        for(const auto& tname : target_names)
        {
            target_info ti = get_target_info(tname);
            auto t         = migraphx::make_target(tname);
213
214
            results.emplace_back(
                tname, detach_async([=] { return run_target(t, p, m, c_opts); }, ti.parallel));
215
216
        }

217
        assert(gold_f.valid());
218
219
220
221
        auto gold = gold_f.get();

        for(auto&& pp : results)
        {
222
            assert(pp.second.valid());
223
224
225
226
227
228
229
230
231
232
233
234
235
            auto tname  = pp.first;
            auto x      = pp.second.get();
            auto cp     = x.first;
            auto result = x.second;

            bool passed = true;
            passed &= (gold.size() == result.size());
            std::size_t num = gold.size();
            for(std::size_t i = 0; ((i < num) and passed); ++i)
            {
                passed &= migraphx::verify_args(tname, gold[i], result[i]);
            }

236
            if(not passed or migraphx::enabled(MIGRAPHX_TRACE_TEST{}))
237
238
239
240
241
242
            {
                std::cout << p << std::endl;
                std::cout << "ref:\n" << p << std::endl;
                std::cout << tname << ":\n" << cp << std::endl;
                std::cout << std::endl;
            }
243
            EXPECT(passed);
244
245
246
247
248
249
250
        }
    }
    std::set_terminate(nullptr);
}

void run_verify::run(int argc, const char* argv[]) const
{
251
252
    std::unordered_map<std::string, std::vector<std::string>> labels;
    for(auto&& p : get_programs())
253
    {
254
        labels[p.section].push_back(p.name);
255
        test::add_test_case(p.name, [=] { verify(p.name, p.get_program(), p.compile_options); });
256
    }
257
258
259
260
261
262
263
    test::driver d{};
    d.get_case_names = [&](const std::string& name) -> std::vector<std::string> {
        if(labels.count(name) > 0)
            return labels.at(name);
        return {name};
    };
    d.run(argc, argv);
264
265
266
267
268
269
270
}

// Marks target `name` as non-parallel: verify() then passes parallel=false to
// detach_async, so that target's run is deferred rather than launched on a
// detached thread.
void run_verify::disable_parallel_for(const std::string& name)
{
    auto& settings    = info[name];
    settings.parallel = false;
}
// Registers `v` as the validation callback for target `name`, replacing any
// previously registered one. validate() invokes it with the compiled program
// and its parameter map right before run_target evaluates the program.
void run_verify::add_validation_for(const std::string& name, target_info::validation_function v)
{
    info[name].validate = std::move(v);
}

// Appends `tests` to the names of tests that verify() skips for target `name`.
void run_verify::disable_test_for(const std::string& name, const std::vector<std::string>& tests)
{
    auto& skipped = info[name].disabled_tests;
    for(const auto& test_name : tests)
        skipped.push_back(test_name);
}