Unverified Commit b8deb54c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into fuse-horiz-contiguous

parents fee84355 ca8a54fe
...@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op) ...@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{ {
for(; first != last; ++first) for(; first != last; ++first)
{ {
init = op(std::move(init), *first); init = op(static_cast<T&&>(init), *first);
} }
return init; return init;
} }
...@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first) ...@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
return d_first; return d_first;
} }
/// Copy the elements of [first, last) that satisfy `pred` to the range
/// starting at `d_first`, keeping their relative order.
///
/// Returns the output iterator positioned one past the last element written.
template <class InputIt, class OutputIt, class UnaryPredicate>
constexpr OutputIt copy_if(InputIt first, InputIt last, OutputIt d_first, UnaryPredicate pred)
{
    while(first != last)
    {
        if(pred(*first))
        {
            // Only advance the destination when an element is actually written
            *d_first = *first;
            ++d_first;
        }
        ++first;
    }
    return d_first;
}
template <class Iterator, class Compare> template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp) constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{ {
...@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value) ...@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value)
return find_if(first, last, [&](const auto& x) { return x == value; }); return find_if(first, last, [&](const auto& x) { return x == value; });
} }
/// Returns true when at least one element of [first, last) satisfies `p`.
/// An empty range yields false.
template <class InputIt, class UnaryPredicate>
constexpr bool any_of(InputIt first, InputIt last, UnaryPredicate p)
{
    // find_if hands back `last` when nothing matched
    const auto match = find_if(first, last, p);
    return match != last;
}
/// Returns true when no element of [first, last) satisfies `p`.
/// An empty range yields true.
template <class InputIt, class UnaryPredicate>
constexpr bool none_of(InputIt first, InputIt last, UnaryPredicate p)
{
    // Reaching `last` means the search found no matching element
    const auto match = find_if(first, last, p);
    return match == last;
}
/// Returns true when every element of [first, last) satisfies `p`.
///
/// Expressed through none_of: all elements pass exactly when none of them
/// fails the predicate, so an empty range yields true.
template <class InputIt, class UnaryPredicate>
constexpr bool all_of(InputIt first, InputIt last, UnaryPredicate p)
{
    // Negated predicate; holds its own copy of `p`
    auto fails = [=](auto&& elem) { return not p(elem); };
    return none_of(first, last, fails);
}
template <class Iterator1, class Iterator2> template <class Iterator1, class Iterator2>
constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last) constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last)
{ {
......
...@@ -41,8 +41,15 @@ struct implicit_conversion_op ...@@ -41,8 +41,15 @@ struct implicit_conversion_op
template <index_int N, class U> template <index_int N, class U>
constexpr operator vec<U, N>() const constexpr operator vec<U, N>() const
{ {
static_assert(vec_size<T>() == N, "Vector mismatch size"); if constexpr(vec_size<T>() == 0)
return __builtin_convertvector(x, vec<U, N>); {
return x;
}
else
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
} }
template <class U> template <class U>
......
...@@ -44,7 +44,7 @@ struct shape ...@@ -44,7 +44,7 @@ struct shape
constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; } constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr auto packed() const { return elements() == element_space(); } constexpr auto packed() const { return not skips() and elements() == element_space(); }
constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; } constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr auto transposed() const constexpr auto transposed() const
{ {
...@@ -53,16 +53,9 @@ struct shape ...@@ -53,16 +53,9 @@ struct shape
if(shape{}.broadcasted()) if(shape{}.broadcasted())
{ {
index_array s{}; index_array s{};
index_int j = 0; auto out = copy_if(
for(index_int i = 0; i < s.size(); i++) lstrides.begin(), lstrides.end(), s.begin(), [](auto x) { return x != 0; });
{ return not is_sorted(s.begin(), out, greater{});
if(lstrides[i] != 0)
{
s[j] = lstrides[i];
j++;
}
}
return not is_sorted(s.begin(), s.begin() + j, greater{});
} }
else else
{ {
...@@ -70,6 +63,13 @@ struct shape ...@@ -70,6 +63,13 @@ struct shape
} }
}); });
} }
constexpr auto skips() const
{
return return_c([] {
auto lstrides = Strides{};
return none_of(lstrides.begin(), lstrides.end(), [](auto x) { return x == 1; });
});
}
constexpr auto standard() const { return packed() and not transposed(); } constexpr auto standard() const { return packed() and not transposed(); }
...@@ -86,26 +86,34 @@ struct shape ...@@ -86,26 +86,34 @@ struct shape
constexpr index_int index(index_int i) const constexpr index_int index(index_int i) const
{ {
if(this->standard()) if(this->standard())
{
MIGRAPHX_ASSERT(i == compute_index(i));
return i; return i;
}
else else
{ {
const auto rank = this->lens.size(); return compute_index(i);
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < rank; j++)
{
const index_int k = rank - j - 1;
const index_int stride = this->strides[k];
const index_int len = this->lens[k];
const index_int slen = s * len;
const index_int idx = (i % slen) / s;
result += stride * idx;
s = slen;
}
return result;
} }
} }
/// Convert a linear element index into a strided offset.
///
/// Dimensions are visited from the last to the first. For each one the
/// coordinate of `i` along that dimension is extracted using `lens`, and
/// coordinate * stride is added to the result.
constexpr index_int compute_index(index_int i) const
{
    const auto rank  = this->lens.size();
    index_int offset = 0;
    // Product of the lens of the dimensions visited so far
    index_int inner = 1;
    index_int j     = 0;
    while(j < rank)
    {
        const index_int dim    = rank - j - 1;
        const index_int step   = this->strides[dim];
        const index_int extent = this->lens[dim];
        const index_int outer  = inner * extent;
        // Coordinate of i along `dim`
        const index_int coord = (i % outer) / inner;
        offset += step * coord;
        inner = outer;
        j++;
    }
    return offset;
}
/// Convert single index into a multi-index /// Convert single index into a multi-index
constexpr index_array multi(index_int idx) const constexpr index_array multi(index_int idx) const
{ {
......
This diff is collapsed.
This diff is collapsed.
...@@ -53,10 +53,10 @@ ...@@ -53,10 +53,10 @@
#include <migraphx/gpu/compile_ops.hpp> #include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/mlir_conv.hpp>
#include <migraphx/gpu/pack_int8_args.hpp> #include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/schedule_model.hpp> #include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/sync_device.hpp> #include <migraphx/gpu/sync_device.hpp>
...@@ -128,7 +128,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -128,7 +128,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}), enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{}, dead_code_elimination{},
mlir_conv{&ctx}, fuse_mlir{&ctx},
dead_code_elimination{},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{}, dead_code_elimination{},
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -200,6 +200,15 @@ TEST_CASE(test_shape_broadcasted5) ...@@ -200,6 +200,15 @@ TEST_CASE(test_shape_broadcasted5)
EXPECT(s.broadcasted()); EXPECT(s.broadcasted());
} }
// A shape that combines a zero entry with a non-unit entry in its second list
// must be reported as broadcasted, and as neither standard, packed nor
// transposed.
TEST_CASE(test_shape_step_broadcasted)
{
// NOTE(review): presumably {2, 2} are the lens and {0, 3} the strides - confirm
// against the migraphx::shape constructor.
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {0, 3}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_default_copy) TEST_CASE(test_shape_default_copy)
{ {
migraphx::shape s1{}; migraphx::shape s1{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment