Commit b75c83d8 authored by Paul Fultz II, committed by GitHub

Use jit for contiguous operator (#1217)

* Jit contiguous
parent 8c35fa94
......@@ -210,10 +210,10 @@ void replace(Range&& r, const T& old, const T& new_x)
std::replace(r.begin(), r.end(), old, new_x);
}
template <class R1, class R2>
bool equal(R1&& r1, R2&& r2)
template <class R1, class R2, class... Predicate>
bool equal(R1&& r1, R2&& r2, Predicate... pred)
{
return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end());
return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...);
}
template <class R>
......
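Editor note: the `equal` helper above now forwards an optional binary predicate to `std::equal`, which the new `verify_shapes` test helper further down relies on. A minimal standalone sketch of the intended behaviour, using plain `std::vector` inputs rather than MIGraphX ranges:

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <vector>

// Zero or one predicate is forwarded on to the four-iterator std::equal overload.
template <class R1, class R2, class... Predicate>
bool equal(R1&& r1, R2&& r2, Predicate... pred)
{
    return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...);
}

int main()
{
    std::vector<int> a = {1, 2, 3};
    std::vector<int> b = {1, 2, 4};
    assert(not equal(a, b));                                                // default operator==
    assert(equal(a, b, [](int x, int y) { return std::abs(x - y) <= 1; })); // custom comparison
}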
......@@ -61,9 +61,7 @@ struct shape_impl
{
assert(t != shape::tuple_type);
assert(m_lens.size() == m_strides.size());
// assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
// "At least one stride must be non-zero");
m_standard = this->elements() == this->element_space() and
m_standard = this->elements() == this->element_space() and not skips() and
std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
......@@ -110,6 +108,15 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
// Does the shape skip over elements?
bool skips() const
{
assert(m_lens.size() == m_strides.size());
if(elements() == 1)
return false;
return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; });
}
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
};
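Editor note: the new `skips()` check is the heart of this change. A multi-element shape can satisfy `elements() == element_space()` and still not be dense, because a shape with no unit stride leaves gaps between the offsets it addresses. A standalone sketch of the same test, assuming plain vectors for lens/strides:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// A multi-element shape "skips" memory when none of its strides is 1.
// Example: lens {2, 2} with strides {0, 3} has elements() == element_space() == 4,
// yet only ever touches offsets 0 and 3, so it is neither packed nor standard.
bool skips(const std::vector<std::size_t>& lens, const std::vector<std::size_t>& strides)
{
    assert(lens.size() == strides.size());
    const std::size_t elements =
        std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<>{});
    if(elements == 1)
        return false;
    return std::none_of(strides.begin(), strides.end(), [](auto x) { return x == 1; });
}

int main()
{
    assert(not skips({2, 2}, {2, 1})); // packed row-major: unit stride present
    assert(skips({2, 2}, {0, 3}));     // the step-broadcast case from the new tests
    assert(not skips({1}, {0}));       // a single element never skips
}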
......@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
bool shape::packed() const
{
return this->sub_shapes().empty() and this->elements() == this->element_space();
return this->sub_shapes().empty() and not impl->skips() and
this->elements() == this->element_space();
}
bool shape::transposed() const
......@@ -285,10 +293,8 @@ bool shape::transposed() const
bool shape::broadcasted() const
{
assert(this->lens().size() == this->strides().size());
return std::accumulate(this->strides().begin(),
this->strides().end(),
std::size_t{1},
std::multiplies<std::size_t>()) == 0;
return std::any_of(
this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
}
bool shape::scalar() const
......
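Editor note: `broadcasted()` is rewritten from "the product of the strides is zero" to "some stride is zero". The two are equivalent for non-negative strides, but `any_of` short-circuits at the first zero and never forms a product that could overflow. A small standalone sketch:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// A shape is broadcasted when some axis has stride 0, i.e. the same data is
// reused for every position along that axis.
bool broadcasted(const std::vector<std::size_t>& strides)
{
    return std::any_of(strides.begin(), strides.end(), [](auto x) { return x == 0; });
}

int main()
{
    assert(broadcasted({0, 1, 0, 0}));
    assert(not broadcasted({147, 49, 7, 1}));
}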
......@@ -48,6 +48,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/array.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <cmath>
#include <set>
......@@ -1012,9 +1013,43 @@ struct find_commutative_broadcast
}
};
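// Lower any remaining gpu::contiguous to the JIT path: wrap it in a
// gpu::precompile_op so the pointwise compiler (which now also handles
// "contiguous") generates the kernel.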
struct find_contiguous
{
auto matcher() const { return match::name("gpu::contiguous"); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
m.replace_instruction(
ins,
make_op("gpu::precompile_op", {{"op", to_value(make_op("contiguous"))}}),
ins->inputs());
}
};
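// A gpu::contiguous fed by a pointwise kernel is just a copy of that kernel's
// output: rerun the pointwise op writing directly into the contiguous op's
// output allocation and drop the copy.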
struct find_contiguous_pointwise
{
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(precompile_name("pointwise")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto pw = ins->inputs().front();
auto alloc = ins->inputs().back();
auto args = pw->inputs();
args.back() = alloc;
m.replace_instruction(ins, pw->get_operator(), args, pw->module_inputs());
}
};
void fuse_ops::apply(module& m) const
{
match::find_matches(m, find_gelu{}, find_gelu_new{fast_math});
match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_triadd{});
match::find_matches(m,
......@@ -1036,6 +1071,7 @@ void fuse_ops::apply(module& m) const
find_gemm_add{},
find_gemm_pointwise{},
find_commutative_broadcast{});
match::find_matches(m, find_contiguous{});
}
} // namespace gpu
......
......@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m)
struct pointwise_compiler : compiler<pointwise_compiler>
{
std::vector<std::string> names() const { return {"pointwise"}; }
std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
static std::size_t oversubscribe_if(bool b)
{
......@@ -114,34 +114,45 @@ struct pointwise_compiler : compiler<pointwise_compiler>
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front();
run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign",
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explicit conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(
compile_op(ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
if(op.name() == "contiguous")
{
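// contiguous carries no sub-module: compile it as an identity pointwise
// kernel; the data rearrangement comes entirely from the shape index math.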
return replace(compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
}
else
{
assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front();
run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign",
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explicit conversions
g.fresult([](const shape& s) {
return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
});
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
}
}
};
} // namespace gpu
......
......@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
for(; first != last; ++first)
{
init = op(std::move(init), *first);
init = op(static_cast<T&&>(init), *first);
}
return init;
}
......@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
return d_first;
}
template <class InputIt, class OutputIt, class UnaryPredicate>
constexpr OutputIt copy_if(InputIt first, InputIt last, OutputIt d_first, UnaryPredicate pred)
{
for(; first != last; ++first)
{
if(pred(*first))
{
*d_first = *first;
++d_first;
}
}
return d_first;
}
template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{
......@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value)
return find_if(first, last, [&](const auto& x) { return x == value; });
}
template <class InputIt, class UnaryPredicate>
constexpr bool any_of(InputIt first, InputIt last, UnaryPredicate p)
{
return find_if(first, last, p) != last;
}
template <class InputIt, class UnaryPredicate>
constexpr bool none_of(InputIt first, InputIt last, UnaryPredicate p)
{
return find_if(first, last, p) == last;
}
template <class InputIt, class UnaryPredicate>
constexpr bool all_of(InputIt first, InputIt last, UnaryPredicate p)
{
return none_of(first, last, [=](auto&& x) { return not p(x); });
}
template <class Iterator1, class Iterator2>
constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last)
{
......
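Editor note: these device-side algorithms mirror their `<algorithm>` counterparts but are `constexpr`, so shape properties such as `skips()` below can be folded at compile time. A standalone sketch (re-declaring the same helpers so it compiles outside the kernel headers) exercising them purely at compile time:

#include <array>

template <class Iterator, class UnaryPredicate>
constexpr Iterator find_if(Iterator first, Iterator last, UnaryPredicate pred)
{
    for(; first != last; ++first)
        if(pred(*first))
            return first;
    return first; // == last when no element matches
}

template <class InputIt, class UnaryPredicate>
constexpr bool any_of(InputIt first, InputIt last, UnaryPredicate p)
{
    return find_if(first, last, p) != last;
}

template <class InputIt, class UnaryPredicate>
constexpr bool none_of(InputIt first, InputIt last, UnaryPredicate p)
{
    return find_if(first, last, p) == last;
}

template <class InputIt, class UnaryPredicate>
constexpr bool all_of(InputIt first, InputIt last, UnaryPredicate p)
{
    return none_of(first, last, [=](auto&& x) { return not p(x); });
}

int main()
{
    constexpr std::array<int, 3> strides = {4, 2, 1};
    static_assert(any_of(strides.begin(), strides.end(), [](auto x) { return x == 1; }));
    static_assert(not none_of(strides.begin(), strides.end(), [](auto x) { return x == 1; }));
    static_assert(all_of(strides.begin(), strides.end(), [](auto x) { return x > 0; }));
}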
......@@ -41,8 +41,15 @@ struct implicit_conversion_op
template <index_int N, class U>
constexpr operator vec<U, N>() const
{
static_assert(vec_size<T>() == N, "Vector size mismatch");
return __builtin_convertvector(x, vec<U, N>);
if constexpr(vec_size<T>() == 0)
{
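// Scalar payload: hand it back unchanged and let the usual implicit
// conversion apply; the __builtin_convertvector branch below is never
// instantiated for non-vector types.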
return x;
}
else
{
static_assert(vec_size<T>() == N, "Vector size mismatch");
return __builtin_convertvector(x, vec<U, N>);
}
}
template <class U>
......
......@@ -44,7 +44,7 @@ struct shape
constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr auto packed() const { return elements() == element_space(); }
constexpr auto packed() const { return not skips() and elements() == element_space(); }
constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr auto transposed() const
{
......@@ -53,16 +53,9 @@ struct shape
if(shape{}.broadcasted())
{
index_array s{};
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{
if(lstrides[i] != 0)
{
s[j] = lstrides[i];
j++;
}
}
return not is_sorted(s.begin(), s.begin() + j, greater{});
auto out = copy_if(
lstrides.begin(), lstrides.end(), s.begin(), [](auto x) { return x != 0; });
return not is_sorted(s.begin(), out, greater{});
}
else
{
......@@ -70,6 +63,13 @@ struct shape
}
});
}
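// Device-side analogue of shape_impl::skips(): no unit stride means the
// shape leaves gaps between the elements it addresses.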
constexpr auto skips() const
{
return return_c([] {
auto lstrides = Strides{};
return none_of(lstrides.begin(), lstrides.end(), [](auto x) { return x == 1; });
});
}
constexpr auto standard() const { return packed() and not transposed(); }
......@@ -86,26 +86,34 @@ struct shape
constexpr index_int index(index_int i) const
{
if(this->standard())
{
MIGRAPHX_ASSERT(i == compute_index(i));
return i;
}
else
{
const auto rank = this->lens.size();
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < rank; j++)
{
const index_int k = rank - j - 1;
const index_int stride = this->strides[k];
const index_int len = this->lens[k];
const index_int slen = s * len;
const index_int idx = (i % slen) / s;
result += stride * idx;
s = slen;
}
return result;
return compute_index(i);
}
}
constexpr index_int compute_index(index_int i) const
{
const auto rank = this->lens.size();
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < rank; j++)
{
const index_int k = rank - j - 1;
const index_int stride = this->strides[k];
const index_int len = this->lens[k];
const index_int slen = s * len;
const index_int idx = (i % slen) / s;
result += stride * idx;
s = slen;
}
return result;
}
/// Convert single index into a multi-index
constexpr index_array multi(index_int idx) const
{
......
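Editor note: `index()` now delegates to the factored-out `compute_index()` and, for standard shapes, asserts that the fast path (returning `i` unchanged) agrees with the general stride arithmetic. A host-side sketch of `compute_index()` with hypothetical shapes, walking dimensions from fastest- to slowest-varying:

#include <array>
#include <cassert>
#include <cstddef>

// Map a linear element index onto a memory offset using lens/strides.
template <std::size_t N>
constexpr std::size_t compute_index(std::size_t i,
                                    const std::array<std::size_t, N>& lens,
                                    const std::array<std::size_t, N>& strides)
{
    std::size_t s      = 1;
    std::size_t result = 0;
    for(std::size_t j = 0; j < N; j++)
    {
        const std::size_t k      = N - j - 1;
        const std::size_t stride = strides[k];
        const std::size_t len    = lens[k];
        const std::size_t slen   = s * len;
        const std::size_t idx    = (i % slen) / s;
        result += stride * idx;
        s = slen;
    }
    return result;
}

int main()
{
    // A 2x3 row-major (standard) shape: the offset equals the index itself.
    constexpr std::array<std::size_t, 2> lens{2, 3};
    constexpr std::array<std::size_t, 2> row_major{3, 1};
    for(std::size_t i = 0; i < 6; i++)
        assert(compute_index(i, lens, row_major) == i);

    // The same 2x3 logical shape stored transposed (strides {1, 2}):
    // element (row, col) lives at offset col * 2 + row.
    constexpr std::array<std::size_t, 2> transposed{1, 2};
    assert(compute_index(1, lens, transposed) == 2); // (0, 1) -> offset 2
    assert(compute_index(3, lens, transposed) == 1); // (1, 0) -> offset 1
}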
......@@ -23,6 +23,7 @@
*/
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
migraphx::shape make_shape(std::vector<std::size_t> lens)
......@@ -35,6 +36,21 @@ migraphx::shape make_shape(std::vector<std::size_t> lens, std::vector<std::size_
return {migraphx::shape::float_type, std::move(lens), std::move(strides)};
}
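// Two shapes are equivalent for reduce_dims purposes when every linear
// element index maps to the same memory offset in both.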
bool verify_shape(const migraphx::shape& s1, const migraphx::shape& s2)
{
if(s1.elements() != s2.elements())
return false;
return migraphx::all_of(migraphx::range(s1.elements()),
[&](auto i) { return s1.index(i) == s2.index(i); });
}
template <class Range1, class Range2>
bool verify_shapes(const Range1& r1, const Range2& r2)
{
return migraphx::equal(
r1, r2, [](const auto& s1, const auto& s2) { return verify_shape(s1, s2); });
}
TEST_CASE(same_standard)
{
auto is = make_shape({64, 3, 7, 7});
......@@ -42,7 +58,7 @@ TEST_CASE(same_standard)
std::vector<migraphx::shape> ishapes = {is, is, is};
std::vector<migraphx::shape> eshapes = {os, os, os};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
......@@ -53,7 +69,7 @@ TEST_CASE(same_broadcast1)
std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 7, 7}, {0, 1, 0, 0}), is};
std::vector<migraphx::shape> eshapes = {os, make_shape({64, 3, 7 * 7}, {0, 1, 0}), os};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
......@@ -64,7 +80,7 @@ TEST_CASE(same_broadcast2)
std::vector<migraphx::shape> ishapes = {is, make_shape({64, 3, 8, 7, 7}, {0, 8, 1, 0, 0}), is};
std::vector<migraphx::shape> eshapes = {os, make_shape({64, 8 * 3, 7 * 7}, {0, 1, 0}), os};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
......@@ -75,7 +91,7 @@ TEST_CASE(same_transposed)
std::vector<migraphx::shape> ishapes = {is, migraphx::reorder_shape(is, {0, 1, 3, 2}), is};
std::vector<migraphx::shape> eshapes = {os, migraphx::reorder_shape(os, {0, 2, 1}), os};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
......@@ -86,7 +102,7 @@ TEST_CASE(different_masked1)
std::vector<migraphx::shape> ishapes = {is, make_shape({1, 3, 1, 1}), is};
std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), os};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
......@@ -98,7 +114,7 @@ TEST_CASE(different_masked2)
is, make_shape({1, 3, 1, 1}), make_shape({64, 1, 7, 7})};
std::vector<migraphx::shape> eshapes = {os, make_shape({1, 3, 1}), make_shape({64, 1, 7 * 7})};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
......@@ -128,7 +144,7 @@ TEST_CASE(transposed1)
std::vector<migraphx::shape> eshapes = {
make_shape({8, 28, 4, 56 * 56}), make_shape({8, 28, 4, 56 * 56}, {351232, 3136, 87808, 1})};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
......@@ -137,6 +153,7 @@ TEST_CASE(non_packed_empty1)
std::vector<migraphx::shape> ishapes = {make_shape({1, 12}, {589824, 64})};
std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
......@@ -145,6 +162,7 @@ TEST_CASE(non_packed_empty2)
std::vector<migraphx::shape> ishapes = {make_shape({12, 1}, {64, 589824})};
std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
......@@ -155,6 +173,16 @@ TEST_CASE(single_dim)
EXPECT(ishapes == rshapes);
}
TEST_CASE(step_broadcast_transpose)
{
std::vector<migraphx::shape> ishapes = {make_shape({1, 2, 2, 1}, {0, 0, 3, 6}),
make_shape({1, 2, 2, 1}, {4, 2, 1, 1})};
std::vector<migraphx::shape> eshapes = {make_shape({2, 2}, {0, 3}), make_shape({2, 2}, {2, 1})};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(verify_shapes(ishapes, rshapes));
EXPECT(eshapes == rshapes);
}
TEST_CASE(empty)
{
auto rshapes = migraphx::reduce_dims({});
......
......@@ -200,6 +200,15 @@ TEST_CASE(test_shape_broadcasted5)
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_step_broadcasted)
{
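// lens {2, 2} with strides {0, 3}: elements() == element_space() == 4, but only
// offsets 0 and 3 are ever addressed, so the shape must not report packed.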
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {0, 3}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_default_copy)
{
migraphx::shape s1{};
......