// run_verify.hpp
#ifndef MIGRAPHX_GUARD_TEST_RUN_VERIFY_HPP
#define MIGRAPHX_GUARD_TEST_RUN_VERIFY_HPP

#include <migraphx/program.hpp>
#include <functional>
#include <map>

struct target_info
{
    using validation_function =
11
        std::function<void(const migraphx::program& p, const migraphx::parameter_map& m)>;
12
13
    bool parallel = true;
    validation_function validate;
14
    std::vector<std::string> disabled_tests;
15
16
17
18
19
};

struct run_verify
{
    std::vector<migraphx::argument> run_ref(migraphx::program p,
20
                                            migraphx::parameter_map inputs) const;
21
22
23
    std::pair<migraphx::program, std::vector<migraphx::argument>>
    run_target(const migraphx::target& t,
               migraphx::program p,
24
               const migraphx::parameter_map& inputs) const;
25
26
    void validate(const migraphx::target& t,
                  const migraphx::program& p,
27
                  const migraphx::parameter_map& m) const;
28
29
30
31
32
33
    void verify(const std::string& name, const migraphx::program& p) const;
    void run(int argc, const char* argv[]) const;

    target_info get_target_info(const std::string& name) const;
    void disable_parallel_for(const std::string& name);
    void add_validation_for(const std::string& name, target_info::validation_function v);
34
    void disable_test_for(const std::string& name, const std::vector<std::string>& tests);
35
36
37
38
39
40

    private:
    std::map<std::string, target_info> info{};
};

#endif