migraphx_py.cpp 12.5 KB
Newer Older
Paul's avatar
Paul committed
1
2
3

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Shucai Xiao's avatar
Shucai Xiao committed
4
#include <pybind11/numpy.h>
Paul's avatar
Paul committed
5
#include <migraphx/program.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
6
#include <migraphx/quantization.hpp>
Paul's avatar
Paul committed
7
8
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
Paul's avatar
Paul committed
9
#include <migraphx/stringutils.hpp>
10
11
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
12
#include <migraphx/type_name.hpp>
13
14
#include <migraphx/load_save.hpp>
#include <migraphx/register_target.hpp>
15
16
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
17

Paul's avatar
Paul committed
18
19
20
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
Paul's avatar
Paul committed
21

Shucai Xiao's avatar
Shucai Xiao committed
22
using half   = half_float::half;
Paul's avatar
Paul committed
23
24
namespace py = pybind11;

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
namespace migraphx {

migraphx::value to_value(py::kwargs kwargs);
migraphx::value to_value(py::list lst);

// Convert the Python object `x` to the closest C++ representation and pass it
// to `f`. The checks run in this order:
//   dict (checked as py::kwargs) -> migraphx::value object (to_value(py::kwargs))
//   list                         -> migraphx::value array  (to_value(py::list))
//   bool                         -> bool  (tested before int, because Python's
//                                   bool is a subclass of int)
//   int                          -> std::int64_t
//   float                        -> double
//   str                          -> std::string
// Any other Python type throws via MIGRAPHX_THROW.
template <class T, class F>
void visit_py(T x, F f)
{
    if(py::isinstance<py::kwargs>(x))
    {
        f(to_value(x.template cast<py::kwargs>()));
    }
    else if(py::isinstance<py::list>(x))
    {
        f(to_value(x.template cast<py::list>()));
    }
    else if(py::isinstance<py::bool_>(x))
    {
        f(x.template cast<bool>());
    }
    else if(py::isinstance<py::int_>(x))
    {
        // 64-bit so that Python ints outside the 32-bit `int` range are
        // accepted instead of making the cast throw.
        f(x.template cast<std::int64_t>());
    }
    else if(py::isinstance<py::float_>(x))
    {
        // Python floats are double precision; casting to float would round
        // the value before it is stored.
        f(x.template cast<double>());
    }
    else if(py::isinstance<py::str>(x))
    {
        f(x.template cast<std::string>());
    }
    else
    {
        MIGRAPHX_THROW("VISIT_PY: Unsupported data type!");
    }
}

// Build a migraphx::value array from a Python list. Each element is converted
// with visit_py, so nested lists/dicts become nested values.
migraphx::value to_value(py::list lst)
{
    migraphx::value v = migraphx::value::array{};
    for(auto val : lst)
    {
        visit_py(val, [&](auto py_val) { v.push_back(py_val); });
    }

    return v;
}

// Build a migraphx::value object from Python keyword arguments (or a dict).
// Keys are converted with str(); values are converted with visit_py.
migraphx::value to_value(py::kwargs kwargs)
{
    migraphx::value v = migraphx::value::object{};

    for(auto arg : kwargs)
    {
        std::string key = py::str(arg.first);
        visit_py(arg.second, [&](auto py_val) { v[key] = py_val; });
    }

    return v;
}
} // namespace migraphx

Shucai Xiao's avatar
Shucai Xiao committed
89
90
namespace pybind11 {
namespace detail {
Paul's avatar
Paul committed
91

Shucai Xiao's avatar
Shucai Xiao committed
92
93
template <>
struct npy_format_descriptor<half>
Paul's avatar
Paul committed
94
{
Shucai Xiao's avatar
Shucai Xiao committed
95
    static std::string format()
Paul's avatar
Paul committed
96
    {
Shucai Xiao's avatar
Shucai Xiao committed
97
98
        // following: https://docs.python.org/3/library/struct.html#format-characters
        return "e";
Paul's avatar
Paul committed
99
    }
Shucai Xiao's avatar
Shucai Xiao committed
100
    static constexpr auto name() { return _("half"); }
Paul's avatar
Paul committed
101
102
};

Shucai Xiao's avatar
Shucai Xiao committed
103
104
105
} // namespace detail
} // namespace pybind11

Paul's avatar
Paul committed
106
template <class F>
Paul's avatar
Paul committed
107
108
void visit_type(const migraphx::shape& s, F f)
{
Shucai Xiao's avatar
Shucai Xiao committed
109
    s.visit_type(f);
Paul's avatar
Paul committed
110
111
}

Paul's avatar
Paul committed
112
113
114
// Invoke `visitor` with a typed view of `raw` (forwards to raw_data::visit).
template <class Derived, class Visitor>
void visit(const migraphx::raw_data<Derived>& raw, Visitor visitor)
{
    raw.visit(visitor);
}

Paul's avatar
Paul committed
118
119
120
template <class F>
void visit_types(F f)
{
Shucai Xiao's avatar
Shucai Xiao committed
121
    migraphx::shape::visit_types(f);
Paul's avatar
Paul committed
122
123
}

Paul's avatar
Paul committed
124
template <class T>
Paul's avatar
Paul committed
125
126
127
py::buffer_info to_buffer_info(T& x)
{
    migraphx::shape s = x.get_shape();
Paul's avatar
Paul committed
128
129
130
    auto strides      = s.strides();
    std::transform(
        strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
Paul's avatar
Paul committed
131
132
    py::buffer_info b;
    visit_type(s, [&](auto as) {
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
        // migraphx use int8_t data to store bool type, we need to
        // explicitly specify the data type as bool for python
        if(s.type() == migraphx::shape::bool_type)
        {
            b = py::buffer_info(x.data(),
                                as.size(),
                                py::format_descriptor<bool>::format(),
                                s.lens().size(),
                                s.lens(),
                                strides);
        }
        else
        {
            b = py::buffer_info(x.data(),
                                as.size(),
                                py::format_descriptor<decltype(as())>::format(),
                                s.lens().size(),
                                s.lens(),
                                strides);
        }
Paul's avatar
Paul committed
153
154
155
156
    });
    return b;
}

Paul's avatar
Paul committed
157
158
159
// Build a migraphx::shape describing the memory exposed by a Python buffer.
//
// The buffer's format code is matched against each type visited by
// visit_types. Two extra cases are accepted:
//   - "l"/"L" are mapped to the 64-bit "q"/"Q" types. This assumes a 64-bit
//     C long (LP64) -- NOTE(review): confirm on non-Linux platforms.
//   - "?" is mapped to bool_type.
// Byte strides are converted to element strides. A 0-d buffer becomes a
// scalar shape.
//
// Throws (MIGRAPHX_THROW) when the format has no migraphx equivalent, or when
// a stride is negative or not a whole multiple of the element size. Neither
// case can be expressed by migraphx::shape.
migraphx::shape to_shape(const py::buffer_info& info)
{
    migraphx::shape::type_t t{};
    std::size_t n = 0;
    visit_types([&](auto as) {
        const std::string fmt = py::format_descriptor<decltype(as())>::format();
        if(info.format == fmt or (info.format == "l" and fmt == "q") or
           (info.format == "L" and fmt == "Q"))
        {
            t = as.type_enum();
            n = sizeof(as());
        }
        else if(info.format == "?" and fmt == "b")
        {
            t = migraphx::shape::bool_type;
            n = sizeof(bool);
        }
    });

    if(n == 0)
    {
        MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type " + info.format);
    }

    auto strides = info.strides;
    std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
        // Negative strides (e.g. arr[::-1]) would wrap to huge unsigned
        // values, and non-multiple strides would be truncated; either would
        // make migraphx index outside the buffer.
        if(i < 0 or i % static_cast<decltype(i)>(n) != 0)
        {
            MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported buffer strides; expected non-negative "
                           "multiples of the element size");
        }
        return i / n;
    });

    // 0-d buffers map to a migraphx scalar shape.
    if(info.shape.empty())
    {
        return migraphx::shape{t};
    }
    return migraphx::shape{t, info.shape, strides};
}

Paul's avatar
Paul committed
197
198
PYBIND11_MODULE(migraphx, m)
{
    py::class_<migraphx::shape>(m, "shape")
        .def(py::init<>())
        .def("type", &migraphx::shape::type)
        .def("lens", &migraphx::shape::lens)
        .def("strides", &migraphx::shape::strides)
        .def("elements", &migraphx::shape::elements)
        .def("bytes", &migraphx::shape::bytes)
        .def("type_size", &migraphx::shape::type_size)
        .def("packed", &migraphx::shape::packed)
        .def("transposed", &migraphx::shape::transposed)
        .def("broadcasted", &migraphx::shape::broadcasted)
        .def("standard", &migraphx::shape::standard)
        .def("scalar", &migraphx::shape::scalar)
        .def("__eq__", std::equal_to<migraphx::shape>{})
        .def("__ne__", std::not_equal_to<migraphx::shape>{})
        .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });

    py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
        .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
        // The argument is constructed from the buffer's raw data pointer
        // (info.ptr). keep_alive<1, 2> keeps the Python object that owns that
        // memory alive for as long as the argument exists.
        .def(py::init([](py::buffer b) {
                 py::buffer_info info = b.request();
                 return migraphx::argument(to_shape(info), info.ptr);
             }),
             py::keep_alive<1, 2>())
        .def("get_shape", &migraphx::argument::get_shape)
        .def("tolist",
             [](migraphx::argument& x) {
                 py::list l{x.get_shape().elements()};
                 visit(x, [&](auto data) { l = py::cast(data.to_vector()); });
                 return l;
             })
        .def("__eq__", std::equal_to<migraphx::argument>{})
        .def("__ne__", std::not_equal_to<migraphx::argument>{})
        .def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });

    py::class_<migraphx::target>(m, "target");

    py::class_<migraphx::program>(m, "program")
        // Return an independent copy by value; pybind11 takes ownership of it.
        .def("clone", [](const migraphx::program& p) { return migraphx::program(p); })
        .def("get_parameter_names", &migraphx::program::get_parameter_names)
        .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
        .def("get_output_shapes", &migraphx::program::get_output_shapes)
        .def(
            "compile",
            [](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
                migraphx::compile_options options;
                options.offload_copy = offload_copy;
                options.fast_math    = fast_math;
                p.compile(t, options);
            },
            py::arg("t"),
            py::arg("offload_copy") = true,
            py::arg("fast_math")    = true)
        .def("run",
             [](migraphx::program& p, py::dict params) {
                 migraphx::program::parameter_map pm;
                 for(auto x : params)
                 {
                     std::string key      = x.first.cast<std::string>();
                     py::buffer b         = x.second.cast<py::buffer>();
                     py::buffer_info info = b.request();
                     pm[key]              = migraphx::argument(to_shape(info), info.ptr);
                 }
                 return p.eval(pm);
             })
        .def("sort", &migraphx::program::sort)
        .def("__eq__", std::equal_to<migraphx::program>{})
        .def("__ne__", std::not_equal_to<migraphx::program>{})
        .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });

    py::class_<migraphx::operation>(m, "op")
        .def(py::init([](const std::string& name, py::kwargs kwargs) {
            migraphx::value v = migraphx::value::object{};
            if(kwargs)
            {
                v = migraphx::to_value(kwargs);
            }
            return migraphx::make_op(name, v);
        }))

        .def("name", &migraphx::operation::name);

    m.def("parse_tf",
          [](const std::string& filename, bool is_nhwc, unsigned int batch_size) {
              return migraphx::parse_tf(filename, migraphx::tf_options{is_nhwc, batch_size});
          },
          "Parse tf protobuf (default format is nhwc)",
          py::arg("filename"),
          py::arg("is_nhwc")    = true,
          py::arg("batch_size") = 1);

    m.def("parse_onnx",
          [](const std::string& filename,
             unsigned int default_dim_value,
             std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
             bool skip_unknown_operators,
             bool print_program_on_error) {
              migraphx::onnx_options options;
              options.default_dim_value      = default_dim_value;
              options.map_input_dims         = map_input_dims;
              options.skip_unknown_operators = skip_unknown_operators;
              options.print_program_on_error = print_program_on_error;
              return migraphx::parse_onnx(filename, options);
          },
          "Parse onnx file",
          py::arg("filename"),
          py::arg("default_dim_value") = 1,
          py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
          py::arg("skip_unknown_operators") = false,
          py::arg("print_program_on_error") = false);

    // The first argument is the serialized model bytes, not a path. The keyword
    // stays "filename" so existing callers using filename= keep working.
    m.def("parse_onnx_buffer",
          [](const std::string& onnx_buffer,
             unsigned int default_dim_value,
             std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
             bool skip_unknown_operators,
             bool print_program_on_error) {
              migraphx::onnx_options options;
              options.default_dim_value      = default_dim_value;
              options.map_input_dims         = map_input_dims;
              options.skip_unknown_operators = skip_unknown_operators;
              options.print_program_on_error = print_program_on_error;
              return migraphx::parse_onnx_buffer(onnx_buffer, options);
          },
          "Parse onnx buffer",
          py::arg("filename"),
          py::arg("default_dim_value") = 1,
          py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
          py::arg("skip_unknown_operators") = false,
          py::arg("print_program_on_error") = false);

    m.def("load",
          [](const std::string& name, const std::string& format) {
              migraphx::file_options options;
              options.format = format;
              return migraphx::load(name, options);
          },
          "Load MIGraphX program",
          py::arg("filename"),
          py::arg("format") = "msgpack");

    m.def("save",
          [](const migraphx::program& p, const std::string& name, const std::string& format) {
              migraphx::file_options options;
              options.format = format;
              return migraphx::save(p, name, options);
          },
          "Save MIGraphX program",
          py::arg("p"),
          py::arg("filename"),
          py::arg("format") = "msgpack");

    m.def("get_target", &migraphx::make_target);
    m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
    m.def("quantize_fp16",
          &migraphx::quantize_fp16,
          py::arg("prog"),
          py::arg("ins_names") = std::vector<std::string>{"all"});
    m.def("quantize_int8",
          &migraphx::quantize_int8,
          py::arg("prog"),
          py::arg("t"),
          py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
          py::arg("ins_names")   = std::vector<std::string>{"dot", "convolution"});

#ifdef HAVE_GPU
    m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
    m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
    m.def("from_gpu", &migraphx::gpu::from_gpu);
    m.def("gpu_sync", &migraphx::gpu::gpu_sync);
#endif

#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif
}