Commit 492bf901 authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent 6a385c26
......@@ -324,23 +324,28 @@ struct slice
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
std::vector<int64_t> t_axes(old_lens.size());
if (axes.size() == 0) {
if(axes.size() == 0)
{
std::iota(t_axes.begin(), t_axes.end(), 0);
}
else {
else
{
std::copy(axes.begin(), axes.end(), t_axes.begin());
}
if (starts.size() || t_axes.size() != ends.size()) {
if(starts.size() || t_axes.size() != ends.size())
{
MIGRAPH_THROW("inconsistent sizes");
}
std::vector<std::size_t> new_lens;
std::copy(old_lens.begin(), old_lens.end(), new_lens.begin());
auto fix_index = [&] (std::size_t axis, int64_t index) {
auto r = std::min(index, static_cast<int64_t>(old_lens[axis]-1));
if (r < 0) r+= old_lens[axis];
auto fix_index = [&](std::size_t axis, int64_t index) {
auto r = std::min(index, static_cast<int64_t>(old_lens[axis] - 1));
if(r < 0)
r += old_lens[axis];
return r;
};
for (std::size_t i = 0; i < t_axes.size(); i++) {
for(std::size_t i = 0; i < t_axes.size(); i++)
{
auto axis = t_axes[i];
new_lens[axis] = fix_index(axis, ends[i]) - fix_index(axis, starts[i]);
}
......@@ -361,22 +366,28 @@ struct squeeze
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
for (auto axis : axes) {
if (input_shape.lens()[axis] != 1) {
for(auto axis : axes)
{
if(input_shape.lens()[axis] != 1)
{
MIGRAPH_THROW("squeeze axis dimension should be equal to 1");
}
}
std::vector<std::size_t> new_lens;
if (axes.size() == 0) {
for (std::size_t i = 0; i < old_lens.size(); i++) {
if (old_lens[i] != 1)
if(axes.size() == 0)
{
for(std::size_t i = 0; i < old_lens.size(); i++)
{
if(old_lens[i] != 1)
new_lens.push_back(old_lens[i]);
}
}
else {
for (std::size_t i = 0; i < old_lens.size(); i++) {
if (std::find(axes.begin(), axes.end(), i)
== axes.end()) {
else
{
for(std::size_t i = 0; i < old_lens.size(); i++)
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]);
}
}
......@@ -401,10 +412,14 @@ struct unsqueeze
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
std::size_t p = 0;
for (std::size_t i = 0; i < new_size; i++) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
for(std::size_t i = 0; i < new_size; i++)
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
new_lens[i] = 1;
} else {
}
else
{
new_lens[i] = old_lens[p++];
}
}
......
......@@ -6,24 +6,26 @@
#include <migraph/verify.hpp>
#include "test.hpp"
void slice_test() {
void slice_test()
{
migraph::program p;
std::vector<float> data(4*3*2);
std::vector<float> data(4 * 3 * 2);
std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::float_type, {4,2,3}};
migraph::shape s{migraph::shape::float_type, {4, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::squeeze{{0},{0},{2}}, l0);
p.add_instruction(migraph::squeeze{{0}, {0}, {2}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
void squeeze_test() {
void squeeze_test()
{
{
migraph::program p;
std::vector<float> data(4*3*3);
migraph::shape s1{migraph::shape::float_type, {4,1,3,1,3}};
migraph::shape s2{migraph::shape::float_type, {4,3,1,3}};
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::squeeze{{1}}, l0);
p.compile(migraph::cpu::cpu_target{});
......@@ -32,9 +34,9 @@ void squeeze_test() {
}
{
migraph::program p;
std::vector<float> data(4*3*3);
migraph::shape s1{migraph::shape::float_type, {4,1,3,1,3}};
migraph::shape s2{migraph::shape::float_type, {4,1,3,3}};
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::squeeze{{3}}, l0);
p.compile(migraph::cpu::cpu_target{});
......@@ -43,9 +45,9 @@ void squeeze_test() {
}
{
migraph::program p;
std::vector<float> data(4*3*3);
migraph::shape s1{migraph::shape::float_type, {4,1,3,1,3}};
migraph::shape s2{migraph::shape::float_type, {4,3,3}};
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::squeeze{}, l0);
p.compile(migraph::cpu::cpu_target{});
......@@ -54,12 +56,13 @@ void squeeze_test() {
}
}
void unsqueeze_test() {
void unsqueeze_test()
{
{
migraph::program p;
std::vector<float> data(4*3*3);
migraph::shape s1{migraph::shape::float_type, {4,3,3}};
migraph::shape s2{migraph::shape::float_type, {4,1,3,3}};
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::unsqueeze{{1}}, l0);
p.compile(migraph::cpu::cpu_target{});
......@@ -68,9 +71,9 @@ void unsqueeze_test() {
}
{
migraph::program p;
std::vector<float> data(4*3*3);
migraph::shape s1{migraph::shape::float_type, {4,3,3}};
migraph::shape s2{migraph::shape::float_type, {4,3,1,3}};
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::unsqueeze{{2}}, l0);
p.compile(migraph::cpu::cpu_target{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment