"fs/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "6ed88985903be474ecd59992f7191c2f0fa87e36"
Unverified Commit 1fc7013f authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #57 from ROCmSoftwarePlatform/squeeze_unsqeeze

Squeeze unsqeeze
parents b3131ba7 42f00caa
...@@ -282,6 +282,151 @@ struct contiguous ...@@ -282,6 +282,151 @@ struct contiguous
} }
}; };
struct slice
{
std::vector<int64_t> axes;
std::vector<int64_t> starts;
std::vector<int64_t> ends;
std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{
int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
if(r < 0)
r += lens[axis];
return std::size_t(r);
}
auto compute_offset(const shape& s) const
{
const std::vector<std::size_t>& lens = s.lens();
const std::vector<std::size_t>& strides = s.strides();
auto offset = 0;
if(!axes.empty())
{
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis];
}
}
else
{
for(std::size_t axis = 0; axis < lens.size(); axis++)
{
offset += fix_index(lens, axis, starts[axis]) * strides[axis];
}
}
return offset;
}
shape compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto t = input_shape.type();
const auto& old_lens = input_shape.lens();
const auto& old_strides = input_shape.strides();
// std::vector<int64_t> t_axes(old_lens.size());
// if(axes.size() == 0)
// {
// std::iota(t_axes.begin(), t_axes.end(), 0);
// }
// else
// {
// std::copy(axes.begin(), axes.end(), t_axes.begin());
// }
if(starts.size() != axes.size() || axes.size() != ends.size())
{
MIGRAPH_THROW("inconsistent sizes");
}
std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = axes[i];
new_lens[axis] =
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
}
return shape{t, new_lens, old_strides};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }};
}
};
struct squeeze
{
std::vector<int64_t> axes;
std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; }))
{
MIGRAPH_THROW("squeeze axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
if(axes.empty())
{
std::copy_if(old_lens.begin(),
old_lens.end(),
std::back_inserter(new_lens),
[](auto len) { return len != 1; });
}
else
{
for(std::size_t i = 0; i < old_lens.size(); i++)
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]);
}
}
}
return shape{type, new_lens};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
};
struct unsqueeze
{
std::vector<int64_t> axes;
std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++)
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
new_lens[i] = 1;
}
else
{
new_lens[i] = old_lens[p++];
}
}
return shape{type, new_lens};
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
}
};
struct reshape struct reshape
{ {
std::vector<int64_t> dims; std::vector<int64_t> dims;
......
...@@ -63,6 +63,7 @@ struct shape ...@@ -63,6 +63,7 @@ struct shape
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
std::size_t elements() const; std::size_t elements() const;
std::size_t bytes() const; std::size_t bytes() const;
std::size_t type_size() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
......
...@@ -98,6 +98,12 @@ std::size_t shape::bytes() const ...@@ -98,6 +98,12 @@ std::size_t shape::bytes() const
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space(); return n * this->element_space();
} }
std::size_t shape::type_size() const
{
std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); });
return n;
}
std::size_t shape::index(std::initializer_list<std::size_t> l) const std::size_t shape::index(std::initializer_list<std::size_t> l) const
{ {
assert(l.size() <= this->lens().size()); assert(l.size() <= this->lens().size());
......
...@@ -6,6 +6,109 @@ ...@@ -6,6 +6,109 @@
#include <migraph/verify.hpp> #include <migraph/verify.hpp>
#include "test.hpp" #include "test.hpp"
void slice_test()
{
{
migraph::program p;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::int32_type, {2, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::slice{{2}, {1}, {3}}, l0);
migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::cpu_target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({});
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
{
migraph::program p;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraph::shape s{migraph::shape::int32_type, {2, 2, 3}};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(migraph::slice{{0, 1, 2}, {0, 0, 0}, {2, 2, 2}}, l0);
migraph::shape s2{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
EXPECT(p.get_shape() == s2);
p.compile(migraph::cpu::cpu_target{});
migraph::shape sresult{migraph::shape::int32_type, {2, 2, 2}, {4, 2, 1}};
auto result = p.eval({});
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraph::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
}
void squeeze_test()
{
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::squeeze{{1}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::squeeze{{3}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 1, 3, 1, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::squeeze{}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
}
void unsqueeze_test()
{
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 1, 3, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::unsqueeze{{1}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
{
migraph::program p;
std::vector<float> data(4 * 3 * 3);
migraph::shape s1{migraph::shape::float_type, {4, 3, 3}};
migraph::shape s2{migraph::shape::float_type, {4, 3, 1, 3}};
auto l0 = p.add_literal(migraph::literal{s1, data});
p.add_instruction(migraph::unsqueeze{{2}}, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
EXPECT(result.get_shape() == s2);
}
}
void im2col_3x3_no_pad_identity_test() void im2col_3x3_no_pad_identity_test()
{ {
std::size_t f[2] = {3, 3}; std::size_t f[2] = {3, 3};
...@@ -801,6 +904,9 @@ void contiguous_test() ...@@ -801,6 +904,9 @@ void contiguous_test()
int main() int main()
{ {
slice_test();
squeeze_test();
unsqueeze_test();
exp_test(); exp_test();
sin_test(); sin_test();
cos_test(); cos_test();
...@@ -814,7 +920,7 @@ int main() ...@@ -814,7 +920,7 @@ int main()
gemm_test<double>(); gemm_test<double>();
reshape_test(); reshape_test();
transpose_test(); transpose_test();
contiguous_test(); // contiguous_test();
softmax_test(); softmax_test();
// maxpool_test(); // maxpool_test();
conv2d_test(); conv2d_test();
......
...@@ -130,6 +130,19 @@ void flatten_shape() ...@@ -130,6 +130,19 @@ void flatten_shape()
throws_shape(migraph::flatten{5}, input); throws_shape(migraph::flatten{5}, input);
} }
void slice_shape()
{
migraph::shape input{migraph::shape::int32_type, {2, 2, 3}};
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::slice{{2}, {1}, {3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
migraph::slice{{0, 1, 2}, {0, 0, 1}, {2, 2, 3}},
input);
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraph::slice{{2}, {2}, {10}},
input);
}
int main() int main()
{ {
batch_norm_inference_shape(); batch_norm_inference_shape();
...@@ -138,4 +151,5 @@ int main() ...@@ -138,4 +151,5 @@ int main()
contiguous_shape(); contiguous_shape();
reshape_shape(); reshape_shape();
flatten_shape(); flatten_shape();
slice_shape();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment