Commit 1b6387fa authored by Paul's avatar Paul
Browse files

Add more complete test

parent 06ff76b3
......@@ -66,14 +66,16 @@ PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); });
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::program>(m, "program")
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
.def("eval", &migraphx::program::eval);
.def("run", &migraphx::program::eval);
m.def("parse_onnx", &migraphx::parse_onnx);
m.def("target", [](const std::string& name) -> migraphx::target {
m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu")
return migraphx::cpu::target{};
throw std::runtime_error("Target not found: " + name);
......
import migraphx
p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
p.compile(migraphx.get_target("cpu"))
params = {}
for key, value in p.get_parameter_shapes().items():
params[key] = migraphx.generate_argument(value)
r = p.run(params)
print(r)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment