Commit 389f556d authored by Paul's avatar Paul
Browse files

Add initial python module

parent ab6cd9d3
......@@ -4,3 +4,4 @@ ROCmSoftwarePlatform/rocBLAS@30a992ae02fda568688bcd190edd5e277d6674d9
ROCmSoftwarePlatform/MIOpen@1.7.0
blaze,https://bitbucket.org/blaze-lib/blaze/get/f0755dea0e03.tar.gz -X header -DHEADER_DIR=blaze
half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969
pybind/pybind11@v2.2.4 -DPYBIND11_TEST=Off --build
......@@ -35,6 +35,7 @@ target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLU
set(PACKAGE_DEPENDS)
add_subdirectory(onnx)
add_subdirectory(py)
add_subdirectory(targets/cpu)
if(MIGRAPHX_ENABLE_GPU)
list(APPEND PACKAGE_DEPENDS MIOpen rocblas)
......
find_package(pybind11 REQUIRED)
pybind11_add_module(migraphx_py migraphx_py.cpp)
target_link_libraries(migraphx_py migraphx migraphx_onnx migraphx_gpu migraphx_cpu)
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
namespace py = pybind11;
template<class F>
struct skip_half
{
F f;
template<class A>
void operator()(A a) const
{
f(a);
}
void operator()(migraphx::shape::as<migraphx::half>) const
{
throw std::runtime_error("Half not supported in python yet.");
}
};
template<class F>
void visit_type(const migraphx::shape& s, F f)
{
s.visit_type(skip_half<F>{f});
}
template<class T>
py::buffer_info to_buffer_info(T& x)
{
migraphx::shape s = x.get_shape();
py::buffer_info b;
visit_type(s, [&](auto as) {
b = py::buffer_info(
x.data(),
as.size(),
py::format_descriptor<decltype(as())>::format(),
s.lens().size(),
s.lens(),
s.strides()
);
});
return b;
}
PYBIND11_MODULE(migraphx, m) {
py::class_<migraphx::shape>(m, "shape")
.def(py::init<>())
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_size", &migraphx::shape::type_size)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
.def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument &x) -> py::buffer_info {
return to_buffer_info(x);
});
py::class_<migraphx::program>(m, "program")
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("compile", [](migraphx::program& p, const migraphx::target& t) {
p.compile(t);
})
.def("eval", &migraphx::program::eval);
m.def("parse_onnx", &migraphx::parse_onnx);
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
m.attr("__version__") = "dev";
#endif
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment