check_context.hpp 797 Bytes
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP

#include <migraph/program.hpp>

namespace migraph {

Paul's avatar
Paul committed
8
template <class T>
Paul's avatar
Paul committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
struct check_context
{
    struct op
    {
        std::string name() const { return "check_context"; }
        shape compute_shape(std::vector<shape>) const { return {}; }
        argument compute(context& ctx, shape, std::vector<argument>) const
        {
            T* x = any_cast<T>(&ctx);
            if(x == nullptr)
                MIGRAPH_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
            return {};
        }
    };

    std::string name() const { return "check_context"; }
Paul's avatar
Paul committed
25
    void apply(program& p) const { p.insert_instruction(p.begin(), op{}); }
Paul's avatar
Paul committed
26
27
28
29
30
};

} // namespace migraph

#endif