"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "ec1ac8c0440202c501df405b1c8e4c5f16dfffbc"
Commit aa7b76b5 authored by Paul's avatar Paul
Browse files

Parse mean as reduce mean instead of pooling

parent 1fe84f2a
...@@ -132,7 +132,11 @@ struct tensor_view ...@@ -132,7 +132,11 @@ struct tensor_view
return m_data + this->size(); return m_data + this->size();
} }
std::vector<T> to_vector() const { return std::vector<T>(this->begin(), this->end()); } template <class U = T>
std::vector<U> to_vector() const
{
return std::vector<U>(this->begin(), this->end());
}
friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x) friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
{ {
......
...@@ -574,23 +574,18 @@ struct tf_parser ...@@ -574,23 +574,18 @@ struct tf_parser
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens(); auto lens = args[0]->get_shape().lens();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size()); auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector<int64_t>(), lens.size());
if(axes == hw_axes and lens.size() == 4)
{
op::pooling op{"average"};
op.lengths[0] = lens[2];
op.lengths[1] = lens[3];
auto l0 = prog.add_instruction(op, args.front());
if(keep_dims) if(keep_dims)
return l0; {
return prog.add_instruction( return prog.add_instruction(op::reduce_mean{axes}, args[0]);
op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0); }
else
{
auto ins = prog.add_instruction(op::reduce_mean{axes}, args[0]);
return prog.add_instruction(op::squeeze{axes}, ins);
} }
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
} }
instruction_ref parse_pack(const std::string&, instruction_ref parse_pack(const std::string&,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment