Commit 6fe85d43 authored by Paul's avatar Paul
Browse files

Add autocontigous pass

parent 682b524e
add_library(migraph
auto_contiguous.cpp
dead_code_elimination.cpp
generate.cpp
program.cpp
......
#include <migraph/auto_contiguous.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
namespace migraph {
void auto_contigous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
shape s = ins->result;
if(not s.packed() or s.broadcasted())
{
auto prev = p.insert_instruction(ins, ins->op, ins->arguments);
p.replace_instruction(ins, contiguous{}, prev);
}
}
}
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGOUS_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct auto_contigous
{
std::string name() const { return "auto_contigous"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
......@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#include <migraph/shape.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/argument.hpp>
#include <migraph/tensor_view.hpp>
#include <migraph/raw_data.hpp>
......@@ -26,24 +27,21 @@ struct literal : raw_data<literal>
template <class T>
literal(shape s, const std::vector<T>& x) : buffer(s.bytes(), 0), m_shape(s)
{
assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); });
fill(x.begin(), x.end());
}
template <class T>
literal(shape s, const std::initializer_list<T>& x) : buffer(s.bytes(), 0), m_shape(s)
{
assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); });
fill(x.begin(), x.end());
}
template <class Iterator>
literal(shape s, Iterator start, Iterator end) : buffer(s.bytes(), 0), m_shape(s)
{
assert(s.packed());
s.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
fill(start, end);
}
literal(shape s, const char* x) : buffer(x, x + s.bytes()), m_shape(s) {}
......@@ -66,6 +64,32 @@ struct literal : raw_data<literal>
private:
std::vector<char> buffer;
shape m_shape;
template <class Iterator>
void fill(Iterator start, Iterator end)
{
if(m_shape.packed())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.data()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
it++;
output(idx.begin(), idx.end()) = *it;
});
});
// visit_all(*this)([&](auto output) {
// shape_for_each(output.get_shape(), [&](const auto& idx) {
// it++;
// output(idx.begin(), idx.end()) = *it;
// });
// });
}
}
};
} // namespace migraph
......
......@@ -78,6 +78,8 @@ struct program
instruction_ref begin();
instruction_ref end();
shape get_shape() const;
instruction_ref validate() const;
void compile(const target& t);
......
......@@ -126,6 +126,11 @@ bool program::has_instruction(instruction_ref ins) const
instruction_ref program::begin() { return impl->instructions.begin(); }
instruction_ref program::end() { return impl->instructions.end(); }
shape program::get_shape() const
{
return impl->instructions.back().result;
}
instruction_ref program::validate() const
{
return std::find_if(impl->instructions.begin(),
......
......@@ -22,7 +22,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
"At least one stride must be non-zero");
m_packed = this->elements() == this->element_space();
m_packed = this->elements() == this->element_space() and std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
void shape::calculate_strides()
......
#include <migraph/auto_contiguous.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct contigous_target
{
std::string name() const { return "contigous"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::auto_contigous{}};
}
migraph::context get_context() const { return {}; }
};
migraph::literal get_2x2()
{
return migraph::literal{{migraph::shape::float_type, {2, 2}}, {1, 2, 3, 4}};
}
void after_literal_transpose()
{
migraph::program p;
auto l = p.add_literal(get_2x2());
EXPECT(p.get_shape().packed());
p.add_instruction(migraph::transpose{{1, 0}}, l);
EXPECT(not p.get_shape().packed());
p.compile(contigous_target{});
EXPECT(p.get_shape().packed());
}
int main() {
after_literal_transpose();
}
......@@ -13,6 +13,24 @@ void test_shape_assign()
EXPECT(!(s1 != s2));
}
void test_shape_packed_default()
{
migraph::shape s{migraph::shape::float_type, {2, 2}};
EXPECT(s.packed());
}
void test_shape_packed()
{
migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}};
EXPECT(s.packed());
}
void test_shape_transposed()
{
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.packed());
}
void test_shape_default()
{
migraph::shape s1{};
......@@ -95,6 +113,9 @@ void test_shape4_nonpacked()
int main()
{
test_shape_assign();
test_shape_packed_default();
test_shape_packed();
test_shape_transposed();
test_shape_default();
test_shape4();
test_shape4_nonpacked();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment