Commit 82201cc0 authored by Khalique

add onehot test and modified other tests

parent 8a08b4ca
@@ -26,7 +26,6 @@ struct tf_parser
 {
     using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
     using node_map      = std::map<std::string, tensorflow::NodeDef>;
-    // using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
     using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
     node_map nodes;
@@ -577,33 +576,22 @@ struct tf_parser
     parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         bool keep_dims = attributes.at("keep_dims").b();
-        // std::vector<int32_t> hw_axes{2, 3};
-        // check if conditions for GlobalAvgPool are met
         auto lens = args[0]->get_shape().lens();
         auto axes = args[1]->eval().get<int32_t>().to_vector();
         std::vector<int64_t> axes_int64 = std::vector<int64_t>(axes.begin(), axes.end());
-        // if(axes == hw_axes and lens.size() == 4)
-        // {
-        //     op::pooling op{"average"};
-        //     op.lengths[0] = lens[2];
-        //     op.lengths[1] = lens[3];
         auto l0 = prog.add_instruction(op::reduce_mean{axes_int64}, args.front());
         if(keep_dims)
             return l0;
         return prog.add_instruction(op::squeeze{axes_int64}, l0);
-        // }
-        // MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
     }
     instruction_ref
     parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
-        // auto indices = args[0]->eval().get<int32_t>().to_vector();
         size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
         int64_t axis = -1;
-        // size_t num_indices = indices.size();
         float on_value  = args[2]->eval().at<float>();
         float off_value = args[3]->eval().at<float>();
...
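Note: with this change parse_mean no longer special-cases a GlobalAvgPool transformation; a TF Mean node is lowered directly to reduce_mean over the requested axes, with a squeeze appended when keep_dims is false. A minimal sketch of the resulting instruction pattern, assuming the same MIGraphX API used in the tests below (make_mean_program is only an illustrative name):

// Minimal sketch (not parser code): the reduce_mean + squeeze pattern emitted
// for a TF Mean over axes {2, 3} of an NCHW tensor with keep_dims=false.
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>

migraphx::program make_mean_program()
{
    migraphx::program p;
    auto x = p.add_parameter(
        "x", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
    // Mean over H and W (axes 2 and 3 in NCHW).
    auto m = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, x);
    // keep_dims=false in TF maps to squeezing the reduced axes afterwards.
    p.add_instruction(migraphx::op::squeeze{{2, 3}}, m);
    return p;
}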
@@ -257,10 +257,8 @@ TEST_CASE(mean_test)
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_literal(l);
     p.add_literal(l);
-    migraphx::op::pooling op;
-    op.lengths = {16, 16};
-    p.add_instruction(op, l0);
-    auto l3 = p.add_instruction(op, l0);
+    p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
+    auto l3 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
     p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
     auto prog = optimize_tf("mean_test.pb", false);
@@ -272,10 +270,9 @@ TEST_CASE(mean_test_nhwc)
     migraphx::program p;
     migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
-    migraphx::op::pooling op;
-    op.lengths = {16, 16};
-    auto l3 = p.add_instruction(op, l0);
-    p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
+    auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
+    auto l2 = p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l1);
+    p.add_instruction(migraphx::op::squeeze{{1, 2}}, l2);
     auto prog = optimize_tf("mean_test_nhwc.pb", true);
     EXPECT(p == prog);
@@ -293,6 +290,21 @@ TEST_CASE(mul_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(onehot_test)
+{
+    migraphx::program p;
+    auto l0 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
+    p.add_literal(2);
+    p.add_literal(1.0f);
+    p.add_literal(0.0f);
+    auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
+    int axis = 0;
+    p.add_instruction(migraphx::op::gather{axis}, l1, l0);
+    auto prog = optimize_tf("onehot_test.pb", false);
+    EXPECT(p == prog);
+}
+
 TEST_CASE(pack_test)
 {
     migraphx::program p;
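Note: onehot_test pins down the lowering parse_onehot is expected to produce: a {depth, depth} literal with on_value on the diagonal and off_value everywhere else, whose rows are gathered along axis 0 by the index tensor. A rough sketch of that construction, assuming the API shown in the test (add_onehot is a hypothetical helper, not the parser's actual code):

// Sketch of the gather-based one-hot lowering checked by onehot_test:
// build the on/off table, then gather its rows with the indices.
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <vector>

migraphx::instruction_ref add_onehot(migraphx::program& p,
                                     migraphx::instruction_ref indices,
                                     std::size_t depth,
                                     float on_value,
                                     float off_value)
{
    std::vector<float> data(depth * depth, off_value);
    for(std::size_t i = 0; i < depth; i++)
        data[i * (depth + 1)] = on_value; // diagonal entries carry on_value
    auto table = p.add_literal(migraphx::literal{
        migraphx::shape{migraphx::shape::float_type, {depth, depth}}, data});
    return p.add_instruction(migraphx::op::gather{0}, table, indices);
}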
@@ -477,15 +489,16 @@ TEST_CASE(stridedslice_test)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
+    auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
     std::size_t num_axes = 4;
     migraphx::op::slice op;
     op.starts = {0, 0, 0, 0};
     op.ends   = {1, 1, 1, 5};
     op.axes   = std::vector<int64_t>(num_axes);
     std::iota(op.axes.begin(), op.axes.end(), 0);
-    auto l1 = p.add_instruction(op, l0);
+    auto l2 = p.add_instruction(op, l1);
     auto shrink_axis = 1;
-    p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
+    p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l2);
     auto prog = optimize_tf("stridedslice_test.pb", true);
     EXPECT(p == prog);
@@ -497,8 +510,8 @@ TEST_CASE(stridedslice_masks_test)
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
     std::size_t num_axes = 4;
     migraphx::op::slice op;
-    op.starts = {0, 0, 1, 1};
-    op.ends   = {1, 10, 3, 3};
+    op.starts = {0, 1, 1, 0};
+    op.ends   = {1, 3, 3, 10};
     op.axes   = std::vector<int64_t>(num_axes);
     std::iota(op.axes.begin(), op.axes.end(), 0);
     // add literals for starts, ends, and strides in tf (NHWC format)
@@ -506,7 +519,9 @@ TEST_CASE(stridedslice_masks_test)
     p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
     p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
-    p.add_instruction(op, l0);
+    auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
+    auto l2 = p.add_instruction(op, l1);
+    p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l2);
     auto prog = migraphx::parse_tf("stridedslice_masks_test.pb", true);
     EXPECT(p == prog);
...
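Note: both stridedslice tests now expect explicit layout handling when the NHWC flag is passed to optimize_tf/parse_tf: the NCHW parameter is transposed to NHWC, sliced with NHWC-ordered starts and ends, and (in the masks test) transposed back to NCHW. A minimal sketch of that pattern, using the same operators as the tests above (add_nhwc_slice is a hypothetical helper for illustration only):

// Sketch of the NHWC slice pattern mirrored from stridedslice_masks_test.
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <numeric>
#include <vector>

migraphx::instruction_ref add_nhwc_slice(migraphx::program& p, migraphx::instruction_ref x)
{
    // NCHW -> NHWC before slicing.
    auto nhwc = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
    migraphx::op::slice op;
    op.starts = {0, 1, 1, 0}; // NHWC order, matching the updated test
    op.ends   = {1, 3, 3, 10};
    op.axes   = std::vector<int64_t>(4);
    std::iota(op.axes.begin(), op.axes.end(), 0);
    auto sliced = p.add_instruction(op, nhwc);
    // Transpose the sliced result back to NCHW.
    return p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, sliced);
}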