Commit c13780c2 authored by Paul
Browse files

Format

parent 15fd8205
......@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler>
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if (inputs.back().lens()[faxis] == 1)
if(inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(faxis, inputs);
}
......@@ -111,11 +111,14 @@ struct reduce_compiler : compiler<reduce_compiler>
{
auto block_size = compute_block_size(reduce_elements, 256);
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() * block_size / vec.size, 256), block_size);
v,
compute_global_for(ctx, inputs.back().elements() * block_size / vec.size, 256),
block_size);
}
else if(algo == "lane")
{
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements() / vec.size, 256));
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() / vec.size, 256));
}
else
{
......@@ -124,7 +127,7 @@ struct reduce_compiler : compiler<reduce_compiler>
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.kernel_name = "reduce_kernel";
options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()},
......
......@@ -163,9 +163,12 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return read(x[j], xs[j]...);
}), op);
return vec_reduce(block_reduce(idx,
op,
init,
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
});
}
......
......@@ -146,7 +146,7 @@ constexpr auto vec_packed_transform(Ts... xs)
};
}
template<class T, class Op>
template <class T, class Op>
constexpr auto vec_reduce(T x, Op op)
{
if constexpr(vec_size<T>() < 2)
......@@ -155,7 +155,7 @@ constexpr auto vec_reduce(T x, Op op)
{
vec_type<T> result;
for(int i = 1; i < vec_size<T>(); i++)
result = op(result[i-1], result[i]);
result = op(result[i - 1], result[i]);
return result;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment