"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "2e337c7fc4eb42c76d548f74cfc7bb5a93740fde"
Unverified Commit abe4ec3e authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Propagate data layout in the operators (#777)



* Add method to compute shape with same layout

* Formatting

* Fix permutation with ambiguous layouts

* Formatting

* Propagate layout for pointwise operators

* Formatting

* Propagate layout for more operators

* Formatting

* Sort with lens

* Formatting

* Simplify permutation sorting

* Formatting

* Propagate layout for concat operator

* Formatting

* Use copy

* Formatting

* Remove header
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 41c0487b
...@@ -29,6 +29,7 @@ add_library(migraphx ...@@ -29,6 +29,7 @@ add_library(migraphx
make_op.cpp make_op.cpp
msgpack.cpp msgpack.cpp
operation.cpp operation.cpp
permutation.cpp
program.cpp program.cpp
module.cpp module.cpp
quantization.cpp quantization.cpp
......
File mode changed from 100644 to 100755
...@@ -25,6 +25,14 @@ struct binary : op_name<Derived> ...@@ -25,6 +25,14 @@ struct binary : op_name<Derived>
{ {
return s0; return s0;
} }
else if(s0.packed() != s1.packed())
{
return s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else else
{ {
return {s0.type(), s0.lens()}; return {s0.type(), s0.lens()};
...@@ -34,32 +42,13 @@ struct binary : op_name<Derived> ...@@ -34,32 +42,13 @@ struct binary : op_name<Derived>
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto s1 = args[0].get_shape(); visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
auto s2 = args[1].get_shape(); std::transform(input1.begin(),
if(s1 == s2 and s1.packed()) input1.end(),
{ input2.begin(),
shape std_shape{s1.type(), s1.lens()}; output.begin(),
argument std_result{std_shape, result.data()}; static_cast<const Derived&>(*this).apply());
argument std_arg0{std_shape, args[0].data()}; });
argument std_arg1{std_shape, args[1].data()};
visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
}
else
{
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
});
}
return result; return result;
} }
}; };
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -80,7 +81,7 @@ struct concat ...@@ -80,7 +81,7 @@ struct concat
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens)); std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis] = new_dim_axis; new_lens[axis] = new_dim_axis;
return {type, new_lens}; return shape::from_permutation(type, new_lens, find_permutation(inputs));
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
...@@ -88,17 +89,12 @@ struct concat ...@@ -88,17 +89,12 @@ struct concat
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args); std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++) for(std::size_t l = 0; l < args.size(); l++)
{ {
auto argl = args[l]; auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) { visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape = auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()}; shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]); auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm std::copy(input.begin(), input.end(), slice.begin());
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
}); });
} }
return result; return result;
......
...@@ -57,7 +57,6 @@ struct convolution ...@@ -57,7 +57,6 @@ struct convolution
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
auto t = input.type();
size_t kdims = input.lens().size() - 2; size_t kdims = input.lens().size() - 2;
if(kdims != this->kdims()) if(kdims != this->kdims())
{ {
...@@ -79,7 +78,7 @@ struct convolution ...@@ -79,7 +78,7 @@ struct convolution
1))); 1)));
} }
return {t, output_lens}; return inputs[0].with_lens(output_lens);
} }
size_t kdims() const size_t kdims() const
......
...@@ -51,7 +51,6 @@ struct deconvolution ...@@ -51,7 +51,6 @@ struct deconvolution
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
auto t = input.type();
size_t kdims = input.lens().size() - 2; size_t kdims = input.lens().size() - 2;
if(kdims != this->kdims()) if(kdims != this->kdims())
{ {
...@@ -67,7 +66,7 @@ struct deconvolution ...@@ -67,7 +66,7 @@ struct deconvolution
stride[i] * (input.lens()[i + 2] - 1) + stride[i] * (input.lens()[i + 2] - 1) +
((weights.lens()[i + 2] - 1) * dilation[i] + 1) - 2 * padding[i]))); ((weights.lens()[i + 2] - 1) * dilation[i] + 1) - 2 * padding[i])));
} }
return {t, output_lens}; return inputs[0].with_lens(output_lens);
} }
size_t kdims() const size_t kdims() const
......
...@@ -51,10 +51,8 @@ struct pooling ...@@ -51,10 +51,8 @@ struct pooling
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
auto t = input.type(); auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2;
auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2;
if(kdims != this->kdims()) if(kdims != this->kdims())
{ {
MIGRAPHX_THROW("pooling: input k-dims does not match attribute size"); MIGRAPHX_THROW("pooling: input k-dims does not match attribute size");
...@@ -71,7 +69,7 @@ struct pooling ...@@ -71,7 +69,7 @@ struct pooling
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, len + 1))); output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, len + 1)));
} }
return {t, output_lens}; return inputs[0].with_lens(output_lens);
} }
size_t kdims() const size_t kdims() const
......
...@@ -81,7 +81,7 @@ struct quant_convolution ...@@ -81,7 +81,7 @@ struct quant_convolution
1))); 1)));
} }
return {t, output_lens}; return inputs[0].with_lens(t, output_lens);
} }
size_t kdims() const size_t kdims() const
......
...@@ -92,7 +92,7 @@ struct reduce_op : op_name<Derived> ...@@ -92,7 +92,7 @@ struct reduce_op : op_name<Derived>
lens[axis] = 1; lens[axis] = 1;
} }
return {s.type(), lens}; return inputs[0].with_lens(lens);
} }
template <class T> template <class T>
......
...@@ -20,49 +20,28 @@ struct unary : op_name<Derived> ...@@ -20,49 +20,28 @@ struct unary : op_name<Derived>
{ {
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
if(s.packed()) if(s.broadcasted())
{ {
return s; return {s.type(), s.lens()};
} }
else else
{ {
return {s.type(), s.lens()}; return s.with_lens(s.lens());
} }
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto in_shape = args[0].get_shape(); result.visit([&](auto output) {
if(in_shape.packed()) args[0].visit([&](auto input) {
{ std::transform(input.begin(),
shape std_in_shape{in_shape.type(), in_shape.lens()}; input.end(),
shape std_out_shape{output_shape.type(), output_shape.lens()}; output.begin(),
argument arg_in{std_in_shape, args[0].data()}; static_cast<const Derived&>(*this).apply());
argument arg_out{std_out_shape, result.data()};
arg_out.visit([&](auto output) {
arg_in.visit([&](auto input) {
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
}); });
} });
else
{
result.visit([&](auto output) {
args[0].visit([&](auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input(idx.begin(), idx.end()));
});
});
});
}
return result; return result;
} }
}; };
......
...@@ -20,29 +20,22 @@ inline Vector reorder_dims(const Vector& dims, const std::vector<int64_t>& permu ...@@ -20,29 +20,22 @@ inline Vector reorder_dims(const Vector& dims, const std::vector<int64_t>& permu
return result; return result;
} }
inline shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation) shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation);
{
return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)};
}
template <class Vector, class Op> template <class Vector, class Op>
inline std::vector<int64_t> sort_permutation(const Vector& data, Op op) inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{ {
std::vector<std::int64_t> result(data.size()); std::vector<std::int64_t> result(data.size());
std::iota(result.begin(), result.end(), 0); std::iota(result.begin(), result.end(), 0);
std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); }); std::stable_sort(
result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
return result; return result;
} }
inline std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation) std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);
{
return sort_permutation(permutation, std::less<>{});
}
inline std::vector<int64_t> find_permutation(const shape& s) std::vector<int64_t> find_permutation(const shape& s);
{ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
return sort_permutation(s.strides(), std::greater<>{});
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -79,6 +79,8 @@ struct shape ...@@ -79,6 +79,8 @@ struct shape
{ {
} }
static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const; type_t type() const;
const std::vector<std::size_t>& lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
...@@ -121,6 +123,9 @@ struct shape ...@@ -121,6 +123,9 @@ struct shape
shape normalize_standard() const; shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
shape with_lens(const std::vector<std::size_t>& l) const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x); friend std::ostream& operator<<(std::ostream& os, const shape& x);
......
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <map>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)};
}
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
std::vector<int64_t> find_permutation(const shape& s)
{
std::vector<std::int64_t> result(s.lens().size());
std::iota(result.begin(), result.end(), 0);
std::stable_sort(result.begin(), result.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(s.strides()[x], s.lens()[x]);
}));
return result;
}
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
{
if(shapes.empty())
return {};
std::map<std::vector<int64_t>, std::size_t> count;
for(auto&& s : shapes)
{
if(s.broadcasted())
continue;
count[find_permutation(s)]++;
}
if(count.empty())
{
std::vector<int64_t> r(shapes.front().lens().size());
std::iota(r.begin(), r.end(), 0);
return r;
}
auto it = std::max_element(
count.begin(), count.end(), by(std::less<>{}, [](auto&& p) { return p.second; }));
assert(it != count.end());
return it->first;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <numeric> #include <numeric>
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
...@@ -99,6 +100,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -99,6 +100,16 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{ {
} }
shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm)
{
auto new_lens = reorder_dims(l, perm);
shape result = reorder_shape({t, new_lens}, invert_permutation(perm));
assert(result.lens() == l);
return result;
}
shape::type_t shape::type() const { return impl->m_type; } shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; } const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; } const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
...@@ -221,6 +232,18 @@ shape shape::normalize_standard() const ...@@ -221,6 +232,18 @@ shape shape::normalize_standard() const
return *this; return *this;
} }
shape shape::with_lens(type_t t, const std::vector<std::size_t>& l) const
{
assert(l.size() == this->lens().size());
auto perm = find_permutation(*this);
return shape::from_permutation(t, l, perm);
}
shape shape::with_lens(const std::vector<std::size_t>& l) const
{
return this->with_lens(this->type(), l);
}
std::size_t shape::element_space() const { return impl->element_space(); } std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const std::string shape::type_string() const
......
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <array> #include <array>
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
...@@ -386,4 +388,147 @@ TEST_CASE(test_serialize) ...@@ -386,4 +388,147 @@ TEST_CASE(test_serialize)
EXPECT(s3 != s4); EXPECT(s3 != s4);
} }
TEST_CASE(test_with_lens1)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {1, 2}};
auto s2 = s1.with_lens({4, 3});
EXPECT(s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {4, 3}, {1, 4}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens2)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {2, 1}};
auto s2 = s1.with_lens({3, 4});
EXPECT(s2.standard());
migraphx::shape s3{migraphx::shape::float_type, {3, 4}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous1)
{
migraphx::shape s1{migraphx::shape::float_type, {64, 1, 24, 24}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous2)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 1}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous3)
{
migraphx::shape s1{migraphx::shape::float_type, {64, 3, 1, 1}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous4)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {64, 1, 1, 3}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous5)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 5, 24, 24}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous6)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 24, 24, 5}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous7)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 1, 1, 3}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous8)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 1, 24, 24}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous9)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 24, 24, 1}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous10)
{
migraphx::shape s1{migraphx::shape::float_type, {3, 2, 4, 1}};
auto s2 = s1.with_lens({3, 2, 4, 1});
EXPECT(not s2.transposed());
migraphx::shape s3{migraphx::shape::float_type, {3, 2, 4, 1}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous11)
{
migraphx::shape s1{migraphx::shape::float_type, {64, 1, 1, 1}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s1.standard());
EXPECT(s2.standard());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous12)
{
migraphx::shape s1{migraphx::shape::float_type, {1, 64, 1, 1}};
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s1.standard());
EXPECT(s2.standard());
migraphx::shape s3{migraphx::shape::float_type, {64, 3, 24, 24}};
EXPECT(s2 == s3);
}
TEST_CASE(test_with_lens_ambigous13)
{
auto s1 = migraphx::reorder_shape({migraphx::shape::float_type, {1, 1, 1, 3}}, {0, 3, 1, 2});
auto s2 = s1.with_lens({64, 3, 24, 24});
EXPECT(s2.transposed());
migraphx::shape s3 =
migraphx::reorder_shape({migraphx::shape::float_type, {64, 24, 24, 3}}, {0, 3, 1, 2});
EXPECT(s2 == s3);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment