Commit 88ed7f85 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from opt_log_softmax

parents d13dcab5 77af16d8
......@@ -111,6 +111,7 @@ struct tf_parser
add_generic_op("Identity", op::identity{});
add_generic_op("Relu", op::relu{});
add_generic_op("Relu6", op::clip{6.0, 0.0});
add_generic_op("Tanh", op::tanh{});
add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{});
......
......@@ -1696,4 +1696,79 @@ TEST_CASE(clip_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reduce_sum_test0)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{0}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{15, 18, 21, 24};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{4, 6, 12, 14, 20, 22};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{3, 7, 11, 15, 19, 23};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{0, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{33, 45};
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
auto l0 = p.add_literal(input);
p.add_instruction(migraphx::op::reduce_sum{{1, 2}}, l0);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{10, 26, 42};
EXPECT(results_vector == gold);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -586,7 +586,7 @@ struct test_softmax2 : verify_program<test_softmax2>
{
migraphx::program p;
auto x =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1000, 1, 1}});
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1028, 1, 25}});
p.add_instruction(migraphx::op::softmax{}, x);
return p;
}
......@@ -3471,4 +3471,40 @@ struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub>
};
};
struct test_reduce_sum : verify_program<test_reduce_sum>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 1026, 4, 3}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
};
};
struct test_reduce_sum_int : verify_program<test_reduce_sum_int>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 4, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
};
};
struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {3, 4, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
:
0 Placeholder*
dtype0*
shape:

tanhTanh0*
T0"
\ No newline at end of file
......@@ -370,4 +370,15 @@ TEST_CASE(sub_test)
EXPECT(p == prog);
}
TEST_CASE(tanh_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::sub{}, l0, l1);
auto prog = migraphx::parse_tf("sub_test.pb", false);
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment