Commit 3b5d27f7 authored by Paul
Browse files

Use average

parent a0921a37
......@@ -61,11 +61,15 @@ struct concat_compiler : compiler<concat_compiler>
{
// Returns a one-element list containing the string "concat".
// NOTE(review): presumably the operator name(s) this compiler is registered for — confirm
// against the compiler<> base.
std::vector<std::string> names() const
{
    std::vector<std::string> result;
    result.emplace_back("concat");
    return result;
}
static std::size_t get_min_elements(const std::vector<shape>& inputs)
static std::size_t get_concat_elements(const std::vector<shape>& inputs)
{
auto it = std::min_element(
inputs.begin(), inputs.end(), by(std::less<>{}, [](auto s) { return s.elements(); }));
return it->elements();
auto total = std::accumulate(inputs.begin(), inputs.end(), 0, [](auto x, auto s) {
return x + s.elements();
});
return total / inputs.size();
// auto it = std::min_element(
// inputs.begin(), inputs.end(), by(std::less<>{}, [](auto s) { return s.elements(); }));
// return it->elements();
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
......@@ -79,7 +83,7 @@ struct concat_compiler : compiler<concat_compiler>
auto vec = vectorize::elements(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params(
v, compute_global_for(ctx, get_min_elements(options.inputs) / vec.size, 256));
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment