"src/vscode:/vscode.git/clone" did not exist on "98754c70eceba7dc0a73b07c6c96963c89b5f8f7"
Commit c570fb57 authored by Paul's avatar Paul
Browse files

Formatting

parent 389f556d
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
namespace py = pybind11; namespace py = pybind11;
template<class F> template <class F>
struct skip_half struct skip_half
{ {
F f; F f;
template<class A> template <class A>
void operator()(A a) const void operator()(A a) const
{ {
f(a); f(a);
...@@ -23,31 +23,30 @@ struct skip_half ...@@ -23,31 +23,30 @@ struct skip_half
} }
}; };
template<class F> template <class F>
void visit_type(const migraphx::shape& s, F f) void visit_type(const migraphx::shape& s, F f)
{ {
s.visit_type(skip_half<F>{f}); s.visit_type(skip_half<F>{f});
} }
template<class T> template <class T>
py::buffer_info to_buffer_info(T& x) py::buffer_info to_buffer_info(T& x)
{ {
migraphx::shape s = x.get_shape(); migraphx::shape s = x.get_shape();
py::buffer_info b; py::buffer_info b;
visit_type(s, [&](auto as) { visit_type(s, [&](auto as) {
b = py::buffer_info( b = py::buffer_info(x.data(),
x.data(),
as.size(), as.size(),
py::format_descriptor<decltype(as())>::format(), py::format_descriptor<decltype(as())>::format(),
s.lens().size(), s.lens().size(),
s.lens(), s.lens(),
s.strides() s.strides());
);
}); });
return b; return b;
} }
PYBIND11_MODULE(migraphx, m) { PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape>(m, "shape")
.def(py::init<>()) .def(py::init<>())
.def("type", &migraphx::shape::type) .def("type", &migraphx::shape::type)
...@@ -63,15 +62,11 @@ PYBIND11_MODULE(migraphx, m) { ...@@ -63,15 +62,11 @@ PYBIND11_MODULE(migraphx, m) {
.def("scalar", &migraphx::shape::scalar); .def("scalar", &migraphx::shape::scalar);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument &x) -> py::buffer_info { .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); });
return to_buffer_info(x);
});
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("compile", [](migraphx::program& p, const migraphx::target& t) { .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
p.compile(t);
})
.def("eval", &migraphx::program::eval); .def("eval", &migraphx::program::eval);
m.def("parse_onnx", &migraphx::parse_onnx); m.def("parse_onnx", &migraphx::parse_onnx);
...@@ -82,4 +77,3 @@ PYBIND11_MODULE(migraphx, m) { ...@@ -82,4 +77,3 @@ PYBIND11_MODULE(migraphx, m) {
m.attr("__version__") = "dev"; m.attr("__version__") = "dev";
#endif #endif
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment