Commit aa7b76b5 authored by Paul's avatar Paul
Browse files

Parse mean as reduce mean instead of pooling

parent 1fe84f2a
......@@ -132,7 +132,11 @@ struct tensor_view
return m_data + this->size();
}
std::vector<T> to_vector() const { return std::vector<T>(this->begin(), this->end()); }
template <class U = T>
std::vector<U> to_vector() const
{
return std::vector<U>(this->begin(), this->end());
}
friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
{
......
......@@ -574,23 +574,18 @@ struct tf_parser
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size());
auto lens = args[0]->get_shape().lens();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector<int64_t>(), lens.size());
if(axes == hw_axes and lens.size() == 4)
if(keep_dims)
{
op::pooling op{"average"};
op.lengths[0] = lens[2];
op.lengths[1] = lens[3];
auto l0 = prog.add_instruction(op, args.front());
if(keep_dims)
return l0;
return prog.add_instruction(
op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0);
return prog.add_instruction(op::reduce_mean{axes}, args[0]);
}
else
{
auto ins = prog.add_instruction(op::reduce_mean{axes}, args[0]);
return prog.add_instruction(op::squeeze{axes}, ins);
}
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref parse_pack(const std::string&,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment