Commit 523a78c7 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added slice w/ tests

parent 492bf901
...@@ -317,43 +317,71 @@ struct slice ...@@ -317,43 +317,71 @@ struct slice
std::vector<int64_t> starts; std::vector<int64_t> starts;
std::vector<int64_t> ends; std::vector<int64_t> ends;
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
shape compute_shape(std::vector<shape> inputs) const
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{ {
auto input_shape = inputs[0]; std::size_t r = std::min(index, static_cast<int64_t>(lens[axis]));
auto t = input_shape.type(); if(r < 0)
auto old_lens = input_shape.lens(); r += lens[axis];
auto old_strides = input_shape.strides(); return r;
std::vector<int64_t> t_axes(old_lens.size()); }
if(axes.size() == 0)
auto compute_offset(const shape& s) const
{
const std::vector<std::size_t>& lens = s.lens();
const std::vector<std::size_t>& strides = s.strides();
auto offset = 0;
if(axes.size() > 0)
{ {
std::iota(t_axes.begin(), t_axes.end(), 0); for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis];
}
} }
else else
{ {
std::copy(axes.begin(), axes.end(), t_axes.begin()); for(std::size_t axis = 0; axis < lens.size(); axis++)
{
offset += fix_index(lens, axis, starts[axis]) * strides[axis];
}
} }
if(starts.size() || t_axes.size() != ends.size()) return offset;
}
shape compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto t = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
// std::vector<int64_t> t_axes(old_lens.size());
// if(axes.size() == 0)
// {
// std::iota(t_axes.begin(), t_axes.end(), 0);
// }
// else
// {
// std::copy(axes.begin(), axes.end(), t_axes.begin());
// }
if(starts.size() != axes.size() || axes.size() != ends.size())
{ {
MIGRAPH_THROW("inconsistent sizes"); MIGRAPH_THROW("inconsistent sizes");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens = old_lens;
std::copy(old_lens.begin(), old_lens.end(), new_lens.begin()); for(std::size_t i = 0; i < axes.size(); i++)
auto fix_index = [&](std::size_t axis, int64_t index) {
auto r = std::min(index, static_cast<int64_t>(old_lens[axis] - 1));
if(r < 0)
r += old_lens[axis];
return r;
};
for(std::size_t i = 0; i < t_axes.size(); i++)
{ {
auto axis = t_axes[i]; auto axis = axes[i];
new_lens[axis] = fix_index(axis, ends[i]) - fix_index(axis, starts[i]); new_lens[axis] =
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
} }
return shape{t, new_lens, old_strides}; return shape{t, new_lens, old_strides};
} }
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }};
} }
}; };
......
...@@ -63,6 +63,7 @@ struct shape ...@@ -63,6 +63,7 @@ struct shape
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
std::size_t elements() const; std::size_t elements() const;
std::size_t bytes() const; std::size_t bytes() const;
std::size_t type_size() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
......
...@@ -98,6 +98,12 @@ std::size_t shape::bytes() const ...@@ -98,6 +98,12 @@ std::size_t shape::bytes() const
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space(); return n * this->element_space();
} }
std::size_t shape::type_size() const
{
std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); });
return n;
}
std::size_t shape::index(std::initializer_list<std::size_t> l) const std::size_t shape::index(std::initializer_list<std::size_t> l) const
{ {
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
......
...@@ -8,15 +8,42 @@ ...@@ -8,15 +8,42 @@
void slice_test() void slice_test()
{ {
migraph::program p; {
std::vector<float> data(4 * 3 * 2); migraph::program p;
std::iota(data.begin(), data.end(), 0); std::vector<int> data(2 * 2 * 3);
migraph::shape s{migraph::shape::float_type, {4, 2, 3}}; std::iota(data.begin(), data.end(), 0);
auto l0 = p.add_literal(migraph::literal{s, data}); migraph::shape s{migraph::shape::int32_type, {2, 2, 3}};
p.add_instruction(migraph::squeeze{{0}, {0}, {2}}, l0); auto l0 = p.add_literal(migraph::literal{s, data});
p.compile(migraph::cpu::cpu_target{}); p.add_instruction(migraph::slice{{2}, {1}, {3}}, l0);
auto result = p.eval({}); migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(result.get_shape() == s2); EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::cpu_target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({});
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
{
migraph::program p;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::int32_type, {2, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::slice{{0, 1, 2}, {0, 0, 0}, {2, 2, 2}}, l0);
migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::cpu_target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({});
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
} }
void squeeze_test() void squeeze_test()
...@@ -877,6 +904,7 @@ void contiguous_test() ...@@ -877,6 +904,7 @@ void contiguous_test()
int main() int main()
{ {
slice_test();
squeeze_test(); squeeze_test();
unsqueeze_test(); unsqueeze_test();
exp_test(); exp_test();
......
...@@ -130,6 +130,19 @@ void flatten_shape() ...@@ -130,6 +130,19 @@ void flatten_shape()
throws_shape(migraph::flatten{5}, input); throws_shape(migraph::flatten{5}, input);
} }
void slice_shape()
{
migraph::shape input{migraph::shape::int32_type, {2, 2, 3}};
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::slice{{2}, {1}, {3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraph::slice{{2}, {2}, {10}},
input);
}
int main() int main()
{ {
batch_norm_inference_shape(); batch_norm_inference_shape();
...@@ -138,4 +151,5 @@ int main() ...@@ -138,4 +151,5 @@ int main()
contiguous_shape(); contiguous_shape();
reshape_shape(); reshape_shape();
flatten_shape(); flatten_shape();
slice_shape();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment