// run_verify.hpp
#ifndef MIGRAPHX_GUARD_TEST_RUN_VERIFY_HPP
#define MIGRAPHX_GUARD_TEST_RUN_VERIFY_HPP

#include <migraphx/program.hpp>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct target_info
{
    using validation_function =
11
        std::function<void(const migraphx::program& p, const migraphx::parameter_map& m)>;
12
13
14
15
16
17
18
    bool parallel = true;
    validation_function validate;
};

struct run_verify
{
    std::vector<migraphx::argument> run_ref(migraphx::program p,
19
                                            migraphx::parameter_map inputs) const;
20
21
22
    std::pair<migraphx::program, std::vector<migraphx::argument>>
    run_target(const migraphx::target& t,
               migraphx::program p,
23
               const migraphx::parameter_map& inputs) const;
24
25
    void validate(const migraphx::target& t,
                  const migraphx::program& p,
26
                  const migraphx::parameter_map& m) const;
27
28
29
30
31
32
33
34
35
36
37
38
    void verify(const std::string& name, const migraphx::program& p) const;
    void run(int argc, const char* argv[]) const;

    target_info get_target_info(const std::string& name) const;
    void disable_parallel_for(const std::string& name);
    void add_validation_for(const std::string& name, target_info::validation_function v);

    private:
    std::map<std::string, target_info> info{};
};

#endif