Commit e7a1d5fa authored by Paul's avatar Paul
Browse files

Merge

parents eb98f95b d36f72c5
...@@ -33,49 +33,92 @@ ...@@ -33,49 +33,92 @@
namespace migraphx { namespace migraphx {
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \ #define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \
template <class U> \ template <class U> \
constexpr array& operator op(const array<U, N>& x) \ constexpr array& operator op(const array<U, N>& x) \
{ \ { \
for(index_int i = 0; i < N; i++) \ array_for_each(*this, x)([](auto& sy, auto sx) { sy op sx; }); \
d[i] op x[i]; \ return *this; \
return *this; \ } \
} \ template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \ constexpr array& operator op(const U& x) \
constexpr array& operator op(const U& x) \ { \
{ \ array_for_each (*this)([&](auto& sy) { sy op x; }); \
for(index_int i = 0; i < N; i++) \ return *this; \
d[i] op x; \ } \
return *this; \ template <class U> \
} \ friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \
template <class U> \ { \
friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \ array<decltype(T {} binary_op U{}), N> z{}; \
{ \ array_for_each(z, x, y)([&](auto& sz, auto sx, auto sy) { sz = sx binary_op sy; }); \
array<decltype(T {} binary_op U{}), N> z{}; \ return z; \
for(index_int i = 0; i < N; i++) \ } \
z[i] = x[i] binary_op y[i]; \ template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
return z; \ friend constexpr auto operator binary_op(const array& x, const U& y) \
} \ { \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \ array<decltype(T {} binary_op U{}), N> z{}; \
friend constexpr auto operator binary_op(const array& x, const U& y) \ array_for_each(z, x)([&](auto& sz, auto sx) { sz = sx binary_op y; }); \
{ \ return z; \
array<decltype(T {} binary_op U{}), N> z{}; \ } \
for(index_int i = 0; i < N; i++) \ template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
z[i] = x[i] binary_op y; \ friend constexpr auto operator binary_op(const U& x, const array& y) \
return z; \ { \
} \ array<decltype(T {} binary_op U{}), N> z{}; \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \ array_for_each(z, y)([&](auto& sz, auto sy) { sz = x binary_op sy; }); \
friend constexpr auto operator binary_op(const U& x, const array& y) \ return z; \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \
z[i] = x binary_op y[i]; \
return z; \
} }
template <class T>
constexpr auto is_vectorizable()
{
return not is_same<T, bool>{} and (is_fundamental<T>{} or is_same<T, half>{});
}
template <class T>
constexpr auto array2vec(T x)
{
using value_type = typename T::value_type;
constexpr auto size = decltype(x.size()){};
using type = vec<value_type, size>;
static_assert(size != 3, "Wrong size");
return __builtin_bit_cast(type, x);
}
// Store the vector `v` back into `x` by bit-casting it to T.
// When T is const-qualified the function does nothing, so it can be called
// uniformly on both writable and read-only arrays.
template <class T, class U, index_int N>
constexpr void vec2array(T& x, vec<U, N> v)
{
    if constexpr(is_const<T>{})
    {
        // Read-only destination: nothing to write back.
    }
    else
    {
        x = __builtin_bit_cast(T, v);
    }
}
// Build a visitor over one or more arrays of equal size.
//
// Returns a lambda that takes a callable `f`:
//  - If at least one of the element types is vectorizable and the size is an
//    even number in [2, 8], every array is bit-cast to a `vec`, `f` is invoked
//    once on those vectors, and each vector is then copied back into its
//    source array (const arrays are skipped by vec2array).
//  - Otherwise `f` is invoked once per index with the i-th element of each
//    array.
//
// The returned lambda captures `x` and `xs` by reference, so it must be
// invoked while those arrays are still alive.
template <class T, class... Ts>
constexpr auto array_for_each(T& x, Ts&... xs)
{
    // The operand of a fold expression must be a cast-expression, so the
    // comparison has to be parenthesized; without the inner parentheses the
    // expression is ill-formed whenever the assert macro expands its argument.
    MIGRAPHX_ASSERT(((x.size() == xs.size()) and ...));
    return [&](auto f) {
        constexpr auto size = decltype(x.size()){};
        if constexpr((is_vectorizable<typename T::value_type>() or
                      (is_vectorizable<typename Ts::value_type>() or ...)) and
                     size <= 8 and size > 1 and (size % 2 == 0))
        {
            // Work on by-value vector copies, then write the results back.
            [&](auto v, auto... vs) {
                f(v, vs...);
                vec2array(x, v);
                // Pack expansion inside a braced initializer: one write-back
                // per additional array.
                swallow{(vec2array(xs, vs), 0)...};
            }(array2vec(x), array2vec(xs)...);
        }
        else
        {
            // Scalar fallback: visit the arrays element by element.
            for(index_int i = 0; i < size; i++)
                f(x[i], xs[i]...);
        }
    };
}
template <class T, index_int N> template <class T, index_int N>
struct array struct array
{ {
using value_type = T;
T d[N]; T d[N];
constexpr T& operator[](index_int i) constexpr T& operator[](index_int i)
{ {
......
...@@ -48,11 +48,10 @@ __device__ void generic_binary_layernorm( ...@@ -48,11 +48,10 @@ __device__ void generic_binary_layernorm(
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
using reduce_type = vec_type<value_type>;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto means = r.reduce(op::sum{}, make_array<reduce_type>(0, 0), [&](auto x1, auto x2) { auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
auto x = op(x1, x2); auto x = op(x1, x2);
return make_array(vec_reduce(x, op::sum{}), vec_reduce(x * x, op::sum{})) / reduce_type{relements}; return make_array(x, x * x) / vec_type<value_type>{relements};
})(input1, input2); })(input1, input2);
auto mean_x = means[0]; auto mean_x = means[0];
......
...@@ -201,12 +201,9 @@ struct block ...@@ -201,12 +201,9 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slicer, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx, return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
op, return vec_reduce(read(x[j], xs[j]...), op);
init, });
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment