Commit 333cd06e authored by charlie's avatar charlie
Browse files

Python API partial update

Adds dynamic_dimension object
parent 863bdfbf
...@@ -236,7 +236,10 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -236,7 +236,10 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m) MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{ {
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape> py_shape(m, "shape");
// TODO: update this def to also create dynamic shapes
py_shape
.def(py::init([](py::kwargs kwargs) { .def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs); auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float")); auto t = migraphx::shape::parse_type(v.get("type", "float"));
...@@ -249,19 +252,30 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -249,19 +252,30 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("type", &migraphx::shape::type) .def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens) .def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides) .def("strides", &migraphx::shape::strides)
.def("ndim", &migraphx::shape::ndim)
.def("elements", &migraphx::shape::elements) .def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes) .def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string) .def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size) .def("type_size", &migraphx::shape::type_size)
.def("dyn_dims", &migraphx::shape::dyn_dims)
.def("packed", &migraphx::shape::packed) .def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed) .def("transposed", &migraphx::shape::transposed)
.def("broadcasted", &migraphx::shape::broadcasted) .def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard) .def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar) .def("scalar", &migraphx::shape::scalar)
.def("dynamic", &migraphx::shape::dynamic)
.def("__eq__", std::equal_to<migraphx::shape>{}) .def("__eq__", std::equal_to<migraphx::shape>{})
.def("__ne__", std::not_equal_to<migraphx::shape>{}) .def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); }); .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::class_<migraphx::shape::dynamic_dimension>(py_shape, "dynamic_dimension")
.def(py::init<std::size_t, std::size_t, std::size_t>())
.def_readwrite("min", &migraphx::shape::dynamic_dimension::min)
.def_readwrite("max", &migraphx::shape::dynamic_dimension::max)
.def_readwrite("opt", &migraphx::shape::dynamic_dimension::opt)
.def("is_fixed", &migraphx::shape::dynamic_dimension::is_fixed)
.def("has_optimal", &migraphx::shape::dynamic_dimension::has_optimal);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); }) .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def(py::init([](py::buffer b) { .def(py::init([](py::buffer b) {
...@@ -428,26 +442,38 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -428,26 +442,38 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx", "parse_onnx",
[](const std::string& filename, [](const std::string& filename,
unsigned int default_dim_value, unsigned int default_dim_value,
migraphx::shape::dynamic_dimension default_dyn_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims,
bool skip_unknown_operators, bool skip_unknown_operators,
bool print_program_on_error, bool print_program_on_error,
int64_t max_loop_iterations) { int64_t max_loop_iterations,
bool use_dyn_output) {
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dim_value = default_dim_value; options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
options.map_input_dims = map_input_dims; options.map_input_dims = map_input_dims;
options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators; options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error; options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations; options.max_loop_iterations = max_loop_iterations;
options.use_dyn_output = use_dyn_output;
return migraphx::parse_onnx(filename, options); return migraphx::parse_onnx(filename, options);
}, },
"Parse onnx file", "Parse onnx file",
py::arg("filename"), py::arg("filename"),
py::arg("default_dim_value") = 1, py::arg("default_dim_value") = 0,
py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(), py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false, py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false, py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10); py::arg("max_loop_iterations") = 10,
py::arg("use_dyn_output") = false);
// TODO: also update reading from ONNX buffer
m.def( m.def(
"parse_onnx_buffer", "parse_onnx_buffer",
[](const std::string& onnx_buffer, [](const std::string& onnx_buffer,
......
...@@ -49,6 +49,15 @@ def test_create_shape_type(): ...@@ -49,6 +49,15 @@ def test_create_shape_type():
assert s.type_size() == 4 assert s.type_size() == 4
def test_create_dynamic_dimension():
dd = migraphx.shape.dynamic_dimension(1, 4)
assert dd.min == 1
assert dd.max == 4
assert dd.opt == 0
assert dd.is_fixed == False
assert dd.has_opt == False
if __name__ == "__main__": if __name__ == "__main__":
test_create_shape() test_create_shape()
test_create_shape_broadcast() test_create_shape_broadcast()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment