Commit c570fb57 authored by Paul's avatar Paul
Browse files

Formatting

parent 389f556d
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
namespace py = pybind11; namespace py = pybind11;
template<class F> template <class F>
struct skip_half struct skip_half
{ {
F f; F f;
template<class A> template <class A>
void operator()(A a) const void operator()(A a) const
{ {
f(a); f(a);
...@@ -20,34 +20,33 @@ struct skip_half ...@@ -20,34 +20,33 @@ struct skip_half
void operator()(migraphx::shape::as<migraphx::half>) const void operator()(migraphx::shape::as<migraphx::half>) const
{ {
throw std::runtime_error("Half not supported in python yet."); throw std::runtime_error("Half not supported in python yet.");
} }
}; };
template<class F> template <class F>
void visit_type(const migraphx::shape& s, F f) void visit_type(const migraphx::shape& s, F f)
{ {
s.visit_type(skip_half<F>{f}); s.visit_type(skip_half<F>{f});
} }
template<class T> template <class T>
py::buffer_info to_buffer_info(T& x) py::buffer_info to_buffer_info(T& x)
{ {
migraphx::shape s = x.get_shape(); migraphx::shape s = x.get_shape();
py::buffer_info b; py::buffer_info b;
visit_type(s, [&](auto as) { visit_type(s, [&](auto as) {
b = py::buffer_info( b = py::buffer_info(x.data(),
x.data(), as.size(),
as.size(), py::format_descriptor<decltype(as())>::format(),
py::format_descriptor<decltype(as())>::format(), s.lens().size(),
s.lens().size(), s.lens(),
s.lens(), s.strides());
s.strides()
);
}); });
return b; return b;
} }
PYBIND11_MODULE(migraphx, m) { PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape>(m, "shape")
.def(py::init<>()) .def(py::init<>())
.def("type", &migraphx::shape::type) .def("type", &migraphx::shape::type)
...@@ -63,15 +62,11 @@ PYBIND11_MODULE(migraphx, m) { ...@@ -63,15 +62,11 @@ PYBIND11_MODULE(migraphx, m) {
.def("scalar", &migraphx::shape::scalar); .def("scalar", &migraphx::shape::scalar);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument &x) -> py::buffer_info { .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); });
return to_buffer_info(x);
});
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("compile", [](migraphx::program& p, const migraphx::target& t) { .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
p.compile(t);
})
.def("eval", &migraphx::program::eval); .def("eval", &migraphx::program::eval);
m.def("parse_onnx", &migraphx::parse_onnx); m.def("parse_onnx", &migraphx::parse_onnx);
...@@ -82,4 +77,3 @@ PYBIND11_MODULE(migraphx, m) { ...@@ -82,4 +77,3 @@ PYBIND11_MODULE(migraphx, m) {
m.attr("__version__") = "dev"; m.attr("__version__") = "dev";
#endif #endif
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment