Commit 9421576c authored by Paul's avatar Paul
Browse files

Make program and shape printable

parent 1b6387fa
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
namespace py = pybind11; namespace py = pybind11;
...@@ -61,7 +62,12 @@ PYBIND11_MODULE(migraphx, m) ...@@ -61,7 +62,12 @@ PYBIND11_MODULE(migraphx, m)
.def("transposed", &migraphx::shape::transposed) .def("transposed", &migraphx::shape::transposed)
.def("broadcasted", &migraphx::shape::broadcasted) .def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard) .def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar); .def("scalar", &migraphx::shape::scalar)
.def("__repr__",
[](const migraphx::shape &s) {
return migraphx::to_string(s);
}
);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); }); .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); });
...@@ -71,7 +77,12 @@ PYBIND11_MODULE(migraphx, m) ...@@ -71,7 +77,12 @@ PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); }) .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
.def("run", &migraphx::program::eval); .def("run", &migraphx::program::eval)
.def("__repr__",
[](const migraphx::program &p) {
return migraphx::to_string(p);
}
);
m.def("parse_onnx", &migraphx::parse_onnx); m.def("parse_onnx", &migraphx::parse_onnx);
......
import migraphx import migraphx
p = migraphx.parse_onnx("conv_relu_maxpool.onnx") p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
print(p)
print("Compiling ...")
p.compile(migraphx.get_target("cpu")) p.compile(migraphx.get_target("cpu"))
print(p)
params = {} params = {}
for key, value in p.get_parameter_shapes().items(): for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value))
params[key] = migraphx.generate_argument(value) params[key] = migraphx.generate_argument(value)
r = p.run(params) r = p.run(params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment