Commit 93239a68 authored by Paul's avatar Paul
Browse files

Merge branch 'jit-concat' into jit-concat-pointwise

parents a839ade9 fc9b2a7d
...@@ -66,9 +66,11 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -66,9 +66,11 @@ struct concat_compiler : compiler<concat_compiler>
static std::size_t get_concat_elements(const std::vector<shape>& inputs) static std::size_t get_concat_elements(const std::vector<shape>& inputs)
{ {
auto total = std::accumulate( auto total =
inputs.begin(), inputs.end(), 0, [](auto x, auto s) { return x + s.elements(); }); std::accumulate(inputs.begin(), std::prev(inputs.end()), 0, [](auto x, auto s) {
return total / inputs.size(); return x + s.elements();
});
return total / (inputs.size() - 1);
} }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment