Unverified Commit b8deb54c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into fuse-horiz-contiguous

parents fee84355 ca8a54fe
......@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
for(; first != last; ++first)
{
init = op(std::move(init), *first);
init = op(static_cast<T&&>(init), *first);
}
return init;
}
......@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
return d_first;
}
template <class InputIt, class OutputIt, class UnaryPredicate>
constexpr OutputIt copy_if(InputIt first, InputIt last, OutputIt d_first, UnaryPredicate pred)
{
for(; first != last; ++first)
{
if(pred(*first))
{
*d_first = *first;
++d_first;
}
}
return d_first;
}
template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{
......@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value)
return find_if(first, last, [&](const auto& x) { return x == value; });
}
template <class InputIt, class UnaryPredicate>
constexpr bool any_of(InputIt first, InputIt last, UnaryPredicate p)
{
return find_if(first, last, p) != last;
}
template <class InputIt, class UnaryPredicate>
constexpr bool none_of(InputIt first, InputIt last, UnaryPredicate p)
{
return find_if(first, last, p) == last;
}
template <class InputIt, class UnaryPredicate>
constexpr bool all_of(InputIt first, InputIt last, UnaryPredicate p)
{
return none_of(first, last, [=](auto&& x) { return not p(x); });
}
template <class Iterator1, class Iterator2>
constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last)
{
......
......@@ -41,8 +41,15 @@ struct implicit_conversion_op
template <index_int N, class U>
constexpr operator vec<U, N>() const
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
if constexpr(vec_size<T>() == 0)
{
return x;
}
else
{
static_assert(vec_size<T>() == N, "Vector mismatch size");
return __builtin_convertvector(x, vec<U, N>);
}
}
template <class U>
......
......@@ -44,7 +44,7 @@ struct shape
constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr auto packed() const { return elements() == element_space(); }
constexpr auto packed() const { return not skips() and elements() == element_space(); }
constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr auto transposed() const
{
......@@ -53,16 +53,9 @@ struct shape
if(shape{}.broadcasted())
{
index_array s{};
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{
if(lstrides[i] != 0)
{
s[j] = lstrides[i];
j++;
}
}
return not is_sorted(s.begin(), s.begin() + j, greater{});
auto out = copy_if(
lstrides.begin(), lstrides.end(), s.begin(), [](auto x) { return x != 0; });
return not is_sorted(s.begin(), out, greater{});
}
else
{
......@@ -70,6 +63,13 @@ struct shape
}
});
}
constexpr auto skips() const
{
return return_c([] {
auto lstrides = Strides{};
return none_of(lstrides.begin(), lstrides.end(), [](auto x) { return x == 1; });
});
}
constexpr auto standard() const { return packed() and not transposed(); }
......@@ -86,26 +86,34 @@ struct shape
constexpr index_int index(index_int i) const
{
if(this->standard())
{
MIGRAPHX_ASSERT(i == compute_index(i));
return i;
}
else
{
const auto rank = this->lens.size();
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < rank; j++)
{
const index_int k = rank - j - 1;
const index_int stride = this->strides[k];
const index_int len = this->lens[k];
const index_int slen = s * len;
const index_int idx = (i % slen) / s;
result += stride * idx;
s = slen;
}
return result;
return compute_index(i);
}
}
constexpr index_int compute_index(index_int i) const
{
const auto rank = this->lens.size();
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < rank; j++)
{
const index_int k = rank - j - 1;
const index_int stride = this->strides[k];
const index_int len = this->lens[k];
const index_int slen = s * len;
const index_int idx = (i % slen) / s;
result += stride * idx;
s = slen;
}
return result;
}
/// Convert single index into a multi-index
constexpr index_array multi(index_int idx) const
{
......
This diff is collapsed.
This diff is collapsed.
......@@ -53,10 +53,10 @@
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/mlir_conv.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/sync_device.hpp>
......@@ -128,7 +128,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
mlir_conv{&ctx},
fuse_mlir{&ctx},
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{},
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -200,6 +200,15 @@ TEST_CASE(test_shape_broadcasted5)
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_step_broadcasted)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {0, 3}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_default_copy)
{
migraphx::shape s1{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment