Unverified Commit 66aa4cc8 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add tuple type to shape (#800)



* Add definitions for all pointwise operators

* Formatting

* Add cpp generator class

* Formatting

* Move compilation to core

* Formatting

* Add clock to tmp name

* Add dynamic loader

* Formatting

* Add tests for code gen

* Formatting

* Add test for literals

* Formatting

* Use with_char

* Add missing header

* Fix mismerge

* Ignore tidy warning

* Fxx gcc 5 errors

* Apply fixits

* Skip signed bitwise of status

* Remove unused parameters

* Explicitly add c++14 flag

* Fix tidy warning

* Add tuple type to shape class

* Formatting

* Make data member private

* Formatting

* Add sub arguments

* Formatting

* Trun clang format off

* Disable clang-format

* Improve visiting tuples

* Formatting

* Add more argument tests

* Formatting

* Handle tuple in load

* Formatting

* Remove .o files

* Add tuple type to api

* Formatting

* Fix tidy warnings

* Fix tidy warnings

* Add a test for share method

* Formatting

* Add a test cpp_type

* Suppress tidy warning
Co-authored-by: default avatarShucai Xiao <Shucai.Xiao@amd.com>
parent 8fcb7409
...@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag) ...@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag)
add_library(migraphx add_library(migraphx
adjust_allocation.cpp adjust_allocation.cpp
analyze_streams.cpp analyze_streams.cpp
argument.cpp
auto_contiguous.cpp auto_contiguous.cpp
eliminate_common_subexpression.cpp eliminate_common_subexpression.cpp
decompose.cpp decompose.cpp
......
...@@ -49,6 +49,7 @@ shape::type_t to_shape_type(migraphx_shape_datatype_t t) ...@@ -49,6 +49,7 @@ shape::type_t to_shape_type(migraphx_shape_datatype_t t)
{ {
switch(t) switch(t)
{ {
case migraphx_shape_tuple_type: return shape::tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \ #define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case migraphx_shape_##x: return shape::x; case migraphx_shape_##x: return shape::x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
...@@ -61,6 +62,7 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t) ...@@ -61,6 +62,7 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
{ {
switch(t) switch(t)
{ {
case shape::tuple_type: return migraphx_shape_tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \ #define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case shape::x: return migraphx_shape_##x; case shape::x: return migraphx_shape_##x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
......
...@@ -36,6 +36,7 @@ typedef enum { ...@@ -36,6 +36,7 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs /// An enum to represent the different data type inputs
typedef enum { typedef enum {
migraphx_shape_tuple_type,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t; } migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
......
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
argument::argument(const shape& s) : m_shape(s)
{
auto buffer = make_shared_array<char>(s.bytes());
m_data = {[=]() mutable { return buffer.get(); }};
}
argument::argument(shape s, std::nullptr_t)
: m_shape(std::move(s)), m_data({[] { return nullptr; }})
{
}
argument::argument(const shape& s, const argument::data_t& d) : m_shape(s), m_data(d) {}
argument argument::load(const shape& s, char* buffer)
{
if(s.type() != shape::tuple_type)
return argument{s, buffer};
// Collect all shapes
std::unordered_map<std::size_t, shape> shapes;
{
// cppcheck-suppress variableScope
std::size_t i = 0;
fix([&](auto self, auto ss) {
if(ss.sub_shapes().empty())
{
shapes[i] = ss;
i++;
}
else
{
for(auto&& child : ss.sub_shapes())
self(child);
}
})(s);
}
// Sort by type size
std::vector<std::size_t> order(shapes.size());
std::iota(order.begin(), order.end(), 0);
std::sort(order.begin(), order.end(), by(std::greater<>{}, [&](auto i) {
return shapes[i].type_size();
}));
// Compute offsets
std::unordered_map<std::size_t, std::size_t> offsets;
std::size_t offset = 0;
for(auto i : order)
{
offsets[i] = offset;
offset += shapes[i].bytes();
}
assert(offset == s.bytes());
// cppcheck-suppress variableScope
std::size_t i = 0;
return fix<argument>([&](auto self, auto ss) {
if(ss.sub_shapes().empty())
{
argument r{shapes[i], buffer + offsets[i]};
i++;
return r;
}
std::vector<argument> subs;
std::transform(ss.sub_shapes().begin(),
ss.sub_shapes().end(),
std::back_inserter(subs),
[&](auto child) { return self(child); });
return argument{subs};
})(s);
}
std::vector<shape> to_shapes(const std::vector<argument>& args)
{
std::vector<shape> shapes;
std::transform(args.begin(), args.end(), std::back_inserter(shapes), [](auto&& arg) {
return arg.get_shape();
});
return shapes;
}
argument::argument(const std::vector<argument>& args)
: m_shape(to_shapes(args)), m_data(data_t::from_args(args))
{
}
char* argument::data() const
{
assert(m_shape.type() != shape::tuple_type);
assert(not this->empty());
return m_data.get();
}
bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
const shape& argument::get_shape() const { return this->m_shape; }
argument argument::reshape(const shape& s) const { return {s, this->m_data}; }
argument::data_t argument::data_t::share() const
{
data_t result;
if(this->get)
{
auto self = std::make_shared<data_t>(*this);
result.get = [self]() mutable { return self->get(); };
}
std::transform(sub.begin(), sub.end(), std::back_inserter(result.sub), [](const auto& d) {
return d.share();
});
return result;
}
argument::data_t argument::data_t::from_args(const std::vector<argument>& args)
{
data_t result;
std::transform(args.begin(), args.end(), std::back_inserter(result.sub), [](auto&& arg) {
return arg.m_data;
});
return result;
}
argument argument::share() const { return {m_shape, m_data.share()}; }
std::vector<argument> argument::get_sub_objects() const
{
std::vector<argument> result;
assert(m_shape.sub_shapes().size() == m_data.sub.size());
std::transform(m_shape.sub_shapes().begin(),
m_shape.sub_shapes().end(),
m_data.sub.begin(),
std::back_inserter(result),
[](auto&& s, auto&& d) {
return argument{s, d};
});
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <functional> #include <functional>
#include <utility> #include <utility>
// clang-format off
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -20,57 +21,61 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -20,57 +21,61 @@ inline namespace MIGRAPHX_INLINE_NS {
*/ */
struct argument : raw_data<argument> struct argument : raw_data<argument>
{ {
argument() {} argument() = default;
argument(const shape& s) : m_shape(s) argument(const shape& s);
{
auto buffer = make_shared_array<char>(s.bytes());
data = [=]() mutable { return buffer.get(); };
}
template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})> template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
argument(shape s, F d) argument(shape s, F d)
: data([f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); }), : m_shape(std::move(s)),
m_shape(std::move(s)) m_data({[f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); }})
{ {
} }
template <class T> template <class T>
argument(shape s, T* d) argument(shape s, T* d)
: data([d] { return reinterpret_cast<char*>(d); }), m_shape(std::move(s)) : m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d); }})
{ {
} }
template <class T> template <class T>
argument(shape s, std::shared_ptr<T> d) argument(shape s, std::shared_ptr<T> d)
: data([d] { return reinterpret_cast<char*>(d.get()); }), m_shape(std::move(s)) : m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d.get()); }})
{ {
} }
argument(shape s, std::nullptr_t) : data([] { return nullptr; }), m_shape(std::move(s)) {} argument(shape s, std::nullptr_t);
argument(const std::vector<argument>& args);
static argument load(const shape& s, char* buffer);
/// Provides a raw pointer to the data /// Provides a raw pointer to the data
std::function<char*()> data = nullptr; char* data() const;
/// Whether data is available /// Whether data is available
bool empty() const { return not data; } bool empty() const;
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const;
argument reshape(const shape& s) const argument reshape(const shape& s) const;
{
argument self = *this;
return {s, [=]() mutable { return self.data(); }};
}
/// Make copy of the argument that is always sharing the data /// Make copy of the argument that is always sharing the data
argument share() const argument share() const;
{
auto self = std::make_shared<argument>(*this); std::vector<argument> get_sub_objects() const;
return {m_shape, [self]() mutable { return self->data(); }};
}
private: private:
struct data_t
{
std::function<char*()> get = nullptr;
std::vector<data_t> sub = {};
data_t share() const;
static data_t from_args(const std::vector<argument>& args);
};
argument(const shape& s, const data_t& d);
shape m_shape; shape m_shape;
data_t m_data{};
}; };
void migraphx_to_value(value& v, const argument& a); void migraphx_to_value(value& v, const argument& a);
...@@ -78,5 +83,6 @@ void migraphx_from_value(const value& v, argument& a); ...@@ -78,5 +83,6 @@ void migraphx_from_value(const value& v, argument& a);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
// clang-format on
#endif #endif
...@@ -66,6 +66,8 @@ struct literal : raw_data<literal> ...@@ -66,6 +66,8 @@ struct literal : raw_data<literal>
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const { return this->m_shape; }
std::vector<literal> get_sub_objects() const { return {}; }
/// Convert the data to an argument /// Convert the data to an argument
argument get_argument() const argument get_argument() const
{ {
......
...@@ -33,7 +33,7 @@ struct as_shape ...@@ -33,7 +33,7 @@ struct as_shape
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args.front().reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -64,7 +64,7 @@ struct broadcast ...@@ -64,7 +64,7 @@ struct broadcast
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -48,7 +48,7 @@ struct flatten ...@@ -48,7 +48,7 @@ struct flatten
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -19,11 +19,8 @@ struct identity ...@@ -19,11 +19,8 @@ struct identity
{ {
std::string name() const { return "identity"; } std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); } shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape, std::vector<argument> args) const { return args[0]; }
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -34,7 +34,7 @@ struct load ...@@ -34,7 +34,7 @@ struct load
{ {
if((offset + s.bytes()) > args[0].get_shape().bytes()) if((offset + s.bytes()) > args[0].get_shape().bytes())
MIGRAPHX_THROW("Load access is out of bounds"); MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset}; return argument::load(s, args[0].data() + offset);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -66,7 +66,7 @@ struct multibroadcast ...@@ -66,7 +66,7 @@ struct multibroadcast
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -68,7 +68,7 @@ struct reshape ...@@ -68,7 +68,7 @@ struct reshape
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
......
...@@ -37,7 +37,7 @@ struct scalar ...@@ -37,7 +37,7 @@ struct scalar
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -75,7 +75,7 @@ struct squeeze ...@@ -75,7 +75,7 @@ struct squeeze
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -61,7 +61,7 @@ struct transpose ...@@ -61,7 +61,7 @@ struct transpose
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -68,7 +68,7 @@ struct unsqueeze ...@@ -68,7 +68,7 @@ struct unsqueeze
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -29,7 +29,15 @@ struct raw_data : raw_data_base ...@@ -29,7 +29,15 @@ struct raw_data : raw_data_base
friend Stream& operator<<(Stream& os, const Derived& d) friend Stream& operator<<(Stream& os, const Derived& d)
{ {
if(not d.empty()) if(not d.empty())
d.visit([&](auto x) { os << x; }); d.visit([&](auto x) { os << x; },
[&](auto&& xs) {
for(auto&& x : xs)
{
os << "{ ";
os << x;
os << " }, ";
}
});
return os; return os;
} }
...@@ -45,9 +53,19 @@ struct raw_data : raw_data_base ...@@ -45,9 +53,19 @@ struct raw_data : raw_data_base
auto&& derived = static_cast<const Derived&>(*this); auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty()) if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!"); MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape(); auto&& s = derived.get_shape();
auto&& buffer = derived.data(); s.visit_type([&](auto as) { v(*(as.from(derived.data()) + s.index(n))); });
s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); }); }
template <class Visitor, class TupleVisitor>
void visit(Visitor v, TupleVisitor tv) const
{
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
s.visit_type([&](auto as) { v(make_view(s, as.from(derived.data()))); },
[&] { tv(derived.get_sub_objects()); });
} }
/** /**
...@@ -60,12 +78,7 @@ struct raw_data : raw_data_base ...@@ -60,12 +78,7 @@ struct raw_data : raw_data_base
template <class Visitor> template <class Visitor>
void visit(Visitor v) const void visit(Visitor v) const
{ {
auto&& derived = static_cast<const Derived&>(*this); visit(v, [&](const auto&) { MIGRAPHX_THROW("Invalid tuple type"); });
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
auto&& buffer = derived.data();
s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
} }
/// Returns true if the raw data is only one element /// Returns true if the raw data is only one element
...@@ -156,43 +169,27 @@ struct raw_data : raw_data_base ...@@ -156,43 +169,27 @@ struct raw_data : raw_data_base
} }
}; };
template <class T, namespace detail {
class U, template <class V1, class V2, class... Ts>
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} && void visit_all_flatten(const shape& s, V1&& v1, V2&& v2, Ts&&... xs)
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{ {
auto&& xshape = x.get_shape(); s.visit_type([&](auto as) { v1(make_view(xs.get_shape(), as.from(xs.data()))...); },
auto&& yshape = y.get_shape(); [&] { v2(xs.get_sub_objects()...); });
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{
auto&& xbuffer = x.data();
auto&& ybuffer = y.data();
// TODO: Dont use tensor view for single values
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
}
return result;
} }
template <class T, template <class V1, class V2, class... Ts>
class U, auto visit_all_pack(const shape& s, V1&& v1, V2&& v2)
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{ {
return !(x == y); return [&](auto&&... xs) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
visit_all_flatten(s, v1, v2, xs...);
};
} }
namespace detail { template <class V1, class... Ts>
template <class V, class... Ts> auto visit_all_pack(const shape& s, V1&& v1)
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
{ {
s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); }); return visit_all_pack(s, v1, [](auto&&...) { MIGRAPHX_THROW("Invalid tuple type"); });
} }
} // namespace detail } // namespace detail
...@@ -215,10 +212,7 @@ auto visit_all(T&& x, Ts&&... xs) ...@@ -215,10 +212,7 @@ auto visit_all(T&& x, Ts&&... xs)
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...}; std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); })) if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
return [&](auto v) { return [&](auto... vs) { detail::visit_all_pack(s, vs...)(x, xs...); };
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...);
};
} }
template <class T> template <class T>
...@@ -240,6 +234,34 @@ auto visit_all(const std::vector<T>& x) ...@@ -240,6 +234,34 @@ auto visit_all(const std::vector<T>& x)
}; };
} }
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() and y.empty();
if(not result and xshape == yshape)
{
visit_all(x, y)([&](auto xview, auto yview) { result = xview == yview; },
[&](auto&& xs, auto&& ys) {
result = std::equal(xs.begin(), xs.end(), ys.begin(), ys.end());
});
}
return result;
}
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
return !(x == y);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -40,7 +40,7 @@ struct shape ...@@ -40,7 +40,7 @@ struct shape
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
{ {
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) tuple_type
}; };
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
...@@ -82,6 +82,8 @@ struct shape ...@@ -82,6 +82,8 @@ struct shape
{ {
} }
shape(const std::vector<shape>& subs);
static shape static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm); from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const; type_t type() const;
...@@ -179,11 +181,16 @@ struct shape ...@@ -179,11 +181,16 @@ struct shape
type_t type_enum() const { return get_type<type>{}; } type_t type_enum() const { return get_type<type>{}; }
}; };
template <class Visitor> template <class Visitor, class TupleVisitor>
static void visit(type_t t, Visitor v) static void visit(type_t t, Visitor v, TupleVisitor tv)
{ {
switch(t) switch(t)
{ {
case tuple_type:
{
tv();
return;
}
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \ #define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return; case x: v(as<t>()); return;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
...@@ -193,9 +200,15 @@ struct shape ...@@ -193,9 +200,15 @@ struct shape
} }
template <class Visitor> template <class Visitor>
void visit_type(Visitor v) const static void visit(type_t t, Visitor v)
{ {
visit(this->type(), v); return visit(t, v, [] { MIGRAPHX_THROW("Tuple cannot be visited."); });
}
template <class... Visitors>
void visit_type(Visitors... vs) const
{
visit(this->type(), vs...);
} }
template <class Visitor> template <class Visitor>
...@@ -209,6 +222,8 @@ struct shape ...@@ -209,6 +222,8 @@ struct shape
std::string type_string() const; std::string type_string() const;
static type_t parse_type(const std::string& s); static type_t parse_type(const std::string& s);
const std::vector<shape>& sub_shapes() const;
private: private:
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
......
...@@ -11,8 +11,11 @@ void raw_data_to_value(value& v, const RawData& rd) ...@@ -11,8 +11,11 @@ void raw_data_to_value(value& v, const RawData& rd)
{ {
value result; value result;
result["shape"] = migraphx::to_value(rd.get_shape()); result["shape"] = migraphx::to_value(rd.get_shape());
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes()); if(rd.get_shape().type() == shape::tuple_type)
v = result; result["sub"] = migraphx::to_value(rd.get_sub_objects());
else
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
v = result;
} }
void migraphx_to_value(value& v, const literal& l) { raw_data_to_value(v, l); } void migraphx_to_value(value& v, const literal& l) { raw_data_to_value(v, l); }
...@@ -25,8 +28,15 @@ void migraphx_from_value(const value& v, literal& l) ...@@ -25,8 +28,15 @@ void migraphx_from_value(const value& v, literal& l)
void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); } void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); }
void migraphx_from_value(const value& v, argument& a) void migraphx_from_value(const value& v, argument& a)
{ {
literal l = migraphx::from_value<literal>(v); if(v.contains("data"))
a = l.get_argument(); {
literal l = migraphx::from_value<literal>(v);
a = l.get_argument();
}
else
{
a = migraphx::from_value<std::vector<argument>>(v.at("sub"));
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment