Commit 11c7cee5 authored by Paul's avatar Paul
Browse files

Support buffer info constructor

parent 8a3d1d09
......@@ -61,6 +61,16 @@ struct shape
shape(type_t t);
shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
template<class Range>
shape(type_t t, const Range& l)
: shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{}
template<class Range1, class Range2>
shape(type_t t, const Range1& l, const Range2& s)
: shape(t, std::vector<std::size_t>(l.begin(), l.end()), std::vector<std::size_t>(s.begin(), s.end()))
{}
type_t type() const;
const std::vector<std::size_t>& lens() const;
......@@ -141,6 +151,11 @@ struct shape
{
return reinterpret_cast<const T*>(buffer) + n;
}
type_t type_enum() const
{
return get_type<T>{};
}
};
template <class Visitor>
......@@ -156,6 +171,15 @@ struct shape
MIGRAPHX_THROW("Unknown type");
}
template <class Visitor>
static void visit_types(Visitor v)
{
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) \
v(as<t>());
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
}
private:
std::shared_ptr<const shape_impl> impl;
......
......@@ -14,7 +14,7 @@
namespace py = pybind11;
template <class F>
struct skip_half
struct throw_half
{
F f;
......@@ -30,10 +30,31 @@ struct skip_half
}
};
template <class F>
struct skip_half
{
F f;
template <class A>
void operator()(A a) const
{
f(a);
}
void operator()(migraphx::shape::as<migraphx::half>) const
{}
};
template <class F>
void visit_type(const migraphx::shape& s, F f)
{
s.visit_type(skip_half<F>{f});
s.visit_type(throw_half<F>{f});
}
template <class F>
void visit_types(F f)
{
migraphx::shape::visit_types(skip_half<F>{f});
}
template <class T>
......@@ -52,6 +73,16 @@ py::buffer_info to_buffer_info(T& x)
return b;
}
migraphx::shape to_shape(const py::buffer_info& info)
{
migraphx::shape::type_t t;
visit_types([&](auto as) {
if (info.format == py::format_descriptor<decltype(as())>::format())
t = as.type_enum();
});
return migraphx::shape{t, info.shape, info.strides};
}
PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
......@@ -70,7 +101,11 @@ PYBIND11_MODULE(migraphx, m)
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); });
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def("__init__", [](migraphx::argument& x, py::buffer b) {
py::buffer_info info = b.request();
new (&x) migraphx::argument(to_shape(info), info.ptr);
});
py::class_<migraphx::target>(m, "target");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment