check_context.hpp 1.2 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_CONTEXT_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
5
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
6
#include <migraphx/register_op.hpp>
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
namespace migraphx {
Paul's avatar
Paul committed
9
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
10

Paul's avatar
Paul committed
11
template <class T>
Paul's avatar
Paul committed
12
13
struct check_context
{
14
    struct op : auto_register_op<op>
Paul's avatar
Paul committed
15
    {
16
        std::string name() const { return "check_context::" + get_type_name<T>(); }
Paul's avatar
Paul committed
17
18
        shape compute_shape(const std::vector<shape>&) const { return {}; }
        argument compute(context& ctx, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
19
20
21
22
23
24
25
26
27
        {
            this->check(ctx);
            return {};
        }
        void finalize(context& ctx, const shape&, const std::vector<shape>&) const
        {
            this->check(ctx);
        }
        void check(context& ctx) const
Paul's avatar
Paul committed
28
29
30
        {
            T* x = any_cast<T>(&ctx);
            if(x == nullptr)
Paul's avatar
Paul committed
31
                MIGRAPHX_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
Paul's avatar
Paul committed
32
33
34
35
        }
    };

    std::string name() const { return "check_context"; }
36
    void apply(module& m) const { m.insert_instruction(m.begin(), op{}); }
Paul's avatar
Paul committed
37
38
};

Paul's avatar
Paul committed
39
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
40
} // namespace migraphx
Paul's avatar
Paul committed
41
42

#endif