Unverified Commit b90d69ae authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enable constructing argument with tuple and buffer (#919)



* Improve handling of constructing a tuple from a buffer
* Add unit test
* Remove unused function
Co-authored-by: default avatarShucai Xiao <shucai@gmail.com>
parent 0859fe90
......@@ -8,7 +8,7 @@ inline namespace MIGRAPHX_INLINE_NS {
argument::argument(const shape& s) : m_shape(s)
{
auto buffer = make_shared_array<char>(s.bytes());
m_data = {[=]() mutable { return buffer.get(); }};
assign_buffer({[=]() mutable { return buffer.get(); }});
}
argument::argument(shape s, std::nullptr_t)
......@@ -18,10 +18,14 @@ argument::argument(shape s, std::nullptr_t)
argument::argument(const shape& s, const argument::data_t& d) : m_shape(s), m_data(d) {}
argument argument::load(const shape& s, char* buffer)
void argument::assign_buffer(std::function<char*()> d)
{
const shape& s = m_shape;
if(s.type() != shape::tuple_type)
return argument{s, buffer};
{
m_data = {std::move(d)};
return;
}
// Collect all shapes
std::unordered_map<std::size_t, shape> shapes;
{
......@@ -58,19 +62,22 @@ argument argument::load(const shape& s, char* buffer)
// cppcheck-suppress variableScope
std::size_t i = 0;
return fix<argument>([&](auto self, auto ss) {
m_data = fix<data_t>([&](auto self, auto ss) {
data_t result;
if(ss.sub_shapes().empty())
{
argument r{shapes[i], buffer + offsets[i]};
auto n = offsets[i];
result = {[d, n]() mutable { return d() + n; }};
i++;
return r;
return result;
}
std::vector<argument> subs;
std::vector<data_t> subs;
std::transform(ss.sub_shapes().begin(),
ss.sub_shapes().end(),
std::back_inserter(subs),
[&](auto child) { return self(child); });
return argument{subs};
result.sub = subs;
return result;
})(s);
}
......
......@@ -27,29 +27,29 @@ struct argument : raw_data<argument>
template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
argument(shape s, F d)
: m_shape(std::move(s)),
m_data({[f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); }})
: m_shape(std::move(s))
{
assign_buffer([f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); });
}
template <class T>
argument(shape s, T* d)
: m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d); }})
: m_shape(std::move(s))
{
assign_buffer([d] { return reinterpret_cast<char*>(d); });
}
template <class T>
argument(shape s, std::shared_ptr<T> d)
: m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d.get()); }})
: m_shape(std::move(s))
{
assign_buffer([d] { return reinterpret_cast<char*>(d.get()); });
}
argument(shape s, std::nullptr_t);
argument(const std::vector<argument>& args);
static argument load(const shape& s, char* buffer);
/// Provides a raw pointer to the data
char* data() const;
......@@ -68,6 +68,7 @@ struct argument : raw_data<argument>
std::vector<argument> get_sub_objects() const;
private:
void assign_buffer(std::function<char*()> d);
struct data_t
{
std::function<char*()> get = nullptr;
......
......@@ -35,7 +35,7 @@ struct load
{
if((offset + s.bytes()) > args[0].get_shape().bytes())
MIGRAPHX_THROW("Load access is out of bounds");
return argument::load(s, args[0].data() + offset);
return argument{s, args[0].data() + offset};
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
......@@ -98,6 +98,22 @@ TEST_CASE(nested_tuple)
EXPECT(a1.to_string() != a3.to_string());
}
TEST_CASE(tuple_construct)
{
migraphx::shape s{{migraphx::shape{migraphx::shape::float_type, {4}},
migraphx::shape{migraphx::shape::int8_type, {3}}}};
migraphx::argument a{s};
EXPECT(a.get_sub_objects().size() == 2);
EXPECT(a.get_shape() == s);
auto b = a; // NOLINT
EXPECT(a.get_shape() == b.get_shape());
EXPECT(a.get_sub_objects().size() == 2);
EXPECT(a.get_sub_objects()[0] == b.get_sub_objects()[0]);
EXPECT(a.get_sub_objects()[1] == b.get_sub_objects()[1]);
EXPECT(a == b);
}
TEST_CASE(tuple_visit)
{
auto a1 = make_tuple(3, 3.0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment