Commit d1481b13 authored by Paul's avatar Paul
Browse files

Merge branch 'contigous-pass'

parents b4d2a740 0df528ee
add_library(migraph add_library(migraph
auto_contiguous.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
generate.cpp generate.cpp
program.cpp program.cpp
......
#include <migraph/auto_contiguous.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
namespace migraph {
// Rewrite every instruction whose result shape is non-standard (transposed or
// broadcasted layout) into: original-op -> contiguous, so downstream code can
// assume a packed, row-major result.
void auto_contiguous::apply(program& p) const
{
    for(auto it : iterator_for(p))
    {
        const shape& r = it->result;
        if(r.standard())
            continue;
        // Re-insert the original operation, then replace the old instruction
        // with a contiguous op consuming it; callers of the old instruction
        // now see a standard-layout result.
        auto copy = p.insert_instruction(it, it->op, it->arguments);
        p.replace_instruction(it, contiguous{}, copy);
    }
}
} // namespace migraph
// Fixed misspelled include guard (CONTIGOUS -> CONTIGUOUS).
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGUOUS_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_CONTIGUOUS_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
// Compiler pass that inserts `contiguous` instructions after any instruction
// producing a non-standard (e.g. transposed or broadcasted) shape, so later
// passes/backends can assume packed row-major data.
struct auto_contiguous
{
    std::string name() const { return "auto_contiguous"; }
    // Mutates `p` in place; see auto_contiguous.cpp for the implementation.
    void apply(program& p) const;
};
} // namespace migraph
#endif
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/tensor_view.hpp> #include <migraph/tensor_view.hpp>
#include <migraph/raw_data.hpp> #include <migraph/raw_data.hpp>
...@@ -26,24 +27,21 @@ struct literal : raw_data<literal> ...@@ -26,24 +27,21 @@ struct literal : raw_data<literal>
template <class T> template <class T>
literal(shape s, const std::vector<T>& x) : buffer(s.bytes(), 0), m_shape(s) literal(shape s, const std::vector<T>& x) : buffer(s.bytes(), 0), m_shape(s)
{ {
assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); }); fill(x.begin(), x.end());
} }
template <class T> template <class T>
literal(shape s, const std::initializer_list<T>& x) : buffer(s.bytes(), 0), m_shape(s) literal(shape s, const std::initializer_list<T>& x) : buffer(s.bytes(), 0), m_shape(s)
{ {
assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); }); fill(x.begin(), x.end());
} }
template <class Iterator> template <class Iterator>
literal(shape s, Iterator start, Iterator end) : buffer(s.bytes(), 0), m_shape(s) literal(shape s, Iterator start, Iterator end) : buffer(s.bytes(), 0), m_shape(s)
{ {
assert(s.packed()); fill(start, end);
s.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
} }
literal(shape s, const char* x) : buffer(x, x + s.bytes()), m_shape(s) {} literal(shape s, const char* x) : buffer(x, x + s.bytes()), m_shape(s) {}
...@@ -66,6 +64,26 @@ struct literal : raw_data<literal> ...@@ -66,6 +64,26 @@ struct literal : raw_data<literal>
private: private:
std::vector<char> buffer; std::vector<char> buffer;
shape m_shape; shape m_shape;
// Copy the element range [start, end) into this literal's buffer following
// m_shape's layout. For a standard shape the elements are copied linearly;
// otherwise each value is written through a tensor_view at its
// multi-dimensional index so non-packed/transposed layouts are filled
// correctly.
template <class Iterator>
void fill(Iterator start, Iterator end)
{
    if(m_shape.standard())
    {
        m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
    }
    else
    {
        auto it = start;
        m_shape.visit_type([&](auto as) {
            auto output = make_view(m_shape, as.from(buffer.data()));
            shape_for_each(output.get_shape(), [&](const auto& idx) {
                // Dereference first, then advance. The previous code did
                // `it++;` before the read, which skipped the first input
                // element and dereferenced one past the end on the final
                // iteration (undefined behavior).
                output(idx.begin(), idx.end()) = *it++;
            });
        });
    }
}
}; };
} // namespace migraph } // namespace migraph
......
...@@ -232,9 +232,9 @@ struct transpose ...@@ -232,9 +232,9 @@ struct transpose
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
MIGRAPH_THROW("not computable"); return {output_shape, std::move(args.front().data)};
} }
}; };
...@@ -297,9 +297,9 @@ struct reshape ...@@ -297,9 +297,9 @@ struct reshape
return s; return s;
} }
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
MIGRAPH_THROW("not computable"); return {output_shape, std::move(args.front().data)};
} }
friend std::ostream& operator<<(std::ostream& os, const reshape& op) friend std::ostream& operator<<(std::ostream& os, const reshape& op)
......
...@@ -78,6 +78,8 @@ struct program ...@@ -78,6 +78,8 @@ struct program
instruction_ref begin(); instruction_ref begin();
instruction_ref end(); instruction_ref end();
shape get_shape() const;
instruction_ref validate() const; instruction_ref validate() const;
void compile(const target& t); void compile(const target& t);
......
...@@ -76,7 +76,9 @@ struct shape ...@@ -76,7 +76,9 @@ struct shape
std::size_t index(std::size_t i) const; std::size_t index(std::size_t i) const;
bool packed() const; bool packed() const;
bool transposed() const;
bool broadcasted() const; bool broadcasted() const;
bool standard() const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
...@@ -139,7 +141,7 @@ struct shape ...@@ -139,7 +141,7 @@ struct shape
type_t m_type; type_t m_type;
std::vector<std::size_t> m_lens; std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides; std::vector<std::size_t> m_strides;
bool m_packed; bool m_standard;
void calculate_strides(); void calculate_strides();
std::size_t element_space() const; std::size_t element_space() const;
......
...@@ -88,16 +88,16 @@ struct tensor_view ...@@ -88,16 +88,16 @@ struct tensor_view
return m_data[m_shape.index(this->size() - 1)]; return m_data[m_shape.index(this->size() - 1)];
} }
// TODO: Add iterators so it can handle nonpacked tensors // TODO: Add iterators so it can handle nonstandard tensors
T* begin() T* begin()
{ {
assert(this->m_shape.packed()); assert(this->m_shape.standard());
return m_data; return m_data;
} }
T* end() T* end()
{ {
assert(this->m_shape.packed()); assert(this->m_shape.standard());
if(this->empty()) if(this->empty())
return m_data; return m_data;
else else
...@@ -106,13 +106,13 @@ struct tensor_view ...@@ -106,13 +106,13 @@ struct tensor_view
const T* begin() const const T* begin() const
{ {
assert(this->m_shape.packed()); assert(this->m_shape.standard());
return m_data; return m_data;
} }
const T* end() const const T* end() const
{ {
assert(this->m_shape.packed()); assert(this->m_shape.standard());
if(this->empty()) if(this->empty())
return m_data; return m_data;
else else
......
...@@ -126,6 +126,8 @@ bool program::has_instruction(instruction_ref ins) const ...@@ -126,6 +126,8 @@ bool program::has_instruction(instruction_ref ins) const
instruction_ref program::begin() { return impl->instructions.begin(); } instruction_ref program::begin() { return impl->instructions.begin(); }
instruction_ref program::end() { return impl->instructions.end(); } instruction_ref program::end() { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().result; }
instruction_ref program::validate() const instruction_ref program::validate() const
{ {
return std::find_if(impl->instructions.begin(), return std::find_if(impl->instructions.begin(),
......
...@@ -8,10 +8,11 @@ ...@@ -8,10 +8,11 @@
namespace migraph { namespace migraph {
shape::shape() : m_type(float_type), m_packed(false) {} shape::shape() : m_type(float_type), m_standard(false) {}
shape::shape(type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_packed(true) {} shape::shape(type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape::shape(type_t t, std::vector<std::size_t> l) : m_type(t), m_lens(std::move(l)), m_packed(true) shape::shape(type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
this->calculate_strides(); this->calculate_strides();
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
...@@ -22,7 +23,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -22,7 +23,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
"At least one stride must be non-zero"); "At least one stride must be non-zero");
m_packed = this->elements() == this->element_space(); m_standard = this->packed() and not this->transposed();
} }
void shape::calculate_strides() void shape::calculate_strides()
...@@ -66,7 +67,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const ...@@ -66,7 +67,7 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std::size_t shape::index(std::size_t i) const std::size_t shape::index(std::size_t i) const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->packed()) if(this->standard())
return i; return i;
else else
return std::inner_product(this->lens().begin(), return std::inner_product(this->lens().begin(),
...@@ -79,7 +80,12 @@ std::size_t shape::index(std::size_t i) const ...@@ -79,7 +80,12 @@ std::size_t shape::index(std::size_t i) const
return ((i / stride) % len) * stride; return ((i / stride) % len) * stride;
}); });
} }
bool shape::packed() const { return this->m_packed; } bool shape::packed() const { return this->elements() == this->element_space(); }
// A shape counts as transposed when its strides are not in non-increasing
// (row-major) order, i.e. some axis was permuted out of its natural position.
bool shape::transposed() const
{
    const auto& st = this->strides();
    return not std::is_sorted(st.rbegin(), st.rend());
}
bool shape::broadcasted() const bool shape::broadcasted() const
{ {
...@@ -90,18 +96,17 @@ bool shape::broadcasted() const ...@@ -90,18 +96,17 @@ bool shape::broadcasted() const
std::multiplies<std::size_t>()) == 0; std::multiplies<std::size_t>()) == 0;
} }
bool shape::standard() const { return this->m_standard; }
std::size_t shape::element_space() const std::size_t shape::element_space() const
{ {
// TODO: Get rid of intermediate vector
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
std::vector<std::size_t> max_indices(this->lens().size()); return std::inner_product(this->lens().begin(),
std::transform(this->lens().begin(), this->lens().end(),
this->lens().end(), this->strides().begin(),
std::vector<std::size_t>(this->lens().size(), 1).begin(), std::size_t{0},
max_indices.begin(), std::plus<std::size_t>{},
std::minus<std::size_t>()); [](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
return std::inner_product(
max_indices.begin(), max_indices.end(), this->strides().begin(), std::size_t{0}) +
1; 1;
} }
......
...@@ -203,18 +203,6 @@ struct cpu_pooling ...@@ -203,18 +203,6 @@ struct cpu_pooling
} }
}; };
struct cpu_transpose
{
transpose op;
std::string name() const { return "cpu::transpose"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct cpu_contiguous struct cpu_contiguous
{ {
contiguous op; contiguous op;
...@@ -232,18 +220,6 @@ struct cpu_contiguous ...@@ -232,18 +220,6 @@ struct cpu_contiguous
} }
}; };
struct cpu_reshape
{
reshape op;
std::string name() const { return "cpu::reshape"; }
shape compute_shape(std::vector<shape> inputs) const { return op.compute_shape(inputs); }
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct cpu_gemm struct cpu_gemm
{ {
gemm op; gemm op;
...@@ -545,9 +521,7 @@ struct cpu_apply ...@@ -545,9 +521,7 @@ struct cpu_apply
apply_map["gemm"] = extend_op<cpu_gemm, gemm>(); apply_map["gemm"] = extend_op<cpu_gemm, gemm>();
apply_map["batch_norm_inference"] = apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, batch_norm_inference>(); extend_op<cpu_batch_norm_inference, batch_norm_inference>();
apply_map["reshape"] = extend_op<cpu_reshape, reshape>();
apply_map["contiguous"] = extend_op<cpu_contiguous, contiguous>(); apply_map["contiguous"] = extend_op<cpu_contiguous, contiguous>();
apply_map["transpose"] = extend_op<cpu_transpose, transpose>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>(); apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>(); apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>();
......
...@@ -183,22 +183,6 @@ struct miopen_gemm ...@@ -183,22 +183,6 @@ struct miopen_gemm
} }
}; };
struct miopen_transpose
{
transpose op;
std::string name() const { return "gpu::transpose"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
return op.compute_shape({inputs.at(0)});
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
return {output_shape, std::move(args.front().data)};
}
};
struct miopen_contiguous struct miopen_contiguous
{ {
contiguous op; contiguous op;
...@@ -271,10 +255,6 @@ struct miopen_apply ...@@ -271,10 +255,6 @@ struct miopen_apply
{ {
apply_gemm(it); apply_gemm(it);
} }
else if(it->op.name() == "transpose")
{
apply_transpose(it);
}
else if(it->op.name() == "contiguous") else if(it->op.name() == "contiguous")
{ {
apply_contiguous(it); apply_contiguous(it);
...@@ -346,13 +326,6 @@ struct miopen_apply ...@@ -346,13 +326,6 @@ struct miopen_apply
ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output); ins, miopen_gemm{op}, ins->arguments.at(0), ins->arguments.at(1), output);
} }
void apply_transpose(instruction_ref ins)
{
auto&& op = any_cast<transpose>(ins->op);
auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, miopen_transpose{op}, ins->arguments.at(0), output);
}
void apply_contiguous(instruction_ref ins) void apply_contiguous(instruction_ref ins)
{ {
auto&& op = any_cast<contiguous>(ins->op); auto&& op = any_cast<contiguous>(ins->op);
......
#include <migraph/auto_contiguous.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct contiguous_target
{
std::string name() const { return "contiguous"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::auto_contiguous{}};
}
migraph::context get_context() const { return {}; }
};
// 2x2 float literal holding {1, 2, 3, 4} in row-major order.
migraph::literal get_2x2()
{
    migraph::shape s{migraph::shape::float_type, {2, 2}};
    return migraph::literal{s, {1, 2, 3, 4}};
}
migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type, {2}}, {1, 2}}; }
// Verifies that compiling with the auto_contiguous pass turns a transposed
// final shape back into a standard one.
void after_literal_transpose()
{
    migraph::program p;
    auto lit = p.add_literal(get_2x2());
    // The literal itself starts out standard and untransposed.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
    auto tr = p.add_instruction(migraph::transpose{{1, 0}}, lit);
    p.add_instruction(pass_op{}, tr);
    // After the transpose the program's output shape is no longer standard.
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().transposed());
    p.compile(contiguous_target{});
    // The pass inserts a contiguous op, restoring a standard layout.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().transposed());
}
// Verifies that compiling with the auto_contiguous pass turns a broadcasted
// final shape back into a standard one.
void after_literal_broadcast()
{
    migraph::program p;
    auto lhs = p.add_literal(get_2x2());
    auto rhs = p.add_literal(get_2());
    // Before broadcasting, the program output is standard.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().broadcasted());
    auto br = p.add_instruction(migraph::broadcast{}, lhs, rhs);
    p.add_instruction(pass_op{}, br);
    // Broadcasting introduces zero strides, so the shape is non-standard.
    EXPECT(not p.get_shape().standard());
    EXPECT(p.get_shape().broadcasted());
    p.compile(contiguous_target{});
    // The pass materializes the data, restoring a standard layout.
    EXPECT(p.get_shape().standard());
    EXPECT(not p.get_shape().broadcasted());
}
// Runs both auto_contiguous scenarios: a transposed output and a broadcasted
// output, each expected to become standard after compilation.
int main()
{
    after_literal_transpose();
    after_literal_broadcast();
}
...@@ -641,7 +641,7 @@ void contiguous_test() ...@@ -641,7 +641,7 @@ void contiguous_test()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2}; std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3}; std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 0};
EXPECT(test::verify_range(results_vector, gold)); EXPECT(test::verify_range(results_vector, gold));
} }
......
...@@ -61,3 +61,22 @@ struct minus_op ...@@ -61,3 +61,22 @@ struct minus_op
return inputs.front(); return inputs.front();
} }
}; };
// Identity test operation: forwards its first argument (and first input
// shape) unchanged, or a default-constructed value when given none. Used to
// keep an instruction's result alive in test programs.
struct pass_op
{
    std::string name() const { return "pass"; }
    migraph::argument
    compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
    {
        return args.empty() ? migraph::argument{} : args.front();
    }
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
    {
        return inputs.empty() ? migraph::shape{} : inputs.front();
    }
};
...@@ -13,6 +13,42 @@ void test_shape_assign() ...@@ -13,6 +13,42 @@ void test_shape_assign()
EXPECT(!(s1 != s2)); EXPECT(!(s1 != s2));
} }
void test_shape_packed_default()
{
migraph::shape s{migraph::shape::float_type, {2, 2}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
void test_shape_packed()
{
migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
void test_shape_transposed()
{
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard());
EXPECT(s.packed());
EXPECT(s.transposed());
EXPECT(not s.broadcasted());
}
void test_shape_broadcasted()
{
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 0}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
void test_shape_default() void test_shape_default()
{ {
migraph::shape s1{}; migraph::shape s1{};
...@@ -24,7 +60,10 @@ void test_shape_default() ...@@ -24,7 +60,10 @@ void test_shape_default()
void test_shape4() void test_shape4()
{ {
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}}; migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}};
EXPECT(s.standard());
EXPECT(s.packed()); EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type); EXPECT(s.type() == migraph::shape::float_type);
EXPECT(s.lens()[0] == 100); EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32); EXPECT(s.lens()[1] == 32);
...@@ -68,7 +107,10 @@ void test_shape4_nonpacked() ...@@ -68,7 +107,10 @@ void test_shape4_nonpacked()
std::multiplies<std::size_t>()); std::multiplies<std::size_t>());
migraph::shape s{migraph::shape::float_type, lens, strides}; migraph::shape s{migraph::shape::float_type, lens, strides};
EXPECT(!s.packed()); EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
EXPECT(s.type() == migraph::shape::float_type); EXPECT(s.type() == migraph::shape::float_type);
EXPECT(s.lens()[0] == 100); EXPECT(s.lens()[0] == 100);
EXPECT(s.lens()[1] == 32); EXPECT(s.lens()[1] == 32);
...@@ -95,6 +137,10 @@ void test_shape4_nonpacked() ...@@ -95,6 +137,10 @@ void test_shape4_nonpacked()
int main() int main()
{ {
test_shape_assign(); test_shape_assign();
test_shape_packed_default();
test_shape_packed();
test_shape_transposed();
test_shape_broadcasted();
test_shape_default(); test_shape_default();
test_shape4(); test_shape4();
test_shape4_nonpacked(); test_shape4_nonpacked();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment