// check_context.hpp
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP

#include <migraph/program.hpp>
5
#include <migraph/config.hpp>
Paul's avatar
Paul committed
6

7
8
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
Paul's avatar
Paul committed
9

Paul's avatar
Paul committed
10
template <class T>
Paul's avatar
Paul committed
11
12
13
14
15
struct check_context
{
    struct op
    {
        std::string name() const { return "check_context"; }
Paul's avatar
Paul committed
16
17
        shape compute_shape(const std::vector<shape>&) const { return {}; }
        argument compute(context& ctx, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
18
19
20
21
22
23
24
25
26
        {
            T* x = any_cast<T>(&ctx);
            if(x == nullptr)
                MIGRAPH_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
            return {};
        }
    };

    std::string name() const { return "check_context"; }
Paul's avatar
Paul committed
27
    void apply(program& p) const { p.insert_instruction(p.begin(), op{}); }
Paul's avatar
Paul committed
28
29
};

30
} // namespace MIGRAPH_INLINE_NS
Paul's avatar
Paul committed
31
32
33
} // namespace migraph

#endif