Unverified Commit f7befe50 authored by Paul Fultz II, committed by GitHub

Cpu fusions using post_ops (#781)
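
This PR teaches the CPU backend to fuse element-wise and binary operators into the preceding DNNL primitive via oneDNN post-ops. As a rough sketch of the mechanism being used (not code from this commit; the call pattern simply mirrors the append_eltwise/append_binary calls added to get_primitive_attr in the diff below, and the names here are illustrative):

```cpp
#include <dnnl.hpp>

// Sketch only: attach a fused ReLU and a fused binary add to a primitive.
// `other_desc` stands in for the memory descriptor of the extra binary input.
dnnl::primitive_attr make_post_op_attr(const dnnl::memory::desc& other_desc)
{
    dnnl::post_ops po;
    po.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f); // scale, algo, alpha, beta
    po.append_binary(dnnl::algorithm::binary_add, other_desc);          // fuse an extra input
    dnnl::primitive_attr attr;
    attr.set_post_ops(po);
    return attr; // passed to the primitive descriptor so DNNL runs everything in one call
}
```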



* Add eliminate_data_type pass

* Formatting

* Auto convert quant ops

* Formatting

* Flip the order of decompose

* Compute max size differently

* Formatting

* Clamp values in convert

* Formatting

* Fix loss of precision in reduce

* Formatting

* Fix bugs in reduction

* Fix accumulator type in reference softmax implementation

* Formatting

* Update convert test

* Remove unused variables

* Remove unnecessary quant_dot check

* Formatting

* Add tests

* Formatting

* Remove unused code

* Remove duplicate ops

* Remove blaze dependency

* Use set since shape::type_t is not hashable on gcc 5

* Formatting

* Add dnnl binary op

* Formatting

* Add binary and eltwise

* Formatting

* Add softmax

* Formatting

* Remove unused operators

* Add missing files

* Formatting

* Add lrn

* Formatting

* Add deconvolution

* Formatting

* Change allocate default

* Add reorder

* Formatting

* Add reductions

* Formatting

* Sort lines

* Change literals in another loop

* Add pow operator

* Formatting

* Add pow operator

* Formatting

* Make sure shapes are packed

* Allow broadcasted inputs

* Remove unused operators

* Simplify functions

* Remove softmax

* Add sub and erf functions

* Formatting

* Fix bug

* Formatting

* Improve parallelism

* Formatting

* Allow multiple batch dimensions

* Formatting

* Move literal transforms out of lowering

* Formatting

* Add gather operator

* Sort lines

* Add early exit for carry

* Formatting

* Add missing concat

* Rename macro

* Fix deep nesting

* Formatting

* Fix cppcheck issues

* Remove else

* Move attribute to typedef

* Formatting

* Disable maybe-uninitialized warning since it's broken on gcc

* Add constexpr default constructor

* Formatting

* Fix compiler warnings

* Fix adjust_allocation test

* Add layernorm matcher

* Add gelu_erf matcher

* Formatting

* Add gelu_tanh matcher

* Formatting

* Remove match namespace

* Formatting

* Use matcher instead of string

* Formatting

* Add fusions

* Formatting

* Add post op field

* Formatting

* Make post_ops serializable

* Formatting

* Add eltwise fusions

* Formatting

* Fix null conversions

* Formatting

* Add fuse_ops source files

* Formatting

* Set binary post op index correctly

* Formatting

* Fix serialization bugs

* Check if used once

* Formatting

* Fix error in get_primitive_attr

* Formatting

* Add compile function

* Formatting

* Limit fusions

* Formatting

* Disable with env variable instead of using compile arg

* Formatting

* Fix implicit conversion to bool

* Declare on separate lines

* Formatting

* Fix cppcheck issues

* Fix ICE in pack_join

* Formatting

* Use const ref

* Make enum hashable

* Formatting

* Add explicit this

* Fix merge issues

* Fix dangling ref

* Formatting

* Add test for compile

* Formatting

* Add more value tests

* Formatting
Co-authored-by: Shucai Xiao <shucai@gmail.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 78eaf2b8
......@@ -147,7 +147,7 @@ std::size_t mismatch_diff(R1&& r1, R2&& r2, T diff)
}
template <class R1, class R2>
double rms_range(R1&& r1, R2&& r2)
double rms_range(const R1& r1, const R2& r2)
{
std::size_t n = range_distance(r1);
if(n == range_distance(r2))
......@@ -164,7 +164,7 @@ double rms_range(R1&& r1, R2&& r2)
}
template <class R1, class R2>
bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = nullptr)
bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr)
{
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2);
......
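For reference, with the default `tolerance = 80` and `float` data, the threshold computed above is `std::numeric_limits<float>::epsilon() * 80 ≈ 1.19e-7 * 80 ≈ 9.5e-6`; `verify_range` then compares the RMS error of the two ranges against that threshold.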
......@@ -443,5 +443,19 @@ shape compute_shape(const operation& op,
return op.compute_shape(to_shapes(args), mods);
}
}
std::vector<shape> try_compute_shape(const operation& op, const std::vector<shape>& inputs)
{
shape new_shape;
try
{
new_shape = op.compute_shape(inputs);
}
catch(...)
{
return {};
}
return {new_shape};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
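The new try_compute_shape helper returns an empty vector instead of throwing when a shape cannot be computed, so callers can cheaply probe whether a candidate op is valid. A minimal usage sketch (mirroring how find_post_ops::apply uses it further down in this diff; `op`, `inputs`, and `ins` are assumed to be in scope):

```cpp
// Sketch only: reject a candidate fused op if its output shape cannot be
// computed or no longer matches the instruction being replaced.
auto shapes = try_compute_shape(op, to_shapes(inputs));
if(shapes.empty() or shapes.front() != ins->get_shape())
    return; // skip the fusion rather than let compute_shape throw later
```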
......@@ -396,6 +396,40 @@ instruction_ref module::validate() const
});
}
bool is_borrowed(instruction_ref ins)
{
auto alias = instruction::get_output_alias(ins, true);
if(alias == ins)
return false;
if(alias->get_operator().is_borrowed())
return true;
return is_borrowed(alias);
}
bool is_param_alias(instruction_ref ins)
{
return instruction::get_output_alias(ins)->name() == "@param";
}
bool is_dangling(instruction_ref ins) { return not is_param_alias(ins) and is_borrowed(ins); }
instruction_ref module::find_dangling_reference() const
{
auto last = std::prev(end());
if(last->name() == "@return")
{
auto dangling = std::find_if(
last->inputs().begin(), last->inputs().end(), [](auto x) { return is_dangling(x); });
if(dangling != last->inputs().end())
return *dangling;
}
else if(is_dangling(last))
{
return last;
}
return end();
}
void module::finalize(context& ctx)
{
for(auto ins : iterator_for(*this))
......
......@@ -135,7 +135,7 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
template <class Stream>
packer<Stream>& operator()(msgpack::packer<Stream>& o, const migraphx::value& v) const
{
v.visit([&](auto&& x) { this->write(o, x); });
v.visit_value([&](auto&& x) { this->write(o, x); });
return o;
}
};
......
......@@ -158,6 +158,13 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW("Invalid module " + mod->name() + " from compilation at instruction " +
std::to_string(std::distance(mod->begin(), invalid)));
}
auto dangling = mod->find_dangling_reference();
if(dangling != mod->end())
{
auto index = std::distance(mod->begin(), dangling);
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index));
}
mod->finalize(this->impl->ctx);
}
}
......@@ -249,7 +256,6 @@ std::vector<argument> generic_eval(const module* mod,
}
assert(results.find(ins) != results.end());
}
return {results.at(std::prev(mod->end()))};
}
......
File mode changed from 100644 to 100755
......@@ -12,6 +12,7 @@ add_library(migraphx_cpu
dnnl.cpp
eltwise.cpp
erf.cpp
fuse_ops.cpp
gather.cpp
gemm.cpp
layernorm.cpp
......
......@@ -11,7 +11,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.algo, "algo"));
return pack_join(self.reflect_base(self, f), pack(f(self.algo, "algo")));
}
std::string name() const { return "dnnl::binary"; }
......@@ -20,7 +20,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
{
// Compensate for allocation
inputs.pop_back();
check_shapes{inputs, *this}.has(2);
check_shapes{this->trim_post_op_inputs(inputs), *this}.has(2);
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
auto r = s0;
......
......@@ -33,9 +33,9 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
return {m.at(DNNL_ARG_DST), std::size_t(op.axis), srcs};
}
auto get_primitive_desc(const desc& d) const
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
{
return dnnl::concat::primitive_desc(d.dst, d.axis, d.srcs, get_dnnl_context().engine);
return dnnl::concat::primitive_desc(d.dst, d.axis, d.srcs, get_dnnl_context().engine, attr);
}
};
......
#include <migraphx/cpu/dnnl.hpp>
#if defined(__GNUC__) && __GNUC__ <= 5
namespace std {
template <>
struct hash<dnnl::algorithm>
{
using argument_type = dnnl::algorithm;
using result_type = std::size_t;
result_type operator()(const argument_type& x) const noexcept
{
return std::hash<underlying_type_t<argument_type>>{}(
static_cast<underlying_type_t<argument_type>>(x));
}
};
} // namespace std
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
......@@ -139,6 +156,23 @@ dnnl::algorithm to_dnnl_algo(const std::string& name)
return dnnl_algo_map().at(name);
}
const std::unordered_map<dnnl::algorithm, std::string>& dnnl_algo_string_map()
{
static const std::unordered_map<dnnl::algorithm, std::string> m = {
#define MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR(x) {dnnl::algorithm::x, #x},
MIGRAPHX_VISIT_DNNL_ALGO(MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR)
#undef MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR
};
return m;
}
std::string to_string(const dnnl::algorithm& algo)
{
if(dnnl_algo_string_map().count(algo) == 0)
return "unknown_" + std::to_string(static_cast<int>(algo));
return dnnl_algo_string_map().at(algo);
}
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -13,7 +13,8 @@ struct dnnl_eltwise : dnnl_op<dnnl_eltwise, dnnl::eltwise_forward>
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta"));
return pack_join(self.reflect_base(self, f),
pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta")));
}
std::string name() const { return "dnnl::eltwise"; }
......@@ -22,7 +23,7 @@ struct dnnl_eltwise : dnnl_op<dnnl_eltwise, dnnl::eltwise_forward>
{
// Compensate for allocation
inputs.pop_back();
check_shapes{inputs, *this}.has(1).packed();
check_shapes{this->trim_post_op_inputs(inputs), *this}.has(1).packed();
auto s = inputs.at(0);
auto r = s;
if(not s.packed())
......
#include <migraphx/cpu/fuse_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/context.hpp>
#include <migraphx/env.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/dead_code_elimination.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_DNNL_POST_OPS_WORKAROUND);
MIGRAPHX_PRED_MATCHER(has_post_ops, instruction_ref ins)
{
auto v = ins->get_operator().to_value();
return v.contains("post_ops");
}
MIGRAPHX_PRED_MATCHER(without_post_ops, instruction_ref ins)
{
auto v = ins->get_operator().to_value();
return v.contains("post_ops") and v["post_ops"].empty();
}
bool workaround_dnnl_broken_post_ops(const operation& op, const operation& post_op)
{
if(contains({"dnnl::dot", "dnnl::convolution"}, op.name()))
return true;
auto pv = post_op.to_value();
if(not pv.at("post_ops").empty())
return true;
auto v = op.to_value();
auto last_op = v.at("post_ops").empty() ? v : v.at("post_ops").back();
auto algo = last_op.contains("algo") ? last_op.at("algo").to<std::string>() : op.name();
auto post_algo = pv["algo"].to<std::string>();
if(starts_with(algo, "eltwise") and starts_with(post_algo, "eltwise"))
return true;
if(algo == post_algo)
return true;
return false;
}
operation merge_post_ops(const operation& op, const operation& post_op)
{
auto pv = post_op.to_value();
auto v = op.to_value();
v["post_ops"].push_back({{"algo", pv["algo"]},
{"alpha", pv["alpha"].value_or(0.0f)},
{"beta", pv["beta"].value_or(0.0f)}});
auto post_ops = pv.at("post_ops");
for(const auto& po : post_ops)
v["post_ops"].push_back(po);
return make_op(op.name(), v);
}
struct find_post_ops
{
context* ctx = nullptr;
match::any_matcher matcher() const
{
if(enabled(MIGRAPHX_DISABLE_DNNL_POST_OPS_WORKAROUND{}))
return match::name("dnnl::eltwise",
"dnnl::binary")(match::arg(0)(has_post_ops(), match::used_once()));
else
return match::name("dnnl::eltwise")(
without_post_ops(),
match::arg(0)(match::name("dnnl::binary")(without_post_ops(), match::used_once())));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = ins->inputs().front();
auto x = x_ins->get_operator();
if(workaround_dnnl_broken_post_ops(x, ins->get_operator()))
return;
auto op = merge_post_ops(x, ins->get_operator());
auto inputs = x_ins->inputs();
inputs.back() = ins->inputs().back();
if(ins->name() == "dnnl::binary")
inputs.insert(std::prev(inputs.end()), ins->inputs().at(1));
auto input_shapes = to_shapes(inputs);
auto new_shape = try_compute_shape(op, input_shapes);
if(new_shape.empty() or new_shape.front() != ins->get_shape())
return;
auto info = compile(op, *ctx, new_shape.front(), input_shapes);
if(info.contains("impl") and starts_with(info.at("impl").to<std::string>(), "ref:"))
return;
m.replace_instruction(ins, op, inputs);
}
};
void fuse_ops::apply(module& m) const
{
for(std::size_t i = 0; i < 4; i++)
{
match::find_matches(m, find_post_ops{ctx});
dead_code_elimination{}.apply(m);
}
}
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
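To make the rewrite concrete, here is a hypothetical before/after at the value level (assuming the cpu target's dnnl operators are registered; the operator names match the diff, but the scenario is illustrative):

```cpp
// An add followed by a relu, each a separate DNNL primitive before the pass runs.
auto add  = make_op("dnnl::binary", {{"algo", "binary_add"}});
auto relu = make_op("dnnl::eltwise", {{"algo", "eltwise_relu"}, {"alpha", 0.0f}, {"beta", 0.0f}});

// merge_post_ops folds the eltwise into the binary's post_ops list:
// fused.to_value()["post_ops"] now holds [{algo: eltwise_relu, alpha: 0, beta: 0}],
// so DNNL executes the add and the relu as a single primitive.
auto fused = merge_post_ops(add, relu);
```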
......@@ -9,6 +9,7 @@
#include <unordered_map>
#include <dnnl.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/assert.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -41,11 +42,50 @@ dnnl::memory to_dnnl_memory(const argument& a);
dnnl::algorithm to_dnnl_algo(const std::string& name);
std::string to_string(const dnnl::algorithm& algo);
struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
{
std::string algo;
float alpha = 0;
float beta = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta"));
}
};
template <class Derived, class Primitive>
struct dnnl_op : auto_register_op<Derived>
{
std::vector<post_op> post_ops;
std::function<argument(context& ctx, const std::vector<argument>& args)> execute;
template <class Self, class F>
static auto reflect_base(Self& self, F f)
{
return pack(f(self.post_ops, "post_ops"));
}
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return reflect_base(self, f);
}
std::size_t get_extra_post_op_args() const
{
return std::count_if(post_ops.begin(), post_ops.end(), [](const auto& po) {
return contains(po.algo, "binary");
});
}
static std::size_t get_binary_post_op_arg(std::size_t pos)
{
return DNNL_ARG_ATTR_MULTIPLE_POST_OP(pos) | DNNL_ARG_SRC_1; // NOLINT
}
static std::vector<shape> to_shapes(const std::vector<argument>& args)
{
std::vector<shape> shapes(args.size());
......@@ -54,6 +94,13 @@ struct dnnl_op : auto_register_op<Derived>
});
return shapes;
}
static std::string impl(const Primitive& prim)
{
auto desc = prim.get_primitive_desc();
const char* str = nullptr;
dnnl_primitive_desc_query(desc, dnnl_query_impl_info_str, 0, &str);
return str == nullptr ? "" : str;
}
// Map arg index to arg in dnnl
std::vector<int> arg_map(int size) const
{
......@@ -81,14 +128,44 @@ struct dnnl_op : auto_register_op<Derived>
}
return s;
}
template <class F>
void for_each_post_op(F f) const
{
int i = 0;
for(auto&& op : post_ops)
{
if(contains(op.algo, "binary"))
{
f(op, get_binary_post_op_arg(i));
}
else
{
f(op, -1);
}
i++;
}
}
shape adjust_shape(const shape& s, int) const { return base_adjust_shape(s); }
std::vector<int> create_arg_map(std::size_t input_size) const
{
const auto& self = static_cast<const Derived&>(*this);
auto npost_ops = get_extra_post_op_args();
auto prim_input_size = input_size - npost_ops;
auto m = self.arg_map(prim_input_size);
for_each_post_op([&](auto&&, auto arg) {
if(arg < 0)
return;
m.push_back(arg);
});
return m;
}
std::unordered_map<int, dnnl::memory::desc>
to_memory_desc(const shape& output_shape, const std::vector<shape>& inputs) const
{
const auto& self = static_cast<const Derived&>(*this);
std::unordered_map<int, dnnl::memory::desc> result;
result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
auto m = self.arg_map(inputs.size());
auto m = create_arg_map(inputs.size());
assert(m.size() >= inputs.size());
for(int i = 0; i < inputs.size(); i++)
{
......@@ -96,17 +173,44 @@ struct dnnl_op : auto_register_op<Derived>
}
return result;
}
dnnl::primitive_attr
get_primitive_attr(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
dnnl::primitive_attr result;
dnnl::post_ops po;
for_each_post_op([&](auto&& op, auto arg) {
if(contains(op.algo, "binary_add"))
{
auto desc = m.at(arg);
if(desc == m.at(DNNL_ARG_DST))
po.append_sum(1.0f);
else
po.append_binary(to_dnnl_algo(op.algo), m.at(arg));
}
else if(contains(op.algo, "binary"))
{
po.append_binary(to_dnnl_algo(op.algo), m.at(arg));
}
else if(contains(op.algo, "eltwise"))
po.append_eltwise(1.0f, to_dnnl_algo(op.algo), op.alpha, op.beta);
else
MIGRAPHX_THROW("Unknown post op algo: " + op.algo);
});
result.set_post_ops(po);
return result;
}
template <class T>
auto get_primitive_desc(const T& desc) const
-> decltype(typename Primitive::primitive_desc(desc, get_dnnl_context().engine))
auto get_primitive_desc(const T& desc, const dnnl::primitive_attr& attr) const
-> decltype(typename Primitive::primitive_desc(desc, attr, get_dnnl_context().engine))
{
return typename Primitive::primitive_desc(desc, get_dnnl_context().engine);
return typename Primitive::primitive_desc(desc, attr, get_dnnl_context().engine);
}
Primitive get_primitive(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
const auto& self = static_cast<const Derived&>(*this);
auto desc = self.get_desc(m);
auto pd = self.get_primitive_desc(desc);
auto attr = MIGRAPHX_ASSERT_NO_THROW(this->get_primitive_attr(m));
auto pd = self.get_primitive_desc(desc, attr);
return Primitive(pd);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......@@ -118,6 +222,15 @@ struct dnnl_op : auto_register_op<Derived>
{
return shapes.size() - 1;
}
value compile(context&, const shape& output_shape, std::vector<shape> inputs)
{
// Compensate for allocation
inputs.pop_back();
auto md = to_memory_desc(output_shape, inputs);
auto prim = get_primitive(md);
auto impl_name = impl(prim);
return {{"impl", impl_name}};
}
void finalize(context&, const shape& output_shape, std::vector<shape> inputs)
{
......@@ -127,8 +240,11 @@ struct dnnl_op : auto_register_op<Derived>
auto name = self.name();
auto md = to_memory_desc(output_shape, inputs);
auto prim = get_primitive(md);
auto arg_lookup = self.arg_map(inputs.size());
execute = [=](context&, const std::vector<argument>& args) {
auto arg_lookup = create_arg_map(inputs.size());
#ifndef NDEBUG
auto prim_attr = get_primitive_attr(md);
#endif
execute = [=](context&, const std::vector<argument>& args) {
#ifndef NDEBUG
// Check that the memory descriptors have not changed
auto debug_args = args;
......@@ -144,6 +260,53 @@ struct dnnl_op : auto_register_op<Derived>
MIGRAPHX_THROW(name +
": Memory descriptor has changed for: " + std::to_string(p.first));
}
// Check post_ops args are correct
auto pos = prim_attr.get_post_ops();
auto prim_input_size = inputs.size() - this->get_extra_post_op_args();
int j = 0;
for(int i = 0; i < pos.len(); i++)
{
auto arg = j + prim_input_size;
auto kind = pos.kind(i);
std::string mesg =
"Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": ";
try
{
dnnl::algorithm algo;
dnnl::memory::desc mdesc;
float scale = 0;
float alpha = 0;
float beta = 0;
if(kind == dnnl::primitive::kind::binary)
{
pos.get_params_binary(i, algo, mdesc);
if(mdesc != md.at(arg_lookup.at(arg)))
MIGRAPHX_THROW(mesg +
"Memory descriptor doesn't match for binary post op");
j++;
}
else if(kind == dnnl::primitive::kind::eltwise)
{
pos.get_params_eltwise(i, scale, algo, alpha, beta);
}
else if(kind == dnnl::primitive::kind::sum)
{
pos.get_params_sum(i, scale);
algo = dnnl::algorithm::binary_add;
}
else
{
MIGRAPHX_THROW("Unknown kind");
}
if(to_dnnl_algo(post_ops[i].algo) != algo)
MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " +
post_ops[i].algo + " != " + to_string(algo));
}
catch(const dnnl::error& e)
{
MIGRAPHX_THROW(mesg + "Failed to get post ops argument " + ": " + e.what());
}
}
#endif
std::unordered_map<int, dnnl::memory> m;
m[DNNL_ARG_DST] = to_dnnl_memory(md.at(DNNL_ARG_DST), args.back());
......@@ -153,6 +316,11 @@ struct dnnl_op : auto_register_op<Derived>
return args.back();
};
}
std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
{
auto prim_input_size = inputs.size() - this->get_extra_post_op_args();
return {inputs.begin(), inputs.begin() + prim_input_size};
}
};
template <class Derived, class Primitive, class Op>
......@@ -163,7 +331,7 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
return pack_join(self.reflect_base(self, f), migraphx::reflect(self.op, f));
}
// dnnl has some issues with non-packed inputs
......@@ -176,7 +344,7 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
// Compensate for allocation
inputs.pop_back();
self.required(check_shapes(inputs, self));
auto r = migraphx::compute_shape(op, inputs);
auto r = migraphx::compute_shape(op, this->trim_post_op_inputs(inputs));
// Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(r, inputs));
return r;
......
#ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace cpu {
struct context;
struct fuse_ops
{
context* ctx = nullptr;
std::string name() const { return "cpu::fuse_ops"; }
void apply(module& m) const;
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
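The pass keeps a pointer to the cpu context because find_post_ops trial-compiles each fused candidate and rejects those that would fall back to a reference implementation. A minimal sketch of scheduling it (this mirrors the target.cpp change further down; `gctx` is the target's migraphx::context):

```cpp
// Sketch: the fusion pass is paired with dead-code elimination so the
// absorbed eltwise/binary instructions are dropped once they are fused.
auto& ctx = migraphx::any_cast<cpu::context>(gctx);
std::vector<pass> passes = {cpu::fuse_ops{&ctx}, dead_code_elimination{}};
```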
......@@ -15,7 +15,7 @@ namespace cpu {
struct target
{
std::string name() const;
std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const;
std::vector<pass> get_passes(migraphx::context& gctx, const compile_options&) const;
migraphx::context get_context() const { return context{}; }
argument copy_to(const argument& arg) const { return arg; }
......
......@@ -20,7 +20,7 @@ struct dnnl_layernorm : dnnl_op<dnnl_layernorm, dnnl::layer_normalization_forwar
{
// Compensate for allocation
inputs.pop_back();
check_shapes{inputs, *this}.has(1);
check_shapes{this->trim_post_op_inputs(inputs), *this}.has(1);
auto s = inputs.at(0);
// Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(s, inputs));
......
......@@ -351,7 +351,7 @@ struct cpu_apply
std::back_inserter(inputs),
[&](const auto& s) { return r.instructions.at(s); });
inputs.push_back(this->insert_allocation(ins, ins->get_shape()));
this->modl->replace_instruction(ins, op, inputs);
modl->replace_instruction(ins, op, inputs);
});
}
......
......@@ -12,7 +12,8 @@ struct dnnl_reduction : dnnl_op<dnnl_reduction, dnnl::reduction>
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.algo, "algo"), f(self.axes, "axes"));
return pack_join(self.reflect_base(self, f),
pack(f(self.algo, "algo"), f(self.axes, "axes")));
}
std::string name() const { return "dnnl::reduction"; }
......@@ -21,7 +22,7 @@ struct dnnl_reduction : dnnl_op<dnnl_reduction, dnnl::reduction>
{
// Compensate for allocation
inputs.pop_back();
check_shapes{inputs, *this}.has(1).standard();
check_shapes{this->trim_post_op_inputs(inputs), *this}.has(1).standard();
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
......
......@@ -7,12 +7,6 @@ namespace cpu {
struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
std::string name() const { return "dnnl::reorder"; }
shape adjust_shape(const shape& x, int) const { return x; }
......@@ -20,7 +14,10 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2);
return inputs.back();
auto r = inputs.back();
// Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(r, inputs));
return r;
}
// Custom desc class since it's missing in dnnl
struct desc
......@@ -33,10 +30,10 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_DST)};
}
auto get_primitive_desc(const desc& d) const
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
{
auto& engine = get_dnnl_context().engine;
return dnnl::reorder::primitive_desc(engine, d.src, engine, d.dst);
return dnnl::reorder::primitive_desc(engine, d.src, engine, d.dst, attr);
}
};
......
......@@ -22,9 +22,11 @@
#include <migraphx/memory_coloring.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/cpu/fuse_ops.hpp>
#include <migraphx/cpu/write_literals.hpp>
#include <migraphx/cpu/allocation_model.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/generate.hpp>
......@@ -36,8 +38,10 @@ namespace cpu {
std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&, const compile_options&) const
// cppcheck-suppress constParameter
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&) const
{
auto& ctx = any_cast<context>(gctx);
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type);
return {normalize_ops{},
......@@ -67,6 +71,8 @@ std::vector<pass> target::get_passes(migraphx::context&, const compile_options&)
dead_code_elimination{},
adjust_allocation{cpu_allocation_model{}},
dead_code_elimination{},
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{},
dead_code_elimination{},
memory_coloring{"cpu::allocate"},
......