Commit d2549384 authored by Khalique's avatar Khalique
Browse files

manual merge

parents 67048d04 ab6cd9d3
softmax-example:I

01"Softmax test-softmaxZ
0


b
1


B
\ No newline at end of file
 sum-example:e

0
1
23"Sum test-dropoutZ
0

Z
1

Z
2

b
2

B
\ No newline at end of file
 tan-example:9
xy"Tantest_tanZ
x


b
y


B
\ No newline at end of file
 tanh-example:;
xy"Tanh test_tanhZ
x

b
y

B
\ No newline at end of file
unknown-example:

0
12"Unknown
2"Unknown test-unknownZ
0




Z
1


b
2




B
\ No newline at end of file
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
template <class... Ts> template <class... Ts>
void expect_shape(const migraph::shape& expected, const migraph::operation& op, Ts... xs) void expect_shape(const migraphx::shape& expected, const migraphx::operation& op, Ts... xs)
{ {
migraph::program p; migraphx::program p;
std::vector<migraph::shape> shapes{xs...}; std::vector<migraphx::shape> shapes{xs...};
std::vector<migraph::instruction_ref> args(shapes.size()); std::vector<migraphx::instruction_ref> args(shapes.size());
std::transform( std::transform(
shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); }); shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
p.add_instruction(op, args); p.add_instruction(op, args);
...@@ -24,11 +24,11 @@ void expect_shape(const migraph::shape& expected, const migraph::operation& op, ...@@ -24,11 +24,11 @@ void expect_shape(const migraph::shape& expected, const migraph::operation& op,
} }
template <class... Ts> template <class... Ts>
void throws_shape(const migraph::operation& op, Ts... xs) void throws_shape(const migraphx::operation& op, Ts... xs)
{ {
migraph::program p; migraphx::program p;
std::vector<migraph::shape> shapes{xs...}; std::vector<migraphx::shape> shapes{xs...};
std::vector<migraph::instruction_ref> args(shapes.size()); std::vector<migraphx::instruction_ref> args(shapes.size());
std::transform( std::transform(
shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); }); shapes.begin(), shapes.end(), args.begin(), [&](auto&& s) { return p.add_outline(s); });
bool thrown = test::throws([&] { p.add_instruction(op, args); }); bool thrown = test::throws([&] { p.add_instruction(op, args); });
...@@ -46,7 +46,7 @@ struct always_false : std::false_type ...@@ -46,7 +46,7 @@ struct always_false : std::false_type
}; };
template <class... Ts> template <class... Ts>
void throws_shape(const migraph::shape&, Ts...) void throws_shape(const migraphx::shape&, Ts...)
{ {
static_assert(always_false<Ts...>{}, static_assert(always_false<Ts...>{},
"An expected shape should not be passed to throws_shape function"); "An expected shape should not be passed to throws_shape function");
...@@ -55,94 +55,97 @@ void throws_shape(const migraph::shape&, Ts...) ...@@ -55,94 +55,97 @@ void throws_shape(const migraph::shape&, Ts...)
TEST_CASE(batch_norm_inference_shape) TEST_CASE(batch_norm_inference_shape)
{ {
const size_t channels = 3; const size_t channels = 3;
migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}}; migraphx::shape s{migraphx::shape::float_type, {4, channels, 3, 3}};
migraph::shape vars{migraph::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
expect_shape(s, migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars); expect_shape(s, migraphx::op::batch_norm_inference{}, s, vars, vars, vars, vars);
throws_shape(migraph::op::batch_norm_inference{}, s); throws_shape(migraphx::op::batch_norm_inference{}, s);
throws_shape(migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars); throws_shape(migraphx::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
} }
TEST_CASE(convolution_shape) TEST_CASE(convolution_shape)
{ {
migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}}; migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
migraph::shape weights{migraph::shape::float_type, {4, 3, 3, 3}}; migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
expect_shape(output, migraph::op::convolution{}, input, weights); expect_shape(output, migraphx::op::convolution{}, input, weights);
throws_shape(migraph::op::convolution{}, input); throws_shape(migraphx::op::convolution{}, input);
migraph::shape input2{migraph::shape::float_type, {3, 3}}; migraphx::shape input2{migraphx::shape::float_type, {3, 3}};
migraph::shape weights2{migraph::shape::float_type, {3, 3}}; migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraph::op::convolution{}, input2, weights2); throws_shape(migraphx::op::convolution{}, input2, weights2);
throws_shape(migraph::op::convolution{}, input2, weights); throws_shape(migraphx::op::convolution{}, input2, weights);
} }
TEST_CASE(transpose_shape) TEST_CASE(transpose_shape)
{ {
migraph::shape input{migraph::shape::float_type, {2, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}};
migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}}; migraphx::shape output{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(input, migraph::op::transpose{{0, 1}}, input); expect_shape(input, migraphx::op::transpose{{0, 1}}, input);
expect_shape(output, migraph::op::transpose{{1, 0}}, input); expect_shape(output, migraphx::op::transpose{{1, 0}}, input);
throws_shape(migraph::op::transpose{{1, 2}}, input); throws_shape(migraphx::op::transpose{{1, 2}}, input);
} }
TEST_CASE(contiguous_shape) TEST_CASE(contiguous_shape)
{ {
migraph::shape output{migraph::shape::float_type, {2, 2}}; migraphx::shape output{migraphx::shape::float_type, {2, 2}};
migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}}; migraphx::shape input{migraphx::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraph::op::contiguous{}, input); expect_shape(output, migraphx::op::contiguous{}, input);
throws_shape(migraph::op::contiguous{}, input, input); throws_shape(migraphx::op::contiguous{}, input, input);
migraph::shape single{migraph::shape::float_type, {2}}; migraphx::shape single{migraphx::shape::float_type, {2}};
expect_shape(single, migraph::op::contiguous{}, single); expect_shape(single, migraphx::op::contiguous{}, single);
} }
TEST_CASE(reshape_shape) TEST_CASE(reshape_shape)
{ {
migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape : for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}}) std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
{ {
std::vector<std::size_t> lens(new_shape.size()); std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin()); std::copy(new_shape.begin(), new_shape.end(), lens.begin());
migraph::shape output{migraph::shape::float_type, lens}; migraphx::shape output{migraphx::shape::float_type, lens};
expect_shape(output, migraph::op::reshape{new_shape}, input); expect_shape(output, migraphx::op::reshape{new_shape}, input);
} }
for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}}) for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}})
{ {
throws_shape(migraph::op::reshape{new_shape}, input); throws_shape(migraphx::op::reshape{new_shape}, input);
} }
} }
TEST_CASE(flatten_shape) TEST_CASE(flatten_shape)
{ {
migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}}; migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
expect_shape(migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2 * 4 * 6 * 8}},
migraph::op::flatten{0}, migraphx::op::flatten{0},
input); input);
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 4 * 6 * 8}},
migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::op::flatten{1}, input); migraphx::op::flatten{1},
expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::op::flatten{2}, input);
expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::op::flatten{3}, input);
expect_shape(migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}},
migraph::op::flatten{4},
input); input);
throws_shape(migraph::op::flatten{5}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4, 6 * 8}},
migraphx::op::flatten{2},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6, 8}},
migraphx::op::flatten{3},
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2 * 4 * 6 * 8, 1}},
migraphx::op::flatten{4},
input);
throws_shape(migraphx::op::flatten{5}, input);
} }
TEST_CASE(slice_shape) TEST_CASE(slice_shape)
{ {
migraph::shape input{migraph::shape::int32_type, {2, 2, 3}}; migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::op::slice{{2}, {1}, {3}}, migraphx::op::slice{{2}, {1}, {3}},
input); input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}}, migraphx::op::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
input); input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 1}, {6, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraph::op::slice{{2}, {2}, {10}}, migraphx::op::slice{{2}, {2}, {10}},
input); input);
} }
...@@ -150,62 +153,99 @@ TEST_CASE(multibroadcast) ...@@ -150,62 +153,99 @@ TEST_CASE(multibroadcast)
{ {
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2, 1, 3}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 0, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 0, 1}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 1, 0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 1, 0, 0}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {5, 1}}; migraphx::shape input{migraphx::shape::float_type, {5, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 1, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 1, 0}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 0, 0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 0, 0, 0}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 2, 5, 3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {3}}; migraphx::shape input{migraphx::shape::float_type, {3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 0, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 0, 0, 1}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 4, 1, 3}; std::vector<std::size_t> lens{4, 4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {0, 3, 3, 1}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 1, 3}; std::vector<std::size_t> lens{4, 1, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 1, 1, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, lens, {1, 1, 1, 0}},
migraph::op::multibroadcast{lens}, migraphx::op::multibroadcast{lens},
input); input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraph::op::multibroadcast{lens}, input); throws_shape(migraphx::op::multibroadcast{lens}, input);
} }
{ {
std::vector<std::size_t> lens{4, 1, 3}; std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {}}; migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraph::op::multibroadcast{lens}, input); throws_shape(migraphx::op::multibroadcast{lens}, input);
}
}
TEST_CASE(gather)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -5;
throws_shape(migraphx::op::gather{axis}, input, indices);
} }
} }
......
#include <migraph/operation.hpp> #include <migraphx/operation.hpp>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include "test.hpp" #include "test.hpp"
...@@ -9,18 +9,19 @@ struct simple_operation ...@@ -9,18 +9,19 @@ struct simple_operation
template <class T, class F> template <class T, class F>
static auto reflect(T& x, F f) static auto reflect(T& x, F f)
{ {
return migraph::pack(f(x.data, "data")); return migraphx::pack(f(x.data, "data"));
} }
int data = 1; int data = 1;
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPHX_THROW("not computable");
} }
migraph::argument migraphx::argument compute(migraphx::context&,
compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPHX_THROW("not computable");
} }
friend std::ostream& operator<<(std::ostream& os, const simple_operation& op) friend std::ostream& operator<<(std::ostream& os, const simple_operation& op)
{ {
...@@ -32,22 +33,23 @@ struct simple_operation ...@@ -32,22 +33,23 @@ struct simple_operation
struct simple_operation_no_print struct simple_operation_no_print
{ {
std::string name() const { return "simple"; } std::string name() const { return "simple"; }
migraph::shape compute_shape(const std::vector<migraph::shape>&) const migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPHX_THROW("not computable");
} }
migraph::argument migraphx::argument compute(migraphx::context&,
compute(migraph::context&, const migraph::shape&, const std::vector<migraph::argument>&) const const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{ {
MIGRAPH_THROW("not computable"); MIGRAPHX_THROW("not computable");
} }
}; };
TEST_CASE(operation_copy_test) TEST_CASE(operation_copy_test)
{ {
simple_operation s{}; simple_operation s{};
migraph::operation op1 = s; // NOLINT migraphx::operation op1 = s; // NOLINT
migraph::operation op2 = op1; // NOLINT migraphx::operation op2 = op1; // NOLINT
// cppcheck-suppress duplicateExpression // cppcheck-suppress duplicateExpression
EXPECT(s == op1); EXPECT(s == op1);
// cppcheck-suppress duplicateExpression // cppcheck-suppress duplicateExpression
...@@ -57,10 +59,10 @@ TEST_CASE(operation_copy_test) ...@@ -57,10 +59,10 @@ TEST_CASE(operation_copy_test)
TEST_CASE(operation_equal_test) TEST_CASE(operation_equal_test)
{ {
simple_operation s{}; simple_operation s{};
migraph::operation op1 = s; migraphx::operation op1 = s;
s.data = 2; s.data = 2;
migraph::operation op2 = op1; // NOLINT migraphx::operation op2 = op1; // NOLINT
migraph::operation op3 = s; // NOLINT migraphx::operation op3 = s; // NOLINT
EXPECT(s != op1); EXPECT(s != op1);
EXPECT(op2 == op1); EXPECT(op2 == op1);
...@@ -74,18 +76,18 @@ struct not_operation ...@@ -74,18 +76,18 @@ struct not_operation
TEST_CASE(operation_any_cast) TEST_CASE(operation_any_cast)
{ {
migraph::operation op1 = simple_operation{}; migraphx::operation op1 = simple_operation{};
EXPECT(migraph::any_cast<simple_operation>(op1).data == 1); EXPECT(migraphx::any_cast<simple_operation>(op1).data == 1);
EXPECT(migraph::any_cast<not_operation*>(&op1) == nullptr); EXPECT(migraphx::any_cast<not_operation*>(&op1) == nullptr);
EXPECT(test::throws([&] { migraph::any_cast<not_operation&>(op1); })); EXPECT(test::throws([&] { migraphx::any_cast<not_operation&>(op1); }));
migraph::operation op2 = simple_operation{2}; migraphx::operation op2 = simple_operation{2};
EXPECT(migraph::any_cast<simple_operation>(op2).data == 2); EXPECT(migraphx::any_cast<simple_operation>(op2).data == 2);
EXPECT(migraph::any_cast<not_operation*>(&op2) == nullptr); EXPECT(migraphx::any_cast<not_operation*>(&op2) == nullptr);
} }
TEST_CASE(operation_print) TEST_CASE(operation_print)
{ {
migraph::operation op = simple_operation{}; migraphx::operation op = simple_operation{};
std::stringstream ss; std::stringstream ss;
ss << op; ss << op;
std::string s = ss.str(); std::string s = ss.str();
...@@ -94,11 +96,71 @@ TEST_CASE(operation_print) ...@@ -94,11 +96,71 @@ TEST_CASE(operation_print)
TEST_CASE(operation_default_print) TEST_CASE(operation_default_print)
{ {
migraph::operation op = simple_operation_no_print{}; migraphx::operation op = simple_operation_no_print{};
std::stringstream ss; std::stringstream ss;
ss << op; ss << op;
std::string s = ss.str(); std::string s = ss.str();
EXPECT(s == "simple"); EXPECT(s == "simple");
} }
struct final_operation
{
std::string name() const { return "final"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("not computable");
}
void
finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
{
}
};
struct final_operation_throw
{
std::string name() const { return "final"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("not computable");
}
[[gnu::noreturn]] void
finalize(migraphx::context&, const migraphx::shape&, const std::vector<migraphx::shape>&) const
{
MIGRAPHX_THROW("finalize");
}
};
TEST_CASE(check_has_finalize_simple)
{
migraphx::operation op = simple_operation{};
EXPECT(not migraphx::has_finalize(op));
}
TEST_CASE(check_has_finalize)
{
migraphx::operation op = final_operation{};
EXPECT(migraphx::has_finalize(op));
}
TEST_CASE(check_run_finalize)
{
migraphx::operation op = final_operation{};
migraphx::context ctx{};
op.finalize(ctx, {}, {});
}
TEST_CASE(check_run_finalize_simple)
{
migraphx::operation op = simple_operation{};
migraphx::context ctx{};
op.finalize(ctx, {}, {});
}
TEST_CASE(check_run_finalize_throw)
{
migraphx::operation op = final_operation_throw{};
migraphx::context ctx{};
EXPECT(test::throws([&] { op.finalize(ctx, {}, {}); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <test.hpp> #include <test.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
TEST_CASE(simple_alias) TEST_CASE(simple_alias)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(1); auto l = p.add_literal(1);
auto p1 = p.add_instruction(pass_op{}, l); auto p1 = p.add_instruction(pass_op{}, l);
EXPECT(bool{migraph::instruction::get_output_alias(l) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(l) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p1) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p1) == l});
} }
TEST_CASE(cascade_alias) TEST_CASE(cascade_alias)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(1); auto l = p.add_literal(1);
auto p1 = p.add_instruction(pass_op{}, l); auto p1 = p.add_instruction(pass_op{}, l);
auto p2 = p.add_instruction(pass_op{}, p1); auto p2 = p.add_instruction(pass_op{}, p1);
auto p3 = p.add_instruction(pass_op{}, p2); auto p3 = p.add_instruction(pass_op{}, p2);
EXPECT(bool{migraph::instruction::get_output_alias(l) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(l) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p1) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p1) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p2) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p2) == l});
EXPECT(bool{migraph::instruction::get_output_alias(p3) == l}); EXPECT(bool{migraphx::instruction::get_output_alias(p3) == l});
} }
TEST_CASE(no_alias) TEST_CASE(no_alias)
{ {
migraph::program p; migraphx::program p;
auto x = p.add_literal(1); auto x = p.add_literal(1);
auto y = p.add_literal(2); auto y = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, x, y); auto sum = p.add_instruction(sum_op{}, x, y);
EXPECT(bool{migraph::instruction::get_output_alias(sum) == sum}); EXPECT(bool{migraphx::instruction::get_output_alias(sum) == sum});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
migraph::program create_program() migraphx::program create_program()
{ {
migraph::program p; migraphx::program p;
auto x = p.add_parameter("x", {migraph::shape::int64_type}); auto x = p.add_parameter("x", {migraphx::shape::int64_type});
auto y = p.add_parameter("y", {migraph::shape::int64_type}); auto y = p.add_parameter("y", {migraphx::shape::int64_type});
auto sum = p.add_instruction(sum_op{}, x, y); auto sum = p.add_instruction(sum_op{}, x, y);
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -22,8 +22,8 @@ migraph::program create_program() ...@@ -22,8 +22,8 @@ migraph::program create_program()
TEST_CASE(program_equality) TEST_CASE(program_equality)
{ {
migraph::program x = create_program(); migraphx::program x = create_program();
migraph::program y = create_program(); migraphx::program y = create_program();
EXPECT(x == y); EXPECT(x == y);
} }
......
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
#include <array> #include <array>
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
...@@ -7,22 +7,22 @@ ...@@ -7,22 +7,22 @@
TEST_CASE(test_shape_default) TEST_CASE(test_shape_default)
{ {
migraph::shape s{}; migraphx::shape s{};
EXPECT(s.elements() == 0); EXPECT(s.elements() == 0);
EXPECT(s.bytes() == 0); EXPECT(s.bytes() == 0);
} }
TEST_CASE(test_shape_assign) TEST_CASE(test_shape_assign)
{ {
migraph::shape s1{migraph::shape::float_type, {100, 32, 8, 8}}; migraphx::shape s1{migraphx::shape::float_type, {100, 32, 8, 8}};
migraph::shape s2 = s1; // NOLINT migraphx::shape s2 = s1; // NOLINT
EXPECT(s1 == s2); EXPECT(s1 == s2);
EXPECT(!(s1 != s2)); EXPECT(!(s1 != s2));
} }
TEST_CASE(test_shape_packed_default) TEST_CASE(test_shape_packed_default)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}};
EXPECT(s.standard()); EXPECT(s.standard());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
...@@ -31,7 +31,7 @@ TEST_CASE(test_shape_packed_default) ...@@ -31,7 +31,7 @@ TEST_CASE(test_shape_packed_default)
TEST_CASE(test_shape_packed) TEST_CASE(test_shape_packed)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}};
EXPECT(s.standard()); EXPECT(s.standard());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
...@@ -40,7 +40,7 @@ TEST_CASE(test_shape_packed) ...@@ -40,7 +40,7 @@ TEST_CASE(test_shape_packed)
TEST_CASE(test_shape_transposed) TEST_CASE(test_shape_transposed)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(s.transposed()); EXPECT(s.transposed());
...@@ -49,7 +49,7 @@ TEST_CASE(test_shape_transposed) ...@@ -49,7 +49,7 @@ TEST_CASE(test_shape_transposed)
TEST_CASE(test_shape_broadcasted) TEST_CASE(test_shape_broadcasted)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 0}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
...@@ -58,20 +58,20 @@ TEST_CASE(test_shape_broadcasted) ...@@ -58,20 +58,20 @@ TEST_CASE(test_shape_broadcasted)
TEST_CASE(test_shape_default_copy) TEST_CASE(test_shape_default_copy)
{ {
migraph::shape s1{}; migraphx::shape s1{};
migraph::shape s2{}; migraphx::shape s2{};
EXPECT(s1 == s2); EXPECT(s1 == s2);
EXPECT(!(s1 != s2)); EXPECT(!(s1 != s2));
} }
TEST_CASE(test_shape4) TEST_CASE(test_shape4)
{ {
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}}; migraphx::shape s{migraphx::shape::float_type, {100, 32, 8, 8}};
EXPECT(s.standard()); EXPECT(s.standard());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type); EXPECT(s.type() == migraphx::shape::float_type);
EXPECT(s.lens()[0] == 100); EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32); EXPECT(s.lens()[1] == 32);
EXPECT(s.lens()[2] == 8); EXPECT(s.lens()[2] == 8);
...@@ -99,12 +99,12 @@ TEST_CASE(test_shape4) ...@@ -99,12 +99,12 @@ TEST_CASE(test_shape4)
TEST_CASE(test_shape42) TEST_CASE(test_shape42)
{ {
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}, {2048, 64, 8, 1}}; migraphx::shape s{migraphx::shape::float_type, {100, 32, 8, 8}, {2048, 64, 8, 1}};
EXPECT(s.standard()); EXPECT(s.standard());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type); EXPECT(s.type() == migraphx::shape::float_type);
EXPECT(s.lens()[0] == 100); EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32); EXPECT(s.lens()[1] == 32);
EXPECT(s.lens()[2] == 8); EXPECT(s.lens()[2] == 8);
...@@ -132,12 +132,12 @@ TEST_CASE(test_shape42) ...@@ -132,12 +132,12 @@ TEST_CASE(test_shape42)
TEST_CASE(test_shape4_transposed) TEST_CASE(test_shape4_transposed)
{ {
migraph::shape s{migraph::shape::float_type, {32, 100, 8, 8}, {64, 2048, 8, 1}}; migraphx::shape s{migraphx::shape::float_type, {32, 100, 8, 8}, {64, 2048, 8, 1}};
EXPECT(s.transposed()); EXPECT(s.transposed());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type); EXPECT(s.type() == migraphx::shape::float_type);
EXPECT(s.lens()[0] == 32); EXPECT(s.lens()[0] == 32);
EXPECT(s.lens()[1] == 100); EXPECT(s.lens()[1] == 100);
EXPECT(s.lens()[2] == 8); EXPECT(s.lens()[2] == 8);
...@@ -179,12 +179,12 @@ TEST_CASE(test_shape4_nonpacked) ...@@ -179,12 +179,12 @@ TEST_CASE(test_shape4_nonpacked)
strides.rbegin() + 1, strides.rbegin() + 1,
std::multiplies<std::size_t>()); std::multiplies<std::size_t>());
migraph::shape s{migraph::shape::float_type, lens, strides}; migraphx::shape s{migraphx::shape::float_type, lens, strides};
EXPECT(not s.standard()); EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
EXPECT(not s.transposed()); EXPECT(not s.transposed());
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type); EXPECT(s.type() == migraphx::shape::float_type);
EXPECT(s.lens()[0] == 100); EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32); EXPECT(s.lens()[1] == 32);
EXPECT(s.lens()[2] == 8); EXPECT(s.lens()[2] == 8);
......
#include <migraph/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct simplify_algebra_target struct simplify_algebra_target
{ {
std::string name() const { return "simplify_algebra"; } std::string name() const { return "simplify_algebra"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::simplify_algebra{}, migraph::dead_code_elimination{}}; return {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
TEST_CASE(simplify_add1) TEST_CASE(simplify_add1)
{ {
migraph::program p1; migraphx::program p1;
{ {
auto x = p1.add_parameter("x", {migraph::shape::int32_type, {1}}); auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraph::shape::int32_type, {1}}); auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, x, one); auto sum1 = p1.add_instruction(migraphx::op::add{}, x, one);
auto sum2 = p1.add_instruction(migraph::op::add{}, y, two); auto sum2 = p1.add_instruction(migraphx::op::add{}, y, two);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(simplify_algebra_target{}); p1.compile(simplify_algebra_target{});
migraph::program p2; migraphx::program p2;
{ {
auto x = p2.add_parameter("x", {migraph::shape::int32_type, {1}}); auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraph::shape::int32_type, {1}}); auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1); auto one = p2.add_literal(1);
auto two = p2.add_literal(2); auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraph::op::add{}, x, y); auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum2, sum1); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1);
p2.add_instruction(pass_op{}, sum3); p2.add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -45,28 +45,28 @@ TEST_CASE(simplify_add1) ...@@ -45,28 +45,28 @@ TEST_CASE(simplify_add1)
TEST_CASE(simplify_add2) TEST_CASE(simplify_add2)
{ {
migraph::program p1; migraphx::program p1;
{ {
auto x = p1.add_parameter("x", {migraph::shape::int32_type, {1}}); auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraph::shape::int32_type, {1}}); auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, x); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x);
auto sum2 = p1.add_instruction(migraph::op::add{}, two, y); auto sum2 = p1.add_instruction(migraphx::op::add{}, two, y);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(simplify_algebra_target{}); p1.compile(simplify_algebra_target{});
migraph::program p2; migraphx::program p2;
{ {
auto x = p2.add_parameter("x", {migraph::shape::int32_type, {1}}); auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraph::shape::int32_type, {1}}); auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1); auto one = p2.add_literal(1);
auto two = p2.add_literal(2); auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraph::op::add{}, x, y); auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum2, sum1); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1);
p2.add_instruction(pass_op{}, sum3); p2.add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -74,26 +74,26 @@ TEST_CASE(simplify_add2) ...@@ -74,26 +74,26 @@ TEST_CASE(simplify_add2)
TEST_CASE(simplify_add3) TEST_CASE(simplify_add3)
{ {
migraph::program p1; migraphx::program p1;
{ {
auto x = p1.add_parameter("x", {migraph::shape::int32_type, {1}}); auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, x); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x);
auto sum2 = p1.add_instruction(migraph::op::add{}, one, two); auto sum2 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(simplify_algebra_target{}); p1.compile(simplify_algebra_target{});
migraph::program p2; migraphx::program p2;
{ {
auto x = p2.add_parameter("x", {migraph::shape::int32_type, {1}}); auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1); auto one = p2.add_literal(1);
auto two = p2.add_literal(2); auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, x); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, x);
auto sum2 = p2.add_instruction(migraph::op::add{}, one, two); auto sum2 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, sum2); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
p2.add_instruction(pass_op{}, sum3); p2.add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -102,28 +102,28 @@ TEST_CASE(simplify_add3) ...@@ -102,28 +102,28 @@ TEST_CASE(simplify_add3)
// TODO: Add test case // TODO: Add test case
void simplify_add4() void simplify_add4()
{ {
migraph::program p1; migraphx::program p1;
{ {
auto x = p1.add_parameter("x", {migraph::shape::int32_type, {1}}); auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraph::shape::int32_type, {1}}); auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
auto two = p1.add_literal(2); auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, x); auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x);
auto sum2 = p1.add_instruction(migraph::op::add{}, sum1, y); auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, y);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum2, two); auto sum3 = p1.add_instruction(migraphx::op::add{}, sum2, two);
p1.add_instruction(pass_op{}, sum3); p1.add_instruction(pass_op{}, sum3);
} }
p1.compile(simplify_algebra_target{}); p1.compile(simplify_algebra_target{});
migraph::program p2; migraphx::program p2;
{ {
auto x = p2.add_parameter("x", {migraph::shape::int32_type, {1}}); auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraph::shape::int32_type, {1}}); auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1); auto one = p2.add_literal(1);
auto two = p2.add_literal(2); auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two); auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraph::op::add{}, x, y); auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum2, sum1); auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1);
p2.add_instruction(pass_op{}, sum3); p2.add_instruction(pass_op{}, sum3);
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
......
#include <migraph/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraph/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
struct simplify_reshapes_target struct simplify_reshapes_target
{ {
std::string name() const { return "simplify_reshapes"; } std::string name() const { return "simplify_reshapes"; }
std::vector<migraph::pass> get_passes(migraph::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraph::simplify_reshapes{}, migraph::dead_code_elimination{}}; return {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}};
} }
migraph::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
TEST_CASE(double_contig) TEST_CASE(double_contig)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraph::op::contiguous{}, t1); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto c2 = p.add_instruction(migraph::op::contiguous{}, c1); auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
...@@ -34,10 +34,10 @@ TEST_CASE(double_contig) ...@@ -34,10 +34,10 @@ TEST_CASE(double_contig)
TEST_CASE(double_transpose) TEST_CASE(double_transpose)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto t2 = p.add_instruction(migraph::op::transpose{{1, 0}}, t1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
p.add_instruction(pass_op{}, t2); p.add_instruction(pass_op{}, t2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
...@@ -51,12 +51,12 @@ TEST_CASE(double_transpose) ...@@ -51,12 +51,12 @@ TEST_CASE(double_transpose)
TEST_CASE(double_transpose_contig) TEST_CASE(double_transpose_contig)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraph::op::contiguous{}, t1); auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto t2 = p.add_instruction(migraph::op::transpose{{1, 0}}, c1); auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1);
auto c2 = p.add_instruction(migraph::op::contiguous{}, t2); auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2);
p.add_instruction(pass_op{}, c2); p.add_instruction(pass_op{}, c2);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
...@@ -70,9 +70,9 @@ TEST_CASE(double_transpose_contig) ...@@ -70,9 +70,9 @@ TEST_CASE(double_transpose_contig)
TEST_CASE(single_transpose) TEST_CASE(single_transpose)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(pass_op{}, t1); p.add_instruction(pass_op{}, t1);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
...@@ -86,10 +86,10 @@ TEST_CASE(single_transpose) ...@@ -86,10 +86,10 @@ TEST_CASE(single_transpose)
TEST_CASE(double_transpose_sin_pass) TEST_CASE(double_transpose_sin_pass)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraph::op::transpose{{1, 0}}, l); auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(migraph::op::transpose{{1, 0}}, t1); p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
EXPECT(p.get_shape().standard()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
...@@ -104,9 +104,9 @@ TEST_CASE(double_transpose_sin_pass) ...@@ -104,9 +104,9 @@ TEST_CASE(double_transpose_sin_pass)
TEST_CASE(single_transpose_sin_pass) TEST_CASE(single_transpose_sin_pass)
{ {
migraph::program p; migraphx::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
p.add_instruction(migraph::op::transpose{{1, 0}}, l); p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().standard()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed()); EXPECT(p.get_shape().transposed());
p.compile(simplify_reshapes_target{}); p.compile(simplify_reshapes_target{});
...@@ -117,4 +117,21 @@ TEST_CASE(single_transpose_sin_pass) ...@@ -117,4 +117,21 @@ TEST_CASE(single_transpose_sin_pass)
EXPECT(result != get_2x2()); EXPECT(result != get_2x2());
} }
TEST_CASE(reshape_transpose)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
auto x = p.add_parameter("x", s);
auto r1 = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
p.add_instruction(pass_op{}, r2);
EXPECT(p.get_shape() == s);
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == s);
EXPECT(std::distance(p.begin(), p.end()) == n);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraph/type_name.hpp> #include <migraphx/type_name.hpp>
#include "test.hpp" #include "test.hpp"
struct global_class struct global_class
...@@ -21,8 +21,8 @@ struct ns_class ...@@ -21,8 +21,8 @@ struct ns_class
int main() int main()
{ {
EXPECT(migraph::get_type_name<global_class>() == "global_class"); EXPECT(migraphx::get_type_name<global_class>() == "global_class");
EXPECT(migraph::get_type_name<global_class::inner_class>() == "global_class::inner_class"); EXPECT(migraphx::get_type_name<global_class::inner_class>() == "global_class::inner_class");
EXPECT(migraph::get_type_name<foo::ns_class>() == "foo::ns_class"); EXPECT(migraphx::get_type_name<foo::ns_class>() == "foo::ns_class");
EXPECT(migraph::get_type_name<foo::ns_class::inner_class>() == "foo::ns_class::inner_class"); EXPECT(migraphx::get_type_name<foo::ns_class::inner_class>() == "foo::ns_class::inner_class");
} }
#include <migraph/program.hpp> #include <migraphx/program.hpp>
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
#include <rob.hpp> #include <rob.hpp>
TEST_CASE(simple_test) TEST_CASE(simple_test)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two); p.add_instruction(sum_op{}, one, two);
EXPECT(bool{p.validate() == p.end()}); EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{3}); EXPECT(result == migraphx::literal{3});
EXPECT(result != migraph::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(out_of_order) TEST_CASE(out_of_order)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -30,7 +30,7 @@ TEST_CASE(out_of_order) ...@@ -30,7 +30,7 @@ TEST_CASE(out_of_order)
TEST_CASE(incomplete_args) TEST_CASE(incomplete_args)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -39,14 +39,14 @@ TEST_CASE(incomplete_args) ...@@ -39,14 +39,14 @@ TEST_CASE(incomplete_args)
EXPECT(bool{p.validate() == ins}); EXPECT(bool{p.validate() == ins});
} }
MIGRAPH_ROB(access_ins_arguments, MIGRAPHX_ROB(access_ins_arguments,
std::vector<migraph::instruction_ref>, std::vector<migraphx::instruction_ref>,
migraph::instruction, migraphx::instruction,
arguments) arguments)
TEST_CASE(invalid_args) TEST_CASE(invalid_args)
{ {
migraph::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
......
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python3.6 $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $DIR/../src/include/migraph/{}" ls -1 $DIR/include/ | xargs -n 1 -P $(nproc) -I{} -t bash -c "python3.6 $DIR/te.py $DIR/include/{} | clang-format-5.0 -style=file > $DIR/../src/include/migraphx/{}"
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP #ifndef MIGRAPHX_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP #define MIGRAPHX_GUARD_CONCAT_OPT_HPP
#include <cassert> #include <cassert>
#include <string> #include <string>
...@@ -8,10 +8,12 @@ ...@@ -8,10 +8,12 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/operation.hpp> #include <migraphx/operation.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -40,6 +42,7 @@ interface('concat_optimization', ...@@ -40,6 +42,7 @@ interface('concat_optimization',
#endif #endif
} // namespace migraph } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_CONTEXT_HPP #ifndef MIGRAPHX_GUARD_CONTEXT_HPP
#define MIGRAPH_GUARD_CONTEXT_HPP #define MIGRAPHX_GUARD_CONTEXT_HPP
#include <cassert> #include <cassert>
#include <string> #include <string>
...@@ -7,8 +7,10 @@ ...@@ -7,8 +7,10 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -31,6 +33,7 @@ interface('context', ...@@ -31,6 +33,7 @@ interface('context',
#endif #endif
} // namespace migraph } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#include <cassert> #include <cassert>
#include <string> #include <string>
...@@ -7,14 +7,16 @@ ...@@ -7,14 +7,16 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
#include <migraph/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraph/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraph/argument.hpp> #include <migraphx/argument.hpp>
#include <migraph/context.hpp> #include <migraphx/context.hpp>
#include <migraph/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -24,6 +26,8 @@ struct operation ...@@ -24,6 +26,8 @@ struct operation
{ {
/// A unique name identifying the operation /// A unique name identifying the operation
std::string name() const; std::string name() const;
/// An optional method that can be used to finalize the operator before running
void finalize(context& ctx);
/// This is used to compute the resulting shape from an operation. If an /// This is used to compute the resulting shape from an operation. If an
/// operation cannot be run with input shapes, then it should throw an /// operation cannot be run with input shapes, then it should throw an
/// exception. /// exception.
...@@ -51,6 +55,11 @@ struct operation ...@@ -51,6 +55,11 @@ struct operation
friend std::ostream& operator<<(std::ostream& os, const operation& op); friend std::ostream& operator<<(std::ostream& os, const operation& op);
}; };
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
#else #else
namespace operation_stream { namespace operation_stream {
...@@ -87,7 +96,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -87,7 +96,7 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_equal } // namespace operation_equal
template <class T> template <class T>
auto compute_op(rank<1>, auto compute_op(rank<2>,
const T& x, const T& x,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -97,18 +106,72 @@ auto compute_op(rank<1>, ...@@ -97,18 +106,72 @@ auto compute_op(rank<1>,
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&) argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPH_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
} }
template <class T> template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return compute_op(rank<1>{}, x, ctx, output_shape, input); return compute_op(rank<2>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, output_shape, input);
}
template <class T>
auto is_context_free_op(rank<1>,
const T& x,
const shape& output_shape,
const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input), std::true_type{});
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
-> std::false_type;
template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
return {};
} }
template <class T> template <class T>
...@@ -130,15 +193,60 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -130,15 +193,60 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
return output_alias_op(rank<1>{}, x, shapes); return output_alias_op(rank<1>{}, x, shapes);
} }
template <class T>
auto finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
x.finalize(auto_any_cast(ctx), output_shape, input);
}
template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
}
template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
finalize_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto has_finalize_op(
rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
-> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});
template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
-> std::false_type;
template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
std::declval<T&>(),
std::declval<context&>(),
std::declval<const shape&>(),
std::declval<std::vector<shape>>()))
{
return {};
}
<% <%
interface( interface(
'operation', 'operation',
virtual('name', returns = 'std::string', const = True), virtual('name', returns = 'std::string', const = True),
virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'),
virtual('output_alias', virtual('output_alias',
returns = 'int', returns = 'int',
input = 'const std::vector<shape>&', input = 'const std::vector<shape>&',
const = True, const = True,
default = 'output_alias_op'), default = 'output_alias_op'),
virtual('finalize',
ctx = 'context&',
output = 'const shape&',
input = 'const std::vector<shape>&',
default = 'finalize_op'),
virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True), virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
virtual('compute', virtual('compute',
returns = 'argument', returns = 'argument',
...@@ -147,24 +255,47 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes) ...@@ -147,24 +255,47 @@ int output_alias_op(const T& x, const std::vector<shape>& shapes)
input = 'const std::vector<argument>&', input = 'const std::vector<argument>&',
const = True, const = True,
default = 'compute_op'), default = 'compute_op'),
virtual('compute',
returns = 'argument',
output = 'const shape&',
input = 'const std::vector<argument>&',
const = True,
default = 'compute_op'),
friend('operator<<', friend('operator<<',
returns = 'std::ostream &', returns = 'std::ostream &',
os = 'std::ostream &', os = 'std::ostream &',
op = 'const operation &', op = 'const operation &',
using = 'migraph::operation_stream::operator<<'), using = 'migraphx::operation_stream::operator<<'),
friend('operator==', friend('operator==',
returns = 'bool', returns = 'bool',
x = 'const operation &', x = 'const operation &',
y = 'const operation &', y = 'const operation &',
using = 'migraph::operation_equal::operator==')) %> using = 'migraphx::operation_equal::operator==')) %>
inline bool operator!=(const operation& x, const operation& y) inline bool operator!=(const operation& x, const operation& y)
{ {
return !(x == y); return !(x == y);
} }
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
template <class T>
bool is_context_free(const T& x)
{
return is_context_free_op(x);
}
inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T>
bool has_finalize(const T& x)
{
return has_finalize_op(x);
}
#endif #endif
} // namespace migraph } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment