Commit c5d87f8f authored by Paul's avatar Paul
Browse files

Format

parent 9466f4c0
...@@ -49,7 +49,9 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs) ...@@ -49,7 +49,9 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
return {4, 2}; return {4, 2};
} }
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs, const std::vector<std::size_t>& sizes) vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{ {
if(std::all_of( if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; })) inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
...@@ -83,16 +85,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -83,16 +85,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs) vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{ {
if (inputs.empty()) if(inputs.empty())
return {1, axis}; return {1, axis};
std::size_t n = std::max_element(inputs.begin(), inputs.end(), by(std::less<>{}, [](const auto& s) { return s.elements(); }))->elements(); std::size_t n = std::max_element(inputs.begin(),
inputs.end(),
by(std::less<>{}, [](const auto& s) { return s.elements(); }))
->elements();
std::size_t max_global = ctx.get_current_device().get_cu_count() * std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu(); ctx.get_current_device().get_max_workitems_per_cu();
std::size_t over = n / max_global; std::size_t over = n / max_global;
std::vector<std::size_t> sizes; std::vector<std::size_t> sizes;
if (over > 8) if(over > 8)
sizes.push_back(8); sizes.push_back(8);
if (over > 4) if(over > 4)
sizes.push_back(4); sizes.push_back(4);
sizes.push_back(2); sizes.push_back(2);
return elements(axis, inputs, sizes); return elements(axis, inputs, sizes);
......
...@@ -47,7 +47,9 @@ struct vectorize ...@@ -47,7 +47,9 @@ struct vectorize
std::size_t axis = 0; std::size_t axis = 0;
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs); static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs); static vectorize elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs, const std::vector<std::size_t>& sizes); static vectorize elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes);
std::string str() const; std::string str() const;
}; };
struct preload struct preload
......
...@@ -78,10 +78,7 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -78,10 +78,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto vec = vectorize::elements(ctx, axis, options.virtual_inputs); auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel"); options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params( options.set_launch_params(
v, v, compute_global_for(ctx, options.output.elements() / vec.size, 256));
compute_global_for(ctx,
options.output.elements() / vec.size,
256));
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment