Commit 9b55685c authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Improve contiguous and concat performance (#368)

* Add env to trace nary device functions

* Formatting

* Improve contiguous and concat performance

* Formatting

* Remove unused variable

* Formatting

* Fix gpu tests

* Formatting

* Add more test for transposed concat

* Formatting

* Compute offset and not index

* Compute multi-index once

* Formatting

* Fix transposed inputs

* Formatting

* Use product order for comparisons of hip_array

* Formatting

* Add missing s parameter

* Formatting

* Dont invert permutation

* Fix tidy warnings

* Formatting

* Remove incorrect license

* Use a single integer for stride

* Formatting

* Fix tidy issue
parent 47b05b0c
......@@ -43,6 +43,12 @@ struct argument : raw_data<argument>
// Shape (type, lens, strides) describing this argument's data buffer.
const shape& get_shape() const { return this->m_shape; }
// Return a view of this argument reinterpreted with shape `s`.
// The closure holds a copy of the original argument so the underlying
// storage stays alive for as long as the returned argument does.
argument reshape(const shape& s) const
{
    return {s, [copy = *this]() mutable { return copy.data(); }};
}
private:
shape m_shape;
};
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PERMUTATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_PERMUTATION_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Reorder `dims` by `permutation`: result[i] = dims[permutation[i]].
// Both containers must have the same size.
template <class Vector>
inline Vector reorder_dims(const Vector& dims, const std::vector<int64_t>& permutation)
{
    assert(dims.size() == permutation.size());
    Vector result(dims.size());
    std::size_t out = 0;
    for(auto p : permutation)
        result[out++] = dims[p];
    return result;
}
// Build a shape with the same element type whose lens and strides are both
// reordered by the same permutation.
inline shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
return {s.type(), reorder_dims(s.lens(), permutation), reorder_dims(s.strides(), permutation)};
}
// Return the index permutation that would sort `data` according to `op`,
// i.e. op(data[result[i]], data[result[i+1]]) holds for consecutive entries.
template <class Vector, class Op>
inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
    std::vector<std::int64_t> indices(data.size());
    std::iota(indices.begin(), indices.end(), 0);
    auto by_data = [&](std::int64_t a, std::int64_t b) { return op(data[a], data[b]); };
    std::sort(indices.begin(), indices.end(), by_data);
    return indices;
}
// Return the inverse permutation q of `permutation` p, i.e. q[p[i]] == i.
// The direct scatter below is O(n) and produces the same result as the
// previous sort-based implementation (sorting a permutation's values
// ascending yields its inverse), without an O(n log n) sort.
inline std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
    std::vector<int64_t> result(permutation.size());
    for(std::size_t i = 0; i < permutation.size(); i++)
        result[permutation[i]] = static_cast<int64_t>(i);
    return result;
}
// Permutation that orders the shape's axes by descending stride, i.e. the
// axis order that would make `s` read as row-major. Ties between equal
// strides are broken arbitrarily (std::sort is not stable).
inline std::vector<int64_t> find_permutation(const shape& s)
{
return sort_permutation(s.strides(), std::greater<>{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -7,6 +7,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/permutation.hpp>
#include <unordered_set>
namespace migraphx {
......@@ -43,17 +44,6 @@ auto get_transpose_dims(instruction_ref ins)
return any_cast<const op::transpose&>(ins->get_operator()).dims;
}
// Reorder dims so result[i] = dims[permutation[i]].
// NOTE(review): duplicate of reorder_dims in migraphx/permutation.hpp; per
// the diff header this local copy is being removed in favor of the shared
// header version.
std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t> permutation)
{
std::vector<int64_t> result(dims.size());
assert(dims.size() == permutation.size());
for(std::size_t i = 0; i < dims.size(); i++)
{
result[i] = dims[permutation[i]];
}
return result;
}
bool is_no_transpose(const std::vector<int64_t>& dims)
{
if(dims.empty())
......@@ -64,25 +54,6 @@ bool is_no_transpose(const std::vector<int64_t>& dims)
dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
}
// Index permutation that sorts `data` by `op`.
// NOTE(review): duplicate of sort_permutation in migraphx/permutation.hpp;
// removed here by this commit.
template <class Vector, class Op>
std::vector<int64_t> sort_permutation(const Vector& data, Op op)
{
std::vector<std::int64_t> result(data.size());
std::iota(result.begin(), result.end(), 0);
std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
return result;
}
// Inverse of a permutation via ascending sort of its values.
// NOTE(review): duplicate of invert_permutation in migraphx/permutation.hpp;
// removed here by this commit.
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
return sort_permutation(permutation, std::less<>{});
}
// Axis order by descending stride.
// NOTE(review): duplicate of find_permutation in migraphx/permutation.hpp;
// removed here by this commit.
std::vector<int64_t> find_permutation(const shape& s)
{
return sort_permutation(s.strides(), std::greater<>{});
}
struct find_reshaper
{
auto matcher() const
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/concat.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -18,16 +17,12 @@ argument concat(hipStream_t stream,
for(std::size_t j = 0; j < ninputs; j++)
{
auto&& arg = args[j];
std::size_t nelements = arg.get_shape().elements();
auto offset = offsets[j];
shape arg_shape{arg.get_shape().type(), arg.get_shape().lens()};
hip_visit_all(args.back(), arg, arg_shape)([&](auto output, auto input, auto input_shape) {
gs_launch(stream, nelements)([=](auto i) {
auto input_idx = input_shape.multi(i);
auto idx = output.get_shape().index(input_idx);
output.data()[idx + offset] = input[input_idx];
});
});
auto byte_offset = offset * arg.get_shape().type_size();
auto output_shape = shape{
arg.get_shape().type(), arg.get_shape().lens(), args.back().get_shape().strides()};
auto output = argument{output_shape, args.back().data() + byte_offset};
contiguous(stream, std::move(output), arg);
}
return args.back();
}
......
......@@ -9,7 +9,7 @@ namespace device {
// Copy `arg` into `result`, materializing a standard (contiguous) layout.
void contiguous(hipStream_t stream, argument result, argument arg)
{
// NOTE(review): the two calls below are the before/after lines of a diff
// (nary_nonstandard replaced by nary); only one should exist in the real
// file. As written the second call reuses result/arg after std::move.
nary_nonstandard(stream, std::move(result), std::move(arg))([](auto x) { return x; });
nary(stream, std::move(result), std::move(arg))([](auto x) { return x; });
}
} // namespace device
......
......@@ -9,6 +9,33 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Generates the elementwise operators for hip_array from a compound
// assignment token `op` (e.g. +=) and its binary counterpart `binary_op`
// (e.g. +): array op= array, array op= scalar, array binary_op array,
// array binary_op scalar, and scalar binary_op array.
// Fix: the scalar-on-left overload previously returned `x op y` (i.e.
// array op= scalar), which computes array/scalar for `scalar / array` —
// wrong for the non-commutative `/` and `%` instantiations below. It now
// applies the scalar as the left operand of binary_op per element.
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op)                                                    \
    MIGRAPHX_DEVICE_CONSTEXPR hip_array& operator op(const hip_array& x)                           \
    {                                                                                              \
        for(std::size_t i = 0; i < N; i++)                                                         \
            d[i] op x[i];                                                                          \
        return *this;                                                                              \
    }                                                                                              \
    MIGRAPHX_DEVICE_CONSTEXPR hip_array& operator op(const T& x)                                   \
    {                                                                                              \
        for(std::size_t i = 0; i < N; i++)                                                         \
            d[i] op x;                                                                             \
        return *this;                                                                              \
    }                                                                                              \
    friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator binary_op(hip_array x, const hip_array& y) \
    {                                                                                              \
        return x op y;                                                                             \
    }                                                                                              \
    friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator binary_op(hip_array x, const T& y)         \
    {                                                                                              \
        return x op y;                                                                             \
    }                                                                                              \
    friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator binary_op(const T& x, hip_array y)         \
    {                                                                                              \
        for(std::size_t i = 0; i < N; i++)                                                         \
            y[i] = x binary_op y[i];                                                               \
        return y;                                                                                  \
    }
template <class T, std::size_t N>
struct hip_array
{
......@@ -49,19 +76,79 @@ struct hip_array
return result;
}
// NOTE(review): this span interleaves removed and added diff lines; the
// removed elementwise operator* is mixed into the new single() member.
// single(): fold the array into one scalar by treating the elements as
// digits in base `width`, least-significant digit at the back.
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y) // removed line (diff residue)
MIGRAPHX_DEVICE_CONSTEXPR T single(std::size_t width = 100) const
{
hip_array result; // removed line (diff residue)
T result = 0;
T a = 1;
for(std::size_t i = 0; i < N; i++)
result[i] = x[i] * y[i]; // removed line (diff residue)
{
result += d[N - i - 1] * a;
a *= width;
}
return result;
}
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator+(const hip_array& x, const hip_array& y) // removed line (diff residue)
// Elementwise arithmetic and bitwise operators, generated from one macro.
MIGRAPHX_DEVICE_ARRAY_OP(+=, +)
MIGRAPHX_DEVICE_ARRAY_OP(*=, *)
MIGRAPHX_DEVICE_ARRAY_OP(/=, /)
MIGRAPHX_DEVICE_ARRAY_OP(%=, %)
MIGRAPHX_DEVICE_ARRAY_OP(&=, &)
MIGRAPHX_DEVICE_ARRAY_OP(|=, |)
MIGRAPHX_DEVICE_ARRAY_OP(^=, ^)
// True only when every element compares equal.
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator==(const hip_array& x, const hip_array& y)
{
for(std::size_t i = 0; i < N; i++)
{
if(x[i] != y[i])
return false;
}
return true;
}
// Negation of elementwise equality.
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator!=(const hip_array& x, const hip_array& y)
{
return !(x == y);
}
// This uses the product order rather than lexical order
// x < y holds only when EVERY component of x is strictly less than the
// corresponding component of y (strict domination), which is what
// multi_index::for_stride relies on for its termination test.
// NOTE(review): two removed diff lines (the old elementwise operator+) are
// interleaved into this body below.
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator<(const hip_array& x, const hip_array& y)
{
hip_array result{}; // removed line (diff residue)
for(std::size_t i = 0; i < N; i++)
result[i] = x[i] + y[i]; // removed line (diff residue)
{
if(not(x[i] < y[i]))
return false;
}
return true;
}
// Strict product order with the operands swapped.
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator>(const hip_array& x, const hip_array& y)
{
return y < x;
}
// Reflexive closure of the strict product order above.
// NOTE(review): this is NOT componentwise <=. E.g. {1,2} vs {1,3} is false
// here (neither x < y nor x == y) even though x[i] <= y[i] holds for all i.
// Confirm callers expect this before relying on it.
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator<=(const hip_array& x, const hip_array& y)
{
return (x < y) or (x == y);
}
// Mirror of operator<= (same caveat: not componentwise >=).
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator>=(const hip_array& x, const hip_array& y)
{
return (y < x) or (x == y);
}
// Normalize a multi-index after an increment may have pushed a digit past
// its bound, where this array (`d`) holds the per-dimension sizes. Scans
// from the least-significant (last) dimension: a digit z that reaches or
// exceeds d[i] is clamped to d[i]-1 and the full excess (z - (d[i]-1)) is
// carried by ADDITION into the next more-significant digit (not a modular
// positional carry). Any excess left after the most-significant digit is
// folded back onto the last element, producing a value that compares >= the
// bound under the product order — presumably so for_stride terminates;
// TODO confirm against multi_index::for_stride usage.
MIGRAPHX_DEVICE_CONSTEXPR hip_array carry(hip_array result) const
{
std::ptrdiff_t rem = 0;
for(std::ptrdiff_t i = result.size() - 1; i >= 0; i--)
{
auto z = result[i] + rem;
// Excess above the maximum digit value d[i]-1 for this dimension.
rem = z - std::ptrdiff_t(d[i]) + 1;
if(rem > 0)
z -= rem;
else
rem = 0;
result[i] = z;
}
// Add overflows to the back
if(rem > 0)
result.back() += rem;
return result;
}
};
......
#ifndef MIGRAPHX_GUARD_RTGLIB_MULTI_INDEX_HPP
#define MIGRAPHX_GUARD_RTGLIB_MULTI_INDEX_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Per-thread iteration state over an N-dimensional index space: a starting
// multi-index `id` and a flat `stride` that is repeatedly added to the last
// (fastest-varying) coordinate, with hip_array::carry normalizing the
// result against the bounds.
template <std::size_t N>
struct multi_index
{
using hip_index = hip_array<std::size_t, N>;
// Starting multi-dimensional index for this thread.
hip_index id{};
// Flat element stride added on each step (see add_stride).
std::size_t stride = 0;
// Advance `i` by `stride` along the last coordinate (no carry applied).
MIGRAPHX_DEVICE_CONSTEXPR hip_index add_stride(hip_index i) const
{
i.back() += stride;
return i;
}
// Call f(i) for every index from `id` up to (exclusive) the bounds `n`,
// stepping by `stride` and normalizing with n.carry. The loop condition
// uses hip_array's product order: it stops once no component is below n.
template <class F>
MIGRAPHX_DEVICE_CONSTEXPR void for_stride(hip_index n, F f) const
{
for(hip_index i = id; i < n; i = n.carry(add_stride(i)))
{
f(i);
}
}
};
// Build a multi_index starting at the multi-dimensional decomposition of
// flat index `i` within shape `s`, stepping by flat stride `n`.
template <std::size_t N>
MIGRAPHX_DEVICE_CONSTEXPR multi_index<N>
make_multi_index(const hip_shape<N>& s, std::size_t i, std::size_t n)
{
return {s.multi(i), n};
}
// NOTE(review): this overload takes the stride as a hip_array, but
// multi_index::stride is a single std::size_t ("Use a single integer for
// stride" in this commit) — this looks like the pre-change overload left in
// the diff view and would not compile against the struct above.
template <std::size_t N>
MIGRAPHX_DEVICE_CONSTEXPR multi_index<N>
make_multi_index(const hip_shape<N>& s, std::size_t i, const hip_array<std::size_t, N>& n)
{
return {s.multi(i), n};
}
// Return a launcher that invokes f(i) for every multi-index of the standard
// shape `s`. Precondition (asserted): s is standard. The launch uses groups
// of `local` threads, capped at 128 groups; each thread starts at its
// global id and strides by the total thread count (nglobal) via
// multi_index::for_stride until the index space is exhausted.
template <std::size_t N>
inline auto mi_launch(hipStream_t stream, const hip_shape<N>& s, std::size_t local = 1024)
{
assert(s.standard);
std::size_t n = s.elements();
// Round up so every element is covered, then cap the grid at 128 groups.
std::size_t groups = (n + local - 1) / local;
std::size_t nglobal = std::min<std::size_t>(128, groups) * local;
return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) {
auto midx = make_multi_index(s, idx.global, nglobal);
midx.for_stride(s.lens, [&](auto i) { f(i); });
});
};
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -2,17 +2,28 @@
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_NARY_HPP
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/multi_index.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/array.hpp>
#include <migraphx/env.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/config.hpp>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Env var: when enabled, prints which nary device function was selected.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY)

// Emit a trace line with the enclosing function's signature when tracing is
// enabled. NOTE(review): expands to a braceless if-statement; fine used as a
// standalone statement (as below) but would bind a following `else`.
// NOLINTNEXTLINE
#define MIGRAPHX_TRACE_NARY_FUNCTION \
if(enabled(MIGRAPHX_TRACE_NARY{})) \
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
template <class... Ts>
auto pack(Ts... xs) __device__
{
......@@ -20,14 +31,28 @@ auto pack(Ts... xs) __device__
}
// Fallback elementwise kernel for arguments whose shapes are neither
// standard nor uniformly packed: iterate a standard shape with the result's
// type/lens and index every tensor through the shared multi-index.
// NOTE(review): this span interleaves the removed nary_nonstandard_impl
// (old name and old gs_launch body) with the new implementation.
template <class F, class... Arguments>
auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments... args) // removed line (diff residue)
auto nary_nonstandard_nonpacked_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
std::size_t nelements = result.get_shape().elements(); // removed line (diff residue)
hip_visit_all(result, args...)([&](auto output, auto... inputs) { // removed line (diff residue)
gs_launch(stream, nelements)([=](auto i) { // removed line (diff residue)
auto idx = output.get_shape().multi(i); // removed line (diff residue)
output[i] = f(inputs[idx]...); // removed line (diff residue)
MIGRAPHX_TRACE_NARY_FUNCTION
shape s{result.get_shape().type(), result.get_shape().lens()};
hip_visit_all(s, result, args...)([&](auto standard_shape, auto output, auto... inputs) {
mi_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
});
}
// Elementwise kernel for packed inputs that all share one (possibly
// transposed) shape: find the permutation that sorts the first input's
// strides descending, reorder all shapes by it, and run the kernel over the
// resulting standard-ordered index space. Requires at least one input
// argument (front() on the arg pack).
template <class F, class... Arguments>
auto nary_nonstandard_packed_impl(hipStream_t stream,
F f,
const argument& result,
Arguments... args)
{
MIGRAPHX_TRACE_NARY_FUNCTION
auto arg_shape = make_array(args...).front().get_shape();
auto perm = find_permutation(arg_shape);
auto s = reorder_shape(arg_shape, perm);
// Reshape the output with the same axis reordering so output and inputs
// are traversed in the same (now standard) order.
hip_visit_all(s, result.reshape(reorder_shape(result.get_shape(), perm)), args.reshape(s)...)(
[&](auto standard_shape, auto output, auto... inputs) {
mi_launch(stream, standard_shape)([=](auto idx) { output[idx] = f(inputs[idx]...); });
});
}
......@@ -35,6 +60,7 @@ template <class F, class... Arguments>
void nary_broadcast_vec_impl(
hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{
MIGRAPHX_TRACE_NARY_FUNCTION
const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape();
auto bdim =
......@@ -83,6 +109,7 @@ void nary_broadcast_vec_impl(
template <class F, class... Arguments>
void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args)
{
MIGRAPHX_TRACE_NARY_FUNCTION
const auto& output_shape = result.get_shape();
const auto& b_shape = barg.get_shape();
auto bdim =
......@@ -122,6 +149,7 @@ template <class F, class... Arguments>
void nary_double_broadcast_vec_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
MIGRAPHX_TRACE_NARY_FUNCTION
assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape());
......@@ -179,6 +207,7 @@ template <class F, class... Arguments>
void nary_double_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
MIGRAPHX_TRACE_NARY_FUNCTION
assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape());
......@@ -226,6 +255,7 @@ void nary_double_broadcast_impl(
template <class F, class... Arguments>
void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
MIGRAPHX_TRACE_NARY_FUNCTION
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
......@@ -250,6 +280,7 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
template <class F, class... Arguments>
void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
MIGRAPHX_TRACE_NARY_FUNCTION
std::size_t nelements = result.get_shape().elements();
hip_pointer_visit_all(result, args...)([&](auto output, auto... inputs) {
gs_launch(stream, nelements)([=](auto i) { output[i] = f(inputs[i]...); });
......@@ -259,20 +290,25 @@ void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... a
// Dispatch to the fastest elementwise implementation based on the input
// shapes: flat indexing when everything is standard (or packed with shapes
// equal to the result), the permuted packed path when inputs share one
// packed shape, otherwise the general nonstandard fallback.
// NOTE(review): the old dispatch logic (removed by this diff) is
// interleaved with the new one below.
template <class F, class... Arguments>
void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }); // removed line (diff residue)
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); }); // removed line (diff residue)
bool same_shapes = // removed line (diff residue)
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); }); // removed line (diff residue)
if(standard or (packed and same_shapes)) // removed line (diff residue)
MIGRAPHX_TRACE_NARY_FUNCTION
const auto shapes = make_array(args.get_shape()...);
const bool standard = all_of(shapes, [](const shape& s) { return s.standard(); });
const bool packed = all_of(shapes, [](const shape& s) { return s.packed(); });
const bool same_shapes =
all_of(shapes, [&](const shape& s) { return s == result.get_shape(); });
const bool same_input_shapes = all_of(shapes, [&](const shape& s) { return s == shapes[0]; });
if((result.get_shape().standard() and standard) or (packed and same_shapes))
nary_standard_impl(stream, f, result, args...);
else if(packed and same_input_shapes)
nary_nonstandard_packed_impl(stream, f, result, args...);
else
nary_nonstandard_impl(stream, f, result, args...); // removed line (diff residue)
nary_nonstandard_nonpacked_impl(stream, f, result, args...);
}
// Curried entry point: returns a callable taking the elementwise functor.
// NOTE(review): the two return statements below are the before/after lines
// of the diff (the old nary_nonstandard_impl call was renamed); the second
// return is unreachable as written.
template <class... Arguments>
auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args)
{
return [=](auto f) { nary_nonstandard_impl(stream, f, result, args...); };
return [=](auto f) { nary_nonstandard_nonpacked_impl(stream, f, result, args...); };
}
template <class... Arguments>
......
......@@ -73,22 +73,6 @@ struct hip_shape
}
return result;
}
// Clamped carry-propagation over a multi-index using this shape's lens.
// NOTE(review): removed by this commit; the equivalent logic lives on
// hip_array::carry, which additionally folds leftover overflow onto the
// last element ("Add overflows to the back") — this older copy drops it.
MIGRAPHX_DEVICE_CONSTEXPR hip_index carry(hip_index result) const
{
std::ptrdiff_t rem = 0;
for(std::ptrdiff_t i = result.size() - 1; i >= 0; i--)
{
auto z = result[i] + rem;
rem = z - std::ptrdiff_t(lens[i]) + 1;
if(rem > 0)
z -= rem;
else
rem = 0;
result[i] = z;
}
return result;
}
};
template <std::size_t N>
......
......@@ -1830,6 +1830,61 @@ struct test_concat2 : verify_program<test_concat2>
}
};
// Concat along axis 1 where the middle input arrives transposed:
// {2,2} ++ transpose({3,2}) ++ {2,4} -> {2,9}.
struct test_concat_transpose : verify_program<test_concat_transpose>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto x  = p.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto y  = p.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {3, 2}});
        auto yt = p.add_instruction(migraphx::op::transpose{{1, 0}}, y);
        auto z  = p.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 4}});
        p.add_instruction(migraphx::op::concat{1}, x, yt, z);
        return p;
    }
};
// Concat along axis 1 where the last input arrives transposed:
// {2,2} ++ {2,3} ++ transpose({5,2}) -> {2,10}.
struct test_concat_transpose2 : verify_program<test_concat_transpose2>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto x  = p.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto y  = p.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
        auto z  = p.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {5, 2}});
        auto zt = p.add_instruction(migraphx::op::transpose{{1, 0}}, z);
        p.add_instruction(migraphx::op::concat{1}, x, y, zt);
        return p;
    }
};
// Concat along axis 1 with two transposed inputs:
// {2,2} ++ transpose({3,2}) ++ transpose({5,2}) -> {2,10}.
struct test_concat_transpose3 : verify_program<test_concat_transpose3>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto x  = p.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
        auto y  = p.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {3, 2}});
        auto yt = p.add_instruction(migraphx::op::transpose{{1, 0}}, y);
        auto z  = p.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {5, 2}});
        auto zt = p.add_instruction(migraphx::op::transpose{{1, 0}}, z);
        p.add_instruction(migraphx::op::concat{1}, x, yt, zt);
        return p;
    }
};
struct test_concat_relu : verify_program<test_concat_relu>
{
migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment