Commit 4001eef7 authored by Paul
Browse files

Format

parent 2aa25de2
...@@ -56,13 +56,11 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -56,13 +56,11 @@ struct pointwise_compiler : compiler<pointwise_compiler>
// Picks a local size from the input shapes.
//
// Returns max_local (1024) when no input is transposed, or when any input is
// broadcasted or has a non-unit stride along v.axis. Otherwise returns the first
// input's length along v.axis divided by v.size, clamped to [1, max_local].
//
// v:      only v.axis and v.size are read here.
// inputs: shapes to inspect; may be empty (yields max_local).
// NOTE(review): assumes v.size is non-zero and v.axis is a valid index for every
// input's lens()/strides() - confirm against the caller.
static std::size_t compute_local(gen::vectorize v, const std::vector<shape>& inputs)
{
    const std::size_t max_local = 1024;
    // none_of is true for an empty range, so this early return also guarantees
    // that inputs.front() below is only reached with at least one input.
    if(std::none_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.transposed(); }))
        return max_local;
    if(std::any_of(inputs.begin(), inputs.end(), [&](const auto& s) {
           return s.broadcasted() or s.strides()[v.axis] != 1;
       }))
        return max_local;
    // The quotient is 0 when the axis is shorter than v.size and can exceed
    // max_local for long axes; keep the result inside [1, max_local] so the
    // value never goes past the bound this function otherwise returns.
    const std::size_t per_axis = inputs.front().lens()[v.axis] / v.size;
    return std::max<std::size_t>(1, std::min(max_local, per_axis));
}
...@@ -80,7 +78,8 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -80,7 +78,8 @@ struct pointwise_compiler : compiler<pointwise_compiler>
v, v,
compute_global_for(ctx, compute_global_for(ctx,
options.output.elements() / vec.size, options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())), compute_local(vec, options.virtual_inputs)); oversubscribe_if(not preloads.is_preloading())),
compute_local(vec, options.virtual_inputs));
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")}, {{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment