Commit 4828226b authored by Paul's avatar Paul
Browse files

Format

parent 276184b2
...@@ -53,13 +53,13 @@ struct index ...@@ -53,13 +53,13 @@ struct index
return blockDim.x; // NOLINT return blockDim.x; // NOLINT
} }
#endif #endif
template<class N> template <class N>
constexpr auto max_global_stride_iterations(N n) const constexpr auto max_global_stride_iterations(N n) const
{ {
return _c<1> + n / nglobal(); return _c<1> + n / nglobal();
} }
template<class N> template <class N>
constexpr auto max_local_stride_iterations(N n) const constexpr auto max_local_stride_iterations(N n) const
{ {
return _c<1> + n / nlocal(); return _c<1> + n / nlocal();
......
...@@ -22,13 +22,10 @@ __device__ void generic_binary_layernorm( ...@@ -22,13 +22,10 @@ __device__ void generic_binary_layernorm(
MIGRAPHX_ASSERT(relements > 0); MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
auto input = r.inner([&](auto x1, auto x2) { auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
return op(x1, x2);
})(input1, input2);
auto mean = [&](auto f) { auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x) { return r.reduce(op::sum{}, 0, [&](auto x) { return f(x) / value_type{relements}; })(
return f(x) / value_type{relements}; input);
})(input);
}; };
// mean(x) // mean(x)
auto mean_x = mean(op::id{}); auto mean_x = mean(op::id{});
......
...@@ -147,10 +147,11 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -147,10 +147,11 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
} }
#endif #endif
namespace reduce { namespace reduce {
struct inner_array_base {}; struct inner_array_base
{
};
template <class Output, class Input, class T> template <class Output, class Input, class T>
constexpr auto reduce_slice(Input input, T i) constexpr auto reduce_slice(Input input, T i)
...@@ -227,20 +228,20 @@ struct block ...@@ -227,20 +228,20 @@ struct block
f(); f();
} }
template<class T, index_int N, index_int Stride, class Shape> template <class T, index_int N, index_int Stride, class Shape>
struct inner_array : inner_array_base struct inner_array : inner_array_base
{ {
array<T, N> arr; array<T, N> arr;
constexpr Shape get_shape() const {return Shape{};} constexpr Shape get_shape() const { return Shape{}; }
template<class U> template <class U>
constexpr auto& operator[](U i) const constexpr auto& operator[](U i) const
{ {
return arr[i/Stride]; return arr[i / Stride];
} }
template<class U> template <class U>
constexpr auto& operator[](U i) constexpr auto& operator[](U i)
{ {
return arr[i/Stride]; return arr[i / Stride];
} }
}; };
...@@ -255,8 +256,14 @@ struct block ...@@ -255,8 +256,14 @@ struct block
} }
else else
{ {
inner_array<result_type, decltype(idx.max_local_stride_iterations(x.get_shape().elements())){}, decltype(idx.nlocal()){}, decltype(x.get_shape())> y; inner_array<result_type,
idx.local_stride(x.get_shape().elements(), [&](auto j) { y[j] = f(x[j], xs[j]...); }); decltype(
idx.max_local_stride_iterations(x.get_shape().elements())){},
decltype(idx.nlocal()){},
decltype(x.get_shape())>
y;
idx.local_stride(x.get_shape().elements(),
[&](auto j) { y[j] = f(x[j], xs[j]...); });
return y; return y;
} }
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment