"references/vscode:/vscode.git/clone" did not exist on "97e21c1094b37ff356841e37fa9d17ab34527919"
Commit 70a993f3 authored by Khalique's avatar Khalique
Browse files

Merge branch 'master' of https://github.com/ROCmSoftwarePlatform/MIGraph into dropout

parents 0ca1089c 6feca68d
...@@ -76,6 +76,7 @@ struct onnx_parser ...@@ -76,6 +76,7 @@ struct onnx_parser
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat); add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
} }
template <class F> template <class F>
...@@ -363,6 +364,18 @@ struct onnx_parser ...@@ -363,6 +364,18 @@ struct onnx_parser
return prog.add_instruction(migraph::op::add{}, img_scaled, bias_bcast); return prog.add_instruction(migraph::op::add{}, img_scaled, bias_bcast);
} }
instruction_ref
parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::vector<int64_t> perm{};
if(contains(attributes, "perm"))
{
auto&& perm_vals = attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return prog.add_instruction(migraph::op::transpose{perm}, args.front());
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
onnx::ModelProto model; onnx::ModelProto model;
......
...@@ -146,6 +146,18 @@ void globalmaxpool_test() ...@@ -146,6 +146,18 @@ void globalmaxpool_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void transpose_test()
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 2, 2, 3}});
std::vector<int64_t> perm{0, 3, 1, 2};
p.add_instruction(migraph::op::transpose{perm}, input);
auto prog = migraph::parse_onnx("transpose_test.onnx");
EXPECT(p == prog);
}
int main() int main()
{ {
pytorch_conv_bias_test(); pytorch_conv_bias_test();
...@@ -156,4 +168,5 @@ int main() ...@@ -156,4 +168,5 @@ int main()
imagescaler_test(); imagescaler_test();
globalavgpool_test(); globalavgpool_test();
globalmaxpool_test(); globalmaxpool_test();
transpose_test();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment