Commit ffcd5b35 authored by Paul's avatar Paul
Browse files

Add standard attribute to shape

parent bc70ef12
...@@ -11,7 +11,7 @@ void auto_contigous::apply(program& p) const ...@@ -11,7 +11,7 @@ void auto_contigous::apply(program& p) const
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
shape s = ins->result; shape s = ins->result;
if(not s.packed() or s.broadcasted()) if(not s.standard())
{ {
auto prev = p.insert_instruction(ins, ins->op, ins->arguments); auto prev = p.insert_instruction(ins, ins->op, ins->arguments);
p.replace_instruction(ins, contiguous{}, prev); p.replace_instruction(ins, contiguous{}, prev);
......
...@@ -68,7 +68,7 @@ struct literal : raw_data<literal> ...@@ -68,7 +68,7 @@ struct literal : raw_data<literal>
template <class Iterator> template <class Iterator>
void fill(Iterator start, Iterator end) void fill(Iterator start, Iterator end)
{ {
if(m_shape.packed()) if(m_shape.standard())
{ {
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); }); m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
} }
...@@ -82,12 +82,6 @@ struct literal : raw_data<literal> ...@@ -82,12 +82,6 @@ struct literal : raw_data<literal>
output(idx.begin(), idx.end()) = *it; output(idx.begin(), idx.end()) = *it;
}); });
}); });
// visit_all(*this)([&](auto output) {
// shape_for_each(output.get_shape(), [&](const auto& idx) {
// it++;
// output(idx.begin(), idx.end()) = *it;
// });
// });
} }
} }
}; };
......
...@@ -76,7 +76,9 @@ struct shape ...@@ -76,7 +76,9 @@ struct shape
std::size_t index(std::size_t i) const; std::size_t index(std::size_t i) const;
bool packed() const; bool packed() const;
bool transposed() const;
bool broadcasted() const; bool broadcasted() const;
bool standard() const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
...@@ -139,7 +141,7 @@ struct shape ...@@ -139,7 +141,7 @@ struct shape
type_t m_type; type_t m_type;
std::vector<std::size_t> m_lens; std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides; std::vector<std::size_t> m_strides;
bool m_packed; bool m_standard;
void calculate_strides(); void calculate_strides();
std::size_t element_space() const; std::size_t element_space() const;
......
...@@ -88,16 +88,16 @@ struct tensor_view ...@@ -88,16 +88,16 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)]; return m_data[m_shape.index(this->size() - 1)];
} }
// TODO: Add iterators so it can handle nonpacked tensors // TODO: Add iterators so it can handle nonstandard tensors
T* begin() T* begin()
{ {
assert(this->m_shape.packed()); assert(this->m_shape.standard());
return m_data; return m_data;
} }
T* end() T* end()
{ {
assert(this->m_shape.packed()); assert(this->m_shape.standard());
if(this->empty()) if(this->empty())
return m_data; return m_data;
else else
...@@ -106,13 +106,13 @@ struct tensor_view ...@@ -106,13 +106,13 @@ struct tensor_view
const T* begin() const const T* begin() const
{ {
assert(this->m_shape.packed()); assert(this->m_shape.standard());
return m_data; return m_data;
} }
const T* end() const const T* end() const
{ {
assert(this->m_shape.packed()); assert(this->m_shape.standard());
if(this->empty()) if(this->empty())
return m_data; return m_data;
else else
......
...@@ -8,10 +8,10 @@ ...@@ -8,10 +8,10 @@
namespace migraph { namespace migraph {
shape::shape() : m_type(float_type), m_packed(false) {} shape::shape() : m_type(float_type), m_standard(false) {}
shape::shape(type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_packed(true) {} shape::shape(type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape::shape(type_t t, std::vector<std::size_t> l) : m_type(t), m_lens(std::move(l)), m_packed(true) shape::shape(type_t t, std::vector<std::size_t> l) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
this->calculate_strides(); this->calculate_strides();
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
...@@ -22,8 +22,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -22,8 +22,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
"At least one stride must be non-zero"); "At least one stride must be non-zero");
m_packed = this->elements() == this->element_space() and m_standard = this->packed() and not this->transposed();
std::is_sorted(m_strides.rbegin(), m_strides.rend());
} }
void shape::calculate_strides() void shape::calculate_strides()
...@@ -67,7 +66,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const ...@@ -67,7 +66,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std::size_t shape::index(std::size_t i) const std::size_t shape::index(std::size_t i) const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->packed()) if(this->standard())
return i; return i;
else else
return std::inner_product(this->lens().begin(), return std::inner_product(this->lens().begin(),
...@@ -80,7 +79,9 @@ std::size_t shape::index(std::size_t i) const ...@@ -80,7 +79,9 @@ std::size_t shape::index(std::size_t i) const
return ((i / stride) % len) * stride; return ((i / stride) % len) * stride;
}); });
} }
bool shape::packed() const { return this->m_packed; } bool shape::packed() const { return this->elements() == this->element_space(); }
bool shape::transposed() const { return not std::is_sorted(this->strides().rbegin(), this->strides().rend()); }
bool shape::broadcasted() const bool shape::broadcasted() const
{ {
...@@ -91,6 +92,8 @@ bool shape::broadcasted() const ...@@ -91,6 +92,8 @@ bool shape::broadcasted() const
std::multiplies<std::size_t>()) == 0; std::multiplies<std::size_t>()) == 0;
} }
bool shape::standard() const { return this->m_standard; }
std::size_t shape::element_space() const std::size_t shape::element_space() const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
......
...@@ -22,11 +22,14 @@ void after_literal_transpose() ...@@ -22,11 +22,14 @@ void after_literal_transpose()
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
EXPECT(p.get_shape().packed()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.add_instruction(migraph::transpose{{1, 0}}, l); p.add_instruction(migraph::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().packed()); EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
p.compile(contigous_target{}); p.compile(contigous_target{});
EXPECT(p.get_shape().packed()); EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
} }
int main() { after_literal_transpose(); } int main() { after_literal_transpose(); }
...@@ -16,19 +16,37 @@ void test_shape_assign() ...@@ -16,19 +16,37 @@ void test_shape_assign()
void test_shape_packed_default() void test_shape_packed_default()
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}}; migraph::shape s{migraph::shape::float_type, {2, 2}};
EXPECT(s.standard());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
} }
void test_shape_packed() void test_shape_packed()
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}}; migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}};
EXPECT(s.standard());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
} }
void test_shape_transposed() void test_shape_transposed()
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}}; migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard());
EXPECT(s.packed());
EXPECT(s.transposed());
EXPECT(not s.broadcasted());
}
void test_shape_broadcasted()
{
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 0}};
EXPECT(not s.standard());
EXPECT(not s.packed()); EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
} }
void test_shape_default() void test_shape_default()
...@@ -42,7 +60,10 @@ void test_shape_default() ...@@ -42,7 +60,10 @@ void test_shape_default()
void test_shape4() void test_shape4()
{ {
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}}; migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}};
EXPECT(s.standard());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type); EXPECT(s.type() == migraph::shape::float_type);
EXPECT(s.lens()[0] == 100); EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32); EXPECT(s.lens()[1] == 32);
...@@ -86,7 +107,10 @@ void test_shape4_nonpacked() ...@@ -86,7 +107,10 @@ void test_shape4_nonpacked()
std::multiplies<std::size_t>()); std::multiplies<std::size_t>());
migraph::shape s{migraph::shape::float_type, lens, strides}; migraph::shape s{migraph::shape::float_type, lens, strides};
EXPECT(!s.packed()); EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type); EXPECT(s.type() == migraph::shape::float_type);
EXPECT(s.lens()[0] == 100); EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32); EXPECT(s.lens()[1] == 32);
...@@ -116,6 +140,7 @@ int main() ...@@ -116,6 +140,7 @@ int main()
test_shape_packed_default(); test_shape_packed_default();
test_shape_packed(); test_shape_packed();
test_shape_transposed(); test_shape_transposed();
test_shape_broadcasted();
test_shape_default(); test_shape_default();
test_shape4(); test_shape4();
test_shape4_nonpacked(); test_shape4_nonpacked();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment