Commit 15385fb1 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents f7f02979 b606ed4f
......@@ -63,7 +63,7 @@ TEST_CASE(biasadd_test)
uint64_t axis = 1;
auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape()}, l1);
auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_tf("biasadd_test.pb", true);
......@@ -80,7 +80,7 @@ TEST_CASE(concat_test)
int axis = 1;
// tf uses axis as the third input, and it is in int32 format
// add the literal using a vector in order to set stride to 1 (like in tf parser)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {1}}, std::vector<int>{axis});
p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
auto prog = migraphx::parse_tf("concat_test.pb", false);
......@@ -91,7 +91,7 @@ TEST_CASE(concat_test)
TEST_CASE(const_test)
{
migraphx::program p;
p.add_literal(migraphx::shape{migraphx::shape::float_type, {1}}, std::vector<float>{1.0f});
p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
auto prog = migraphx::parse_tf("constant_test.pb", false);
EXPECT(p == prog);
......@@ -129,6 +129,77 @@ TEST_CASE(identity_test)
EXPECT(p == prog);
}
TEST_CASE(matmul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
auto trans_l0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto prog = migraphx::parse_tf("matmul_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(mul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::mul{}, l0, l1);
auto prog = migraphx::parse_tf("mul_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pack_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t axis = 1;
std::transform(args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(pack_test_nhwc)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 1;
std::transform(args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog);
}
TEST_CASE(pooling_test)
{
migraphx::program p;
......@@ -193,4 +264,28 @@ TEST_CASE(squeeze_test)
EXPECT(p == prog);
}
TEST_CASE(stridedslice_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
op.ends = {1, 5, 1, 1};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 5});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(op, l0);
auto shrink_axis = 2;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
auto prog = migraphx::parse_tf("stridedslice_test.pb", true);
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -49,7 +49,7 @@ struct operation
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional method to return which argument the output will alias. If
/// there is no aliased output then -1 can be returned.
int output_alias(const std::vector<shape>& input) const;
std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
/// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment