#include #include #include #include #include #include namespace py = pybind11; template struct skip_half { F f; template void operator()(A a) const { f(a); } void operator()(migraphx::shape::as) const { throw std::runtime_error("Half not supported in python yet."); } }; template void visit_type(const migraphx::shape& s, F f) { s.visit_type(skip_half{f}); } template py::buffer_info to_buffer_info(T& x) { migraphx::shape s = x.get_shape(); py::buffer_info b; visit_type(s, [&](auto as) { b = py::buffer_info(x.data(), as.size(), py::format_descriptor::format(), s.lens().size(), s.lens(), s.strides()); }); return b; } PYBIND11_MODULE(migraphx, m) { py::class_(m, "shape") .def(py::init<>()) .def("type", &migraphx::shape::type) .def("lens", &migraphx::shape::lens) .def("strides", &migraphx::shape::strides) .def("elements", &migraphx::shape::elements) .def("bytes", &migraphx::shape::bytes) .def("type_size", &migraphx::shape::type_size) .def("packed", &migraphx::shape::packed) .def("transposed", &migraphx::shape::transposed) .def("broadcasted", &migraphx::shape::broadcasted) .def("standard", &migraphx::shape::standard) .def("scalar", &migraphx::shape::scalar); py::class_(m, "argument", py::buffer_protocol()) .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); }); py::class_(m, "target"); py::class_(m, "program") .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); }) .def("run", &migraphx::program::eval); m.def("parse_onnx", &migraphx::parse_onnx); m.def("get_target", [](const std::string& name) -> migraphx::target { if(name == "cpu") return migraphx::cpu::target{}; throw std::runtime_error("Target not found: " + name); }); m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); #ifdef VERSION_INFO m.attr("__version__") = VERSION_INFO; #else m.attr("__version__") = "dev"; #endif }