Commit c5d87f8f authored by Paul
Browse files

Format

parent 9466f4c0
......@@ -49,7 +49,9 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
return {4, 2};
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs, const std::vector<std::size_t>& sizes)
vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{
if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
......@@ -83,16 +85,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
if (inputs.empty())
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](const auto& s) { return s.elements(); }))->elements();
std::size_t n = std::max_element(inputs.begin(),
inputs.end(),
by(std::less<>{}, [](const auto& s) { return s.elements(); }))
->elements();
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
std::size_t over = n / max_global;
std::vector<std::size_t> sizes;
if (over > 8)
if(over > 8)
sizes.push_back(8);
if (over > 4)
if(over > 4)
sizes.push_back(4);
sizes.push_back(2);
return elements(axis, inputs, sizes);
......
......@@ -47,7 +47,9 @@ struct vectorize
std::size_t axis = 0;
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs, const std::vector<std::size_t>& sizes);
static vectorize elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes);
std::string str() const;
};
struct preload
......
......@@ -78,10 +78,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params(
v,
compute_global_for(ctx,
options.output.elements() / vec.size,
256));
v, compute_global_for(ctx, options.output.elements() / vec.size, 256));
auto src = interpolate_string(pointwise_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment