run_verify.cpp 10.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
24
25
26
#include "run_verify.hpp"
#include "auto_print.hpp"
#include "verify_program.hpp"
27
#include "test.hpp"
28
#include <migraphx/env.hpp>
29
#include <migraphx/register_target.hpp>
30
#include <migraphx/ranges.hpp>
31
#include <migraphx/generate.hpp>
32
#include <migraphx/load_save.hpp>
Paul Fultz II's avatar
Paul Fultz II committed
33
#include <migraphx/tmp_dir.hpp>
34
35
36
37
38
39
40
41
#include <migraphx/verify_args.hpp>
#include <set>

#include <future>
#include <thread>
#include <utility>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST_COMPILE)
42
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_TEST)
43
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST)
44
45
46

// An improved async, that doesn't block
//
// When `parallel` is true the callable is wrapped in a std::packaged_task and
// run on a detached std::thread, so the returned future's destructor does not
// wait for completion (unlike a std::launch::async future from std::async).
// When `parallel` is false the callable is handed to std::async with
// std::launch::deferred and only runs when the future is waited on.
template <class Function>
std::future<typename std::invoke_result_t<Function>> detach_async(Function&& f,
                                                                  bool parallel = true)
{
    if(not parallel)
        return std::async(std::launch::deferred, std::forward<Function>(f));

    using value_type = std::invoke_result_t<Function>;
    std::packaged_task<value_type()> job{std::forward<Function>(f)};
    auto result = job.get_future();
    std::thread worker{std::move(job)};
    worker.detach();
    return result;
}

Paul Fultz II's avatar
Paul Fultz II committed
61
62
63
64
65
66
67
68
69
/// Serializes `p` to a `test.mxr` file inside a fresh temporary directory,
/// loads it back, and EXPECTs the reloaded program to compare equal to `p`.
inline void verify_load_save(const migraphx::program& p)
{
    migraphx::tmp_dir scratch{"migraphx_test"};
    const std::string mxr_file = (scratch.path / "test.mxr").string();
    migraphx::save(p, mxr_file);
    auto reloaded = migraphx::load(mxr_file);
    EXPECT(p == reloaded);
}

70
71
72
73
/// Compiles `p` for target `t` and checks that compilation preserves the
/// program's output signature.
///
/// Requirements checked after compiling:
/// - the number of outputs is unchanged;
/// - a dynamic output keeps the same dynamic dimensions;
/// - a static output keeps the same lens;
/// - no output switches between static and dynamic.
///
/// For every target except "ref", the compiled program is also round-tripped
/// through save/load with verify_load_save.
///
/// @param p          program to compile; modified in place.
/// @param t          target to compile for.
/// @param c_opts     compile options, taken by value so a tracer can be attached locally.
/// @param show_trace when true, compilation is traced to std::cout.
/// @throws std::runtime_error if compilation alters the output signature.
inline void compile_check(migraphx::program& p,
                          const migraphx::target& t,
                          migraphx::compile_options c_opts,
                          bool show_trace = false)
{
    const auto name          = t.name();
    const auto before_shapes = p.get_output_shapes();
    if(show_trace)
        c_opts.trace = migraphx::tracer{std::cout};
    p.compile(t, c_opts);

    // Query the compiled output shapes once rather than once per output.
    const auto after_shapes = p.get_output_shapes();
    if(before_shapes.size() != after_shapes.size())
    {
        throw std::runtime_error("Compiling program with " + name +
                                 " alters its number of outputs");
    }

    for(std::size_t i = 0; i < before_shapes.size(); ++i)
    {
        const auto& before = before_shapes[i];
        const auto& after  = after_shapes[i];
        if(after.dynamic() and before.dynamic())
        {
            if(after.dyn_dims() != before.dyn_dims())
            {
                throw std::runtime_error("Compiling program with " + name +
                                         " alters its dynamic output dimensions");
            }
        }
        else if(not(after.dynamic() or before.dynamic()))
        {
            if(after.lens() != before.lens())
            {
                throw std::runtime_error("Compiling program with " + name +
                                         " alters its static output dimensions");
            }
        }
        else
        {
            // Exactly one of the two shapes is dynamic.
            throw std::runtime_error(
                "Compiling program with " + name +
                " alters its output dimensions (static shape vs dynamic shape)");
        }
    }

    if(name != "ref")
        verify_load_save(p);
}

/// Returns a copy of the registered per-target settings for `name`, or a
/// default-constructed target_info when nothing was registered for it.
target_info run_verify::get_target_info(const std::string& name) const
{
    const auto entry = info.find(name);
    if(entry == info.end())
        return {};
    return entry->second;
}

/// Invokes the validation callback registered for target `t` (if any) on the
/// compiled program `p` and its parameter map `m`; does nothing otherwise.
void run_verify::validate(const migraphx::target& t,
                          const migraphx::program& p,
                          const migraphx::parameter_map& m) const
{
    const auto settings = get_target_info(t.name());
    if(not settings.validate)
        return;
    settings.validate(p, m);
}

/// Produces the gold outputs for a test: compiles a private copy of the
/// program for the "ref" target (via compile_check, without tracing) and
/// evaluates it on `inputs`.
std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
                                                    migraphx::parameter_map inputs,
                                                    const migraphx::compile_options& c_opts) const
{
    const auto ref = migraphx::make_target("ref");
    auto_print printer{p, ref.name()};
    compile_check(p, ref, c_opts);
    return p.eval(std::move(inputs));
}
149
150
151
152
153
/// Compiles `p` for target `t`, evaluates it on `inputs`, and returns the
/// compiled program together with its outputs copied back from the target.
///
/// Compilation is traced to std::cout when the MIGRAPHX_TRACE_TEST_COMPILE
/// environment variable names this target. Parameters that are not present
/// in `inputs` are given a freshly allocated target buffer. Any validation
/// callback registered for the target runs before evaluation.
std::pair<migraphx::program, std::vector<migraphx::argument>>
run_verify::run_target(const migraphx::target& t,
                       migraphx::program p,
                       const migraphx::parameter_map& inputs,
                       const migraphx::compile_options& c_opts) const
{
    auto_print pp{p, t.name()};
    auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
    compile_check(p, t, c_opts, (trace_target == t.name()));

    // Copy caller-supplied inputs onto the target, then allocate target
    // buffers for any remaining parameters.
    migraphx::parameter_map m;
    for(auto&& input : inputs)
    {
        m[input.first] = t.copy_to(input.second);
    }
    for(auto&& x : p.get_parameter_shapes())
    {
        if(m.count(x.first) == 0)
        {
            m[x.first] = t.allocate(x.second);
        }
    }
    validate(t, p, m);

    // NOTE(review): the program is evaluated twice and this first result is
    // discarded. Kept as-is; it is presumably intentional (for example, to
    // exercise repeated evaluation). Confirm before removing.
    p.eval(m);

    auto tres = p.eval(m);
    std::vector<migraphx::argument> res(tres.size());
    std::transform(
        tres.begin(), tres.end(), res.begin(), [&](auto& argu) { return t.copy_from(argu); });

    // Move the results out instead of copying the whole vector into the pair.
    return std::make_pair(std::move(p), std::move(res));
}

/// Returns `std::hash<T>{}(x)`. This file passes the result as the seed
/// argument of migraphx::generate_argument for each parameter.
template <class T>
auto get_hash(const T& x)
{
    const std::hash<T> hasher{};
    return hasher(x);
}

187
188
189
/// Runs verify test `name`.
///
/// Steps:
/// 1. Install the auto_print terminate handler.
/// 2. Optionally dump `p` to `<name>.mxr` (MIGRAPHX_DUMP_TEST).
/// 3. Check that `p` survives a save/load round trip.
/// 4. Generate inputs: each parameter gets data seeded by get_hash of its
///    name. Dynamic shapes are realized at their maximum lens.
/// 5. Evaluate `p` on "ref" to obtain gold outputs.
/// 6. Evaluate `p` on every other registered target, except "fpga" and any
///    target that disabled this test.
/// 7. Compare each target's outputs against gold with
///    migraphx::verify_args_with_tolerance.
///
/// Target runs use detach_async, so they run on detached threads unless
/// parallelism was disabled for that target. On a mismatch, or when
/// MIGRAPHX_TRACE_TEST is enabled, the source program and the program as
/// compiled for the target are printed.
void run_verify::verify(const std::string& name,
                        const migraphx::program& p,
                        const migraphx::compile_options& c_opts) const
{
    using result_future =
        std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
    auto_print::set_terminate_handler(name);
    if(migraphx::enabled(MIGRAPHX_DUMP_TEST{}))
        migraphx::save(p, name + ".mxr");
    verify_load_save(p);

    std::vector<std::string> target_names;
    for(const auto& tname : migraphx::get_targets())
    {
        // TODO(varunsh): once verify tests can run, remove fpga
        if(tname == "ref" or tname == "fpga")
            continue;

        // if tests disabled, skip running it
        target_info ti = get_target_info(tname);
        if(migraphx::contains(ti.disabled_tests, name))
            continue;

        target_names.push_back(tname);
    }
    if(not target_names.empty())
    {
        std::vector<std::pair<std::string, result_future>> results;
        migraphx::parameter_map m;
        for(auto&& [pname, pshape] : p.get_parameter_shapes())
        {
            if(pshape.dynamic())
            {
                // create static shape using maximum dimensions
                migraphx::shape static_shape{pshape.type(), pshape.max_lens()};
                m[pname] = migraphx::generate_argument(static_shape, get_hash(pname));
            }
            else
            {
                m[pname] = migraphx::generate_argument(pshape, get_hash(pname));
            }
        }

        auto gold_f = detach_async([=] { return run_ref(p, m, c_opts); });
        for(const auto& tname : target_names)
        {
            target_info ti = get_target_info(tname);
            auto t         = migraphx::make_target(tname);
            results.emplace_back(
                tname, detach_async([=] { return run_target(t, p, m, c_opts); }, ti.parallel));
        }

        assert(gold_f.valid());
        auto gold = gold_f.get();

        for(auto&& [tname, fut] : results)
        {
            assert(fut.valid());
            // Bind the (compiled program, outputs) pair directly; no copies.
            auto [cp, result] = fut.get();

            bool passed = true;
            passed &= (gold.size() == result.size());
            std::size_t num = gold.size();
            for(std::size_t i = 0; ((i < num) and passed); ++i)
            {
                passed &= migraphx::verify_args_with_tolerance(
                    tname, result[i], migraphx::verify::expected{gold[i]});
            }

            if(not passed or migraphx::enabled(MIGRAPHX_TRACE_TEST{}))
            {
                // The source program once (the previous code printed it twice,
                // the second copy mislabelled "ref:"), then the program as
                // compiled for this target.
                std::cout << p << std::endl;
                std::cout << tname << ":\n" << cp << std::endl;
                std::cout << std::endl;
            }
            EXPECT(passed);
        }
    }
    std::set_terminate(nullptr);
}

void run_verify::run(int argc, const char* argv[]) const
{
273
274
    std::unordered_map<std::string, std::vector<std::string>> labels;
    for(auto&& p : get_programs())
275
    {
276
        labels[p.section].push_back(p.name);
277
        test::add_test_case(p.name, [=] { verify(p.name, p.get_program(), p.compile_options); });
278
    }
279
280
281
282
283
284
285
    test::driver d{};
    d.get_case_names = [&](const std::string& name) -> std::vector<std::string> {
        if(labels.count(name) > 0)
            return labels.at(name);
        return {name};
    };
    d.run(argc, argv);
286
287
288
289
290
291
292
}

void run_verify::disable_parallel_for(const std::string& name) { info[name].parallel = false; }
void run_verify::add_validation_for(const std::string& name, target_info::validation_function v)
{
    info[name].validate = std::move(v);
}
293
294
295
296
297
298

/// Appends `tests` to the list of test names that verify() skips for target
/// `name`.
void run_verify::disable_test_for(const std::string& name, const std::vector<std::string>& tests)
{
    auto& skipped = info[name].disabled_tests;
    for(const auto& test_name : tests)
        skipped.push_back(test_name);
}