Unverified Commit 42685803 authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

expose enum datatypes to python api (#1655)

Expose the shape::type_t values to be used by the python api and is required by torch_migraphx to support torchbench models.
parent 7b2a5ccf
...@@ -62,6 +62,7 @@ namespace py = pybind11; ...@@ -62,6 +62,7 @@ namespace py = pybind11;
PYBIND11_MODULE(__VA_ARGS__) \ PYBIND11_MODULE(__VA_ARGS__) \
MIGRAPHX_POP_WARNING MIGRAPHX_POP_WARNING
#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
namespace migraphx { namespace migraphx {
migraphx::value to_value(py::kwargs kwargs); migraphx::value to_value(py::kwargs kwargs);
...@@ -235,7 +236,8 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -235,7 +236,8 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m) MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{ {
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape> shape_cls(m, "shape");
shape_cls
.def(py::init([](py::kwargs kwargs) { .def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs); auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float")); auto t = migraphx::shape::parse_type(v.get("type", "float"));
...@@ -261,6 +263,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -261,6 +263,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("__ne__", std::not_equal_to<migraphx::shape>{}) .def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); }); .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); }) .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def(py::init([](py::buffer b) { .def(py::init([](py::buffer b) {
......
...@@ -49,6 +49,16 @@ def test_create_shape_type(): ...@@ -49,6 +49,16 @@ def test_create_shape_type():
assert s.type_size() == 4 assert s.type_size() == 4
def test_type_enum():
mgx_types = [
'bool_type', 'double_type', 'float_type', 'half_type', 'int16_type',
'int32_type', 'int64_type', 'int8_type', 'uint16_type', 'uint32_type',
'uint64_type', 'uint8_type'
]
for t in mgx_types:
assert hasattr(migraphx.shape.type_t, t)
if __name__ == "__main__": if __name__ == "__main__":
test_create_shape() test_create_shape()
test_create_shape_broadcast() test_create_shape_broadcast()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment