Commit cf86db72 authored by Paul's avatar Paul
Browse files

Merge branch 'master' into fp16

parents af454aeb 414e2fac
...@@ -20,11 +20,11 @@ migraph::program create_program() ...@@ -20,11 +20,11 @@ migraph::program create_program()
return p; return p;
} }
void program_equality() TEST_CASE(program_equality)
{ {
migraph::program x = create_program(); migraph::program x = create_program();
migraph::program y = create_program(); migraph::program y = create_program();
EXPECT(x == y); EXPECT(x == y);
} }
int main() { program_equality(); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
#include <numeric> #include <numeric>
#include "test.hpp" #include "test.hpp"
void test_shape_default() TEST_CASE(test_shape_default)
{ {
migraph::shape s{}; migraph::shape s{};
EXPECT(s.elements() == 0); EXPECT(s.elements() == 0);
EXPECT(s.bytes() == 0); EXPECT(s.bytes() == 0);
} }
void test_shape_assign() TEST_CASE(test_shape_assign)
{ {
migraph::shape s1{migraph::shape::float_type, {100, 32, 8, 8}}; migraph::shape s1{migraph::shape::float_type, {100, 32, 8, 8}};
migraph::shape s2 = s1; // NOLINT migraph::shape s2 = s1; // NOLINT
...@@ -20,7 +20,7 @@ void test_shape_assign() ...@@ -20,7 +20,7 @@ void test_shape_assign()
EXPECT(!(s1 != s2)); EXPECT(!(s1 != s2));
} }
void test_shape_packed_default() TEST_CASE(test_shape_packed_default)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}}; migraph::shape s{migraph::shape::float_type, {2, 2}};
EXPECT(s.standard()); EXPECT(s.standard());
...@@ -29,7 +29,7 @@ void test_shape_packed_default() ...@@ -29,7 +29,7 @@ void test_shape_packed_default()
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
void test_shape_packed() TEST_CASE(test_shape_packed)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}}; migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}};
EXPECT(s.standard()); EXPECT(s.standard());
...@@ -38,7 +38,7 @@ void test_shape_packed() ...@@ -38,7 +38,7 @@ void test_shape_packed()
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
void test_shape_transposed() TEST_CASE(test_shape_transposed)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}}; migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard()); EXPECT(not s.standard());
...@@ -47,7 +47,7 @@ void test_shape_transposed() ...@@ -47,7 +47,7 @@ void test_shape_transposed()
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
void test_shape_broadcasted() TEST_CASE(test_shape_broadcasted)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 0}}; migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 0}};
EXPECT(not s.standard()); EXPECT(not s.standard());
...@@ -56,7 +56,7 @@ void test_shape_broadcasted() ...@@ -56,7 +56,7 @@ void test_shape_broadcasted()
EXPECT(s.broadcasted()); EXPECT(s.broadcasted());
} }
void test_shape_default_copy() TEST_CASE(test_shape_default_copy)
{ {
migraph::shape s1{}; migraph::shape s1{};
migraph::shape s2{}; migraph::shape s2{};
...@@ -64,7 +64,7 @@ void test_shape_default_copy() ...@@ -64,7 +64,7 @@ void test_shape_default_copy()
EXPECT(!(s1 != s2)); EXPECT(!(s1 != s2));
} }
void test_shape4() TEST_CASE(test_shape4)
{ {
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}}; migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}};
EXPECT(s.standard()); EXPECT(s.standard());
...@@ -97,7 +97,7 @@ void test_shape4() ...@@ -97,7 +97,7 @@ void test_shape4()
EXPECT(s.index(s.elements() - 1) == s.elements() - 1); EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
} }
void test_shape42() TEST_CASE(test_shape42)
{ {
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}, {2048, 64, 8, 1}}; migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}, {2048, 64, 8, 1}};
EXPECT(s.standard()); EXPECT(s.standard());
...@@ -130,7 +130,7 @@ void test_shape42() ...@@ -130,7 +130,7 @@ void test_shape42()
EXPECT(s.index(s.elements() - 1) == s.elements() - 1); EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
} }
void test_shape4_transposed() TEST_CASE(test_shape4_transposed)
{ {
migraph::shape s{migraph::shape::float_type, {32, 100, 8, 8}, {64, 2048, 8, 1}}; migraph::shape s{migraph::shape::float_type, {32, 100, 8, 8}, {64, 2048, 8, 1}};
EXPECT(s.transposed()); EXPECT(s.transposed());
...@@ -163,7 +163,7 @@ void test_shape4_transposed() ...@@ -163,7 +163,7 @@ void test_shape4_transposed()
EXPECT(s.index(s.elements() - 1) == s.elements() - 1); EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
} }
void test_shape4_nonpacked() TEST_CASE(test_shape4_nonpacked)
{ {
std::vector<std::size_t> lens = {100, 32, 8, 8}; std::vector<std::size_t> lens = {100, 32, 8, 8};
std::array<std::size_t, 4> offsets = {{5, 10, 0, 6}}; std::array<std::size_t, 4> offsets = {{5, 10, 0, 6}};
...@@ -206,17 +206,4 @@ void test_shape4_nonpacked() ...@@ -206,17 +206,4 @@ void test_shape4_nonpacked()
EXPECT(s.index(s.elements() - 1) == 469273); EXPECT(s.index(s.elements() - 1) == 469273);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
test_shape_default();
test_shape_assign();
test_shape_packed_default();
test_shape_packed();
test_shape_transposed();
test_shape_broadcasted();
test_shape_default_copy();
test_shape4();
test_shape42();
test_shape4_transposed();
test_shape4_nonpacked();
}
...@@ -14,7 +14,7 @@ struct simplify_algebra_target ...@@ -14,7 +14,7 @@ struct simplify_algebra_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
void simplify_add1() TEST_CASE(simplify_add1)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -43,7 +43,7 @@ void simplify_add1() ...@@ -43,7 +43,7 @@ void simplify_add1()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
void simplify_add2() TEST_CASE(simplify_add2)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -72,7 +72,7 @@ void simplify_add2() ...@@ -72,7 +72,7 @@ void simplify_add2()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
void simplify_add3() TEST_CASE(simplify_add3)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -99,6 +99,7 @@ void simplify_add3() ...@@ -99,6 +99,7 @@ void simplify_add3()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
// TODO: Add test case
void simplify_add4() void simplify_add4()
{ {
migraph::program p1; migraph::program p1;
...@@ -128,10 +129,4 @@ void simplify_add4() ...@@ -128,10 +129,4 @@ void simplify_add4()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
simplify_add1();
simplify_add2();
simplify_add3();
// simplify_add4();
}
...@@ -14,7 +14,7 @@ struct simplify_reshapes_target ...@@ -14,7 +14,7 @@ struct simplify_reshapes_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
void double_contig() TEST_CASE(double_contig)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
...@@ -32,7 +32,7 @@ void double_contig() ...@@ -32,7 +32,7 @@ void double_contig()
EXPECT(result == get_2x2()); EXPECT(result == get_2x2());
} }
void double_transpose() TEST_CASE(double_transpose)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
...@@ -49,7 +49,7 @@ void double_transpose() ...@@ -49,7 +49,7 @@ void double_transpose()
EXPECT(result == get_2x2()); EXPECT(result == get_2x2());
} }
void double_transpose_contig() TEST_CASE(double_transpose_contig)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
...@@ -68,7 +68,7 @@ void double_transpose_contig() ...@@ -68,7 +68,7 @@ void double_transpose_contig()
EXPECT(result == get_2x2()); EXPECT(result == get_2x2());
} }
void single_transpose() TEST_CASE(single_transpose)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
...@@ -84,7 +84,7 @@ void single_transpose() ...@@ -84,7 +84,7 @@ void single_transpose()
EXPECT(result != get_2x2()); EXPECT(result != get_2x2());
} }
void double_transpose_sin_pass() TEST_CASE(double_transpose_sin_pass)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
...@@ -102,7 +102,7 @@ void double_transpose_sin_pass() ...@@ -102,7 +102,7 @@ void double_transpose_sin_pass()
EXPECT(result == get_2x2()); EXPECT(result == get_2x2());
} }
void single_transpose_sin_pass() TEST_CASE(single_transpose_sin_pass)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
...@@ -117,12 +117,4 @@ void single_transpose_sin_pass() ...@@ -117,12 +117,4 @@ void single_transpose_sin_pass()
EXPECT(result != get_2x2()); EXPECT(result != get_2x2());
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
double_contig();
double_transpose();
double_transpose_contig();
single_transpose();
double_transpose_sin_pass();
single_transpose_sin_pass();
}
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <test.hpp> #include <test.hpp>
#include <rob.hpp> #include <rob.hpp>
void simple_test() TEST_CASE(simple_test)
{ {
migraph::program p; migraph::program p;
...@@ -17,7 +17,7 @@ void simple_test() ...@@ -17,7 +17,7 @@ void simple_test()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void out_of_order() TEST_CASE(out_of_order)
{ {
migraph::program p; migraph::program p;
...@@ -28,7 +28,7 @@ void out_of_order() ...@@ -28,7 +28,7 @@ void out_of_order()
EXPECT(bool{p.validate() == ins}); EXPECT(bool{p.validate() == ins});
} }
void incomplete_args() TEST_CASE(incomplete_args)
{ {
migraph::program p; migraph::program p;
...@@ -44,7 +44,7 @@ MIGRAPH_ROB(access_ins_arguments, ...@@ -44,7 +44,7 @@ MIGRAPH_ROB(access_ins_arguments,
migraph::instruction, migraph::instruction,
arguments) arguments)
void invalid_args() TEST_CASE(invalid_args)
{ {
migraph::program p; migraph::program p;
...@@ -55,10 +55,4 @@ void invalid_args() ...@@ -55,10 +55,4 @@ void invalid_args()
EXPECT(bool{p.validate() == p.begin()}); EXPECT(bool{p.validate() == p.begin()});
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
simple_test();
out_of_order();
incomplete_args();
invalid_args();
}
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/operation.hpp>
#include <migraph/operators.hpp>
namespace migraph {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent optimization for the concat instruction
struct concat_optimization
{
/// The name of the target-dependent concat operator
std::string name() const;
/// A name of the target-dependent allocate operator
std::string allocate() const;
/// Return the target-independent concat operator
op::concat get_concat(const operation& op) const;
};
#else
<%
interface('concat_optimization',
virtual('name', returns='std::string', const=True),
virtual('allocate', returns='std::string', const=True),
virtual('get_concat', returns='op::concat', op='const operation&', const=True)
)
%>
#endif
} // namespace migraph
#endif
...@@ -43,6 +43,9 @@ struct operation ...@@ -43,6 +43,9 @@ struct operation
* the same the `output` shape. * the same the `output` shape.
*/ */
argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const; argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
/// An optional method to return which argument the output will alias. If
/// there is no aliased output then -1 can be returned.
int output_alias(const std::vector<shape>& input) const;
/// An optional stream operator to print the operation. When this is not /// An optional stream operator to print the operation. When this is not
/// implemented, it will just print the operation's name. /// implemented, it will just print the operation's name.
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
...@@ -108,10 +111,34 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto ...@@ -108,10 +111,34 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
return compute_op(rank<1>{}, x, ctx, output_shape, input); return compute_op(rank<1>{}, x, ctx, output_shape, input);
} }
template <class T>
int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{
return -1;
}
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
-> decltype(x.output_alias(shapes))
{
return x.output_alias(shapes);
}
template <class T>
int output_alias_op(const T& x, const std::vector<shape>& shapes)
{
return output_alias_op(rank<1>{}, x, shapes);
}
<% <%
interface( interface(
'operation', 'operation',
virtual('name', returns = 'std::string', const = True), virtual('name', returns = 'std::string', const = True),
virtual('output_alias',
returns = 'int',
input = 'const std::vector<shape>&',
const = True,
default = 'output_alias_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment