Commit 492bf901 authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent 6a385c26
...@@ -320,29 +320,34 @@ struct slice ...@@ -320,29 +320,34 @@ struct slice
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto t = input_shape.type(); auto t = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides(); auto old_strides = input_shape.strides();
std::vector<int64_t> t_axes(old_lens.size()); std::vector<int64_t> t_axes(old_lens.size());
if (axes.size() == 0) { if(axes.size() == 0)
{
std::iota(t_axes.begin(), t_axes.end(), 0); std::iota(t_axes.begin(), t_axes.end(), 0);
} }
else { else
{
std::copy(axes.begin(), axes.end(), t_axes.begin()); std::copy(axes.begin(), axes.end(), t_axes.begin());
} }
if (starts.size() || t_axes.size() != ends.size()) { if(starts.size() || t_axes.size() != ends.size())
{
MIGRAPH_THROW("inconsistent sizes"); MIGRAPH_THROW("inconsistent sizes");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::copy(old_lens.begin(), old_lens.end(), new_lens.begin()); std::copy(old_lens.begin(), old_lens.end(), new_lens.begin());
auto fix_index = [&] (std::size_t axis, int64_t index) { auto fix_index = [&](std::size_t axis, int64_t index) {
auto r = std::min(index, static_cast<int64_t>(old_lens[axis]-1)); auto r = std::min(index, static_cast<int64_t>(old_lens[axis] - 1));
if (r < 0) r+= old_lens[axis]; if(r < 0)
r += old_lens[axis];
return r; return r;
}; };
for (std::size_t i = 0; i < t_axes.size(); i++) { for(std::size_t i = 0; i < t_axes.size(); i++)
auto axis = t_axes[i]; {
new_lens[axis] = fix_index(axis, ends[i]) - fix_index(axis, starts[i]); auto axis = t_axes[i];
new_lens[axis] = fix_index(axis, ends[i]) - fix_index(axis, starts[i]);
} }
return shape{t, new_lens, old_strides}; return shape{t, new_lens, old_strides};
} }
...@@ -359,24 +364,30 @@ struct squeeze ...@@ -359,24 +364,30 @@ struct squeeze
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
for (auto axis : axes) { for(auto axis : axes)
if (input_shape.lens()[axis] != 1) { {
if(input_shape.lens()[axis] != 1)
{
MIGRAPH_THROW("squeeze axis dimension should be equal to 1"); MIGRAPH_THROW("squeeze axis dimension should be equal to 1");
} }
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
if (axes.size() == 0) { if(axes.size() == 0)
for (std::size_t i = 0; i < old_lens.size(); i++) { {
if (old_lens[i] != 1) for(std::size_t i = 0; i < old_lens.size(); i++)
{
if(old_lens[i] != 1)
new_lens.push_back(old_lens[i]); new_lens.push_back(old_lens[i]);
} }
} }
else { else
for (std::size_t i = 0; i < old_lens.size(); i++) { {
if (std::find(axes.begin(), axes.end(), i) for(std::size_t i = 0; i < old_lens.size(); i++)
== axes.end()) { {
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]); new_lens.push_back(old_lens[i]);
} }
} }
...@@ -386,7 +397,7 @@ struct squeeze ...@@ -386,7 +397,7 @@ struct squeeze
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
}; };
struct unsqueeze struct unsqueeze
...@@ -395,16 +406,20 @@ struct unsqueeze ...@@ -395,16 +406,20 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::size_t p = 0; std::size_t p = 0;
for (std::size_t i = 0; i < new_size; i++) { for(std::size_t i = 0; i < new_size; i++)
if (std::find(axes.begin(), axes.end(), i) != axes.end()) { {
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
new_lens[i] = 1; new_lens[i] = 1;
} else { }
else
{
new_lens[i] = old_lens[p++]; new_lens[i] = old_lens[p++];
} }
} }
......
...@@ -6,77 +6,80 @@ ...@@ -6,77 +6,80 @@
#include <migraph/verify.hpp> #include <migraph/verify.hpp>
#include "test.hpp" #include "test.hpp"
void slice_test() { void slice_test()
migraph::program p; {
std::vector<float> data(4*3*2); migraph::program p;
std::iota(data.begin(), data.end(), 0); std::vector<float> data(4 * 3 * 2);
migraph::shape s{migraph::shape::float_type, {4,2,3}}; std::iota(data.begin(), data.end(), 0);
auto l0 = p.add_literal(migraph::literal{s, data}); migraph::shape s{migraph::shape::float_type, {4, 2, 3}};
p.add_instruction(migraph::squeeze{{0},{0},{2}}, l0); auto l0 = p.add_literal(migraph::literal{s, data});
p.compile(migraph::cpu::cpu_target{}); p.add_instruction(migraph::squeeze{{0}, {0}, {2}}, l0);
auto result = p.eval({}); p.compile(migraph::cpu::cpu_target{});
EXPECT(result.get_shape() == s2); auto result = p.eval({});
EXPECT(result.get_shape() == s2);
} }
void squeeze_test() { void squeeze_test()
{
{ {
migraph::program p; migraph::program p;
std::vector<float> data(4*3*3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4,1,3,1,3}}; migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4,3,1,3}}; migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::squeeze{{1}}, l0); p.add_instruction(migraph::squeeze{{1}}, l0);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
{ {
migraph::program p; migraph::program p;
std::vector<float> data(4*3*3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4,1,3,1,3}}; migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4,1,3,3}}; migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::squeeze{{3}}, l0); p.add_instruction(migraph::squeeze{{3}}, l0);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
{ {
migraph::program p; migraph::program p;
std::vector<float> data(4*3*3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4,1,3,1,3}}; migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4,3,3}}; migraph::shape s2{migraph::shape::float_type, {4, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::squeeze{}, l0); p.add_instruction(migraph::squeeze{}, l0);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
} }
void unsqueeze_test() { void unsqueeze_test()
{
{ {
migraph::program p; migraph::program p;
std::vector<float> data(4*3*3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4,3,3}}; migraph::shape s1{migraph::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4,1,3,3}}; migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::unsqueeze{{1}}, l0); p.add_instruction(migraph::unsqueeze{{1}}, l0);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
{ {
migraph::program p; migraph::program p;
std::vector<float> data(4*3*3); std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4,3,3}}; migraph::shape s1{migraph::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4,3,1,3}}; migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data}); auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::unsqueeze{{2}}, l0); p.add_instruction(migraph::unsqueeze{{2}}, l0);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result.get_shape() == s2); EXPECT(result.get_shape() == s2);
} }
} }
void im2col_3x3_no_pad_identity_test() void im2col_3x3_no_pad_identity_test()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment