Commit 82201cc0 authored by Khalique's avatar Khalique
Browse files

add onehot test and modified other tests

parent 8a08b4ca
......@@ -26,7 +26,6 @@ struct tf_parser
{
using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
using node_map = std::map<std::string, tensorflow::NodeDef>;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
......@@ -577,33 +576,22 @@ struct tf_parser
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
bool keep_dims = attributes.at("keep_dims").b();
// std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens();
auto axes = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> axes_int64 = std::vector<int64_t>(axes.begin(), axes.end());
// if(axes == hw_axes and lens.size() == 4)
// {
// op::pooling op{"average"};
// op.lengths[0] = lens[2];
// op.lengths[1] = lens[3];
auto l0 = prog.add_instruction(op::reduce_mean{axes_int64}, args.front());
if(keep_dims)
return l0;
return prog.add_instruction(op::squeeze{axes_int64}, l0);
// }
// MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref
parse_onehot(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
// auto indices = args[0]->eval().get<int32_t>().to_vector();
size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());
int64_t axis = -1;
// size_t num_indices = indices.size();
float on_value = args[2]->eval().at<float>();
float off_value = args[3]->eval().at<float>();
......
......@@ -257,10 +257,8 @@ TEST_CASE(mean_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_literal(l);
p.add_literal(l);
migraphx::op::pooling op;
op.lengths = {16, 16};
p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
auto l3 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = optimize_tf("mean_test.pb", false);
......@@ -272,10 +270,9 @@ TEST_CASE(mean_test_nhwc)
migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::pooling op;
op.lengths = {16, 16};
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto l1 = p.add_instruction(migraphx::op::transpose{{0,2,3,1}}, l0);
auto l2 = p.add_instruction(migraphx::op::reduce_mean{{1,2}}, l1);
p.add_instruction(migraphx::op::squeeze{{1, 2}}, l2);
auto prog = optimize_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog);
......@@ -293,6 +290,21 @@ TEST_CASE(mul_test)
EXPECT(p == prog);
}
TEST_CASE(onehot_test)
{
migraphx::program p;
auto l0 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
p.add_literal(2);
p.add_literal(1.0f);
p.add_literal(0.0f);
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2,2}}, {1, 0, 0, 1}});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l0);
auto prog = optimize_tf("onehot_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pack_test)
{
migraphx::program p;
......@@ -477,15 +489,16 @@ TEST_CASE(stridedslice_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = p.add_instruction(migraphx::op::transpose{{0,2,3,1}}, l0);
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
op.ends = {1, 1, 1, 5};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
auto l1 = p.add_instruction(op, l0);
auto l2 = p.add_instruction(op, l1);
auto shrink_axis = 1;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l2);
auto prog = optimize_tf("stridedslice_test.pb", true);
EXPECT(p == prog);
......@@ -497,8 +510,8 @@ TEST_CASE(stridedslice_masks_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 1, 1};
op.ends = {1, 10, 3, 3};
op.starts = {0, 1, 1, 0};
op.ends = {1, 3, 3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
......@@ -506,7 +519,9 @@ TEST_CASE(stridedslice_masks_test)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
p.add_instruction(op, l0);
auto l1 = p.add_instruction(migraphx::op::transpose{{0,2,3,1}}, l0);
auto l2 = p.add_instruction(op, l1);
p.add_instruction(migraphx::op::transpose{{0,3,1,2}}, l2);
auto prog = migraphx::parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment