Commit a5f62830 authored by Khalique's avatar Khalique
Browse files

formatting

parent ab855464
......@@ -354,16 +354,20 @@ struct tf_parser
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref
parse_pack(const std::string&, const attribute_map& attributes, std::vector<instruction_ref> args)
instruction_ref parse_pack(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
// reinterpret as unsqueeze with concat
std::vector<instruction_ref> unsqueezed_args;
int64_t axis = 0;
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
std::transform(args.begin(), args.end(), std::back_inserter(unsqueezed_args),
[&](instruction_ref arg){ return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
}
......
......@@ -139,8 +139,12 @@ TEST_CASE(pack_test)
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t axis = 1;
std::transform(args.begin(), args.end(), std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) { return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg); });
std::transform(args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test.pb", false);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment