Unverified commit a0b570b2 authored by Paul Fultz II, committed by GitHub

Add more supported operators and optimizations for the cpu backend (#746)



* Add eliminate_data_type pass

* Formatting

* Auto convert quant ops

* Formatting

* Flip the order of decompose

* Compute max size differently

* Formatting

* Clamp values in convert

* Formatting

* Fix loss of precision in reduce

* Formatting

* Fix bugs in reduction

* Fix accumulator type in reference softmax implementation

* Formatting

* Update convert test

* Remove unused variables

* Remove unnecessary quant_dot check

* Formatting

* Add tests

* Formatting

* Remove unused code

* Remove duplicate ops

* Remove blaze dependency

* Use set since shape::type_t is not hashable on gcc 5

* Formatting

* Add dnnl binary op

* Formatting

* Add binary and eltwise

* Formatting

* Add softmax

* Formatting

* Remove unused operators

* Add missing files

* Formatting

* Add lrn

* Formatting

* Add deconvolution

* Formatting

* Change allocate default

* Add reorder

* Formatting

* Add reductions

* Formatting

* Sort lines

* Change literals in another loop

* Add pow operator

* Formatting

* Add pow operator

* Formatting

* Make sure shapes are packed

* Allow broadcasted inputs

* Remove unused operators

* Simplify functions

* Remove softmax

* Add sub and erf functions

* Formatting

* Fix bug

* Formatting

* Improve parallelism

* Formatting

* Allow multiple batch dimensions

* Formatting

* Move literal transforms out of lowering

* Formatting

* Add gather operator

* Sort lines

* Add early exit for carry

* Formatting

* Add missing concat

* Rename macro

* Fix deep nesting

* Formatting

* Fix cppcheck issues

* Remove else

* Move attribute to typedef

* Formatting

* Disable maybe-uninitialized warning since it's broken on gcc

* Add constexpr default constructor

* Formatting

* Fix compiler warnings

* Fix adjust_allocation test
Co-authored-by: Shucai Xiao <shucai@gmail.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 165d1a17
#include <migraphx/config.hpp>
#include <migraphx/cpu/dnnl.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

struct dnnl_reduction : dnnl_op<dnnl_reduction, dnnl::reduction>
{
    std::string algo;
    std::vector<std::int64_t> axes{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.algo, "algo"), f(self.axes, "axes"));
    }

    std::string name() const { return "dnnl::reduction"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        // Compensate for the allocation argument appended as the last input
        inputs.pop_back();
        check_shapes{inputs, *this}.has(1).standard();
        auto s    = inputs.at(0);
        auto lens = s.lens();
        for(auto axis : axes)
        {
            lens[axis] = 1;
        }
        auto r = shape{s.type(), lens};
        // Call get_primitive to make sure an algo is available
        this->get_primitive(this->to_memory_desc(r, inputs));
        return r;
    }

    dnnl::reduction::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        return {to_dnnl_algo(algo), m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_DST), 0, 0};
    }
};

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
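A quick illustration of what compute_shape above does: each reduced axis collapses to 1 while the remaining dimensions are kept. This standalone sketch mirrors that loop; reduce_lens is a hypothetical helper, not part of this PR.

// Standalone sketch of the axis-collapsing loop in dnnl_reduction::compute_shape
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::size_t> reduce_lens(std::vector<std::size_t> lens,
                                     const std::vector<std::int64_t>& axes)
{
    for(auto axis : axes)
        lens[axis] = 1; // each reduced dimension collapses to 1
    return lens;
}

int main()
{
    for(auto d : reduce_lens({2, 3, 4}, {1}))
        std::cout << d << ' '; // prints: 2 1 4
}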
#include <migraphx/config.hpp>
#include <migraphx/cpu/dnnl.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
{
    template <class Self, class F>
    static auto reflect(Self&, F)
    {
        return pack();
    }

    std::string name() const { return "dnnl::reorder"; }

    shape adjust_shape(const shape& x, int) const { return x; }

    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        return inputs.back();
    }

    // Custom desc class since it's missing in dnnl
    struct desc
    {
        dnnl::memory::desc src;
        dnnl::memory::desc dst;
    };

    desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_DST)};
    }

    auto get_primitive_desc(const desc& d) const
    {
        auto& engine = get_dnnl_context().engine;
        return dnnl::reorder::primitive_desc(engine, d.src, engine, d.dst);
    }
};

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
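For context, this is roughly how a reorder built from get_primitive_desc above would be executed with the public oneDNN API. A hedged sketch only; eng, s, src, and dst are stand-ins assumed to be set up by the caller, not code from this PR.

#include <dnnl.hpp>

void run_reorder(const dnnl::engine& eng, dnnl::stream& s, dnnl::memory& src, dnnl::memory& dst)
{
    // Same primitive_desc construction as dnnl_reorder::get_primitive_desc
    dnnl::reorder::primitive_desc pd(eng, src.get_desc(), eng, dst.get_desc());
    // Copy src into dst, converting between the two memory layouts
    dnnl::reorder(pd).execute(s, src, dst);
}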
#include <migraphx/config.hpp>
#include <migraphx/cpu/pointwise.hpp>
#include <migraphx/op/relu.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

struct dnnl_relu : dnnl_extend_op<dnnl_relu, dnnl::eltwise_forward, op::relu>
{
    dnnl::eltwise_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        return {dnnl::prop_kind::forward_inference,
                dnnl::algorithm::eltwise_relu,
                m.at(DNNL_ARG_SRC_0)};
    }
};

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#include <migraphx/config.hpp>
#include <migraphx/cpu/dnnl.hpp>
#include <migraphx/op/softmax.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

struct dnnl_softmax : dnnl_extend_op<dnnl_softmax, dnnl::softmax_forward, op::softmax>
{
    dnnl::softmax_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        int axis = this->op.axis;
        return {dnnl::prop_kind::forward_inference, m.at(DNNL_ARG_SRC_0), axis};
    }
};

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
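The dnnl_extend_op pattern above makes further eltwise operators cheap to add: only the algorithm enum differs. As a hedged sketch, an abs operator might look like the following; dnnl_abs is illustrative and not part of this commit.

// Hypothetical eltwise op following the dnnl_relu pattern above
#include <migraphx/config.hpp>
#include <migraphx/cpu/pointwise.hpp>
#include <migraphx/op/abs.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

struct dnnl_abs : dnnl_extend_op<dnnl_abs, dnnl::eltwise_forward, op::abs>
{
    dnnl::eltwise_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        return {dnnl::prop_kind::forward_inference,
                dnnl::algorithm::eltwise_abs,
                m.at(DNNL_ARG_SRC_0)};
    }
};

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx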
#include <migraphx/config.hpp>
#include <migraphx/cpu/pointwise.hpp>
#include <migraphx/op/add.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

struct dnnl_add : dnnl_extend_op<dnnl_add, dnnl::binary, op::add>
{
    dnnl::binary::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        return {dnnl::algorithm::binary_add,
                m.at(DNNL_ARG_SRC_0),
                m.at(DNNL_ARG_SRC_1),
                m.at(DNNL_ARG_DST)};
    }
};

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#include <migraphx/config.hpp>
#include <migraphx/cpu/pointwise.hpp>
#include <migraphx/op/sub.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

template struct cpu_binary<op::sub>;

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
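Note that op::sub falls back to the generic cpu_binary template here rather than a dnnl primitive. A dnnl-backed version could follow the dnnl_add pattern, assuming the linked oneDNN version exposes dnnl::algorithm::binary_sub; this sketch is illustrative and not part of the commit.

// Hypothetical dnnl_sub mirroring dnnl_add above
struct dnnl_sub : dnnl_extend_op<dnnl_sub, dnnl::binary, op::sub>
{
    dnnl::binary::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
    {
        return {dnnl::algorithm::binary_sub,
                m.at(DNNL_ARG_SRC_0),
                m.at(DNNL_ARG_SRC_1),
                m.at(DNNL_ARG_DST)};
    }
};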
@@ -22,6 +22,7 @@
 #include <migraphx/memory_coloring.hpp>
 #include <migraphx/simplify_algebra.hpp>
 #include <migraphx/simplify_reshapes.hpp>
+#include <migraphx/cpu/write_literals.hpp>
 #include <migraphx/cpu/allocation_model.hpp>
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/cpu/lowering.hpp>
@@ -38,7 +39,6 @@ std::string target::name() const { return "cpu"; }
 std::vector<pass> target::get_passes(migraphx::context&, const compile_options&) const
 {
     std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
-    unsupported_types.erase(shape::type_t::double_type);
     unsupported_types.erase(shape::type_t::float_type);
     return {normalize_ops{},
             eliminate_data_type{unsupported_types, shape::type_t::float_type},
@@ -63,10 +63,12 @@ std::vector<pass> target::get_passes(migraphx::context&, const compile_options&)
             propagate_constant{},
             dead_code_elimination{},
             lowering{},
-            eliminate_contiguous{},
+            eliminate_contiguous{"dnnl::reorder"},
             dead_code_elimination{},
             adjust_allocation{cpu_allocation_model{}},
             dead_code_elimination{},
+            write_literals{},
+            dead_code_elimination{},
             memory_coloring{"cpu::allocate"},
             dead_code_elimination{}};
 }
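The eliminate_data_type pass is what lets the pass list above drop double support: any type left in unsupported_types is computed in float instead. A minimal sketch of running just that rewrite over a module, with the header path and aggregate fields assumed from this diff rather than confirmed:

#include <set>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>

void demote_doubles(migraphx::module& m)
{
    std::set<migraphx::shape::type_t> unsupported = {migraphx::shape::type_t::double_type};
    // Instructions producing an unsupported type are computed in float, with
    // converts inserted at the boundaries
    migraphx::run_passes(
        m,
        {migraphx::eliminate_data_type{unsupported, migraphx::shape::type_t::float_type},
         migraphx::dead_code_elimination{}});
}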
#include <migraphx/cpu/write_literals.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

// Wraps a literal's data in a regular operator so later passes can treat it
// like any other cpu instruction
struct cpu_literal
{
    argument data;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.data, "data"));
    }

    std::string name() const { return "cpu::literal"; }
    shape compute_shape(const std::vector<shape>&) const { return data.get_shape(); }
    argument compute(const shape&, const std::vector<argument>&) const { return data; }

    friend std::ostream& operator<<(std::ostream& os, const cpu_literal& x)
    {
        os << x.name();
        return os;
    }
};

void write_literals::apply(module& m) const
{
    // Replace every @literal instruction with a cpu::literal op
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "@literal")
            continue;
        m.replace_instruction(ins, cpu_literal{ins->get_literal().get_argument()});
    }
}

} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
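The pass list above runs write_literals right before memory_coloring, presumably so literal data flows through the same allocation machinery as other cpu ops. A hedged sketch of driving the pass in isolation, mirroring the run_passes pattern the tests below use:

#include <migraphx/cpu/write_literals.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>

void run_write_literals(migraphx::module& m)
{
    // Every @literal instruction becomes a cpu::literal op after this
    migraphx::run_passes(m,
                         {migraphx::cpu::write_literals{},
                          migraphx::dead_code_elimination{}});
}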
@@ -68,7 +68,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&
             propagate_constant{},
             dead_code_elimination{},
             lowering{&ctx, options.offload_copy},
-            eliminate_contiguous{},
+            eliminate_contiguous{"gpu::contiguous"},
             dead_code_elimination{},
             eliminate_concat{concat_gpu_optimization{}},
             dead_code_elimination{},
@@ -8,7 +8,8 @@
 void run_pass(migraphx::module& m)
 {
-    migraphx::run_passes(m, {migraphx::eliminate_contiguous{}, migraphx::dead_code_elimination{}});
+    migraphx::run_passes(
+        m, {migraphx::eliminate_contiguous{"contiguous"}, migraphx::dead_code_elimination{}});
 }

 TEST_CASE(standard_op)
@@ -23,7 +23,7 @@ void run_lowering(migraphx::program& p)
     {migraphx::auto_contiguous{},
      migraphx::gpu::lowering{&ctx, false},
      migraphx::dead_code_elimination{},
-     migraphx::eliminate_contiguous{},
+     migraphx::eliminate_contiguous{"gpu::contiguous"},
      migraphx::dead_code_elimination{}});
 }
@@ -10,7 +10,7 @@ struct test_acos : verify_program<test_acos>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape s{migraphx::shape::double_type, {16}};
+    migraphx::shape s{migraphx::shape::float_type, {16}};
     auto x = mm->add_parameter("x", s);
     mm->add_instruction(migraphx::make_op("acos"), x);
     return p;
@@ -10,7 +10,7 @@ struct test_asin : verify_program<test_asin>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape s{migraphx::shape::double_type, {16}};
+    migraphx::shape s{migraphx::shape::float_type, {16}};
     auto x = mm->add_parameter("x", s);
     mm->add_instruction(migraphx::make_op("asin"), x);
     return p;
@@ -10,7 +10,7 @@ struct test_asinh : verify_program<test_asinh>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape s{migraphx::shape::double_type, {16}};
+    migraphx::shape s{migraphx::shape::float_type, {16}};
     auto x = mm->add_parameter("x", s);
     mm->add_instruction(migraphx::make_op("asinh"), x);
     return p;
@@ -10,7 +10,7 @@ struct test_atan : verify_program<test_atan>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape s{migraphx::shape::double_type, {16}};
+    migraphx::shape s{migraphx::shape::float_type, {16}};
     auto x = mm->add_parameter("x", s);
     mm->add_instruction(migraphx::make_op("atan"), x);
     return p;
@@ -10,10 +10,10 @@ struct test_atanh : verify_program<test_atanh>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape s{migraphx::shape::double_type, {16}};
+    migraphx::shape s{migraphx::shape::float_type, {16}};
     auto x = mm->add_parameter("x", s);
-    auto min_val = mm->add_literal(-0.95);
-    auto max_val = mm->add_literal(0.95);
+    auto min_val = mm->add_literal(-0.95f);
+    auto max_val = mm->add_literal(0.95f);
     min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {16}}}),
                                   min_val);
     max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {16}}}),
@@ -10,7 +10,7 @@ struct test_cos : verify_program<test_cos>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape s{migraphx::shape::double_type, {8}};
+    migraphx::shape s{migraphx::shape::float_type, {8}};
     auto x = mm->add_parameter("x", s);
     mm->add_instruction(migraphx::make_op("cos"), x);
     return p;
@@ -10,7 +10,7 @@ struct test_cosh : verify_program<test_cosh>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape s{migraphx::shape::double_type, {16}};
+    migraphx::shape s{migraphx::shape::float_type, {16}};
     auto x = mm->add_parameter("x", s);
     mm->add_instruction(migraphx::make_op("cosh"), x);
     return p;
@@ -23,8 +23,6 @@ template struct test_logsoftmax<0, migraphx::shape::float_type>;
 template struct test_logsoftmax<1, migraphx::shape::float_type>;
 template struct test_logsoftmax<2, migraphx::shape::float_type>;
 template struct test_logsoftmax<3, migraphx::shape::float_type>;
-template struct test_logsoftmax<1, migraphx::shape::double_type>;
-template struct test_logsoftmax<3, migraphx::shape::double_type>;
 template struct test_logsoftmax<1, migraphx::shape::half_type>;
 template struct test_logsoftmax<0, migraphx::shape::half_type>;
 template struct test_logsoftmax<2, migraphx::shape::half_type>;
@@ -10,7 +10,7 @@ struct test_recip : verify_program<test_recip>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape s{migraphx::shape::double_type, {3}};
+    migraphx::shape s{migraphx::shape::float_type, {3}};
     auto x = mm->add_parameter("x", s);
     mm->add_instruction(migraphx::make_op("recip"), x);
     return p;
@@ -10,7 +10,7 @@ struct test_sinh : verify_program<test_sinh>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    migraphx::shape s{migraphx::shape::double_type, {16}};
+    migraphx::shape s{migraphx::shape::float_type, {16}};
     auto x = mm->add_parameter("x", s);
     mm->add_instruction(migraphx::make_op("sinh"), x);
     return p;
@@ -21,8 +21,6 @@ struct test_softmax : verify_program<test_softmax<Axis, T>>
 template struct test_softmax<0, migraphx::shape::float_type>;
 template struct test_softmax<2, migraphx::shape::float_type>;
-template struct test_softmax<1, migraphx::shape::double_type>;
-template struct test_softmax<3, migraphx::shape::double_type>;
 template struct test_softmax<0, migraphx::shape::half_type>;
 template struct test_softmax<1, migraphx::shape::half_type>;
 template struct test_softmax<2, migraphx::shape::half_type>;