Commit b28bd72d authored by Paul's avatar Paul
Browse files

Formatting

parent 42a952cb
...@@ -429,14 +429,14 @@ struct flatten ...@@ -429,14 +429,14 @@ struct flatten
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
auto&& lens = inputs.front().lens(); auto&& lens = inputs.front().lens();
if(axis > lens.size()) if(axis > lens.size())
{ {
MIGRAPH_THROW("axis for flatten must be less than tensor rank"); MIGRAPH_THROW("axis for flatten must be less than tensor rank");
} }
auto x = std::accumulate( auto x =
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{}); std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate( auto y =
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{}); std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}}; return {inputs.at(0).type(), {x, y}};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
......
...@@ -604,7 +604,7 @@ void transpose_test() ...@@ -604,7 +604,7 @@ void transpose_test()
auto result = p.eval({}); auto result = p.eval({});
result.visit([&](auto output) { result.visit([&](auto output) {
std::vector<size_t> new_lens = {1, 3, 2, 2}; std::vector<size_t> new_lens = {1, 3, 2, 2};
EXPECT(bool{output.get_shape().lens() == new_lens}); EXPECT(bool{output.get_shape().lens() == new_lens});
}); });
} }
......
...@@ -5,51 +5,54 @@ ...@@ -5,51 +5,54 @@
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
template<class... Ts> template <class... Ts>
void expect_shape(migraph::shape expected, migraph::operation op, Ts... xs) void expect_shape(migraph::shape expected, migraph::operation op, Ts... xs)
{ {
migraph::program p; migraph::program p;
std::vector<migraph::shape> shapes{xs...}; std::vector<migraph::shape> shapes{xs...};
std::vector<migraph::instruction_ref> args; std::vector<migraph::instruction_ref> args;
for(auto&& s:shapes) for(auto&& s : shapes)
args.push_back(p.add_outline(s)); args.push_back(p.add_outline(s));
p.add_instruction(op, args); p.add_instruction(op, args);
if(p.get_shape() != expected) { if(p.get_shape() != expected)
{
std::cout << "FAILED: Incorrect shape for " << op.name() << ": "; std::cout << "FAILED: Incorrect shape for " << op.name() << ": ";
std::cout << expected << " != " << p.get_shape() << std::endl; std::cout << expected << " != " << p.get_shape() << std::endl;
for(auto&& s:shapes) for(auto&& s : shapes)
std::cout << " " << s << std::endl; std::cout << " " << s << std::endl;
} }
} }
template<class... Ts> template <class... Ts>
void throws_shape(migraph::operation op, Ts... xs) void throws_shape(migraph::operation op, Ts... xs)
{ {
migraph::program p; migraph::program p;
std::vector<migraph::shape> shapes{xs...}; std::vector<migraph::shape> shapes{xs...};
std::vector<migraph::instruction_ref> args; std::vector<migraph::instruction_ref> args;
for(auto&& s:shapes) for(auto&& s : shapes)
args.push_back(p.add_outline(s)); args.push_back(p.add_outline(s));
bool thrown = test::throws([&] { p.add_instruction(op, args); }); bool thrown = test::throws([&] { p.add_instruction(op, args); });
if(not thrown) { if(not thrown)
{
std::cout << "FAILED: No error found for " << op.name() << ": "; std::cout << "FAILED: No error found for " << op.name() << ": ";
for(auto&& s:shapes) for(auto&& s : shapes)
std::cout << " " << s << std::endl; std::cout << " " << s << std::endl;
} }
} }
template<class...> template <class...>
struct always_false struct always_false : std::false_type
: std::false_type {
{}; };
template<class... Ts> template <class... Ts>
void throws_shape(migraph::shape, Ts...) void throws_shape(migraph::shape, Ts...)
{ {
static_assert(always_false<Ts...>{}, "An expected shape should not be passed to throws_shape function"); static_assert(always_false<Ts...>{},
"An expected shape should not be passed to throws_shape function");
} }
void batch_norm_inference_shape() void batch_norm_inference_shape()
{ {
const size_t channels = 3; const size_t channels = 3;
migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}}; migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}};
...@@ -59,7 +62,7 @@ void batch_norm_inference_shape() ...@@ -59,7 +62,7 @@ void batch_norm_inference_shape()
throws_shape(migraph::batch_norm_inference{}, s, vars, vars, vars, vars, vars); throws_shape(migraph::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
} }
void convolution_shape() void convolution_shape()
{ {
migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}}; migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}};
migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}}; migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}};
...@@ -88,7 +91,7 @@ void contiguous_shape() ...@@ -88,7 +91,7 @@ void contiguous_shape()
migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}}; migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}};
expect_shape(output, migraph::contiguous{}, input); expect_shape(output, migraph::contiguous{}, input);
throws_shape(migraph::contiguous{}, input, input); throws_shape(migraph::contiguous{}, input, input);
migraph::shape single{migraph::shape::float_type, {2}}; migraph::shape single{migraph::shape::float_type, {2}};
throws_shape(migraph::contiguous{}, single); throws_shape(migraph::contiguous{}, single);
} }
...@@ -96,11 +99,8 @@ void contiguous_shape() ...@@ -96,11 +99,8 @@ void contiguous_shape()
void reshape_shape() void reshape_shape()
{ {
migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}}; migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape:std::vector<std::vector<int64_t>>{ for(auto&& new_shape :
{8, 3, 1, 1}, std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
{1, 3, 4, 2},
{1, 3, 4, 2}
})
{ {
std::vector<std::size_t> lens(new_shape.size()); std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin()); std::copy(new_shape.begin(), new_shape.end(), lens.begin());
...@@ -108,10 +108,7 @@ void reshape_shape() ...@@ -108,10 +108,7 @@ void reshape_shape()
expect_shape(output, migraph::reshape{new_shape}, input); expect_shape(output, migraph::reshape{new_shape}, input);
} }
for(auto&& new_shape:std::vector<std::vector<int64_t>>{ for(auto&& new_shape : std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}})
{8, 3, 2, 2},
{1, 3, -1, -1}
})
{ {
throws_shape(migraph::reshape{new_shape}, input); throws_shape(migraph::reshape{new_shape}, input);
} }
...@@ -120,15 +117,20 @@ void reshape_shape() ...@@ -120,15 +117,20 @@ void reshape_shape()
void flatten_shape() void flatten_shape()
{ {
migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}}; migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}};
expect_shape(migraph::shape{migraph::shape::float_type, {1, 2*4*6*8}}, migraph::flatten{0}, input); expect_shape(
expect_shape(migraph::shape{migraph::shape::float_type, {2, 4*6*8}}, migraph::flatten{1}, input); migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}}, migraph::flatten{0}, input);
expect_shape(migraph::shape{migraph::shape::float_type, {2*4, 6*8}}, migraph::flatten{2}, input); expect_shape(
expect_shape(migraph::shape{migraph::shape::float_type, {2*4*6, 8}}, migraph::flatten{3}, input); migraph::shape{migraph::shape::float_type, {2, 4 * 6 * 8}}, migraph::flatten{1}, input);
expect_shape(migraph::shape{migraph::shape::float_type, {2*4*6*8, 1}}, migraph::flatten{4}, input); expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4, 6 * 8}}, migraph::flatten{2}, input);
expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4 * 6, 8}}, migraph::flatten{3}, input);
expect_shape(
migraph::shape{migraph::shape::float_type, {2 * 4 * 6 * 8, 1}}, migraph::flatten{4}, input);
throws_shape(migraph::flatten{5}, input); throws_shape(migraph::flatten{5}, input);
} }
int main() int main()
{ {
batch_norm_inference_shape(); batch_norm_inference_shape();
convolution_shape(); convolution_shape();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment