if_op.hpp 2.17 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
5
6
7
8
9
10
11
#ifndef MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/module.hpp>
#include <algorithm>
#include <array>
#include <cmath>
#include <functional>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
Shucai Xiao's avatar
Shucai Xiao committed
12
#include <set>
Shucai Xiao's avatar
Shucai Xiao committed
13
14
15
16
17
18
19
20
21
22
23

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct if_op
{
    std::string name() const { return "if"; }

    shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
    {
Shucai Xiao's avatar
Shucai Xiao committed
24
        check_shapes{inputs, *this}.standard();
Shucai Xiao's avatar
Shucai Xiao committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
        if(mods.size() != 2)
        {
            MIGRAPHX_THROW("IF: operator should have two submodules.");
        }

        auto out_shapes0 = mods[0]->get_output_shapes();
        auto out_shapes1 = mods[1]->get_output_shapes();
        if(not std::equal(
               out_shapes1.begin(), out_shapes1.end(), out_shapes0.begin(), out_shapes0.end()))
        {
            MIGRAPHX_THROW("IF: output shapes of submodules must be the same.");
        }

        return out_shapes0.front();
    }
Shucai Xiao's avatar
Shucai Xiao committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67

    argument compute(
        const std::vector<argument>& args,
        const std::vector<module_ref>& mods,
        const std::function<std::vector<argument>(
            module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)>& run) const
    {
        auto cond      = args.front().at<bool>();
        module_ref mod = cond ? mods[0] : mods[1];
        std::unordered_map<std::string, argument> params;

        std::set<std::string> pnames;
        for(const auto& smod : mods)
        {
            auto names = smod->get_parameter_names();
            pnames.insert(names.begin(), names.end());
        }

        assert(pnames.size() < args.size());
        std::transform(pnames.begin(),
                       pnames.end(),
                       args.begin() + 1,
                       std::inserter(params, params.end()),
                       [](auto&& name, auto&& arg) { return std::make_pair(name, arg); });

        auto results = run(mod, params);
        return results[0];
    }
Shucai Xiao's avatar
Shucai Xiao committed
68
69
70
71
72
73
74
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif