Commit c310d928 authored by Paul
Browse files

Fix reduced shape calculations

parent d4613133
......@@ -61,6 +61,16 @@ __global__ void ${kernel}(${params})
/// Builds a shape with the same element type and number of dimensions as `s`
/// in which only the dimensions listed in `axes` keep their original length;
/// every other dimension is set to 1.
///
/// @param s     Input shape whose type and lengths are read.
/// @param axes  Indices of the dimensions to keep. Each value is used directly
///              as an index into the lengths, with no range check performed here.
/// @return      A shape constructed from `s.type()` and the masked lengths.
template <class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
{
    const auto& input_lens = s.lens();

    // Start from a copy so the result has one entry per input dimension,
    // then collapse every dimension to length 1.
    auto reduced_lens = input_lens;
    for(auto& len : reduced_lens)
        len = 1;

    // Restore the original length on each requested axis.
    for(const auto& axis : axes)
        reduced_lens[axis] = input_lens[axis];

    return shape{s.type(), reduced_lens};
}
template <class T>
static shape get_output_shape(const shape& s, const std::vector<T>& axes)
{
auto lens = s.lens();
for(const auto& axis : axes)
......@@ -93,10 +103,13 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
auto axes = v.at("axes").to_vector<std::size_t>();
auto virtual_inputs = inputs;
virtual_inputs.push_back(
get_reduced_shape(inputs.front(), v.at("axes").to_vector<std::size_t>()));
virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
virtual_inputs = reduce_dims(virtual_inputs);
auto output_shape = virtual_inputs.back();
virtual_inputs.pop_back();
auto reduced_shape = virtual_inputs.back();
virtual_inputs.pop_back();
......@@ -136,7 +149,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"algo", algo},
{"reduced", "decltype(" + generate_make_shape(reduced_shape) + ")"},
{"reduced", "decltype(" + generate_make_shape(output_shape) + ")"},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment