Unverified Commit abe4ec3e authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Propagate data layout in the operators (#777)



* Add method to compute shape with same layout

* Formatting

* Fix permutation with ambiguous layouts

* Formatting

* Propagate layout for pointwise operators

* Formatting

* Propagate layout for more operators

* Formatting

* Sort with lens

* Formatting

* Simplify permutation sorting

* Formatting

* Propagate layout for concat operator

* Formatting

* Use copy

* Formatting

* Remove header
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 41c0487b
......@@ -29,6 +29,7 @@ add_library(migraphx
make_op.cpp
msgpack.cpp
operation.cpp
permutation.cpp
program.cpp
module.cpp
quantization.cpp
......
File mode changed from 100644 to 100755
......@@ -25,6 +25,14 @@ struct binary : op_name<Derived>
{
return s0;
}
else if(s0.packed() != s1.packed())
{
return s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
return {s0.type(), s0.lens()};
......@@ -34,32 +42,13 @@ struct binary : op_name<Derived>
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto s1 = args[0].get_shape();
auto s2 = args[1].get_shape();
if(s1 == s2 and s1.packed())
{
shape std_shape{s1.type(), s1.lens()};
argument std_result{std_shape, result.data()};
argument std_arg0{std_shape, args[0].data()};
argument std_arg1{std_shape, args[1].data()};
visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
}
else
{
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
});
}
return result;
}
};
......
......@@ -9,6 +9,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -80,7 +81,7 @@ struct concat
std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis;
return {type, new_lens};
return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
......@@ -89,16 +90,11 @@ struct concat
for(std::size_t l = 0; l < args.size(); l++)
{
auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
std::copy(input.begin(), input.end(), slice.begin());
});
}
return result;
......
......@@ -57,7 +57,6 @@ struct convolution
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
auto t = input.type();
size_t kdims = input.lens().size() - 2;
if(kdims != this->kdims())
{
......@@ -79,7 +78,7 @@ struct convolution
1)));
}
return {t, output_lens};
return inputs[0].with_lens(output_lens);
}
size_t kdims() const
......
......@@ -51,7 +51,6 @@ struct deconvolution
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
auto t = input.type();
size_t kdims = input.lens().size() - 2;
if(kdims != this->kdims())
{
......@@ -67,7 +66,7 @@ struct deconvolution
stride[i] * (input.lens()[i + 2] - 1) +
((weights.lens()[i + 2] - 1) * dilation[i] + 1) - 2 * padding[i])));
}
return {t, output_lens};
return inputs[0].with_lens(output_lens);
}
size_t kdims() const
......
......@@ -51,8 +51,6 @@ struct pooling
check_shapes{inputs, *this}.has(1);
const shape& input = inputs.at(0);
auto t = input.type();
auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2;
if(kdims != this->kdims())
......@@ -71,7 +69,7 @@ struct pooling
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, len + 1)));
}
return {t, output_lens};
return inputs[0].with_lens(output_lens);
}
size_t kdims() const
......
......@@ -81,7 +81,7 @@ struct quant_convolution
1)));
}
return {t, output_lens};
return inputs[0].with_lens(t, output_lens);
}
size_t kdims() const
......
......@@ -92,7 +92,7 @@ struct reduce_op : op_name<Derived>
lens[axis] = 1;
}
return {s.type(), lens};
return inputs[0].with_lens(lens);
}
template <class T>
......
......@@ -20,28 +20,21 @@ struct unary : op_name<Derived>
{
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
auto s = inputs.at(0);
if(s.packed())
if(s.broadcasted())
{
return s;
return {s.type(), s.lens()};
}
else
{
return {s.type(), s.lens()};
return s.with_lens(s.lens());
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto in_shape = args[0].get_shape();
if(in_shape.packed())
{
shape std_in_shape{in_shape.type(), in_shape.lens()};
shape std_out_shape{output_shape.type(), output_shape.lens()};
argument arg_in{std_in_shape, args[0].data()};
argument arg_out{std_out_shape, result.data()};
arg_out.visit([&](auto output) {
arg_in.visit([&](auto input) {
result.visit([&](auto output) {
args[0].visit([&](auto input) {
std::transform(input.begin(),
input.end(),
output.begin(),
......@@ -49,20 +42,6 @@ struct unary : op_name<Derived>
});
});
}
else
{
result.visit([&](auto output) {
args[0].visit([&](auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input(idx.begin(), idx.end()));
});
});
});
}
return result;
}
};
......
......@@ -20,29 +20,22 @@ inline Vector reorder_dims(const Vector& dims, const std::vector<int64_t>& permu
return result;
}
inline shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)};
}
shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation);
template <class Vector, class Op>
inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
std::vector<std::int64_t> result(data.size());
std::iota(result.begin(), result.end(), 0);
std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
std::stable_sort(
result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
return result;
}
inline std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);
inline std::vector<int64_t> find_permutation(const shape& s)
{
return sort_permutation(s.strides(), std::greater<>{});
}
std::vector<int64_t> find_permutation(const shape& s);
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -79,6 +79,8 @@ struct shape
{
}
static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const;
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
......@@ -121,6 +123,9 @@ struct shape
shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
......
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <map>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)};
}
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
std::vector<int64_t> find_permutation(const shape& s)
{
std::vector<std::int64_t> result(s.lens().size());
std::iota(result.begin(), result.end(), 0);
std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(s.strides()[x], s.lens()[x]);
}));
return result;
}
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
{
if(shapes.empty())
return {};
std::map<std::vector<int64_t>, std::size_t> count;
for(auto&& s : shapes)
{
if(s.broadcasted())
continue;
count[find_permutation(s)]++;
}
if(count.empty())
{
std::vector<int64_t> r(shapes.front().lens().size());
std::iota(r.begin(), r.end(), 0);
return r;
}
auto it = std::max_element(
count.begin(), count.end(), by(std::less<>{}, [](auto&& p) { return p.second; }));
assert(it != count.end());
return it->first;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -2,6 +2,7 @@
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <numeric>
#include <algorithm>
#include <functional>
......@@ -99,6 +100,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{
}
shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm)
{
auto new_lens = reorder_dims(l, perm);
shape result = reorder_shape({t, new_lens}, invert_permutation(perm));
assert(result.lens() == l);
return result;
}
shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
......@@ -221,6 +232,18 @@ shape shape::normalize_standard() const
return *this;
}
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
assert(l.size() == this->lens().size());
auto perm = find_permutation(*this);
return shape::from_permutation(t, l, perm);
}
shape shape::with_lens(const std::vector<std::size_t>& l) const
{
return this->with_lens(this->type(), l);
}
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const
......
#include <migraphx/shape.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <array>
#include <algorithm>
#include <numeric>
......@@ -386,4 +388,147 @@ TEST_CASE(test_serialize)
EXPECT(s3 != s4);
}
TEST_CASE(test_with_lens1)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {1, 2}};
auto s2 = s1.with_lens({4, 3});
EXPECT(s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {4, 3}, {1, 4}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens2)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {2, 1}};
auto s2 = s1.with_lens({3, 4});
EXPECT(s2.standard());
migraphx::shape s3{migraphx::shape::float_type, {3, 4}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous1)
{
migraphx::shape s1{migraphx::shape::float_type, {64, 1, 24, 24}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous2)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 1}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous3)
{
migraphx::shape s1{migraphx::shape::float_type, {64, 3, 1, 1}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous4)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {64, 1, 1, 3}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous5)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 5, 24, 24}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous6)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 24, 24, 5}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous7)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 1, 1, 3}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous8)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 1, 24, 24}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous9)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 24, 24, 1}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous10)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 2, 4, 1}};
auto s2 = s1.with_lens({3, 2, 4, 1});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {3, 2, 4, 1}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous11)
{
migraphx::shape s1{migraphx::shape::float_type, {64, 1, 1, 1}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s1.standard());
EXPECT(s2.standard());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous12)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 64, 1, 1}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s1.standard());
EXPECT(s2.standard());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous13)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 1, 1, 3}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment