Unverified Commit 8045f7c8 authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

pybind updates for torch_migraphx library (#1323)

Add function argument_from_pointer to allow directly passing a migraphx.shape object and a memory address. 
Expose the is_compiled() method from migraphx::program. 
Expose the enum types under migraphx::op. 
parent 7c8f2690
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
...@@ -82,7 +83,7 @@ void visit_py(T x, F f) ...@@ -82,7 +83,7 @@ void visit_py(T x, F f)
{ {
f(x.template cast<bool>()); f(x.template cast<bool>());
} }
else if(py::isinstance<py::int_>(x)) else if(py::isinstance<py::int_>(x) || py::hasattr(x, "__index__"))
{ {
f(x.template cast<int>()); f(x.template cast<int>());
} }
...@@ -324,6 +325,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -324,6 +325,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("get_parameter_names", &migraphx::program::get_parameter_names) .def("get_parameter_names", &migraphx::program::get_parameter_names)
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_output_shapes", &migraphx::program::get_output_shapes) .def("get_output_shapes", &migraphx::program::get_output_shapes)
.def("is_compiled", &migraphx::program::is_compiled)
.def( .def(
"compile", "compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) { [](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
...@@ -358,18 +360,35 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -358,18 +360,35 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("__ne__", std::not_equal_to<migraphx::program>{}) .def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); }); .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
py::class_<migraphx::operation>(m, "op") py::class_<migraphx::operation> op(m, "op");
.def(py::init([](const std::string& name, py::kwargs kwargs) { op.def(py::init([](const std::string& name, py::kwargs kwargs) {
migraphx::value v = migraphx::value::object{}; migraphx::value v = migraphx::value::object{};
if(kwargs) if(kwargs)
{ {
v = migraphx::to_value(kwargs); v = migraphx::to_value(kwargs);
} }
return migraphx::make_op(name, v); return migraphx::make_op(name, v);
})) }))
.def("name", &migraphx::operation::name); .def("name", &migraphx::operation::name);
py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode")
.value("average", migraphx::op::pooling_mode::average)
.value("max", migraphx::op::pooling_mode::max)
.value("lpnorm", migraphx::op::pooling_mode::lpnorm);
py::enum_<migraphx::op::rnn_direction>(op, "rnn_direction")
.value("forward", migraphx::op::rnn_direction::forward)
.value("reverse", migraphx::op::rnn_direction::reverse)
.value("bidirectional", migraphx::op::rnn_direction::bidirectional);
m.def(
"argument_from_pointer",
[](const migraphx::shape shape, const int64_t address) {
return migraphx::argument(shape, reinterpret_cast<void*>(address));
},
py::arg("shape"),
py::arg("address"));
m.def( m.def(
"parse_tf", "parse_tf",
[](const std::string& filename, [](const std::string& filename,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment