Commit 11c7cee5 authored by Paul's avatar Paul
Browse files

Support buffer info constructor

parent 8a3d1d09
...@@ -62,6 +62,16 @@ struct shape ...@@ -62,6 +62,16 @@ struct shape
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s); shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
template<class Range>
shape(type_t t, const Range& l)
: shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{}
template<class Range1, class Range2>
shape(type_t t, const Range1& l, const Range2& s)
: shape(t, std::vector<std::size_t>(l.begin(), l.end()), std::vector<std::size_t>(s.begin(), s.end()))
{}
type_t type() const; type_t type() const;
const std::vector<std::size_t>& lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
...@@ -141,6 +151,11 @@ struct shape ...@@ -141,6 +151,11 @@ struct shape
{ {
return reinterpret_cast<const T*>(buffer) + n; return reinterpret_cast<const T*>(buffer) + n;
} }
type_t type_enum() const
{
return get_type<T>{};
}
}; };
template <class Visitor> template <class Visitor>
...@@ -156,6 +171,15 @@ struct shape ...@@ -156,6 +171,15 @@ struct shape
MIGRAPHX_THROW("Unknown type"); MIGRAPHX_THROW("Unknown type");
} }
template <class Visitor>
static void visit_types(Visitor v)
{
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) \
v(as<t>());
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
}
private: private:
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace py = pybind11; namespace py = pybind11;
template <class F> template <class F>
struct skip_half struct throw_half
{ {
F f; F f;
...@@ -30,10 +30,31 @@ struct skip_half ...@@ -30,10 +30,31 @@ struct skip_half
} }
}; };
template <class F>
struct skip_half
{
F f;
template <class A>
void operator()(A a) const
{
f(a);
}
void operator()(migraphx::shape::as<migraphx::half>) const
{}
};
template <class F> template <class F>
void visit_type(const migraphx::shape& s, F f) void visit_type(const migraphx::shape& s, F f)
{ {
s.visit_type(skip_half<F>{f}); s.visit_type(throw_half<F>{f});
}
template <class F>
void visit_types(F f)
{
migraphx::shape::visit_types(skip_half<F>{f});
} }
template <class T> template <class T>
...@@ -52,6 +73,16 @@ py::buffer_info to_buffer_info(T& x) ...@@ -52,6 +73,16 @@ py::buffer_info to_buffer_info(T& x)
return b; return b;
} }
migraphx::shape to_shape(const py::buffer_info& info)
{
migraphx::shape::type_t t;
visit_types([&](auto as) {
if (info.format == py::format_descriptor<decltype(as())>::format())
t = as.type_enum();
});
return migraphx::shape{t, info.shape, info.strides};
}
PYBIND11_MODULE(migraphx, m) PYBIND11_MODULE(migraphx, m)
{ {
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape>(m, "shape")
...@@ -70,7 +101,11 @@ PYBIND11_MODULE(migraphx, m) ...@@ -70,7 +101,11 @@ PYBIND11_MODULE(migraphx, m)
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); }); .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol()) py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); }); .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def("__init__", [](migraphx::argument& x, py::buffer b) {
py::buffer_info info = b.request();
new (&x) migraphx::argument(to_shape(info), info.ptr);
});
py::class_<migraphx::target>(m, "target"); py::class_<migraphx::target>(m, "target");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment