Commit eb0d8fee authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into driver

parents 65ef35cd 0d796941
 sum-example:e
 sum-example:a

0
1
23"Sum test-dropoutZ
23"Sumtest-sumZ
0

......@@ -15,7 +15,7 @@

b
2
3

B
\ No newline at end of file
unknown-example:
unknown-example:

0
12"Unknown
2"Unknown test-unknownZ

23"Unknown test-unknownZ
0


......@@ -14,7 +14,7 @@


b
2
3



......
......@@ -229,13 +229,43 @@ TEST_CASE(multibroadcast)
}
}
TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::op::broadcast{0, lens},
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::op::broadcast{2, lens},
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::op::broadcast{2, lens}, input);
}
}
TEST_CASE(gather)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
......@@ -245,7 +275,57 @@ TEST_CASE(gather)
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 3, 4, 5}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape indices{migraphx::shape::int32_type};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {3}};
migraphx::shape indices{migraphx::shape::int32_type, {1}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {1}},
migraphx::op::gather{axis},
input,
indices);
......@@ -266,6 +346,244 @@ TEST_CASE(gather)
}
}
TEST_CASE(logsoftmax)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 2;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 5;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = -1;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
}
// 2 inputs arguments
TEST_CASE(matmul)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 4}}, migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::op::dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::op::dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1}}, migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
migraphx::op::dot{},
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2);
}
}
// 3 input arguments
TEST_CASE(gemm)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 1}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 1}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::op::dot{},
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::op::dot{},
s_m1,
s_m2,
s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {1, 4, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type, {4, 8}};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
migraphx::shape s_m3{migraphx::shape::float_type};
throws_shape(migraphx::op::dot{}, s_m1, s_m2, s_m3);
}
}
TEST_CASE(rnn)
{
{
......@@ -590,4 +908,168 @@ TEST_CASE(gru)
}
}
TEST_CASE(lstm)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::lstm{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::lstm{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::lstm{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(migraphx::op::lstm{hidden_size + 1,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(migraphx::op::lstm{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(
migraphx::op::lstm{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/program.hpp>
#include <migraphx/ranges.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
migraphx::program create_program()
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraphx::shape::int64_type});
auto sum = p.add_instruction(sum_op{}, x, y);
auto one = p.add_literal(1);
p.add_instruction(sum_op{}, sum, one);
return p;
}
TEST_CASE(basic_graph_test)
{
migraphx::program p = create_program();
std::stringstream ss;
p.print_graph(ss);
std::string test = ss.str();
EXPECT(migraphx::contains(test, "digraph"));
EXPECT(migraphx::contains(test, "rankdir=LR"));
EXPECT(migraphx::contains(test, "\"@0\"[label=\"@literal\"]"));
EXPECT(migraphx::contains(test, "\"y\"[label=\"@param:y\"]"));
EXPECT(migraphx::contains(test, "\"x\"[label=\"@param:x\"]"));
EXPECT(migraphx::contains(test, "\"@1\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"@2\"[label=\"sum\"]"));
EXPECT(migraphx::contains(test, "\"x\" -> \"@1\""));
EXPECT(migraphx::contains(test, "\"y\" -> \"@1\""));
EXPECT(migraphx::contains(test, "\"@1\" -> \"@2\""));
EXPECT(migraphx::contains(test, "\"@0\" -> \"@2\""));
EXPECT(migraphx::contains(test, "[label=\"int64_type, {1}, {0}\"]"));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -2,6 +2,10 @@
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/cpu/target.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
......@@ -27,4 +31,78 @@ TEST_CASE(program_equality)
EXPECT(x == y);
}
TEST_CASE(program_copy)
{
auto create_program_1 = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5}};
std::vector<float> data(3 * 4 * 5);
std::iota(data.begin(), data.end(), 1.0f);
auto l2 = p.add_literal(migraphx::literal(s, data));
auto p1 = p.add_parameter("x", s);
auto po = p.add_outline(s);
auto sum = p.add_instruction(migraphx::op::add{}, l2, po);
p.add_instruction(migraphx::op::mul{}, sum, p1);
return p;
};
{
auto p1 = create_program_1();
migraphx::program p2{};
p2 = p1;
p2.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2(p1);
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
}
{
auto p1 = create_program_1();
auto p2 = create_program();
EXPECT(p1 != p2);
p2 = p1;
EXPECT(p1 == p2);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p1 == p2);
}
{
migraphx::program p1;
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
migraphx::shape s3{migraphx::shape::float_type, {2, 6}};
auto para1 = p1.add_parameter("m1", s1);
auto para2 = p1.add_parameter("m2", s2);
auto para3 = p1.add_parameter("m3", s3);
p1.add_instruction(migraphx::op::dot{0.31f, 0.28f}, para1, para2, para3);
migraphx::program p2{};
p2 = p1;
EXPECT(p2 == p1);
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
EXPECT(p2 == p1);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/constant_propagate.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/mul.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -9,12 +11,12 @@ struct const_prop_target
std::string name() const { return "const_prop"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::constant_propagate{}, migraphx::dead_code_elimination{}};
return {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(const_add1)
TEST_CASE(const_add)
{
migraphx::program p1;
auto one = p1.add_literal(1);
......@@ -29,7 +31,7 @@ TEST_CASE(const_add1)
EXPECT(p1 == p2);
}
TEST_CASE(const_add2)
TEST_CASE(const_add_parameter)
{
migraphx::program p1;
auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}});
......@@ -44,7 +46,7 @@ TEST_CASE(const_add2)
EXPECT(p1 != p2);
}
TEST_CASE(const_add3)
TEST_CASE(const_multiadd)
{
migraphx::program p1;
auto one = p1.add_literal(1);
......@@ -60,4 +62,54 @@ TEST_CASE(const_add3)
EXPECT(p1 == p2);
}
TEST_CASE(const_add_mul)
{
migraphx::program p1;
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto mul = p1.add_instruction(migraphx::op::mul{}, two, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, mul);
auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{});
migraphx::program p2;
auto total = p2.add_literal(7);
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
TEST_CASE(const_add_scalar)
{
migraphx::program p1;
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
auto two = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(2));
auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{});
migraphx::program p2;
auto total =
p2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}});
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
TEST_CASE(const_scalar)
{
migraphx::program p1;
{
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
p1.add_instruction(pass_op{}, one);
}
p1.compile(const_prop_target{});
migraphx::program p2;
{
auto one = p2.add_instruction(migraphx::op::scalar{{2, 2}}, p2.add_literal(1));
p2.add_instruction(pass_op{}, one);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -4,6 +4,8 @@ find_package(PythonInterp)
function(add_py_test NAME SCRIPT)
set (ENV_COMMAND ${CMAKE_COMMAND} -E env
"PYTHONPATH=$<TARGET_FILE_DIR:migraphx_py>"
"PYTHONMALLOC=debug"
"MALLOC_CHECK_=3"
)
add_test(
NAME test_py_${NAME}
......@@ -15,7 +17,8 @@ endfunction()
add_dependencies(tests migraphx_py)
add_dependencies(check migraphx_py)
add_py_test(cpu cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(cpu test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
endif()
import migraphx, struct, array, sys
try:
from functools import reduce
except:
pass
def assert_eq(x, y):
if x == y:
pass
else:
raise Exception(str(x) + " != " + str(y))
def read_float(b, index):
return struct.unpack_from('f', b, index*4)[0]
def write_float(b, index):
struct.pack_into('f', b, index*4)
def nelements(lens):
return reduce(lambda x,y: x*y,lens, 1)
def create_buffer(t, data, shape):
a = array.array(t, data)
if sys.version_info >= (3, 0):
m = memoryview(a.tobytes())
return m.cast(t, shape)
else:
m = memoryview(a.tostring())
return m
def check_argument(a):
l = a.tolist()
for i in range(len(l)):
assert_eq(l[i], read_float(a, i))
def check_shapes(r, m):
lens = list(m.shape)
strides = [int(s/m.itemsize) for s in m.strides]
elements = nelements(lens)
assert_eq(r.get_shape().elements(), elements)
assert_eq(r.get_shape().lens(), lens)
assert_eq(r.get_shape().strides(), strides)
def run(p):
params = {}
for key, value in p.get_parameter_shapes().items():
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
return migraphx.from_gpu(p.run(params))
def test_shape(shape):
data = list(range(nelements(shape)))
m = create_buffer('f', data, shape)
a = migraphx.argument(m)
check_shapes(a, m)
assert_eq(a.tolist(), data)
def test_input():
if sys.version_info >= (3, 0):
test_shape([4])
test_shape([2, 3])
else:
data = list(range(4))
m = create_buffer('f', data, [4])
a1 = migraphx.argument(m)
a2 = migraphx.argument(bytearray(a1))
check_shapes(a2, m)
assert_eq(a1.tolist(), m.tolist())
def test_output():
p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
p.compile(migraphx.get_target("gpu"))
r1 = run(p)
r2 = run(p)
assert_eq(r1, r2)
assert_eq(r1.tolist(), r2.tolist())
check_argument(r1)
check_argument(r2)
m1 = memoryview(r1)
m2 = memoryview(r2)
check_shapes(r1, m1)
check_shapes(r2, m2)
test_input()
test_output()
#include <migraphx/schedule.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct unary_op
{
std::string name() const { return "unary"; }
migraphx::argument
compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};
struct nary_op
{
std::string comment = "";
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.comment, "comment"));
}
std::string name() const { return "nary"; }
migraphx::argument
compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
};
struct stream_free_op
{
std::string comment = "";
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.comment, "comment"));
}
std::string name() const { return "stream_free"; }
migraphx::argument
compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{
if(args.empty())
return {};
return args.front();
}
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
};
struct wait_event
{
std::shared_ptr<std::vector<std::size_t>> wait_for =
std::make_shared<std::vector<std::size_t>>();
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(*self.wait_for, "wait_for"));
}
std::string name() const { return "wait_event"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
migraphx::argument compute(migraphx::context&,
const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{
assert(wait_for != nullptr);
assert(not wait_for->empty());
return {};
}
};
using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>;
using int_map = std::unordered_map<std::size_t, std::size_t>;
using wait_map =
std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<std::size_t>>>;
struct schedule_model_test
{
std::shared_ptr<instruction_map> ins2stream = std::make_shared<instruction_map>();
std::shared_ptr<int_map> wait2stream = std::make_shared<int_map>();
std::shared_ptr<wait_map> ins2wait_for = std::make_shared<wait_map>();
std::size_t concurrency() const { return 4; }
void sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
{
(*ins2stream)[ins] = n;
}
void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
{
if(ins2wait_for->count(ins) == 0)
{
auto event = wait_event{};
p.insert_instruction(ins, event);
(*ins2wait_for)[ins] = event.wait_for;
}
(*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id));
}
void record(migraphx::program&, migraphx::instruction_ref ins, std::size_t wait_id) const
{
(*wait2stream)[wait_id] = ins2stream->at(ins);
}
std::size_t weight(const migraphx::operation& op) const
{
if(op.name() == "stream_free")
return 0;
else if(op.name() == "binary" or op.name() == "unary")
return 4;
else
return 1;
}
};
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
for(auto ins : migraphx::iterator_for(p))
{
if(ins->name() != "identity")
continue;
if(not migraphx::contains(ins->inputs(), x))
continue;
if(not migraphx::contains(ins->inputs(), y))
continue;
return true;
}
return false;
}
struct schedule_target
{
schedule_model_test model{};
std::string name() const { return "schedule"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::schedule{model}};
}
migraphx::context get_context() const { return {}; }
std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); }
std::vector<std::size_t> get_streams(std::vector<migraphx::instruction_ref> inss)
{
std::vector<std::size_t> result;
std::transform(inss.begin(), inss.end(), std::back_inserter(result), [&](auto ins) {
return this->get_stream(ins);
});
return result;
}
bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
void check_conflicts(migraphx::program& p,
std::vector<std::vector<migraphx::instruction_ref>> conflicts,
bool result = true)
{
migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
if(i == j)
return;
for(auto ins1 : conflicts[i])
{
for(auto ins2 : conflicts[j])
{
// If both instructions are on the same stream then dont check for a conflict
if(this->has_stream(ins1) and this->has_stream(ins2) and
this->get_stream(ins1) == this->get_stream(ins2))
continue;
CHECK(::check_conflicts(p, ins1, ins2) == result);
}
}
});
}
};
template <class T>
std::vector<T> sorted(std::vector<T> x)
{
std::sort(x.begin(), x.end());
return x;
}
template <class T>
std::vector<T> unique(std::vector<T> x)
{
std::sort(x.begin(), x.end());
x.erase(std::unique(x.begin(), x.end()), x.end());
return x;
}
std::vector<std::size_t> get_wait_for(std::vector<std::size_t> wait_for)
{
return unique(std::move(wait_for));
}
std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size_t> wait_for)
{
wait_for.erase(std::find(wait_for.begin(), wait_for.end(), wait_on));
return unique(wait_for);
}
std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
{
auto wait_ins = std::prev(ins);
// Skip identity operators
while(wait_ins->name() == "identity")
wait_ins = std::prev(wait_ins);
if(wait_ins->name() != "wait_event")
return {};
auto wf = *migraphx::any_cast<wait_event>(wait_ins->get_operator()).wait_for;
std::sort(wf.begin(), wf.end());
return wf;
}
template <class T>
std::vector<migraphx::instruction_ref>
chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
{
std::vector<migraphx::instruction_ref> result;
for(std::size_t i = 0; i < n; i++)
{
result.push_back(p.add_instruction(x, input));
input = result.back();
}
return result;
}
TEST_CASE(single_entry)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
}
TEST_CASE(stream_free)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(stream_free_op{}, one);
auto onep2 = p.add_instruction(stream_free_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(onep1));
EXPECT(not t.has_stream(onep2));
EXPECT(t.get_stream(binary) == 0);
}
TEST_CASE(zero_record)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto onei1 = p.add_instruction(migraphx::op::identity{}, onep1);
auto onei2 = p.add_instruction(migraphx::op::identity{}, onep2);
auto binary = p.add_instruction(nary_op{}, onei1, onei2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.has_stream(binary));
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
t.check_conflicts(p, {{onep1, onei1}, {onep2, onei2}});
}
TEST_CASE(zero_merge1)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(not t.has_stream(binary));
// There is no wait
EXPECT(get_wait_for(binary).empty());
EXPECT(check_conflicts(p, onep1, onep2));
}
TEST_CASE(zero_merge2)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(migraphx::op::identity{},
p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2));
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(not t.has_stream(binary));
// There is no wait
EXPECT(get_wait_for(binary).empty());
EXPECT(check_conflicts(p, onep1, onep2));
}
TEST_CASE(zero_merge3)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto id = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
auto final = p.add_instruction(unary_op{}, id);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(not t.has_stream(id));
// There is no wait
EXPECT(get_wait_for(id).empty());
// Stream assignment for final op
EXPECT(t.get_stream(final) == 0);
EXPECT(get_wait_for(final) ==
get_wait_for(t.get_stream(final), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
}
TEST_CASE(zero_merge4)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto id = p.add_instruction(migraphx::op::identity{},
p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2));
auto final = p.add_instruction(unary_op{}, id);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(not t.has_stream(id));
// There is no wait
EXPECT(get_wait_for(id).empty());
// Stream assignment for final op
EXPECT(t.get_stream(final) == 0);
EXPECT(get_wait_for(final) ==
get_wait_for(t.get_stream(final), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
}
TEST_CASE(double_entry)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_instruction(stream_free_op{}, p.add_literal(1));
auto two = p.add_instruction(stream_free_op{}, p.add_literal(2));
auto onep = p.add_instruction(unary_op{}, one);
auto twop = p.add_instruction(unary_op{}, two);
auto binary = p.add_instruction(nary_op{}, onep, twop);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(onep) != t.get_stream(twop));
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
t.check_conflicts(p, {{onep, one}, {twop, two}});
}
TEST_CASE(two_branches)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 1);
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
}
TEST_CASE(four_branches)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 4, unary_op{}, one);
auto c2 = chain(p, 3, unary_op{}, one);
auto c3 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 1);
for(auto ins : c3)
EXPECT(t.get_stream(ins) == 2);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary),
{t.get_stream(c1.back()),
t.get_stream(c2.back()),
t.get_stream(c3.back()),
t.get_stream(i1)}));
t.check_conflicts(p, {c1, c2, c3, {i1}});
}
TEST_CASE(five_branches)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 5, unary_op{}, one);
auto c2 = chain(p, 4, unary_op{}, one);
auto c3 = chain(p, 3, unary_op{}, one);
auto c4 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 1);
for(auto ins : c3)
EXPECT(t.get_stream(ins) == 2);
for(auto ins : c4)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary),
{t.get_stream(c1.back()),
t.get_stream(c2.back()),
t.get_stream(c3.back()),
t.get_stream(i1)}));
t.check_conflicts(p, {c1, c2, c3, c4});
t.check_conflicts(p, {c1, c2, c3, {i1}});
}
TEST_CASE(four_branches_eq)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto onep3 = p.add_instruction(unary_op{}, one);
auto onep4 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(
sorted<std::size_t>(
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}) ==
unique<std::size_t>(
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
EXPECT(t.get_stream(binary) == 0);
EXPECT(
get_wait_for(binary) ==
get_wait_for(
t.get_stream(binary),
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
t.check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}});
}
TEST_CASE(seq_merge)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());
auto c2 = chain(p, 2, unary_op{}, binary1);
auto i2 = p.add_instruction(unary_op{}, binary1);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(c1.back()));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == t.get_stream(c1.back()));
EXPECT(t.get_stream(binary1) == t.get_stream(c1.back()));
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) != t.get_stream(c2.back()));
for(auto ins : c2)
EXPECT(t.get_stream(ins) == t.get_stream(c2.back()));
EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
t.check_conflicts(p, {c2, {i2}});
}
TEST_CASE(par_merge)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
auto c1 = chain(p, 3, unary_op{}, start1);
auto i1 = p.add_instruction(unary_op{}, start1);
auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());
auto start2 = p.add_instruction(unary_op{}, one);
auto c2 = chain(p, 2, unary_op{}, start2);
auto i2 = p.add_instruction(unary_op{}, start2);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(binary3) == 0);
EXPECT(t.get_stream(i1) != t.get_stream(i2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2)
EXPECT(t.get_stream(ins) == t.get_stream(binary2));
EXPECT(t.get_stream(binary2) != t.get_stream(i1));
EXPECT(t.get_stream(binary2) != t.get_stream(i2));
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
t.check_conflicts(p, {c1, {i1}, c2, {i2}});
}
TEST_CASE(inner_par_merge)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
auto c1 = chain(p, 3, unary_op{}, start1);
auto i1 = p.add_instruction(unary_op{}, start1);
auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());
auto start2 = p.add_instruction(unary_op{}, one);
auto c2 = chain(p, 2, unary_op{}, start2);
auto i2 = p.add_instruction(unary_op{}, start2);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
auto outer1 = p.add_instruction(unary_op{}, one);
auto outer2 = p.add_instruction(unary_op{}, one);
auto output = p.add_instruction(nary_op{}, binary1, binary2, outer1, outer2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output) == get_wait_for(t.get_stream(output),
{t.get_stream(binary1),
t.get_stream(binary2),
t.get_stream(outer1),
t.get_stream(outer2)}));
EXPECT(t.get_stream(outer1) == 1);
EXPECT(t.get_stream(outer2) == 2);
EXPECT(t.get_stream(i1) != t.get_stream(i2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2)
EXPECT(t.get_stream(ins) == t.get_stream(binary2));
EXPECT(t.get_stream(binary2) != t.get_stream(i1));
EXPECT(t.get_stream(binary2) != t.get_stream(i2));
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
t.check_conflicts(p, {c1, {i1}, c2, {i2}, {outer1}, {outer2}});
}
TEST_CASE(par_merge_multi_entry)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
auto c1 = chain(p, 3, unary_op{}, start1);
auto i1 = p.add_instruction(unary_op{}, start1);
auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());
auto two = p.add_literal(1);
auto start2 = p.add_instruction(unary_op{}, two);
auto c2 = chain(p, 2, unary_op{}, start2);
auto i2 = p.add_instruction(unary_op{}, start2);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(binary3) == 0);
EXPECT(t.get_stream(i1) != t.get_stream(i2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2)
EXPECT(t.get_stream(ins) == t.get_stream(binary2));
EXPECT(t.get_stream(binary2) != t.get_stream(i1));
EXPECT(t.get_stream(binary2) != t.get_stream(i2));
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
t.check_conflicts(p, {c1, {i1}, c2, {i2}});
}
TEST_CASE(inner_split1)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto s1 = p.add_instruction(unary_op{}, c1);
auto s2 = p.add_instruction(unary_op{}, c1);
auto output = p.add_instruction(nary_op{}, i1, s1, s2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(s1));
EXPECT(t.get_stream(i1) != t.get_stream(s2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1) != t.get_stream(s2));
EXPECT(t.get_stream(output) == 0);
EXPECT(
get_wait_for(output) ==
get_wait_for(t.get_stream(output), {t.get_stream(i1), t.get_stream(s1), t.get_stream(s2)}));
EXPECT(get_wait_for(s1).empty());
// TODO: Remove the extra wait here
// EXPECT(get_wait_for(s2).empty());
t.check_conflicts(p, {c1, {i1}, {s1}, {s2}});
}
TEST_CASE(inner_split2)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto s1 = chain(p, 3, unary_op{}, c1.back());
auto s2 = chain(p, 4, unary_op{}, c1.back());
auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(s1.back()));
EXPECT(t.get_stream(i1) != t.get_stream(s2.back()));
for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1.back()) != t.get_stream(s2.back()));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output) ==
get_wait_for(t.get_stream(output),
{t.get_stream(i1), t.get_stream(s1.back()), t.get_stream(s2.back())}));
EXPECT(get_wait_for(s1.front()) == get_wait_for({t.get_stream(c1.back())}));
t.check_conflicts(p, {c1, {i1}, s1, s2});
}
TEST_CASE(inception_resnet)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto input = p.add_instruction(unary_op{}, one);
auto c1 = chain(p, 2, unary_op{}, input);
auto i1 = p.add_instruction(unary_op{}, input);
auto binary = p.add_instruction(nary_op{}, i1, c1.back());
auto output = p.add_instruction(nary_op{}, binary, input);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != 0);
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output).empty());
t.check_conflicts(p, {c1, {i1}});
}
TEST_CASE(inception1)
{
schedule_target t{};
migraphx::program p;
auto i1 = p.add_literal(0);
auto i2 = p.add_literal(1);
auto i3 = p.add_literal(1);
auto i4 = p.add_literal(2);
auto i7 = p.add_instruction(nary_op{"i7"}, i1, i4, i3, i2);
auto i8 = p.add_literal(2);
auto i9 = p.add_instruction(migraphx::op::identity{}, i8);
auto i10 = p.add_literal(1);
auto i11 = p.add_instruction(nary_op{"i11"}, i7, i9, i10);
auto i12 = p.add_literal(2);
auto i13 = p.add_instruction(migraphx::op::identity{}, i12);
auto i14 = p.add_literal(1);
auto i15 = p.add_literal(1);
auto i16 = p.add_literal(2);
auto i17 = p.add_instruction(nary_op{"i17"}, i11, i16, i15, i13, i14);
auto i18 = p.add_literal(2);
auto i19 = p.add_instruction(migraphx::op::identity{}, i18);
auto i20 = p.add_literal(1);
auto i21 = p.add_literal(1);
auto i22 = p.add_literal(2);
auto i23 = p.add_instruction(nary_op{"i23"}, i17, i22, i21, i19, i20);
auto i24 = p.add_literal(1);
auto i25 = p.add_instruction(nary_op{"i25"}, i23, i24);
auto i26 = p.add_literal(2);
auto i27 = p.add_instruction(migraphx::op::identity{}, i26);
auto i28 = p.add_literal(1);
auto i29 = p.add_literal(1);
auto i30 = p.add_literal(2);
auto i31 = p.add_instruction(nary_op{"i31"}, i25, i30, i29, i27, i28);
auto i32 = p.add_literal(2);
auto i33 = p.add_instruction(migraphx::op::identity{}, i32);
auto i34 = p.add_literal(1);
auto i35 = p.add_literal(1);
auto i36 = p.add_literal(2);
auto i37 = p.add_instruction(nary_op{"i37"}, i31, i36, i35, i33, i34);
auto i38 = p.add_literal(1);
auto i39 = p.add_instruction(nary_op{"i39"}, i37, i38);
auto i41 = p.add_literal(2);
auto i42 = p.add_instruction(migraphx::op::identity{}, i41);
auto i43 = p.add_literal(1);
auto i44 = p.add_literal(1);
auto i45 = p.add_literal(2);
auto i48 = p.add_instruction(nary_op{"i48"}, i39, i45, i44, i42, i43);
auto i49 = p.add_literal(2);
auto i50 = p.add_instruction(migraphx::op::identity{}, i49);
auto i51 = p.add_literal(1);
auto i52 = p.add_literal(1);
auto i53 = p.add_literal(2);
auto i54 = p.add_instruction(nary_op{"i54"}, i48, i53, i52, i50, i51);
auto i55 = p.add_literal(1);
auto i56 = p.add_instruction(migraphx::op::identity{}, i55);
auto i57 = p.add_literal(2);
auto i58 = p.add_instruction(migraphx::op::identity{}, i57);
auto i59 = p.add_literal(1);
auto i60 = p.add_literal(2);
auto i61 = p.add_instruction(nary_op{"i61"}, i54, i60, i59, i58, i56);
auto i62 = p.add_literal(2);
auto i63 = p.add_instruction(migraphx::op::identity{}, i62);
auto i64 = p.add_literal(1);
auto i65 = p.add_literal(1);
auto i66 = p.add_literal(2);
auto i69 = p.add_instruction(nary_op{"i69"}, i39, i66, i65, i63, i64);
auto i70 = p.add_instruction(migraphx::op::identity{}, i55);
auto i71 = p.add_literal(2);
auto i72 = p.add_instruction(migraphx::op::identity{}, i71);
auto i73 = p.add_literal(1);
auto i74 = p.add_literal(2);
auto i75 = p.add_instruction(nary_op{"i75"}, i69, i74, i73, i72, i70);
auto i77 = p.add_literal(1);
auto i80 = p.add_instruction(nary_op{"i80"}, i39, i77);
auto i81 = p.add_instruction(migraphx::op::identity{}, i55);
auto i82 = p.add_literal(2);
auto i83 = p.add_instruction(migraphx::op::identity{}, i82);
auto i84 = p.add_literal(1);
auto i85 = p.add_literal(2);
auto i86 = p.add_instruction(nary_op{"i86"}, i80, i85, i84, i83, i81);
auto i88 = p.add_instruction(migraphx::op::identity{}, i55);
auto i89 = p.add_literal(2);
auto i90 = p.add_instruction(migraphx::op::identity{}, i89);
auto i91 = p.add_literal(1);
auto i92 = p.add_literal(2);
auto i94 = p.add_instruction(nary_op{"i94"}, i39, i92, i91, i90, i88);
auto i96 = p.add_instruction(migraphx::op::identity{}, i55, i94, i75, i61, i86);
auto i97 = p.add_literal(2);
auto i98 = p.add_instruction(migraphx::op::identity{}, i97);
auto i99 = p.add_literal(3);
auto i100 = p.add_literal(1);
auto i101 = p.add_literal(2);
auto output = p.add_instruction(nary_op{"output"}, i96, i101, i100, i98, i99);
p.compile(t);
EXPECT(t.get_streams({i7, i11, i17, i23, i25, i31, i37, i39}) ==
t.get_streams({i7, i7, i7, i7, i7, i7, i7, i7}));
EXPECT(t.get_streams({i48, i54, i61, output}) ==
t.get_streams({output, output, output, output}));
EXPECT(t.get_streams({i80, i86}) == t.get_streams({i80, i80}));
EXPECT(t.get_streams({i69, i75}) == t.get_streams({i69, i69}));
EXPECT(t.get_stream(i7) != t.get_stream(i80));
EXPECT(t.get_stream(i69) != t.get_stream(i80));
EXPECT(t.get_stream(i69) != t.get_stream(i7));
EXPECT(t.get_stream(output) != t.get_stream(i69));
EXPECT(t.get_stream(output) != t.get_stream(i80));
EXPECT(get_wait_for(i80) == get_wait_for({t.get_stream(i39)}));
EXPECT(get_wait_for(i69) == get_wait_for({t.get_stream(i39)}));
EXPECT(get_wait_for(i94) == get_wait_for({t.get_stream(i39)}));
EXPECT(
get_wait_for(output) ==
get_wait_for(t.get_stream(output),
{t.get_stream(i94), t.get_stream(i75), t.get_stream(i61), t.get_stream(i86)}));
t.check_conflicts(p, {{i80, i86}, {i69, i75}, {i48, i54, i61}, {i94}});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
2
0 Placeholder*
shape
:*
dtype0
2
1 Placeholder*
dtype0*
shape
:
add_bcast1Add01*
T0"
\ No newline at end of file
:
0 Placeholder*
shape:*
dtype0
:
1 Placeholder*
dtype0*
shape:

add1Add01*
T0"
\ No newline at end of file
;
0 Placeholder*
shape:*
dtype0
/
1 Placeholder*
dtype0*
shape:
:
bias_add1BiasAdd01*
T0*
data_formatNHWC"
\ No newline at end of file
:
0 Placeholder*
shape:*
dtype0
identityIdentity0*
T0"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment