"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "5cef60b8578fd32ff505b8dfb115b602a2822692"
migraphx_py.cpp 4.63 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <migraphx/program.hpp>
Paul's avatar
Paul committed
5
6
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/onnx.hpp>
Paul's avatar
Paul committed
8
#include <migraphx/stringutils.hpp>
Paul's avatar
Paul committed
9
10
11
12
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
Paul's avatar
Paul committed
13
14
15

namespace py = pybind11;

Paul's avatar
Paul committed
16
template <class F>
Paul's avatar
Paul committed
17
struct throw_half
Paul's avatar
Paul committed
18
19
20
{
    F f;

Paul's avatar
Paul committed
21
    template <class A>
Paul's avatar
Paul committed
22
23
24
25
26
27
28
29
    void operator()(A a) const
    {
        f(a);
    }

    void operator()(migraphx::shape::as<migraphx::half>) const
    {
        throw std::runtime_error("Half not supported in python yet.");
Paul's avatar
Paul committed
30
    }
Paul's avatar
Paul committed
31
32
};

Paul's avatar
Paul committed
33
34
35
36
37
38
39
40
41
42
43
template <class F>
struct skip_half
{
    F f;

    template <class A>
    void operator()(A a) const
    {
        f(a);
    }

Paul's avatar
Paul committed
44
    void operator()(migraphx::shape::as<migraphx::half>) const {}
Paul's avatar
Paul committed
45
46
};

Paul's avatar
Paul committed
47
template <class F>
Paul's avatar
Paul committed
48
49
void visit_type(const migraphx::shape& s, F f)
{
Paul's avatar
Paul committed
50
51
52
53
54
55
56
    s.visit_type(throw_half<F>{f});
}

template <class F>
void visit_types(F f)
{
    migraphx::shape::visit_types(skip_half<F>{f});
Paul's avatar
Paul committed
57
58
}

Paul's avatar
Paul committed
59
template <class T>
Paul's avatar
Paul committed
60
61
62
63
64
py::buffer_info to_buffer_info(T& x)
{
    migraphx::shape s = x.get_shape();
    py::buffer_info b;
    visit_type(s, [&](auto as) {
Paul's avatar
Paul committed
65
66
67
68
69
70
        b = py::buffer_info(x.data(),
                            as.size(),
                            py::format_descriptor<decltype(as())>::format(),
                            s.lens().size(),
                            s.lens(),
                            s.strides());
Paul's avatar
Paul committed
71
72
73
74
    });
    return b;
}

Paul's avatar
Paul committed
75
76
77
78
/// Build a MIGraphX shape describing the memory exposed by a Python buffer.
///
/// @param info  buffer description, e.g. from py::buffer::request().
/// @return      a shape whose element type matches the buffer format, whose lens are the
///              buffer dimensions, and whose strides are the buffer strides converted
///              from bytes to elements.
/// @throws std::runtime_error if the format matches no supported (non-half) element type,
///         or if a stride is negative or not a whole number of elements.
migraphx::shape to_shape(const py::buffer_info& info)
{
    migraphx::shape::type_t t{};
    std::size_t itemsize = 0; // stays 0 unless some element type matched the format
    visit_types([&](auto as) {
        const std::string fmt = py::format_descriptor<decltype(as())>::format();
        // numpy may report 64-bit integers as "l"/"L" (C long) while pybind11 describes
        // them as "q"/"Q"; accept the alias only when the element sizes agree.
        const bool same_size = static_cast<std::size_t>(info.itemsize) == as.size();
        const bool long_alias = same_size && ((info.format == "l" && fmt == "q") ||
                                              (info.format == "L" && fmt == "Q"));
        if(info.format == fmt || long_alias)
        {
            t        = as.type_enum();
            itemsize = as.size();
        }
    });
    // Previously an unmatched format left `t` uninitialized (undefined behavior).
    if(itemsize == 0)
        throw std::runtime_error("Unsupported buffer format: " + info.format);

    // Copy through iterators so this works whether pybind11 stores dims as size_t or ssize_t.
    std::vector<std::size_t> lens(info.shape.begin(), info.shape.end());

    // Buffer strides are in bytes; MIGraphX shapes expect element strides.
    const auto elem_bytes = static_cast<long long>(itemsize);
    std::vector<std::size_t> strides;
    strides.reserve(info.strides.size());
    for(auto stride : info.strides)
    {
        const auto bytes = static_cast<long long>(stride);
        // Negative (reversed views) or misaligned strides cannot be expressed in elements.
        if(bytes < 0 || bytes % elem_bytes != 0)
            throw std::runtime_error("Unsupported buffer stride: " + std::to_string(bytes));
        strides.push_back(static_cast<std::size_t>(bytes / elem_bytes));
    }
    return migraphx::shape{t, lens, strides};
}

Paul's avatar
Paul committed
85
86
PYBIND11_MODULE(migraphx, m)
{
    // shape: exposes type, dimensions, strides and layout queries.
    py::class_<migraphx::shape>(m, "shape")
        .def(py::init<>())
        .def("type", &migraphx::shape::type)
        .def("lens", &migraphx::shape::lens)
        .def("strides", &migraphx::shape::strides)
        .def("elements", &migraphx::shape::elements)
        .def("bytes", &migraphx::shape::bytes)
        .def("type_size", &migraphx::shape::type_size)
        .def("packed", &migraphx::shape::packed)
        .def("transposed", &migraphx::shape::transposed)
        .def("broadcasted", &migraphx::shape::broadcasted)
        .def("standard", &migraphx::shape::standard)
        .def("scalar", &migraphx::shape::scalar)
        .def("__eq__", std::equal_to<migraphx::shape>{})
        .def("__ne__", std::not_equal_to<migraphx::shape>{})
        .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });

    // argument: a tensor that supports the Python buffer protocol in both directions.
    py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
        .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
        // Construct from any Python buffer. The argument is handed the buffer's raw pointer
        // (info.ptr) rather than a copy. NOTE(review): this assumes argument(shape, ptr) does
        // not take ownership. keep_alive<1, 2> keeps the source buffer alive as long as the
        // new argument, so a temporary numpy array cannot leave a dangling pointer.
        // py::init replaces the deprecated placement-new "__init__" form.
        .def(py::init([](py::buffer b) {
                 py::buffer_info info = b.request();
                 return migraphx::argument(to_shape(info), info.ptr);
             }),
             py::keep_alive<1, 2>())
        .def("__eq__", std::equal_to<migraphx::argument>{})
        .def("__ne__", std::not_equal_to<migraphx::argument>{})
        .def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });

    py::class_<migraphx::target>(m, "target");

    py::class_<migraphx::program>(m, "program")
        .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
        .def("get_shape", &migraphx::program::get_shape)
        .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
        .def("run", &migraphx::program::eval)
        .def("__eq__", std::equal_to<migraphx::program>{})
        .def("__ne__", std::not_equal_to<migraphx::program>{})
        .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });

    m.def("parse_onnx", &migraphx::parse_onnx);

    // Look up a compilation target by name; unknown names raise.
    m.def("get_target", [](const std::string& name) -> migraphx::target {
        if(name == "cpu")
            return migraphx::cpu::target{};
#ifdef HAVE_GPU
        if(name == "gpu")
            return migraphx::gpu::target{};
#endif
        throw std::runtime_error("Target not found: " + name);
    });

    m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);

#ifdef HAVE_GPU
    m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
    m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
    m.def("from_gpu", &migraphx::gpu::from_gpu);
    m.def("gpu_sync", &migraphx::gpu::gpu_sync);
    m.def("copy_to_gpu", &migraphx::gpu::copy_to_gpu);
#endif

#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif
}