Commit 2aa25de2 authored by Paul
Browse files

Try to calculate local

parent f1e1d443
......@@ -53,6 +53,19 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else
return 1;
}
/// Chooses the "local" size handed to the launch parameters for a pointwise kernel.
///
/// @param v      Vectorization descriptor; `v.axis` selects the dimension that is
///               inspected and `v.size` is the divisor applied to that dimension.
/// @param inputs Shapes to inspect (may be empty).
/// @return `max_local` (1024) unless at least one input is transposed AND every
///         input is non-broadcasted with stride 1 along `v.axis`; in that case the
///         first input's length along `v.axis` divided by `v.size`, capped at
///         `max_local`.
static std::size_t compute_local(gen::vectorize v, const std::vector<shape>& inputs)
{
    const std::size_t max_local = 1024;

    // Nothing is transposed (this is also the path taken for an empty input
    // list, so the `front()` call below is never reached with no inputs).
    // Shapes are taken by const reference to avoid copying each one.
    if(std::none_of(
           inputs.begin(), inputs.end(), [](const auto& s) { return s.transposed(); }))
        return max_local;

    // Any broadcasted input, or any input that is not unit-stride along the
    // vectorized axis, falls back to the maximum as well.
    if(std::any_of(inputs.begin(), inputs.end(), [&](const auto& s) {
           return s.broadcasted() or s.strides()[v.axis] != 1;
       }))
        return max_local;

    // Size the local dimension from the vectorized axis of the first input, but
    // never return more than max_local: the quotient is otherwise unbounded.
    // NOTE(review): presumably max_local mirrors the device's work-group limit
    // -- confirm. The quotient is 0 when the axis length is smaller than
    // v.size; verify callers never hit that case.
    return std::min<std::size_t>(max_local, inputs.front().lens()[v.axis] / v.size);
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
......@@ -67,7 +80,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
v,
compute_global_for(ctx,
options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())));
oversubscribe_if(not preloads.is_preloading())), compute_local(vec, options.virtual_inputs));
auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment