// capture.hpp
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP

#include <array>
#include <cmath>
#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct capture
{
Shucai Xiao's avatar
Shucai Xiao committed
21
22
    std::size_t ins_index;
    std::function<void(std::size_t ins_index, std::vector<argument>)> f;
Shucai Xiao's avatar
Shucai Xiao committed
23
    template <class Self, class F>
Shucai Xiao's avatar
Shucai Xiao committed
24
25
    static auto reflect(Self& self, F f)
    {
Shucai Xiao's avatar
Shucai Xiao committed
26
        return pack(f(self.ins_index, "ins_index"));
Shucai Xiao's avatar
Shucai Xiao committed
27
28
    }

Shucai Xiao's avatar
Shucai Xiao committed
29
    std::string name() const { return "capture"; }
Shucai Xiao's avatar
Shucai Xiao committed
30

Shucai Xiao's avatar
Shucai Xiao committed
31
    shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
Shucai Xiao's avatar
Shucai Xiao committed
32

Shucai Xiao's avatar
Shucai Xiao committed
33
    argument compute(const shape&, std::vector<argument> args) const
Shucai Xiao's avatar
Shucai Xiao committed
34
    {
35
36
37
38
39
40
41
42
43
        if (f)
        {
            f(ins_index, args);
        }
        else
        {
            MIGRAPHX_THROW("CAPTURE: callback function is not callable!");
        }
        
Shucai Xiao's avatar
Shucai Xiao committed
44
        return args.front();
Shucai Xiao's avatar
Shucai Xiao committed
45
46
47
48
49
50
51
52
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif