target.hpp 3.76 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <cassert>
Paul's avatar
Paul committed
5
6
7
8
9
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
Paul's avatar
Paul committed
10
#include <vector>
Paul's avatar
Paul committed
11
12
#include <migraphx/context.hpp>
#include <migraphx/pass.hpp>
Paul's avatar
Paul committed
13
#include <migraphx/config.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
14
15
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
namespace migraphx {
Paul's avatar
Paul committed
18
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
19

Paul's avatar
Paul committed
20
#ifdef DOXYGEN
Paul's avatar
Paul committed
21

Paul's avatar
Paul committed
22
/// An interface for a compilation target
Paul's avatar
Paul committed
23
24
25
26
27
28
struct target
{
    /// A unique name used to identify the target
    std::string name() const;
    /**
     * @brief The transformation passes to be run during compilation.
     *
     * @param ctx This is the target-dependent context that is created by `get_context`
     * @return The passes to be run
     */
    std::vector<pass> get_passes(context& ctx) const;
    /**
     * @brief Construct a context for the target.
     * @return The context to be used during compilation and execution.
     */
    context get_context() const;
    /**
     * @brief Copy an argument to the current target.
     *
     * The interface declaration names `copy_to_target` as the default; that
     * fallback throws when the target type provides no `copy_to` member.
     *
     * @param arg Input argument to be copied to the target
     * @return Argument in the target.
     */
    argument copy_to(const argument& arg) const;
    /**
     * @brief Copy an argument from the current target.
     *
     * The interface declaration names `copy_from_target` as the default; that
     * fallback throws when the target type provides no `copy_from` member.
     *
     * @param arg Input argument to be copied from the target
     * @return Argument in the host.
     */
    argument copy_from(const argument& arg) const;
    /**
     * @brief Allocate an argument based on the input shape.
     *
     * The interface declaration names `target_allocate` as the default; that
     * fallback throws when the target type provides no `allocate` member.
     *
     * @param s Shape of the argument to be allocated in the target
     * @return Allocated argument in the target.
     */
    argument allocate(const shape& s) const;
};

#else
Paul's avatar
Paul committed
64

Shucai Xiao's avatar
Shucai Xiao committed
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
/// Preferred overload: viable only when `x.allocate(s)` is a well-formed
/// expression (SFINAE on the trailing return type); forwards to that member.
template <typename T>
auto target_allocate(rank<1>, T& x, const shape& s) -> decltype(x.allocate(s))
{
    return x.allocate(s);
}

/// Fallback overload, chosen when the rank<1> overload is removed by SFINAE
/// (i.e. `x.allocate(s)` is not a valid expression for `T`).
/// Raises via MIGRAPHX_THROW with a message naming the target.
template <class T>
argument target_allocate(rank<0>, T& x, const shape&)
{
    const std::string name = x.name();
    // MIGRAPHX_THROW raises, so no (unreachable) trailing return is needed.
    MIGRAPHX_THROW("Not computable: " + name);
}

/// Allocate an argument of shape `s` on target `x`.
///
/// Dispatch entry point: passing `rank<1>{}` makes overload resolution prefer
/// the member-forwarding overload and fall back to the rank<0> overload
/// (presumably rank<1> derives from rank<0>, as this overload set requires).
template <typename T>
argument target_allocate(T& x, const shape& s)
{
    return target_allocate(rank<1>{}, x, s);
}

/// Preferred overload: viable only when `x.copy_to(arg)` is a well-formed
/// expression (SFINAE on the trailing return type); forwards to that member.
template <typename T>
auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(arg))
{
    return x.copy_to(arg);
}

/// Fallback overload, chosen when the rank<1> overload is removed by SFINAE
/// (i.e. `x.copy_to(arg)` is not a valid expression for `T`).
/// Raises via MIGRAPHX_THROW with a message naming the target.
/// NOTE(review): the message says "Not computable" even though this is a copy;
/// kept byte-for-byte in case callers match on it.
template <class T>
argument copy_to_target(rank<0>, T& x, const argument&)
{
    const std::string name = x.name();
    // MIGRAPHX_THROW raises, so no (unreachable) trailing return is needed.
    MIGRAPHX_THROW("Not computable: " + name);
}

/// Copy `arg` to target `x`.
///
/// Dispatch entry point: passing `rank<1>{}` makes overload resolution prefer
/// the member-forwarding overload and fall back to the rank<0> overload
/// (presumably rank<1> derives from rank<0>, as this overload set requires).
template <typename T>
argument copy_to_target(T& x, const argument& arg)
{
    return copy_to_target(rank<1>{}, x, arg);
}

/// Preferred overload: viable only when `x.copy_from(arg)` is a well-formed
/// expression (SFINAE on the trailing return type); forwards to that member.
template <typename T>
auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_from(arg))
{
    return x.copy_from(arg);
}

/// Fallback overload, chosen when the rank<1> overload is removed by SFINAE
/// (i.e. `x.copy_from(arg)` is not a valid expression for `T`).
/// Raises via MIGRAPHX_THROW with a message naming the target.
/// NOTE(review): the message says "Not computable" even though this is a copy;
/// kept byte-for-byte in case callers match on it.
template <class T>
argument copy_from_target(rank<0>, T& x, const argument&)
{
    const std::string name = x.name();
    // MIGRAPHX_THROW raises, so no (unreachable) trailing return is needed.
    MIGRAPHX_THROW("Not computable: " + name);
}

/// Copy `arg` back from target `x`.
///
/// Dispatch entry point: passing `rank<1>{}` makes overload resolution prefer
/// the member-forwarding overload and fall back to the rank<0> overload
/// (presumably rank<1> derives from rank<0>, as this overload set requires).
template <typename T>
argument copy_from_target(T& x, const argument& arg)
{
    return copy_from_target(rank<1>{}, x, arg);
}

Paul's avatar
Paul committed
130
131
132
<%
interface('target',
    virtual('name', returns='std::string', const=True),
    virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
    virtual('get_context', returns='context', const=True),
    virtual('copy_to',
            returns='argument',
            input='const argument&',
            const=True,
            default='copy_to_target'),
    virtual('copy_from',
            returns='argument',
            input='const argument&',
            const=True,
            default='copy_from_target'),
    virtual('allocate',
            s='const shape&',
            returns='argument',
            const=True,
            default='target_allocate')
)
%>

Paul's avatar
Paul committed
150
151
#endif

Paul's avatar
Paul committed
152
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
153
} // namespace migraphx
Paul's avatar
Paul committed
154
155

#endif