Commit 4828226b authored by Paul's avatar Paul
Browse files

Format

parent 276184b2
......@@ -53,13 +53,13 @@ struct index
return blockDim.x; // NOLINT
}
#endif
template<class N>
template <class N>
constexpr auto max_global_stride_iterations(N n) const
{
return _c<1> + n / nglobal();
}
template<class N>
template <class N>
constexpr auto max_local_stride_iterations(N n) const
{
return _c<1> + n / nlocal();
......
......@@ -22,13 +22,10 @@ __device__ void generic_binary_layernorm(
MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
auto input = r.inner([&](auto x1, auto x2) {
return op(x1, x2);
})(input1, input2);
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x) {
return f(x) / value_type{relements};
})(input);
return r.reduce(op::sum{}, 0, [&](auto x) { return f(x) / value_type{relements}; })(
input);
};
// mean(x)
auto mean_x = mean(op::id{});
......
......@@ -147,10 +147,11 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
}
#endif
namespace reduce {
struct inner_array_base {};
struct inner_array_base
{
};
template <class Output, class Input, class T>
constexpr auto reduce_slice(Input input, T i)
......@@ -227,20 +228,20 @@ struct block
f();
}
template<class T, index_int N, index_int Stride, class Shape>
template <class T, index_int N, index_int Stride, class Shape>
struct inner_array : inner_array_base
{
array<T, N> arr;
constexpr Shape get_shape() const {return Shape{};}
template<class U>
constexpr Shape get_shape() const { return Shape{}; }
template <class U>
constexpr auto& operator[](U i) const
{
return arr[i/Stride];
return arr[i / Stride];
}
template<class U>
template <class U>
constexpr auto& operator[](U i)
{
return arr[i/Stride];
return arr[i / Stride];
}
};
......@@ -255,8 +256,14 @@ struct block
}
else
{
inner_array<result_type, decltype(idx.max_local_stride_iterations(x.get_shape().elements())){}, decltype(idx.nlocal()){}, decltype(x.get_shape())> y;
idx.local_stride(x.get_shape().elements(), [&](auto j) { y[j] = f(x[j], xs[j]...); });
inner_array<result_type,
decltype(
idx.max_local_stride_iterations(x.get_shape().elements())){},
decltype(idx.nlocal()){},
decltype(x.get_shape())>
y;
idx.local_stride(x.get_shape().elements(),
[&](auto j) { y[j] = f(x[j], xs[j]...); });
return y;
}
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment