Unverified Commit 66aa4cc8 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add tuple type to shape (#800)



* Add definitions for all pointwise operators

* Formatting

* Add cpp generator class

* Formatting

* Move compilation to core

* Formatting

* Add clock to tmp name

* Add dynamic loader

* Formatting

* Add tests for code gen

* Formatting

* Add test for literals

* Formatting

* Use with_char

* Add missing header

* Fix mismerge

* Ignore tidy warning

* Fxx gcc 5 errors

* Apply fixits

* Skip signed bitwise of status

* Remove unused parameters

* Explicitly add c++14 flag

* Fix tidy warning

* Add tuple type to shape class

* Formatting

* Make data member private

* Formatting

* Add sub arguments

* Formatting

* Trun clang format off

* Disable clang-format

* Improve visiting tuples

* Formatting

* Add more argument tests

* Formatting

* Handle tuple in load

* Formatting

* Remove .o files

* Add tuple type to api

* Formatting

* Fix tidy warnings

* Fix tidy warnings

* Add a test for share method

* Formatting

* Add a test cpp_type

* Suppress tidy warning
Co-authored-by: default avatarShucai Xiao <Shucai.Xiao@amd.com>
parent 8fcb7409
......@@ -20,28 +20,36 @@ struct shape_impl
return result;
}
shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl() : m_type(shape::float_type) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
{
assert(t != shape::tuple_type);
}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
assert(t != shape::tuple_type);
this->calculate_strides();
assert(m_lens.size() == m_strides.size());
}
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{
assert(t != shape::tuple_type);
assert(m_lens.size() == m_strides.size());
// assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
// "At least one stride must be non-zero");
m_standard = this->elements() == this->element_space() and
std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type;
std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides;
bool m_standard;
std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {};
bool m_standard = false;
void calculate_strides()
{
......@@ -84,7 +92,7 @@ const std::vector<shape::type_t>& shape::types()
{
static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR)};
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
return result;
}
......@@ -92,6 +100,7 @@ std::string shape::name(shape::type_t t)
{
switch(t)
{
case tuple_type: return "tuple_type";
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
case x: return #x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
......@@ -103,6 +112,7 @@ std::string shape::cpp_type(shape::type_t t)
{
switch(t)
{
case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
case x: return #t;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
......@@ -123,6 +133,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{
}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm)
......@@ -139,14 +151,25 @@ const std::vector<std::size_t>& shape::strides() const { return impl->m_strides;
std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const
{
std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space();
if(this->sub_shapes().empty())
{
std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space();
}
else
{
return std::accumulate(this->sub_shapes().begin(),
this->sub_shapes().end(),
std::size_t{0},
[&](auto x, auto y) { return x + y.bytes(); });
}
}
std::size_t shape::type_size() const
{
std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); });
if(this->sub_shapes().empty())
this->visit_type([&](auto as) { n = as.size(); });
return n;
}
std::size_t shape::index(std::initializer_list<std::size_t> l) const
......@@ -208,7 +231,10 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
});
}
bool shape::packed() const { return this->elements() == this->element_space(); }
bool shape::packed() const
{
return this->sub_shapes().empty() and this->elements() == this->element_space();
}
bool shape::transposed() const
{
......@@ -242,7 +268,8 @@ bool shape::scalar() const
{
assert(this->lens().size() == this->strides().size());
// if any stride > 0, then accumulate will return false
return std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
return this->sub_shapes().empty() and
std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
}
bool shape::standard() const { return impl->m_standard; }
......@@ -273,15 +300,23 @@ std::string shape::type_string() const { return name(this->type()); }
bool operator==(const shape& x, const shape& y)
{
return x.type() == y.type() && x.lens() == y.lens() && x.strides() == y.strides();
return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
}
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x)
{
os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}";
if(x.sub_shapes().empty())
{
os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}";
}
else
{
os << "[" << to_string_range(x.sub_shapes()) << "]";
}
return os;
}
......@@ -289,23 +324,36 @@ shape::type_t shape::parse_type(const std::string& s)
{
static const std::unordered_map<std::string, shape::type_t> m = {
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP)};
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
tuple_type},
{"tuple", tuple_type}};
return m.at(s);
}
const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void migraphx_to_value(value& v, const shape& s)
{
value result;
result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
v = result;
result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
v = result;
}
void migraphx_from_value(const value& v, shape& s)
{
s = shape{shape::parse_type(v.at("type").get_string()),
v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()};
auto t = v.at("type").get_string();
if(t == "tuple_type")
{
s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
}
else
{
s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()};
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -48,17 +48,6 @@ std::string generate_index_ints(const std::vector<T>& v)
return "index_ints<" + to_string_range(v) + ">{}";
}
std::string generate_cpp_type(shape::type_t t)
{
switch(t)
{
#define MIGRAPHX_GPU_GENERATE_TYPE_STRING(x, t) \
case shape::x: return #t;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GPU_GENERATE_TYPE_STRING)
}
MIGRAPHX_THROW("Invalid type");
}
std::string generate_make_shape(const shape& s)
{
return "make_shape(" + generate_index_ints(s.lens()) + ", " + generate_index_ints(s.strides()) +
......@@ -80,7 +69,7 @@ std::string generate_make_tensor(std::size_t n, const shape& s)
{
return interpolate_string(make_tensor_template,
{{"n", std::to_string(n)},
{"type", generate_cpp_type(s.type())},
{"type", shape::cpp_type(s.type())},
{"lens", generate_index_ints(s.lens())},
{"strides", generate_index_ints(s.strides())}});
}
......
......@@ -16,6 +16,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
case shape::int16_type:
......
#include <migraphx/argument.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/serialize.hpp>
#include <sstream>
#include <string>
#include "test.hpp"
migraphx::argument as_argument(migraphx::argument a) { return a; }
template <class T>
migraphx::argument as_argument(T x)
{
return migraphx::literal{x}.get_argument();
}
template <class... Ts>
migraphx::argument make_tuple(Ts... xs)
{
return migraphx::argument{{as_argument(xs)...}};
}
TEST_CASE(copy_eq)
{
auto a1 = as_argument(3);
auto a2 = as_argument(3);
auto a3 = as_argument(1);
auto a4 = a1; // NOLINT
EXPECT(a1 == a2);
EXPECT(a2 != a3);
EXPECT(a1 == a4);
EXPECT(a4 != a3);
EXPECT(a1.get_sub_objects().empty());
EXPECT(a2.get_sub_objects().empty());
EXPECT(a3.get_sub_objects().empty());
EXPECT(a4.get_sub_objects().empty());
}
TEST_CASE(default_construct)
{
migraphx::argument a1{};
migraphx::argument a2{};
EXPECT(a1.empty());
EXPECT(a2.empty());
EXPECT(a1 == a2);
EXPECT(a1.to_string().empty());
EXPECT(a2.to_string().empty());
EXPECT(a1.get_sub_objects().empty());
EXPECT(a2.get_sub_objects().empty());
}
TEST_CASE(string_elems)
{
migraphx::shape s{migraphx::shape::int64_type, {3}};
migraphx::literal l{s, {1, 2, 3}};
auto a = l.get_argument();
EXPECT(a.to_string() == "1, 2, 3");
}
TEST_CASE(tuple)
{
auto a1 = make_tuple(3, 3.0);
EXPECT(a1.get_shape().type() == migraphx::shape::tuple_type);
EXPECT(a1.get_sub_objects().size() == 2);
EXPECT(a1.get_sub_objects()[0] == as_argument(3));
EXPECT(a1.get_sub_objects()[1] == as_argument(3.0));
auto a2 = make_tuple(3, 3.0);
EXPECT(a1 == a2);
EXPECT(a1.to_string() == a2.to_string());
auto a3 = make_tuple(3, 4.0);
EXPECT(a1 != a3);
EXPECT(a1.to_string() != a3.to_string());
}
TEST_CASE(nested_tuple)
{
auto a1 = make_tuple(3, make_tuple(5, 4));
EXPECT(a1.get_shape().type() == migraphx::shape::tuple_type);
EXPECT(a1.get_sub_objects().size() == 2);
EXPECT(a1.get_sub_objects()[0] == as_argument(3));
EXPECT(a1.get_sub_objects()[1] == make_tuple(5, 4));
auto a2 = make_tuple(3, make_tuple(5, 4));
EXPECT(a1 == a2);
EXPECT(a1.to_string() == a2.to_string());
auto a3 = make_tuple(3, make_tuple(5, 6));
EXPECT(a1 != a3);
EXPECT(a1.to_string() != a3.to_string());
}
TEST_CASE(tuple_visit)
{
auto a1 = make_tuple(3, 3.0);
EXPECT(test::throws([&] { a1.visit([](auto&&) {}); }));
EXPECT(test::throws([&] { a1.at<float>(); }));
bool reaches = false;
a1.visit([&](auto&&) { EXPECT(false); },
[&](auto&& xs) {
reaches = true;
EXPECT(xs.size() == 2);
EXPECT(xs[0] == as_argument(3));
EXPECT(xs[1] == as_argument(3.0));
});
EXPECT(reaches);
}
TEST_CASE(tuple_visit_all)
{
auto a1 = make_tuple(3, 3.0);
auto a2 = make_tuple(1, 2, 3);
EXPECT(test::throws([&] { visit_all(a1, a2)([](auto&&, auto&&) {}); }));
bool reaches = false;
visit_all(a1, a2)([&](auto&&, auto&&) { EXPECT(false); },
[&](auto&& xs, auto&& ys) {
reaches = true;
EXPECT(xs.size() == 2);
EXPECT(xs[0] == as_argument(3));
EXPECT(xs[1] == as_argument(3.0));
EXPECT(ys.size() == 3);
EXPECT(ys[0] == as_argument(1));
EXPECT(ys[1] == as_argument(2));
EXPECT(ys[2] == as_argument(3));
});
EXPECT(reaches);
}
TEST_CASE(value_argument)
{
migraphx::shape s{migraphx::shape::int64_type, {3}};
migraphx::literal l1{s, {1, 2, 3}};
auto a1 = l1.get_argument();
auto v1 = migraphx::to_value(a1);
migraphx::literal l2{1};
auto a2 = l2.get_argument();
auto v2 = migraphx::to_value(a2);
EXPECT(v1 != v2);
auto a3 = migraphx::from_value<migraphx::argument>(v1);
EXPECT(a3 == a1);
auto a4 = migraphx::from_value<migraphx::argument>(v2);
EXPECT(a4 == a2);
}
TEST_CASE(value_tuple)
{
auto a1 = make_tuple(3, 3.0, make_tuple(3, 4));
auto a2 = make_tuple(1, 2, 3);
auto v1 = migraphx::to_value(a1);
auto v2 = migraphx::to_value(a2);
EXPECT(v1 != v2);
auto a3 = migraphx::from_value<migraphx::argument>(v1);
EXPECT(a3 == a1);
auto a4 = migraphx::from_value<migraphx::argument>(v2);
EXPECT(a4 == a2);
}
TEST_CASE(argument_share)
{
migraphx::shape s{migraphx::shape::int64_type, {3}};
std::vector<char> buffer(s.bytes());
migraphx::argument a1(s, [=]() mutable { return buffer.data(); });
auto a2 = a1; // NOLINT
EXPECT(a1.data() != a2.data());
auto a3 = a1.share();
EXPECT(a1.data() != a3.data());
auto a4 = a3; // NOLINT
EXPECT(a4.data() == a3.data());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -135,21 +135,4 @@ TEST_CASE(value_literal)
EXPECT(l4 == l2);
}
TEST_CASE(value_argument)
{
migraphx::shape s{migraphx::shape::int64_type, {3}};
migraphx::literal l1{s, {1, 2, 3}};
auto a1 = l1.get_argument();
auto v1 = migraphx::to_value(a1);
migraphx::literal l2{1};
auto a2 = l2.get_argument();
auto v2 = migraphx::to_value(a2);
EXPECT(v1 != v2);
auto a3 = migraphx::from_value<migraphx::argument>(v1);
EXPECT(a3 == a1);
auto a4 = migraphx::from_value<migraphx::argument>(v2);
EXPECT(a4 == a2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -72,6 +72,43 @@ TEST_CASE(make_op_invalid_key)
EXPECT(test::throws([] { migraphx::make_op("convolution", {{"paddings", {1, 1}}}); }));
}
TEST_CASE(load_offset)
{
migraphx::shape s{migraphx::shape::float_type, {4}};
migraphx::shape bs{migraphx::shape::int8_type, {32}};
auto op = migraphx::make_op("load", {{"offset", 4}, {"shape", migraphx::to_value(s)}});
EXPECT(op.compute_shape({bs}) == s);
migraphx::argument a{bs};
EXPECT(op.compute(bs, {a}).data() == a.data() + 4);
}
TEST_CASE(load_out_of_bounds)
{
migraphx::shape s{migraphx::shape::float_type, {4}};
migraphx::shape bs{migraphx::shape::int8_type, {16}};
auto op = migraphx::make_op("load", {{"offset", 4}, {"shape", migraphx::to_value(s)}});
migraphx::argument a{bs};
EXPECT(test::throws([&] { op.compute(bs, {a}); }));
}
TEST_CASE(load_tuple)
{
migraphx::shape s{{migraphx::shape{migraphx::shape::int8_type, {3}},
migraphx::shape{migraphx::shape::float_type, {4}}}};
migraphx::shape bs{migraphx::shape::int8_type, {32}};
auto op = migraphx::make_op("load", {{"offset", 4}, {"shape", migraphx::to_value(s)}});
EXPECT(op.compute_shape({bs}) == s);
migraphx::argument a{bs};
auto r = op.compute(bs, {a});
EXPECT(r.get_sub_objects().size() == 2);
auto* start = a.data() + 4;
EXPECT(r.get_sub_objects()[0].data() == start + 16);
EXPECT(r.get_sub_objects()[1].data() == start);
}
TEST_CASE(ops)
{
auto names = migraphx::get_operators();
......
#include <migraphx/shape.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <array>
......@@ -388,6 +389,74 @@ TEST_CASE(test_serialize)
EXPECT(s3 != s4);
}
TEST_CASE(tuple)
{
migraphx::shape s{{migraphx::shape{migraphx::shape::float_type},
migraphx::shape{migraphx::shape::int8_type}}};
EXPECT(s.type() == migraphx::shape::tuple_type);
EXPECT(s.bytes() == 4 + 1);
EXPECT(s.type_size() == 0);
EXPECT(s.type_string() == "tuple_type");
EXPECT(s.lens().empty());
EXPECT(s.strides().empty());
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.broadcasted());
EXPECT(not s.transposed());
EXPECT(not s.scalar());
EXPECT(s.sub_shapes().size() == 2);
EXPECT(s.sub_shapes()[0].type() == migraphx::shape::float_type);
EXPECT(s.sub_shapes()[0].elements() == 1);
EXPECT(s.sub_shapes()[1].type() == migraphx::shape::int8_type);
EXPECT(s.sub_shapes()[1].elements() == 1);
EXPECT(test::throws([&] { s.visit_type([](auto) {}); }));
}
TEST_CASE(tuple_copy)
{
migraphx::shape s1{{migraphx::shape{migraphx::shape::float_type},
migraphx::shape{migraphx::shape::int8_type}}};
migraphx::shape s2{{migraphx::shape{migraphx::shape::float_type},
migraphx::shape{migraphx::shape::int8_type}}};
EXPECT(s1 == s2);
auto s3 = s1;
EXPECT(s3 == s1);
EXPECT(s3 == s2);
migraphx::shape s4{{migraphx::shape{migraphx::shape::int8_type},
migraphx::shape{migraphx::shape::float_type}}};
EXPECT(s4 != s1);
EXPECT(s4 != s2);
EXPECT(s4 != s3);
}
TEST_CASE(tuple_print)
{
migraphx::shape s{{migraphx::shape{migraphx::shape::float_type},
migraphx::shape{migraphx::shape::int8_type}}};
std::string x = migraphx::to_string(s);
EXPECT(x.front() == '[');
EXPECT(x.back() == ']');
EXPECT(migraphx::contains(x, "float"));
EXPECT(migraphx::contains(x, "int8"));
}
TEST_CASE(tuple_serialize)
{
migraphx::shape s1{{migraphx::shape{migraphx::shape::float_type},
migraphx::shape{migraphx::shape::int8_type}}};
migraphx::shape s2{{migraphx::shape{migraphx::shape::int8_type},
migraphx::shape{migraphx::shape::float_type}}};
auto v1 = migraphx::to_value(s1);
auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2);
auto s3 = migraphx::from_value<migraphx::shape>(v1);
EXPECT(s3 == s1);
auto s4 = migraphx::from_value<migraphx::shape>(v2);
EXPECT(s4 == s2);
EXPECT(s3 != s4);
}
TEST_CASE(test_with_lens1)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {1, 2}};
......@@ -531,4 +600,12 @@ TEST_CASE(test_with_lens_ambigous13)
EXPECT(s2 == s3);
}
TEST_CASE(cpp_type_name)
{
EXPECT(migraphx::shape::cpp_type(migraphx::shape::int8_type) == "int8_t");
EXPECT(migraphx::shape::cpp_type(migraphx::shape::float_type) == "float");
EXPECT(migraphx::shape::cpp_type(migraphx::shape::half_type) == "half");
EXPECT(test::throws([&] { migraphx::shape::cpp_type(migraphx::shape::tuple_type); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -49,6 +49,7 @@ shape::type_t to_shape_type(migraphx_shape_datatype_t t)
{
switch(t)
{
case migraphx_shape_tuple_type: return shape::tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case migraphx_shape_##x: return shape::x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
......@@ -61,6 +62,7 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
{
switch(t)
{
case shape::tuple_type: return migraphx_shape_tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case shape::x: return migraphx_shape_##x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
......
......@@ -36,6 +36,7 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
typedef enum {
migraphx_shape_tuple_type,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment