Commit 523a78c7 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added slice w/ tests

parent 492bf901
......@@ -317,43 +317,71 @@ struct slice
std::vector<int64_t> starts;
std::vector<int64_t> ends;
std::string name() const { return "slice"; }
shape compute_shape(std::vector<shape> inputs) const
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{
auto input_shape = inputs[0];
auto t = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
std::vector<int64_t> t_axes(old_lens.size());
if(axes.size() == 0)
std::size_t r = std::min(index, static_cast<int64_t>(lens[axis]));
if(r < 0)
r += lens[axis];
return r;
}
auto compute_offset(const shape& s) const
{
const std::vector<std::size_t>& lens = s.lens();
const std::vector<std::size_t>& strides = s.strides();
auto offset = 0;
if(axes.size() > 0)
{
std::iota(t_axes.begin(), t_axes.end(), 0);
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis];
}
}
else
{
std::copy(axes.begin(), axes.end(), t_axes.begin());
for(std::size_t axis = 0; axis < lens.size(); axis++)
{
offset += fix_index(lens, axis, starts[axis]) * strides[axis];
}
}
if(starts.size() || t_axes.size() != ends.size())
return offset;
}
shape compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto t = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
// std::vector<int64_t> t_axes(old_lens.size());
// if(axes.size() == 0)
// {
// std::iota(t_axes.begin(), t_axes.end(), 0);
// }
// else
// {
// std::copy(axes.begin(), axes.end(), t_axes.begin());
// }
if(starts.size() != axes.size() || axes.size() != ends.size())
{
MIGRAPH_THROW("inconsistent sizes");
}
std::vector<std::size_t> new_lens;
std::copy(old_lens.begin(), old_lens.end(), new_lens.begin());
auto fix_index = [&](std::size_t axis, int64_t index) {
auto r = std::min(index, static_cast<int64_t>(old_lens[axis] - 1));
if(r < 0)
r += old_lens[axis];
return r;
};
for(std::size_t i = 0; i < t_axes.size(); i++)
std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = t_axes[i];
new_lens[axis] = fix_index(axis, ends[i]) - fix_index(axis, starts[i]);
auto axis = axes[i];
new_lens[axis] =
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
}
return shape{t, new_lens, old_strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }};
}
};
......
......@@ -63,6 +63,7 @@ struct shape
const std::vector<std::size_t>& strides() const;
std::size_t elements() const;
std::size_t bytes() const;
std::size_t type_size() const;
/// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const;
......
......@@ -98,6 +98,12 @@ std::size_t shape::bytes() const
this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space();
}
std::size_t shape::type_size() const
{
std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); });
return n;
}
std::size_t shape::index(std::initializer_list<std::size_t> l) const
{
assert(l.size() <= this->lens().size());
......
......@@ -8,15 +8,42 @@
void slice_test()
{
migraph::program p;
std::vector<float> data(4 * 3 * 2);
std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::float_type, {4, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::squeeze{{0}, {0}, {2}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
{
migraph::program p;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::int32_type, {2, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::slice{{2}, {1}, {3}}, l0);
migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::cpu_target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({});
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
{
migraph::program p;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::int32_type, {2, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::slice{{0, 1, 2}, {0, 0, 0}, {2, 2, 2}}, l0);
migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::cpu_target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({});
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
}
void squeeze_test()
......@@ -877,6 +904,7 @@ void contiguous_test()
int main()
{
slice_test();
squeeze_test();
unsqueeze_test();
exp_test();
......
......@@ -130,6 +130,19 @@ void flatten_shape()
throws_shape(migraph::flatten{5}, input);
}
void slice_shape()
{
migraph::shape input{migraph::shape::int32_type, {2, 2, 3}};
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::slice{{2}, {1}, {3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraph::slice{{2}, {2}, {10}},
input);
}
int main()
{
batch_norm_inference_shape();
......@@ -138,4 +151,5 @@ int main()
contiguous_shape();
reshape_shape();
flatten_shape();
slice_shape();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment