Commit c13780c2 authored by Paul
Browse files

Format

parent 15fd8205
...@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler>
auto faxis = find_fast_axis({inputs.front()}); auto faxis = find_fast_axis({inputs.front()});
vectorize vec{}; vectorize vec{};
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if (inputs.back().lens()[faxis] == 1) if(inputs.back().lens()[faxis] == 1)
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(faxis, inputs);
} }
...@@ -111,11 +111,14 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -111,11 +111,14 @@ struct reduce_compiler : compiler<reduce_compiler>
{ {
auto block_size = compute_block_size(reduce_elements, 256); auto block_size = compute_block_size(reduce_elements, 256);
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() * block_size / vec.size, 256), block_size); v,
compute_global_for(ctx, inputs.back().elements() * block_size / vec.size, 256),
block_size);
} }
else if(algo == "lane") else if(algo == "lane")
{ {
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements() / vec.size, 256)); options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() / vec.size, 256));
} }
else else
{ {
......
...@@ -163,9 +163,12 @@ struct block ...@@ -163,9 +163,12 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slicer, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) { return vec_reduce(block_reduce(idx,
return read(x[j], xs[j]...); op,
}), op); init,
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
}); });
} }
......
...@@ -146,7 +146,7 @@ constexpr auto vec_packed_transform(Ts... xs) ...@@ -146,7 +146,7 @@ constexpr auto vec_packed_transform(Ts... xs)
}; };
} }
template<class T, class Op> template <class T, class Op>
constexpr auto vec_reduce(T x, Op op) constexpr auto vec_reduce(T x, Op op)
{ {
if constexpr(vec_size<T>() < 2) if constexpr(vec_size<T>() < 2)
...@@ -155,7 +155,7 @@ constexpr auto vec_reduce(T x, Op op) ...@@ -155,7 +155,7 @@ constexpr auto vec_reduce(T x, Op op)
{ {
vec_type<T> result; vec_type<T> result;
for(int i = 1; i < vec_size<T>(); i++) for(int i = 1; i < vec_size<T>(); i++)
result = op(result[i-1], result[i]); result = op(result[i - 1], result[i]);
return result; return result;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment