"ts/webui/vscode:/vscode.git/clone" did not exist on "64ea284f3b4d27a321c55755ff862c36d753e482"
Commit 9045e5ac authored by Khalique's avatar Khalique
Browse files

add gather and test

parent 03f5c679
...@@ -170,6 +170,7 @@ struct tf_parser ...@@ -170,6 +170,7 @@ struct tf_parser
add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv); add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false); add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm); add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("GatherV2", &tf_parser::parse_gather, false);
add_mem_op("MatMul", &tf_parser::parse_matmul, false); add_mem_op("MatMul", &tf_parser::parse_matmul, false);
add_mem_op("MaxPool", &tf_parser::parse_pooling); add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean); add_mem_op("Mean", &tf_parser::parse_mean);
...@@ -513,6 +514,14 @@ struct tf_parser ...@@ -513,6 +514,14 @@ struct tf_parser
return prog.add_instruction(op::reshape{new_dims}, args[0]); return prog.add_instruction(op::reshape{new_dims}, args[0]);
} }
instruction_ref
parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
int axis = args[2]->eval().at<int32_t>();
op::gather op{axis};
return prog.add_instruction(op, {args[0], args[1]});
}
instruction_ref instruction_ref
parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
......
...@@ -184,6 +184,21 @@ TEST_CASE(expanddims_test_neg_dims) ...@@ -184,6 +184,21 @@ TEST_CASE(expanddims_test_neg_dims)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}});
p.add_literal(1);
int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = optimize_tf("gather_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(identity_test) TEST_CASE(identity_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment