Commit cf86db72 authored by Paul's avatar Paul
Browse files

Merge branch 'master' into fp16

parents af454aeb 414e2fac
......@@ -14,6 +14,7 @@ struct contiguous_target
migraph::context get_context() const { return {}; }
};
// TODO: Add this test case
void literal_broadcast()
{
migraph::program p;
......@@ -25,7 +26,7 @@ void literal_broadcast()
EXPECT(not p.get_shape().broadcasted());
}
void literal_transpose()
TEST_CASE(literal_transpose)
{
migraph::program p;
p.add_literal(get_2x2_transposed());
......@@ -36,7 +37,7 @@ void literal_transpose()
EXPECT(not p.get_shape().transposed());
}
void after_literal_transpose()
TEST_CASE(after_literal_transpose)
{
migraph::program p;
auto l = p.add_literal(get_2x2());
......@@ -51,7 +52,7 @@ void after_literal_transpose()
EXPECT(not p.get_shape().transposed());
}
void after_literal_broadcast()
TEST_CASE(after_literal_broadcast)
{
migraph::program p;
auto l1 = p.add_literal(get_2x2());
......@@ -67,7 +68,7 @@ void after_literal_broadcast()
EXPECT(not p.get_shape().broadcasted());
}
void after_param_transpose()
TEST_CASE(after_param_transpose)
{
migraph::program p;
auto l = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}});
......@@ -82,7 +83,7 @@ void after_param_transpose()
EXPECT(not p.get_shape().transposed());
}
void after_param_broadcast()
TEST_CASE(after_param_broadcast)
{
migraph::program p;
auto l1 = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}});
......@@ -98,12 +99,4 @@ void after_param_broadcast()
EXPECT(not p.get_shape().broadcasted());
}
int main()
{
// literal_broadcast();
literal_transpose();
after_literal_transpose();
after_literal_broadcast();
after_param_transpose();
after_param_broadcast();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -14,7 +14,7 @@ struct cse_target
migraph::context get_context() const { return {}; }
};
void cse_test1()
TEST_CASE(cse_test1)
{
migraph::program p1;
{
......@@ -38,7 +38,7 @@ void cse_test1()
EXPECT(p1 == p2);
}
void cse_test2()
TEST_CASE(cse_test2)
{
migraph::program p1;
{
......@@ -63,7 +63,7 @@ void cse_test2()
EXPECT(p1 == p2);
}
void cse_test3()
TEST_CASE(cse_test3)
{
migraph::program p1;
{
......@@ -86,7 +86,7 @@ void cse_test3()
EXPECT(p1 == p2);
}
void cse_test4()
TEST_CASE(cse_test4)
{
migraph::program p1;
{
......@@ -112,10 +112,4 @@ void cse_test4()
EXPECT(p1 == p2);
}
int main()
{
cse_test1();
cse_test2();
cse_test3();
cse_test4();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -14,7 +14,7 @@ struct const_prop_target
migraph::context get_context() const { return {}; }
};
void const_add1()
TEST_CASE(const_add1)
{
migraph::program p1;
auto one = p1.add_literal(1);
......@@ -29,7 +29,7 @@ void const_add1()
EXPECT(p1 == p2);
}
void const_add2()
TEST_CASE(const_add2)
{
migraph::program p1;
auto one = p1.add_parameter("one", {migraph::shape::int32_type, {1}});
......@@ -44,7 +44,7 @@ void const_add2()
EXPECT(p1 != p2);
}
void const_add3()
TEST_CASE(const_add3)
{
migraph::program p1;
auto one = p1.add_literal(1);
......@@ -60,9 +60,4 @@ void const_add3()
EXPECT(p1 == p2);
}
int main()
{
const_add1();
const_add2();
const_add3();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -7,7 +7,7 @@
#include <migraph/verify.hpp>
#include "test.hpp"
void slice_test()
TEST_CASE(slice_test)
{
{
migraph::program p;
......@@ -47,7 +47,7 @@ void slice_test()
}
}
void concat_test()
TEST_CASE(concat_test)
{
{
migraph::program p;
......@@ -97,7 +97,7 @@ void concat_test()
}
}
void squeeze_test()
TEST_CASE(squeeze_test)
{
{
migraph::program p;
......@@ -134,7 +134,7 @@ void squeeze_test()
}
}
void unsqueeze_test()
TEST_CASE(unsqueeze_test)
{
{
migraph::program p;
......@@ -160,7 +160,7 @@ void unsqueeze_test()
}
}
void globalavgpool_test()
TEST_CASE(globalavgpool_test)
{
migraph::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}};
......@@ -180,7 +180,7 @@ void globalavgpool_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void globalmaxpool_test()
TEST_CASE(globalmaxpool_test)
{
migraph::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}};
......@@ -200,7 +200,7 @@ void globalmaxpool_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void im2col_3x3_no_pad_identity_test()
TEST_CASE(im2col_3x3_no_pad_identity_test)
{
std::size_t f[2] = {3, 3};
std::size_t size[2] = {3, 3};
......@@ -229,7 +229,7 @@ void im2col_3x3_no_pad_identity_test()
EXPECT(migraph::verify_range(results_vector, input));
}
void im2col_3x3_no_pad_test()
TEST_CASE(im2col_3x3_no_pad_test)
{
std::size_t f[2] = {3, 3};
std::size_t size[2] = {4, 4};
......@@ -261,7 +261,7 @@ void im2col_3x3_no_pad_test()
EXPECT(migraph::verify_range(results_vector, correct));
}
void im2col_3x3_stride_2_no_pad_test()
TEST_CASE(im2col_3x3_stride_2_no_pad_test)
{
std::size_t f[2] = {3, 3};
std::size_t size[2] = {6, 6};
......@@ -294,7 +294,7 @@ void im2col_3x3_stride_2_no_pad_test()
EXPECT(migraph::verify_range(results_vector, correct));
}
void im2col_3x3_with_padding_test()
TEST_CASE(im2col_3x3_with_padding_test)
{
std::size_t f[2] = {3, 3};
std::size_t size[2] = {2, 2};
......@@ -326,7 +326,7 @@ void im2col_3x3_with_padding_test()
EXPECT(migraph::verify_range(results_vector, correct));
}
void batch_norm_inference_test()
TEST_CASE(batch_norm_inference_test)
{
migraph::program p;
const size_t width = 2, height = 2, channels = 4, batches = 2;
......@@ -366,7 +366,7 @@ void batch_norm_inference_test()
EXPECT(migraph::verify_range(result_vector, gold));
}
void im2col_3x3_with_channels_identity_test()
TEST_CASE(im2col_3x3_with_channels_identity_test)
{
std::size_t f[2] = {3, 3};
std::size_t size[2] = {3, 3};
......@@ -395,7 +395,7 @@ void im2col_3x3_with_channels_identity_test()
EXPECT(migraph::verify_range(results_vector, input));
}
void exp_test()
TEST_CASE(exp_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
......@@ -409,7 +409,7 @@ void exp_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void sin_test()
TEST_CASE(sin_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
......@@ -423,7 +423,7 @@ void sin_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void cos_test()
TEST_CASE(cos_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
......@@ -437,7 +437,7 @@ void cos_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void tan_test()
TEST_CASE(tan_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
......@@ -451,7 +451,7 @@ void tan_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void add_test()
TEST_CASE(add_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
......@@ -466,7 +466,7 @@ void add_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void broadcast_test()
TEST_CASE(broadcast_test)
{
migraph::program p;
migraph::shape a_shape{migraph::shape::int32_type, {2, 2}};
......@@ -485,7 +485,7 @@ void broadcast_test()
EXPECT(output(1, 0) == -3);
EXPECT(output(1, 1) == -3);
}
void add_broadcast_test()
TEST_CASE(add_broadcast_test)
{
migraph::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
......@@ -506,7 +506,7 @@ void add_broadcast_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void sub_test()
TEST_CASE(sub_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
......@@ -521,7 +521,7 @@ void sub_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void mul_test()
TEST_CASE(mul_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
......@@ -536,7 +536,7 @@ void mul_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void div_test()
TEST_CASE(div_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
......@@ -551,12 +551,12 @@ void div_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void relu_test()
TEST_CASE(relu_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
auto l = p.add_literal(migraph::literal{s, {-1.f, 0.f, 1.f}});
p.add_instruction(migraph::op::activation{"relu"}, l);
p.add_instruction(migraph::op::relu{}, l);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
......@@ -565,7 +565,7 @@ void relu_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void leaky_relu_test()
TEST_CASE(leaky_relu_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}};
......@@ -579,7 +579,7 @@ void leaky_relu_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void imagescaler_test()
TEST_CASE(imagescaler_test)
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {1, 3, 2, 2}};
......@@ -626,7 +626,7 @@ void imagescaler_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
void reshape_test()
TEST_CASE(reshape_test)
{
migraph::shape a_shape{migraph::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24);
......@@ -716,8 +716,10 @@ void gemm_test()
EXPECT(std::abs(results_vector[i] - c[i]) < tol);
}
}
TEST_CASE_REGISTER(gemm_test<float>)
TEST_CASE_REGISTER(gemm_test<double>)
void maxpool_test()
TEST_CASE(maxpool_test)
{
migraph::program p;
std::vector<float> a = {
......@@ -763,7 +765,7 @@ void maxpool_test()
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
std::cout << result.get_shape() << std::endl;
// std::cout << result.get_shape() << std::endl;
std::vector<float> results_vector(36);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6;
......@@ -774,7 +776,7 @@ void maxpool_test()
}
}
void softmax_test()
TEST_CASE(softmax_test)
{
migraph::program p;
std::vector<float> a = {
......@@ -833,7 +835,7 @@ void softmax_test()
EXPECT(migraph::verify_range(results_vector, s));
}
void conv2d_test()
TEST_CASE(conv2d_test)
{
migraph::program p;
std::vector<float> a = {
......@@ -896,7 +898,7 @@ void conv2d_test()
EXPECT(migraph::verify_range(results_vector, s));
}
void conv2d_padding_test()
TEST_CASE(conv2d_padding_test)
{
migraph::program p;
std::vector<float> a = {
......@@ -952,7 +954,7 @@ void conv2d_padding_test()
EXPECT(migraph::verify_range(results_vector, s));
}
void conv2d_padding_stride_test()
TEST_CASE(conv2d_padding_stride_test)
{
migraph::program p;
std::vector<float> a = {
......@@ -1013,7 +1015,7 @@ void conv2d_padding_stride_test()
EXPECT(migraph::verify_range(results_vector, s));
}
void transpose_test()
TEST_CASE(transpose_test)
{
migraph::shape a_shape{migraph::shape::float_type, {1, 2, 2, 3}};
std::vector<float> data(12);
......@@ -1048,7 +1050,7 @@ void transpose_test()
}
}
void contiguous_test()
TEST_CASE(contiguous_test)
{
migraph::shape a_shape{migraph::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> data(12);
......@@ -1068,41 +1070,4 @@ void contiguous_test()
EXPECT(migraph::verify_range(results_vector, gold));
}
int main()
{
concat_test();
slice_test();
squeeze_test();
unsqueeze_test();
exp_test();
sin_test();
cos_test();
tan_test();
add_test();
broadcast_test();
add_broadcast_test();
imagescaler_test();
sub_test();
mul_test();
div_test();
relu_test();
leaky_relu_test();
gemm_test<float>();
gemm_test<double>();
reshape_test();
transpose_test();
// contiguous_test();
softmax_test();
// maxpool_test();
conv2d_test();
conv2d_padding_test();
conv2d_padding_stride_test();
batch_norm_inference_test();
globalavgpool_test();
globalmaxpool_test();
im2col_3x3_no_pad_identity_test();
im2col_3x3_no_pad_test();
im2col_3x3_stride_2_no_pad_test();
im2col_3x3_with_channels_identity_test();
im2col_3x3_with_padding_test();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -12,7 +12,7 @@ struct dce_target
migraph::context get_context() const { return {}; }
};
void simple_test()
TEST_CASE(simple_test)
{
migraph::program p;
......@@ -27,7 +27,7 @@ void simple_test()
EXPECT(result != migraph::literal{4});
}
void simple_test_nop()
TEST_CASE(simple_test_nop)
{
migraph::program p;
......@@ -43,7 +43,7 @@ void simple_test_nop()
EXPECT(result != migraph::literal{4});
}
void simple_test_nop2()
TEST_CASE(simple_test_nop2)
{
migraph::program p;
......@@ -59,7 +59,7 @@ void simple_test_nop2()
EXPECT(result != migraph::literal{4});
}
void duplicate_test1()
TEST_CASE(duplicate_test1)
{
migraph::program p;
......@@ -75,7 +75,7 @@ void duplicate_test1()
EXPECT(result != migraph::literal{4});
}
void duplicate_test2()
TEST_CASE(duplicate_test2)
{
migraph::program p;
......@@ -92,7 +92,7 @@ void duplicate_test2()
EXPECT(result != migraph::literal{4});
}
void depth_test()
TEST_CASE(depth_test)
{
migraph::program p;
......@@ -111,12 +111,4 @@ void depth_test()
EXPECT(result != migraph::literal{4});
}
int main()
{
simple_test();
simple_test_nop();
simple_test_nop2();
duplicate_test1();
duplicate_test2();
depth_test();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -32,7 +32,7 @@ struct allocate
}
};
void basic()
TEST_CASE(basic)
{
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {8}}});
......@@ -49,7 +49,7 @@ void basic()
EXPECT(p.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4));
}
void aligned()
TEST_CASE(aligned)
{
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}});
......@@ -66,7 +66,7 @@ void aligned()
EXPECT(p.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4));
}
void unaligned()
TEST_CASE(unaligned)
{
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}});
......@@ -83,7 +83,7 @@ void unaligned()
EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
}
void float_aligned()
TEST_CASE(float_aligned)
{
migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}});
......@@ -100,11 +100,8 @@ void float_aligned()
EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
}
int main()
int main(int argc, const char* argv[])
{
setenv("MIGRAPH_DISABLE_MEMORY_COLORING", "1", 1);
basic();
aligned();
unaligned();
float_aligned();
test::run(argc, argv);
}
#include <migraph/eliminate_concat.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct concat
{
concat(std::size_t axis) { op.axis = axis; }
migraph::op::concat op;
std::string name() const { return "eliminate_concat::concat"; }
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
migraph::argument compute(migraph::context&,
const migraph::shape& output_shape,
const std::vector<migraph::argument>&) const
{
return {output_shape};
}
};
struct concat_test_optimization
{
/// A unique name used to identify the concat optimization
std::string name() const { return "eliminate_concat::concat"; }
/// A unique name used to identify the allocate operator
std::string allocate() const { return "allocate"; }
/// Return the lowered concat operator
migraph::op::concat get_concat(const migraph::operation& op) const
{
return migraph::any_cast<concat>(op).op;
}
};
struct eliminate_concat_target
{
std::size_t align = 32;
std::string name() const { return "eliminate_target"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::eliminate_concat{concat_test_optimization{}},
migraph::dead_code_elimination{}};
}
migraph::context get_context() const { return {}; }
};
struct allocate
{
migraph::shape s{};
std::string name() const { return "allocate"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs}.has(0);
return s;
}
migraph::argument compute(migraph::context&,
const migraph::shape& output_shape,
const std::vector<migraph::argument>&) const
{
return {output_shape};
}
};
struct fred_op
{
std::string name() const { return "fred_op"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs}.has(1);
return inputs.at(0);
}
migraph::argument compute(migraph::context&,
const migraph::shape&,
const std::vector<migraph::argument>& args) const
{
return args.at(0);
}
};
TEST_CASE(basic)
{
auto create_test_program = []() {
migraph::program p;
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto create_control_program = []() {
migraph::program p;
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1, 10, 8, 8}}});
auto l1 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 2, 8, 8}}, 0}, {a1});
auto p1 = p.add_instruction(fred_op{}, l1);
auto l2 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 3, 8, 8}}, 512}, {a1});
auto p2 = p.add_instruction(fred_op{}, l2);
auto l3 = p.add_instruction(
migraph::op::load{migraph::shape{migraph::shape::float_type, {1, 5, 8, 8}}, 1280},
{a1});
auto p3 = p.add_instruction(fred_op{}, l3);
p.add_instruction(migraph::op::identity{}, {a1, p1, p2, p3});
return p;
};
auto p1 = create_test_program();
auto p2 = create_control_program();
p1.compile(eliminate_concat_target{});
EXPECT(p1 == p2);
}
TEST_CASE(wont_work)
{
auto create_test_program = []() {
migraph::program p;
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto create_control_program = []() {
migraph::program p;
auto a1 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 2, 8, 8}}});
auto p1 = p.add_instruction(fred_op{}, a1);
auto a2 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 3, 8, 8}}});
auto p2 = p.add_instruction(fred_op{}, a2);
auto a3 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 5, 8, 8}}});
auto p3 = p.add_instruction(fred_op{}, a3);
std::size_t axis = 1;
auto a4 =
p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {2, 10, 8, 8}}});
p.add_instruction(concat(axis), p1, p2, p3, a4);
return p;
};
auto p1 = create_test_program();
auto p2 = create_control_program();
p1.compile(eliminate_concat_target{});
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -14,7 +14,7 @@ struct eliminate_contiguous_target
migraph::context get_context() const { return {}; }
};
void standard_op()
TEST_CASE(standard_op)
{
migraph::program p;
auto l = p.add_literal(get_2x2());
......@@ -26,7 +26,7 @@ void standard_op()
EXPECT(std::distance(p.begin(), p.end()) == count);
}
void non_standard_op()
TEST_CASE(non_standard_op)
{
migraph::program p;
auto l = p.add_literal(get_2x2());
......@@ -38,8 +38,4 @@ void non_standard_op()
EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
}
int main()
{
standard_op();
non_standard_op();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -50,7 +50,7 @@ struct double_reverse_target
migraph::context get_context() const { return {}; }
};
void literal_test1()
TEST_CASE(literal_test1)
{
migraph::program p;
......@@ -62,7 +62,7 @@ void literal_test1()
EXPECT(result != migraph::literal{4});
}
void literal_test2()
TEST_CASE(literal_test2)
{
migraph::program p;
......@@ -76,7 +76,7 @@ void literal_test2()
EXPECT(result != migraph::literal{3});
}
void print_test()
TEST_CASE(print_test)
{
migraph::program p;
......@@ -90,7 +90,7 @@ void print_test()
EXPECT(!s.empty());
}
void param_test()
TEST_CASE(param_test)
{
migraph::program p;
......@@ -104,7 +104,22 @@ void param_test()
EXPECT(result != migraph::literal{4});
}
void replace_test()
TEST_CASE(param_error_test)
{
migraph::program p;
auto x = p.add_parameter("x", {migraph::shape::int64_type});
auto y = p.add_parameter("y", {migraph::shape::int64_type});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraph::exception>(
[&] {
p.eval({{"x", migraph::literal{1}.get_argument()}});
},
"Parameter not found: y"));
}
TEST_CASE(replace_test)
{
migraph::program p;
......@@ -119,7 +134,7 @@ void replace_test()
EXPECT(result != migraph::literal{3});
}
void replace_ins_test()
TEST_CASE(replace_ins_test)
{
migraph::program p;
......@@ -135,7 +150,7 @@ void replace_ins_test()
EXPECT(result != migraph::literal{3});
}
void replace_ins_test2()
TEST_CASE(replace_ins_test2)
{
migraph::program p;
......@@ -152,7 +167,7 @@ void replace_ins_test2()
EXPECT(result != migraph::literal{3});
}
void insert_replace_test()
TEST_CASE(insert_replace_test)
{
migraph::program p;
......@@ -170,7 +185,7 @@ void insert_replace_test()
EXPECT(result != migraph::literal{5});
}
void target_test()
TEST_CASE(target_test)
{
migraph::program p;
......@@ -183,7 +198,7 @@ void target_test()
EXPECT(result != migraph::literal{4});
}
void reverse_target_test()
TEST_CASE(reverse_target_test)
{
migraph::program p;
......@@ -196,7 +211,7 @@ void reverse_target_test()
EXPECT(result != migraph::literal{4});
}
void double_reverse_target_test()
TEST_CASE(double_reverse_target_test)
{
migraph::program p;
......@@ -209,16 +224,4 @@ void double_reverse_target_test()
EXPECT(result != migraph::literal{4});
}
int main()
{
literal_test1();
literal_test2();
print_test();
param_test();
replace_test();
replace_ins_test();
replace_ins_test2();
insert_replace_test();
target_test();
reverse_target_test();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -6,7 +6,7 @@
#include <test.hpp>
#include <migraph/verify.hpp>
void fwd_conv_batchnorm_rewrite_test()
TEST_CASE(fwd_conv_batchnorm_rewrite_test)
{
std::vector<float> xdata = {
0.26485917, 0.61703885, 0.32762103, 0.2503367, 0.6552712, 0.07947932, 0.95442678,
......@@ -64,8 +64,4 @@ void fwd_conv_batchnorm_rewrite_test()
EXPECT(migraph::verify_range(results_vector1, results_vector2));
}
int main()
{
fwd_conv_batchnorm_rewrite_test();
return 0;
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -158,7 +158,7 @@ struct test_literals
auto weights = p.add_literal(
generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}));
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights);
p.add_instruction(migraph::op::activation{"relu"}, conv);
p.add_instruction(migraph::op::relu{}, conv);
return p;
}
};
......@@ -216,6 +216,21 @@ struct test_scale
}
};
struct test_slice
{
migraph::program create_program() const
{
migraph::program p;
migraph::shape s{migraph::shape::int32_type, {2, 2, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", {migraph::shape::int32_type, {2, 2, 2}});
auto slice0 = p.add_instruction(migraph::op::slice{{2}, {0}, {2}}, x);
p.add_instruction(migraph::op::add{}, y, slice0);
return p;
}
};
struct test_triadd
{
migraph::program create_program() const
......@@ -392,7 +407,7 @@ struct test_conv_relu
auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights);
p.add_instruction(migraph::op::activation{"relu"}, conv);
p.add_instruction(migraph::op::relu{}, conv);
return p;
}
};
......@@ -406,7 +421,7 @@ struct test_conv_relu_half
auto weights =
p.add_parameter("w", migraph::shape{migraph::shape::half_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights);
p.add_instruction(migraph::op::activation{"relu"}, conv);
p.add_instruction(migraph::op::relu{}, conv);
return p;
}
};
......@@ -419,7 +434,7 @@ struct test_add_relu
auto x = p.add_parameter("x", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto add = p.add_instruction(migraph::op::add{}, x, y);
p.add_instruction(migraph::op::activation{"relu"}, add);
p.add_instruction(migraph::op::relu{}, add);
return p;
}
};
......@@ -446,7 +461,7 @@ struct test_conv_pooling
p.add_parameter("w", migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
auto conv = p.add_instruction(migraph::op::convolution{}, input, weights);
auto pooling = p.add_instruction(migraph::op::pooling{"max"}, conv);
p.add_instruction(migraph::op::activation{"relu"}, pooling);
p.add_instruction(migraph::op::relu{}, pooling);
return p;
}
};
......@@ -669,7 +684,7 @@ struct test_conv_bn_relu_pooling
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4)));
auto bn = p.add_instruction(
migraph::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
auto relu = p.add_instruction(migraph::op::activation{"relu"}, bn);
auto relu = p.add_instruction(migraph::op::relu{}, bn);
p.add_instruction(migraph::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p;
}
......@@ -709,6 +724,73 @@ struct test_concat2
}
};
struct test_concat_relu
{
migraph::program create_program() const
{
migraph::program p;
std::size_t axis = 0;
migraph::shape s0{migraph::shape::float_type, {2, 2}};
migraph::shape s1{migraph::shape::float_type, {3, 2}};
migraph::shape s2{migraph::shape::float_type, {1, 2}};
auto l0 = p.add_parameter("x", s0);
auto l1 = p.add_parameter("y", s1);
auto l2 = p.add_parameter("z", s2);
auto r0 = p.add_instruction(migraph::op::relu{}, l0);
auto r1 = p.add_instruction(migraph::op::relu{}, l1);
auto r2 = p.add_instruction(migraph::op::relu{}, l2);
auto c0 = p.add_instruction(migraph::op::concat{axis}, r0, r1, r2);
p.add_instruction(migraph::op::relu{}, c0);
return p;
}
};
void manual_identity()
{
migraph::program p;
std::vector<float> data0 = {0, 1, 2, 3};
migraph::shape s0{migraph::shape::float_type, {2, 2}};
auto l0 = p.add_literal(migraph::literal{s0, data0});
p.add_instruction(migraph::op::identity{}, l0);
p.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
}
auto result = migraph::gpu::from_gpu(p.eval(m));
std::cout << result << std::endl;
}
void manual_test_concat_relu()
{
migraph::program p;
std::size_t axis = 0;
std::vector<float> data0 = {0, 1, 2, 3};
std::vector<float> data1 = {4, 5, 6, 7, 8, 9};
std::vector<float> data2 = {10, 11};
migraph::shape s0{migraph::shape::float_type, {2, 2}};
migraph::shape s1{migraph::shape::float_type, {3, 2}};
migraph::shape s2{migraph::shape::float_type, {1, 2}};
auto l0 = p.add_literal(migraph::literal{s0, data0});
auto l1 = p.add_literal(migraph::literal{s1, data1});
auto l2 = p.add_literal(migraph::literal{s2, data2});
auto r0 = p.add_instruction(migraph::op::relu{}, l0);
auto r1 = p.add_instruction(migraph::op::relu{}, l1);
auto r2 = p.add_instruction(migraph::op::relu{}, l2);
auto c0 = p.add_instruction(migraph::op::concat{axis}, r0, r1, r2);
p.add_instruction(migraph::op::relu{}, c0);
p.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
}
auto result = migraph::gpu::from_gpu(p.eval(m));
std::cout << result << std::endl;
}
struct test_conv_bn_relu_pooling2
{
static migraph::instruction_ref
......@@ -739,7 +821,7 @@ struct test_conv_bn_relu_pooling2
auto conv2 = p.add_instruction(migraph::op::convolution{{0, 0}, {2, 2}, {1, 1}}, x2, w2);
auto bn2 = add_bn(p, conv2, 2048);
auto add = p.add_instruction(migraph::op::add{}, bn1, bn2);
auto relu = p.add_instruction(migraph::op::activation{"relu"}, add);
auto relu = p.add_instruction(migraph::op::relu{}, add);
p.add_instruction(migraph::op::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p;
}
......@@ -749,6 +831,7 @@ int main()
{
verify_program<test_concat>();
verify_program<test_concat2>();
verify_program<test_concat_relu>();
verify_program<test_add>();
verify_program<test_add_half>();
verify_program<test_mul>();
......@@ -785,4 +868,5 @@ int main()
verify_program<test_conv_bn>();
verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>();
}
......@@ -79,6 +79,7 @@ struct pass_op
return {};
return inputs.front();
}
int output_alias(const std::vector<migraph::shape>&) const { return 0; }
};
struct pass_standard_op
......@@ -103,6 +104,7 @@ struct pass_standard_op
return {};
return inputs.front();
}
int output_alias(const std::vector<migraph::shape>&) const { return 0; }
};
struct nop
......
......@@ -2,7 +2,10 @@
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <vector>
#ifndef MIGRAPH_GUARD_TEST_TEST_HPP
#define MIGRAPH_GUARD_TEST_TEST_HPP
......@@ -140,7 +143,7 @@ bool throws(F f)
}
}
template <class F, class Exception>
template <class Exception, class F>
bool throws(F f, const std::string& msg = "")
{
try
......@@ -154,11 +157,75 @@ bool throws(F f, const std::string& msg = "")
}
}
template <class T>
void run_test()
using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class Keyword>
string_map parse(std::vector<std::string> as, Keyword keyword)
{
string_map result;
std::string flag;
for(auto&& x : as)
{
auto f = keyword(x);
if(f.empty())
{
result[flag].push_back(x);
}
else
{
flag = f.front();
result[flag]; // Ensure the flag exists
}
}
return result;
}
inline auto& get_test_cases()
{
static std::vector<std::pair<std::string, std::function<void()>>> cases;
return cases;
}
inline void add_test_case(std::string name, std::function<void()> f)
{
T t = {};
t.run();
get_test_cases().emplace_back(name, f);
}
struct auto_register
{
template <class F>
auto_register(const char* name, F f) noexcept
{
add_test_case(name, f);
}
};
inline void run_test_case(const std::string& name, const std::function<void()>& f)
{
std::cout << "[ RUN ] " << name << std::endl;
f();
std::cout << "[ COMPLETE ] " << name << std::endl;
}
inline void run(int argc, const char* argv[])
{
std::vector<std::string> as(argv + 1, argv + argc);
auto args = parse(as, [](auto &&) -> std::vector<std::string> { return {}; });
auto cases = args[""];
if(cases.empty())
{
for(auto&& tc : get_test_cases())
run_test_case(tc.first, tc.second);
}
else
{
std::unordered_map<std::string, std::function<void()>> m(get_test_cases().begin(),
get_test_cases().end());
for(auto&& name : cases)
run_test_case(name, m[name]);
}
}
} // namespace test
......@@ -179,4 +246,24 @@ void run_test()
// NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0)
// NOLINTNEXTLINE
#define TEST_CAT(x, ...) TEST_PRIMITIVE_CAT(x, __VA_ARGS__)
#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__
// NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE
#define TEST_CASE(...) \
void __VA_ARGS__(); \
TEST_CASE_REGISTER(__VA_ARGS__) \
void __VA_ARGS__()
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
#endif
......@@ -4,7 +4,7 @@
#include <string>
#include "test.hpp"
void literal_test()
TEST_CASE(literal_test)
{
EXPECT(migraph::literal{1} == migraph::literal{1});
EXPECT(migraph::literal{1} != migraph::literal{2});
......@@ -25,7 +25,7 @@ void literal_test()
EXPECT(l4.empty());
}
void literal_os1()
TEST_CASE(literal_os1)
{
migraph::literal l{1};
std::stringstream ss;
......@@ -33,7 +33,7 @@ void literal_os1()
EXPECT(ss.str() == "1");
}
void literal_os2()
TEST_CASE(literal_os2)
{
migraph::literal l{};
std::stringstream ss;
......@@ -41,7 +41,7 @@ void literal_os2()
EXPECT(ss.str().empty());
}
void literal_os3()
TEST_CASE(literal_os3)
{
migraph::shape s{migraph::shape::int64_type, {3}};
migraph::literal l{s, {1, 2, 3}};
......@@ -50,9 +50,4 @@ void literal_os3()
EXPECT(ss.str() == "1, 2, 3");
}
int main()
{
literal_test();
literal_os1();
literal_os2();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -27,7 +27,7 @@ void match1()
EXPECT(bool{r.result == l});
}
void match_name1()
TEST_CASE(match_name1)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -39,7 +39,7 @@ void match_name1()
EXPECT(bool{r.result == sum});
}
void match_name2()
TEST_CASE(match_name2)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -51,7 +51,7 @@ void match_name2()
EXPECT(bool{r.result == p.end()});
}
void match_name3()
TEST_CASE(match_name3)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -63,7 +63,7 @@ void match_name3()
EXPECT(bool{r.result == sum});
}
void match_arg1()
TEST_CASE(match_arg1)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -75,7 +75,7 @@ void match_arg1()
EXPECT(bool{r.result == sum});
}
void match_arg2()
TEST_CASE(match_arg2)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -87,7 +87,7 @@ void match_arg2()
EXPECT(bool{r.result == p.end()});
}
void match_arg3()
TEST_CASE(match_arg3)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -99,7 +99,7 @@ void match_arg3()
EXPECT(bool{r.result == sum});
}
void match_arg4()
TEST_CASE(match_arg4)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -111,7 +111,7 @@ void match_arg4()
EXPECT(bool{r.result == pass});
}
void match_arg5()
TEST_CASE(match_arg5)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -123,7 +123,7 @@ void match_arg5()
EXPECT(bool{r.result == p.end()});
}
void match_arg6()
TEST_CASE(match_arg6)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -135,7 +135,7 @@ void match_arg6()
EXPECT(bool{r.result == sum});
}
void match_arg7()
TEST_CASE(match_arg7)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -148,7 +148,7 @@ void match_arg7()
EXPECT(bool{r.result == sum});
}
void match_args1()
TEST_CASE(match_args1)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -161,7 +161,7 @@ void match_args1()
EXPECT(bool{r.result == sum});
}
void match_args2()
TEST_CASE(match_args2)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -174,7 +174,7 @@ void match_args2()
EXPECT(bool{r.result == p.end()});
}
void match_args3()
TEST_CASE(match_args3)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -186,7 +186,7 @@ void match_args3()
EXPECT(bool{r.result == p.end()});
}
void match_args4()
TEST_CASE(match_args4)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -200,7 +200,7 @@ void match_args4()
EXPECT(bool{r.result == sum2});
}
void match_args5()
TEST_CASE(match_args5)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -213,7 +213,7 @@ void match_args5()
EXPECT(bool{r.result == p.end()});
}
void match_args6()
TEST_CASE(match_args6)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -225,7 +225,7 @@ void match_args6()
EXPECT(bool{r.result == pass});
}
void match_args7()
TEST_CASE(match_args7)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -239,7 +239,7 @@ void match_args7()
EXPECT(bool{r.result == pass});
}
void match_either_args1()
TEST_CASE(match_either_args1)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -253,7 +253,7 @@ void match_either_args1()
EXPECT(bool{r.result == sum2});
}
void match_either_args2()
TEST_CASE(match_either_args2)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -267,7 +267,7 @@ void match_either_args2()
EXPECT(bool{r.result == sum2});
}
void match_either_args3()
TEST_CASE(match_either_args3)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -281,7 +281,7 @@ void match_either_args3()
EXPECT(bool{r.result == p.end()});
}
void match_all_of1()
TEST_CASE(match_all_of1)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -294,7 +294,7 @@ void match_all_of1()
EXPECT(bool{r.result == sum});
}
void match_all_of2()
TEST_CASE(match_all_of2)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -307,7 +307,7 @@ void match_all_of2()
EXPECT(bool{r.result == p.end()});
}
void match_any_of1()
TEST_CASE(match_any_of1)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -320,7 +320,7 @@ void match_any_of1()
EXPECT(bool{r.result == sum});
}
void match_any_of2()
TEST_CASE(match_any_of2)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -333,7 +333,7 @@ void match_any_of2()
EXPECT(bool{r.result == p.end()});
}
void match_none_of1()
TEST_CASE(match_none_of1)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -346,7 +346,7 @@ void match_none_of1()
EXPECT(bool{r.result == sum});
}
void match_none_of2()
TEST_CASE(match_none_of2)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -359,7 +359,7 @@ void match_none_of2()
EXPECT(bool{r.result == p.end()});
}
void match_bind1()
TEST_CASE(match_bind1)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -400,7 +400,7 @@ struct match_find_literal
}
};
void match_finder()
TEST_CASE(match_finder)
{
migraph::program p;
auto one = p.add_literal(1);
......@@ -410,43 +410,4 @@ void match_finder()
match::find_matches(p, match_find_sum{sum}, match_find_literal{sum});
}
int main()
{
match1();
match_name1();
match_name2();
match_name3();
match_arg1();
match_arg2();
match_arg3();
match_arg4();
match_arg5();
match_arg6();
match_arg7();
match_args1();
match_args2();
match_args3();
match_args4();
match_args5();
match_args6();
match_args7();
match_either_args1();
match_either_args2();
match_either_args3();
match_all_of1();
match_all_of2();
match_any_of1();
match_any_of2();
match_none_of1();
match_none_of2();
match_bind1();
match_finder();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraph/memory_coloring.hpp>
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -9,7 +10,7 @@ struct memory_coloring_target
std::string name() const { return "memory_coloring"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::memory_coloring{"allocate"}};
return {migraph::memory_coloring{"allocate", true}};
}
migraph::context get_context() const { return {}; }
};
......@@ -31,104 +32,580 @@ struct allocate
}
};
// A custom test operator that takes a single argument and an allocation
// This operator's output is an operand alias of argument 1
struct pass_memory
migraph::instruction_ref add_alloc(migraph::program& p, const migraph::shape& s)
{
std::string name() const { return "memory_coloring::pass_memory"; }
migraph::shape compute_shape(const std::vector<migraph::shape>& inputs) const
{
migraph::check_shapes{inputs, *this}.has(2);
return inputs.at(1);
}
migraph::argument compute(migraph::context&,
const migraph::shape&,
const std::vector<migraph::argument>& args) const
{
return args[1];
}
};
auto a0 = p.add_outline(s);
return p.add_instruction(allocate{}, a0);
}
bool no_allocate(const migraph::program& p)
{
return std::none_of(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "allocate"; });
}
// The previous existing test
void test1()
TEST_CASE(test1)
{
migraph::program p;
auto a0 = p.add_outline(migraph::shape{migraph::shape::float_type, {8}});
auto a1 = p.add_instruction(allocate{}, a0);
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_outline(migraph::shape{migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(allocate{}, a2);
p.add_instruction(pass_op{}, p2, p1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
// This test uses the pass_memory operator
void test2()
TEST_CASE(test2)
{
migraph::program p;
auto input = p.add_parameter("input", migraph::shape{migraph::shape::float_type, {16}});
auto a0 = p.add_outline(migraph::shape{migraph::shape::float_type, {128}});
auto a1 = p.add_instruction(allocate{}, a0);
auto p1 = p.add_instruction(pass_memory{}, input, a1);
auto a2 = p.add_outline(migraph::shape{migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(allocate{}, a2);
p.add_instruction(pass_memory{}, p1, p2);
auto a1 = add_alloc(p, {migraph::shape::float_type, {128}});
auto p1 = p.add_instruction(pass_op{}, a1, input);
auto p2 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, p2, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(no_allocate(p));
}
TEST_CASE(test3)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p2 = add_alloc(p, {migraph::shape::float_type, {128}});
auto p1 = p.add_instruction(pass_op{}, p2, a1);
auto p3 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p1);
p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(p.get_parameter_shape("scratch").bytes() == 704); // The optimal solution is actually 672
CHECK(no_allocate(p));
}
// This test uses the pass_memory operator with two memory allocation passed together.
// This is similar to allocations done for workspaces, that is one allocation is aliased and the
// other is just used
void test3()
TEST_CASE(test4)
{
migraph::program p;
auto a0 = p.add_outline(migraph::shape{migraph::shape::float_type, {8}});
auto a1 = p.add_instruction(allocate{}, a0);
auto a2 = p.add_outline(migraph::shape{migraph::shape::float_type, {128}});
auto p2 = p.add_instruction(allocate{}, a2);
auto p1 = p.add_instruction(pass_memory{}, a1, p2);
auto a3 = p.add_outline(migraph::shape{migraph::shape::float_type, {40}});
auto p3 = p.add_instruction(allocate{}, a3);
p.add_instruction(pass_memory{}, p1, p3);
auto a1 = add_alloc(p, {migraph::shape::float_type, {0}});
auto p2 = add_alloc(p, {migraph::shape::float_type, {128}});
auto p1 = p.add_instruction(pass_op{}, p2, a1);
auto p3 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p1);
p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 704);
CHECK(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(no_allocate(p));
}
// Like the previous test, but this tests a zero workspace memory allocation
void test4()
TEST_CASE(test5)
{
migraph::program p;
auto a0 = p.add_outline(migraph::shape{migraph::shape::float_type, {0}});
auto a1 = p.add_instruction(allocate{}, a0);
auto a2 = p.add_outline(migraph::shape{migraph::shape::float_type, {128}});
auto p2 = p.add_instruction(allocate{}, a2);
auto p1 = p.add_instruction(pass_memory{}, a1, p2);
auto a3 = p.add_outline(migraph::shape{migraph::shape::float_type, {40}});
auto p3 = p.add_instruction(allocate{}, a3);
p.add_instruction(pass_memory{}, p1, p3);
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, p2, p1);
p.compile(memory_coloring_target{});
EXPECT(p.get_parameter_shape("scratch").bytes() == 672);
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
void literal_test()
TEST_CASE(test6)
{
migraph::program p;
auto lit = generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit);
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{});
auto result = p.eval({});
EXPECT(lit == result);
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
TEST_CASE(test7)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
TEST_CASE(test8)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p3 = add_alloc(p, {migraph::shape::float_type, {192}});
p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 960);
CHECK(no_allocate(p));
}
TEST_CASE(test9)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p3 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, p3, p2, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 96);
CHECK(no_allocate(p));
}
TEST_CASE(test10)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, a1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 32);
CHECK(no_allocate(p));
}
TEST_CASE(test11)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
TEST_CASE(test12)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
TEST_CASE(test13)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
TEST_CASE(test14)
{
migraph::program p;
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, a3, p2);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
TEST_CASE(test15)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
TEST_CASE(test16)
{
migraph::program p;
auto a1 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {8}}));
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {40}}));
auto p2 = p.add_instruction(pass_op{}, a2);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 160);
CHECK(no_allocate(p));
}
TEST_CASE(test17)
{
migraph::program p;
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
auto a1 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {8}}));
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {40}}));
auto p2 = p.add_instruction(pass_op{}, a2);
p.add_instruction(pass_op{}, a3, p1, p2);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 160);
CHECK(no_allocate(p));
}
TEST_CASE(test18)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto p2 = p.add_instruction(pass_op{}, a1, p1);
auto p3 = p.add_instruction(pass_op{}, p2, p1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1, p2, p3);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
TEST_CASE(test19)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, a3, p2, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
TEST_CASE(test20)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {32}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {32}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {32}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 384);
CHECK(no_allocate(p));
}
TEST_CASE(test21)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {32}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 288);
CHECK(no_allocate(p));
}
TEST_CASE(test22)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {32}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 288);
CHECK(no_allocate(p));
}
TEST_CASE(test23)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {32}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {32}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 288);
CHECK(no_allocate(p));
}
TEST_CASE(test24)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {32}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {32}});
auto p1 = p.add_instruction(pass_op{}, a1, a2, a3);
auto a4 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, a4, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 384);
CHECK(no_allocate(p));
}
TEST_CASE(test25)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(nop{});
auto p1 = p.add_instruction(pass_op{}, a1);
p.add_instruction(nop{});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
TEST_CASE(test26)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(nop{}, a1);
auto p1 = p.add_instruction(pass_op{}, a1);
p.add_instruction(nop{}, a1, p1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
TEST_CASE(test27)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(nop{}, a2, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
TEST_CASE(test28)
{
migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {8}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
TEST_CASE(test29)
{
migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {8}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.move_instruction(output, p2);
p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
TEST_CASE(test30)
{
migraph::program p;
auto output = p.add_parameter("x", {migraph::shape::float_type, {8}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a2, p1);
p.move_instruction(output, p2);
p.add_instruction(pass_op{}, p2, output);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
TEST_CASE(test31)
{
migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {8}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a1);
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
p.move_instruction(output, a2);
p.add_instruction(pass_op{}, a2, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 192);
CHECK(no_allocate(p));
}
TEST_CASE(test32)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 352);
CHECK(no_allocate(p));
}
int main()
TEST_CASE(test33)
{
test1();
test2();
test3();
test4();
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraph::shape::float_type, {40}});
p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
literal_test();
TEST_CASE(test34)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 480);
CHECK(no_allocate(p));
}
TEST_CASE(test35)
{
migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {8}});
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}});
auto p1 = p.add_instruction(pass_op{}, a2, a1, a3);
auto a5 = add_alloc(p, {migraph::shape::float_type, {8}});
p.add_instruction(pass_op{}, a5, p1);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 224);
CHECK(no_allocate(p));
}
TEST_CASE(test36)
{
migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {20}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {0}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a2, a1);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a3, p1);
auto a4 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p3 = p.add_instruction(pass_op{}, a4, p2);
p.add_instruction(pass_op{}, output, p3);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 320);
CHECK(no_allocate(p));
}
TEST_CASE(test37)
{
migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {20}});
auto a1 = add_alloc(p, {migraph::shape::float_type, {4}});
auto a2 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p1 = p.add_instruction(pass_op{}, a2, a1);
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p2 = p.add_instruction(pass_op{}, a3, p1);
auto a4 = add_alloc(p, {migraph::shape::float_type, {40}});
auto p3 = p.add_instruction(pass_op{}, a4, p2);
p.add_instruction(pass_op{}, output, p3);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 320);
CHECK(no_allocate(p));
}
TEST_CASE(test38)
{
migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {1, 64, 56, 56}});
auto p29 = add_alloc(p, {migraph::shape::float_type, {0}});
auto p30 = add_alloc(p, {migraph::shape::float_type, {1, 64, 112, 112}});
auto p31 = p.add_instruction(pass_op{}, p30, p29);
auto p32 = add_alloc(p, {migraph::shape::float_type, {1, 64, 112, 112}});
auto p37 = p.add_instruction(pass_op{}, p32, p31);
auto p38 = add_alloc(p, {migraph::shape::float_type, {1, 64, 112, 112}});
auto p39 = p.add_instruction(pass_op{}, p38, p37);
auto p40 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p41 = p.add_instruction(pass_op{}, p40, p39);
auto p42 = add_alloc(p, {migraph::shape::float_type, {0}});
auto p43 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p44 = p.add_instruction(pass_op{}, p43, p41, p42);
auto p45 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p50 = p.add_instruction(pass_op{}, p45, p44);
auto p51 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p52 = p.add_instruction(pass_op{}, p51, p50);
auto p53 = add_alloc(p, {migraph::shape::float_type, {0}});
auto p54 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p55 = p.add_instruction(pass_op{}, p54, p52, p53);
auto p56 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p61 = p.add_instruction(pass_op{}, p56, p55);
auto p62 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p63 = p.add_instruction(pass_op{}, p62, p61, p41);
auto p64 = add_alloc(p, {migraph::shape::float_type, {0}});
auto p65 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p66 = p.add_instruction(pass_op{}, p65, p63, p64);
auto p67 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p72 = p.add_instruction(pass_op{}, p67, p66);
auto p73 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p74 = p.add_instruction(pass_op{}, p73, p72);
auto p75 = add_alloc(p, {migraph::shape::float_type, {0}});
auto p76 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p77 = p.add_instruction(pass_op{}, p76, p74, p75);
auto p78 = add_alloc(p, {migraph::shape::float_type, {1, 64, 56, 56}});
auto p83 = p.add_instruction(pass_op{}, p78, p77);
p.add_instruction(pass_op{}, output, p83, p63);
p.compile(memory_coloring_target{});
CHECK(p.get_parameter_shape("scratch").bytes() == 6422528);
CHECK(no_allocate(p));
}
TEST_CASE(literal_test)
{
migraph::program p;
auto lit = generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
p.add_literal(lit);
p.compile(memory_coloring_target{});
auto result = p.eval({});
CHECK(lit == result);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -32,7 +32,7 @@ void pytorch_conv_relu_maxpool()
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5);
auto l6 = p.add_instruction(migraph::op::relu{}, l5);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto prog = migraph::parse_onnx("conv_relu_maxpool.onnx");
......@@ -55,7 +55,7 @@ void pytorch_conv_bn_relu_maxpool()
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::batch_norm_inference{1.0e-5f}, l5, p3, p4, p5, p6);
auto l7 = p.add_instruction(migraph::op::activation{"relu"}, l6);
auto l7 = p.add_instruction(migraph::op::relu{}, l6);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l7);
auto prog = migraph::parse_onnx("conv_bn_relu_maxpool.onnx");
......@@ -72,7 +72,7 @@ void pytorch_conv_relu_maxpool_x2()
auto l3 = p.add_instruction(migraph::op::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::op::broadcast{axis, l3->get_shape()}, l2);
auto l5 = p.add_instruction(migraph::op::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::op::activation{"relu"}, l5);
auto l6 = p.add_instruction(migraph::op::relu{}, l5);
auto l7 = p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}});
......@@ -80,7 +80,7 @@ void pytorch_conv_relu_maxpool_x2()
auto l10 = p.add_instruction(migraph::op::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraph::op::broadcast{axis, l10->get_shape()}, l9);
auto l12 = p.add_instruction(migraph::op::add{}, l10, l11);
auto l13 = p.add_instruction(migraph::op::activation{"relu"}, l12);
auto l13 = p.add_instruction(migraph::op::relu{}, l12);
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto prog = migraph::parse_onnx("conv_relu_maxpoolX2.onnx");
......
......@@ -52,7 +52,7 @@ void throws_shape(const migraph::shape&, Ts...)
"An expected shape should not be passed to throws_shape function");
}
void batch_norm_inference_shape()
TEST_CASE(batch_norm_inference_shape)
{
const size_t channels = 3;
migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}};
......@@ -62,7 +62,7 @@ void batch_norm_inference_shape()
throws_shape(migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
}
void convolution_shape()
TEST_CASE(convolution_shape)
{
migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}};
migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}};
......@@ -76,7 +76,7 @@ void convolution_shape()
throws_shape(migraph::op::convolution{}, input2, weights);
}
void transpose_shape()
TEST_CASE(transpose_shape)
{
migraph::shape input{migraph::shape::float_type, {2, 2}};
migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}};
......@@ -85,7 +85,7 @@ void transpose_shape()
throws_shape(migraph::op::transpose{{1, 2}}, input);
}
void contiguous_shape()
TEST_CASE(contiguous_shape)
{
migraph::shape output{migraph::shape::float_type, {2, 2}};
migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}};
......@@ -96,7 +96,7 @@ void contiguous_shape()
expect_shape(single, migraph::op::contiguous{}, single);
}
void reshape_shape()
TEST_CASE(reshape_shape)
{
migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape :
......@@ -114,7 +114,7 @@ void reshape_shape()
}
}
void flatten_shape()
TEST_CASE(flatten_shape)
{
migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}};
expect_shape(migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}},
......@@ -132,7 +132,7 @@ void flatten_shape()
throws_shape(migraph::op::flatten{5}, input);
}
void slice_shape()
TEST_CASE(slice_shape)
{
migraph::shape input{migraph::shape::int32_type, {2, 2, 3}};
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
......@@ -145,13 +145,4 @@ void slice_shape()
migraph::op::slice{{2}, {2}, {10}},
input);
}
int main()
{
batch_norm_inference_shape();
convolution_shape();
transpose_shape();
contiguous_shape();
reshape_shape();
flatten_shape();
slice_shape();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -43,7 +43,7 @@ struct simple_operation_no_print
}
};
void operation_copy_test()
TEST_CASE(operation_copy_test)
{
simple_operation s{};
migraph::operation op1 = s; // NOLINT
......@@ -54,7 +54,7 @@ void operation_copy_test()
EXPECT(op2 == op1);
}
void operation_equal_test()
TEST_CASE(operation_equal_test)
{
simple_operation s{};
migraph::operation op1 = s;
......@@ -72,7 +72,7 @@ struct not_operation
{
};
void operation_any_cast()
TEST_CASE(operation_any_cast)
{
migraph::operation op1 = simple_operation{};
EXPECT(migraph::any_cast<simple_operation>(op1).data == 1);
......@@ -83,7 +83,7 @@ void operation_any_cast()
EXPECT(migraph::any_cast<not_operation*>(&op2) == nullptr);
}
void operation_print()
TEST_CASE(operation_print)
{
migraph::operation op = simple_operation{};
std::stringstream ss;
......@@ -92,7 +92,7 @@ void operation_print()
EXPECT(s == "simple[1]");
}
void operation_default_print()
TEST_CASE(operation_default_print)
{
migraph::operation op = simple_operation_no_print{};
std::stringstream ss;
......@@ -101,11 +101,4 @@ void operation_default_print()
EXPECT(s == "simple");
}
int main()
{
operation_copy_test();
operation_equal_test();
operation_any_cast();
operation_print();
operation_default_print();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <test.hpp>
#include <basic_ops.hpp>
TEST_CASE(simple_alias)
{
migraph::program p;
auto l = p.add_literal(1);
auto p1 = p.add_instruction(pass_op{}, l);
EXPECT(bool{migraph::instruction::get_output_alias(l) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p1) == l});
}
TEST_CASE(cascade_alias)
{
migraph::program p;
auto l = p.add_literal(1);
auto p1 = p.add_instruction(pass_op{}, l);
auto p2 = p.add_instruction(pass_op{}, p1);
auto p3 = p.add_instruction(pass_op{}, p2);
EXPECT(bool{migraph::instruction::get_output_alias(l) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p1) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p2) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p3) == l});
}
TEST_CASE(no_alias)
{
migraph::program p;
auto x = p.add_literal(1);
auto y = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, x, y);
EXPECT(bool{migraph::instruction::get_output_alias(sum) == sum});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment