migraphx_py.cpp 7.62 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <migraphx/program.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
5
#include <migraphx/quantization.hpp>
Paul's avatar
Paul committed
6
7
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
Paul's avatar
Paul committed
8
#include <migraphx/stringutils.hpp>
9
10
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
11
#include <migraphx/type_name.hpp>
12

Paul's avatar
Paul committed
13
14
15
16
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
Paul's avatar
Paul committed
17
18
19

namespace py = pybind11;

Paul's avatar
Paul committed
20
template <class F>
Paul's avatar
Paul committed
21
struct throw_half
Paul's avatar
Paul committed
22
23
24
{
    F f;

Paul's avatar
Paul committed
25
    template <class A>
Paul's avatar
Paul committed
26
27
28
29
30
31
32
33
    void operator()(A a) const
    {
        f(a);
    }

    void operator()(migraphx::shape::as<migraphx::half>) const
    {
        throw std::runtime_error("Half not supported in python yet.");
Paul's avatar
Paul committed
34
    }
Paul's avatar
Paul committed
35
36
37
38
39

    void operator()(migraphx::tensor_view<migraphx::half>) const
    {
        throw std::runtime_error("Half not supported in python yet.");
    }
Paul's avatar
Paul committed
40
41
};

Paul's avatar
Paul committed
42
43
44
45
46
47
48
49
50
51
52
template <class F>
struct skip_half
{
    F f;

    template <class A>
    void operator()(A a) const
    {
        f(a);
    }

Paul's avatar
Paul committed
53
    void operator()(migraphx::shape::as<migraphx::half>) const {}
Paul's avatar
Paul committed
54

Paul's avatar
Paul committed
55
    void operator()(migraphx::tensor_view<migraphx::half>) const {}
Paul's avatar
Paul committed
56
57
};

Paul's avatar
Paul committed
58
template <class F>
Paul's avatar
Paul committed
59
60
void visit_type(const migraphx::shape& s, F f)
{
Paul's avatar
Paul committed
61
62
63
    s.visit_type(throw_half<F>{f});
}

Paul's avatar
Paul committed
64
65
66
67
68
69
// Invoke `f` with a typed view of the data held by `x` (any migraphx::raw_data
// derivative, e.g. migraphx::argument); half-precision data throws
// std::runtime_error via throw_half.
template <class T, class F>
void visit(const migraphx::raw_data<T>& x, F f)
{
    throw_half<F> visitor{f};
    x.visit(visitor);
}

Paul's avatar
Paul committed
70
71
72
73
template <class F>
void visit_types(F f)
{
    migraphx::shape::visit_types(skip_half<F>{f});
Paul's avatar
Paul committed
74
75
}

Paul's avatar
Paul committed
76
template <class T>
Paul's avatar
Paul committed
77
78
79
py::buffer_info to_buffer_info(T& x)
{
    migraphx::shape s = x.get_shape();
Paul's avatar
Paul committed
80
81
82
    auto strides      = s.strides();
    std::transform(
        strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
Paul's avatar
Paul committed
83
84
    py::buffer_info b;
    visit_type(s, [&](auto as) {
Paul's avatar
Paul committed
85
86
87
88
89
        b = py::buffer_info(x.data(),
                            as.size(),
                            py::format_descriptor<decltype(as())>::format(),
                            s.lens().size(),
                            s.lens(),
90
                            strides);
Paul's avatar
Paul committed
91
92
93
94
    });
    return b;
}

Paul's avatar
Paul committed
95
96
97
// Decide whether a buffer-protocol format code `actual` (with element size
// `actual_size` bytes) denotes the same element type as the pybind11 format
// code `expected` (element size `expected_size`).
// pybind11 spells fixed-width integers by size ("i"/"q" signed, "I"/"Q"
// unsigned), while the buffer protocol may describe C `long` / `unsigned long`
// as "l"/"L" (e.g. numpy int64 on LP64 platforms). Those aliases are accepted
// only when the element sizes agree.
static bool buffer_format_matches(const std::string& actual,
                                  std::size_t actual_size,
                                  const std::string& expected,
                                  std::size_t expected_size)
{
    if(actual == expected)
        return true;
    if(actual_size != expected_size)
        return false;
    if(actual == "l")
        return expected == "i" or expected == "q";
    if(actual == "L")
        return expected == "I" or expected == "Q";
    return false;
}

// Build a migraphx::shape describing the memory exposed by a Python buffer.
//
// info: buffer description obtained from py::buffer::request().
// Returns a scalar shape for 0-d buffers, otherwise a shape with the buffer's
// dimensions and its strides converted from bytes to elements.
// Throws (MIGRAPHX_THROW) when the element format has no non-half migraphx
// counterpart, or when an axis longer than 1 has a negative stride (which a
// migraphx::shape cannot represent).
migraphx::shape to_shape(const py::buffer_info& info)
{
    migraphx::shape::type_t t{};
    std::size_t n = 0;
    const auto itemsize = static_cast<std::size_t>(info.itemsize);
    visit_types([&](auto as) {
        if(buffer_format_matches(info.format,
                                 itemsize,
                                 py::format_descriptor<decltype(as())>::format(),
                                 sizeof(as())))
        {
            t = as.type_enum();
            n = sizeof(as());
        }
    });

    if(n == 0)
    {
        MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type " + info.format);
    }

    // Buffer strides are in bytes and signed; migraphx strides are element counts.
    auto strides   = info.strides;
    using stride_t = decltype(strides)::value_type;
    for(std::size_t i = 0; i < strides.size(); i++)
    {
        if(strides[i] < 0)
        {
            // A negative stride (e.g. a reversed numpy view) cannot be expressed;
            // converting it to unsigned would produce a huge bogus stride.
            if(info.shape[i] > 1)
                MIGRAPHX_THROW("MIGRAPHX PYTHON: Negative strides are not supported");
            // On an axis of length 0 or 1 the stride is never used to step.
            strides[i] = 0;
        }
        else
        {
            strides[i] /= static_cast<stride_t>(n);
        }
    }

    // scalar support
    if(info.shape.empty())
    {
        return migraphx::shape{t};
    }
    return migraphx::shape{t, info.shape, strides};
}

Paul's avatar
Paul committed
128
129
// Python module `migraphx`: exposes shapes, arguments, programs, targets,
// parsers (TF / ONNX), quantization passes and, when built with HAVE_GPU,
// GPU memory helpers.
PYBIND11_MODULE(migraphx, m)
{
    // migraphx.shape: element type, dimensions and layout of a tensor.
    py::class_<migraphx::shape>(m, "shape")
        .def(py::init<>())
        .def("type", &migraphx::shape::type)
        .def("lens", &migraphx::shape::lens)
        .def("strides", &migraphx::shape::strides)
        .def("elements", &migraphx::shape::elements)
        .def("bytes", &migraphx::shape::bytes)
        .def("type_size", &migraphx::shape::type_size)
        .def("packed", &migraphx::shape::packed)
        .def("transposed", &migraphx::shape::transposed)
        .def("broadcasted", &migraphx::shape::broadcasted)
        .def("standard", &migraphx::shape::standard)
        .def("scalar", &migraphx::shape::scalar)
        .def("__eq__", std::equal_to<migraphx::shape>{})
        .def("__ne__", std::not_equal_to<migraphx::shape>{})
        .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });

    // migraphx.argument: shaped data; supports the buffer protocol in both
    // directions (viewable from Python, constructible from a Python buffer).
    py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
        .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
        // The argument is constructed from the raw pointer info.ptr (presumably a
        // non-owning view -- confirm against migraphx::argument), so keep the source
        // buffer (arg 2) alive for as long as the new argument (arg 1) exists.
        // py::init replaces the deprecated placement-new form of __init__.
        .def(py::init([](py::buffer b) {
                 py::buffer_info info = b.request();
                 return migraphx::argument(to_shape(info), info.ptr);
             }),
             py::keep_alive<1, 2>())
        .def("get_shape", &migraphx::argument::get_shape)
        .def("tolist",
             [](migraphx::argument& x) {
                 // The visitor replaces the list wholesale, so there is no point in
                 // pre-filling it with elements() placeholder entries.
                 py::list l;
                 visit(x, [&](auto data) { l = py::cast(data.to_vector()); });
                 return l;
             })
        .def("__eq__", std::equal_to<migraphx::argument>{})
        .def("__ne__", std::not_equal_to<migraphx::argument>{})
        .def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });

    py::class_<migraphx::target>(m, "target");

    py::class_<migraphx::program>(m, "program")
        // Return an independent copy by value; no heap allocation is needed (a
        // `new`-ed temporary here would be copied out of and then leaked).
        .def("clone", [](migraphx::program& p) { return migraphx::program(p); })
        .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
        .def("get_output_shapes", &migraphx::program::get_output_shapes)
        .def("compile",
             [](migraphx::program& p, const migraphx::target& t, bool offload_copy) {
                 migraphx::compile_options options;
                 options.offload_copy = offload_copy;
                 p.compile(t, options);
             },
             py::arg("t"),
             py::arg("offload_copy") = true)
        .def("run", &migraphx::program::eval)
        .def("__eq__", std::equal_to<migraphx::program>{})
        .def("__ne__", std::not_equal_to<migraphx::program>{})
        .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });

    m.def("parse_tf",
          [](const std::string& filename, bool is_nhwc, unsigned int batch_size) {
              return migraphx::parse_tf(filename, migraphx::tf_options{is_nhwc, batch_size});
          },
          "Parse tf protobuf (default format is nhwc)",
          py::arg("filename"),
          py::arg("is_nhwc")    = true,
          py::arg("batch_size") = 1);

    m.def("parse_onnx",
          [](const std::string& filename,
             std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
             std::size_t value) {
              migraphx::onnx_options options;
              options.map_input_dims    = map_input_dims;
              options.default_dim_value = value;
              return migraphx::parse_onnx(filename, options);
          },
          "Parse onnx file",
          py::arg("filename"),
          // Default uses the same container type as the parameter it binds to.
          py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
          py::arg("value")          = 1);

    // Resolve a compilation target by name; only targets compiled into this
    // build are available.
    m.def("get_target", [](const std::string& name) -> migraphx::target {
        if(name == "cpu")
            return migraphx::cpu::target{};
#ifdef HAVE_GPU
        if(name == "gpu")
            return migraphx::gpu::target{};
#endif
        throw std::runtime_error("Target not found: " + name);
    });

    m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
    m.def("quantize_fp16",
          &migraphx::quantize_fp16,
          py::arg("prog"),
          py::arg("ins_names") = std::vector<std::string>{"all"});
    m.def("quantize_int8",
          &migraphx::quantize_int8,
          py::arg("prog"),
          py::arg("t"),
          py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
          py::arg("ins_names")   = std::vector<std::string>{"dot", "convolution"});

#ifdef HAVE_GPU
    m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
    m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
    m.def("from_gpu", &migraphx::gpu::from_gpu);
    m.def("gpu_sync", &migraphx::gpu::gpu_sync);
#endif

#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif
}