"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "2741f329959ab7603a6bed844137f990a1a87808"
Commit 754dbbd4 authored by Paul's avatar Paul
Browse files

Add equality operators

parent c9cd5e1f
...@@ -97,6 +97,8 @@ PYBIND11_MODULE(migraphx, m) ...@@ -97,6 +97,8 @@ PYBIND11_MODULE(migraphx, m)
.def("broadcasted", &migraphx::shape::broadcasted) .def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard) .def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar) .def("scalar", &migraphx::shape::scalar)
.def("__eq__", std::equal_to<migraphx::shape>{})
.def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); }); .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
...@@ -104,14 +106,20 @@ PYBIND11_MODULE(migraphx, m) ...@@ -104,14 +106,20 @@ PYBIND11_MODULE(migraphx, m)
.def("__init__", [](migraphx::argument& x, py::buffer b) { .def("__init__", [](migraphx::argument& x, py::buffer b) {
py::buffer_info info = b.request(); py::buffer_info info = b.request();
new(&x) migraphx::argument(to_shape(info), info.ptr); new(&x) migraphx::argument(to_shape(info), info.ptr);
}); })
.def("__eq__", std::equal_to<migraphx::argument>{})
.def("__ne__", std::not_equal_to<migraphx::argument>{})
.def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });
py::class_<migraphx::target>(m, "target"); py::class_<migraphx::target>(m, "target");
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_shape", &migraphx::program::get_shape)
.def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); }) .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })
.def("run", &migraphx::program::eval) .def("run", &migraphx::program::eval)
.def("__eq__", std::equal_to<migraphx::program>{})
.def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); }); .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
m.def("parse_onnx", &migraphx::parse_onnx); m.def("parse_onnx", &migraphx::parse_onnx);
......
...@@ -2,9 +2,12 @@ import migraphx ...@@ -2,9 +2,12 @@ import migraphx
p = migraphx.parse_onnx("conv_relu_maxpool.onnx") p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
print(p) print(p)
s1 = p.get_shape()
print("Compiling ...") print("Compiling ...")
p.compile(migraphx.get_target("cpu")) p.compile(migraphx.get_target("cpu"))
print(p) print(p)
s2 = p.get_shape()
assert s1 == s2
params = {} params = {}
for key, value in p.get_parameter_shapes().items(): for key, value in p.get_parameter_shapes().items():
print("Parameter {} -> {}".format(key, value)) print("Parameter {} -> {}".format(key, value))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment