Commit ffcd5b35 authored by Paul's avatar Paul
Browse files

Add standard attribute to shape

parent bc70ef12
......@@ -11,7 +11,7 @@ void auto_contigous::apply(program& p) const
for(auto ins : iterator_for(p))
{
shape s = ins->result;
if(not s.packed() or s.broadcasted())
if(not s.standard())
{
auto prev = p.insert_instruction(ins, ins->op, ins->arguments);
p.replace_instruction(ins, contiguous{}, prev);
......
......@@ -68,7 +68,7 @@ struct literal : raw_data<literal>
template <class Iterator>
void fill(Iterator start, Iterator end)
{
if(m_shape.packed())
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
}
......@@ -82,12 +82,6 @@ struct literal : raw_data<literal>
output(idx.begin(), idx.end()) = *it;
});
});
// visit_all(*this)([&](auto output) {
// shape_for_each(output.get_shape(), [&](const auto& idx) {
// it++;
// output(idx.begin(), idx.end()) = *it;
// });
// });
}
}
};
......
......@@ -76,7 +76,9 @@ struct shape
std::size_t index(std::size_t i) const;
bool packed() const;
bool transposed() const;
bool broadcasted() const;
bool standard() const;
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
......@@ -139,7 +141,7 @@ struct shape
type_t m_type;
std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides;
bool m_packed;
bool m_standard;
void calculate_strides();
std::size_t element_space() const;
......
......@@ -88,16 +88,16 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)];
}
// TODO: Add iterators so it can handle nonpacked tensors
// TODO: Add iterators so it can handle nonstandard tensors
T* begin()
{
assert(this->m_shape.packed());
assert(this->m_shape.standard());
return m_data;
}
T* end()
{
assert(this->m_shape.packed());
assert(this->m_shape.standard());
if(this->empty())
return m_data;
else
......@@ -106,13 +106,13 @@ struct tensor_view
const T* begin() const
{
assert(this->m_shape.packed());
assert(this->m_shape.standard());
return m_data;
}
const T* end() const
{
assert(this->m_shape.packed());
assert(this->m_shape.standard());
if(this->empty())
return m_data;
else
......
......@@ -8,10 +8,10 @@
namespace migraph {
shape::shape() : m_type(float_type), m_packed(false) {}
shape::shape() : m_type(float_type), m_standard(false) {}
shape::shape(type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_packed(true) {}
shape::shape(type_t t, std::vector<std::size_t> l) : m_type(t), m_lens(std::move(l)), m_packed(true)
shape::shape(type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape::shape(type_t t, std::vector<std::size_t> l) : m_type(t), m_lens(std::move(l)), m_standard(true)
{
this->calculate_strides();
assert(m_lens.size() == m_strides.size());
......@@ -22,8 +22,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
"At least one stride must be non-zero");
m_packed = this->elements() == this->element_space() and
std::is_sorted(m_strides.rbegin(), m_strides.rend());
m_standard = this->packed() and not this->transposed();
}
void shape::calculate_strides()
......@@ -67,7 +66,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std::size_t shape::index(std::size_t i) const
{
assert(this->lens().size() == this->strides().size());
if(this->packed())
if(this->standard())
return i;
else
return std::inner_product(this->lens().begin(),
......@@ -80,7 +79,9 @@ std::size_t shape::index(std::size_t i) const
return ((i / stride) % len) * stride;
});
}
bool shape::packed() const { return this->m_packed; }
bool shape::packed() const { return this->elements() == this->element_space(); }
bool shape::transposed() const { return not std::is_sorted(this->strides().rbegin(), this->strides().rend()); }
bool shape::broadcasted() const
{
......@@ -91,6 +92,8 @@ bool shape::broadcasted() const
std::multiplies<std::size_t>()) == 0;
}
bool shape::standard() const { return this->m_standard; }
std::size_t shape::element_space() const
{
assert(this->lens().size() == this->strides().size());
......
......@@ -22,11 +22,14 @@ void after_literal_transpose()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
EXPECT(p.get_shape().packed());
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
p.add_instruction(migraph::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().packed());
EXPECT(not p.get_shape().standard());
EXPECT(p.get_shape().transposed());
p.compile(contigous_target{});
EXPECT(p.get_shape().packed());
EXPECT(p.get_shape().standard());
EXPECT(not p.get_shape().transposed());
}
int main() { after_literal_transpose(); }
......@@ -16,19 +16,37 @@ void test_shape_assign()
void test_shape_packed_default()
{
migraph::shape s{migraph::shape::float_type, {2, 2}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
void test_shape_packed()
{
migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
void test_shape_transposed()
{
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard());
EXPECT(s.packed());
EXPECT(s.transposed());
EXPECT(not s.broadcasted());
}
void test_shape_broadcasted()
{
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 0}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
void test_shape_default()
......@@ -42,7 +60,10 @@ void test_shape_default()
void test_shape4()
{
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type);
EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32);
......@@ -86,7 +107,10 @@ void test_shape4_nonpacked()
std::multiplies<std::size_t>());
migraph::shape s{migraph::shape::float_type, lens, strides};
EXPECT(!s.packed());
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type);
EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32);
......@@ -116,6 +140,7 @@ int main()
test_shape_packed_default();
test_shape_packed();
test_shape_transposed();
test_shape_broadcasted();
test_shape_default();
test_shape4();
test_shape4_nonpacked();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment