Commit f9437603 authored by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 781ce146 658cdab0
#include <migraphx/process.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/env.hpp>
#include <array>
#include <cstdio>
#include <functional>
#include <iostream>
#include <memory>
#include <unistd.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_CMD_EXECUTE)
std::function<void(const char*)> redirect_to(std::ostream& os)
{
return [&](const char* x) { os << x; };
}
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
{
int ec = 0;
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl;
std::array<char, 128> buffer;
auto closer = [&](FILE* stream) {
auto status = pclose(stream);
ec = WIFEXITED(status) ? WEXITSTATUS(status) : -1; // propagate the child's exit status // NOLINT
};
{
// TODO: Use execve instead of popen
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
if(!pipe)
MIGRAPHX_THROW("popen() failed: " + cmd);
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
std_out(buffer.data());
}
return ec;
}
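// Hedged usage sketch of exec() above (the command string is a hypothetical example):
// collect a command's stdout into a string instead of printing it.
//
//   std::string out;
//   int rc = exec("echo hello", [&](const char* line) { out += line; });
//   // rc == 0 and out == "hello\n" when run under a POSIX shell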
struct process_impl
{
std::string command{};
fs::path cwd{};
std::string get_command() const
{
std::string result;
if(not cwd.empty())
result += "cd " + cwd.string() + "; ";
result += command;
return result;
}
};
process::process(const std::string& cmd) : impl(std::make_unique<process_impl>())
{
impl->command = cmd;
}
process::process(process&&) noexcept = default;
process& process::operator=(process rhs)
{
std::swap(impl, rhs.impl);
return *this;
}
process::~process() noexcept = default;
process& process::cwd(const fs::path& p)
{
impl->cwd = p;
return *this;
}
void process::exec()
{
auto ec = migraphx::exec(impl->get_command(), redirect_to(std::cout));
if(ec != 0)
MIGRAPHX_THROW("Command " + impl->get_command() + " exited with status " +
std::to_string(ec));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
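// A minimal usage sketch of the process API above, assuming a hypothetical command and
// working directory; exec() streams the child's stdout to std::cout and throws when the
// command exits with a non-zero status.
#include <migraphx/process.hpp>

void run_build_step()
{
    migraphx::process{"cmake --build ."} // hypothetical command
        .cwd("/tmp/build")               // hypothetical working directory
        .exec();
}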
...@@ -158,6 +158,13 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW("Invalid module " + mod->name() + " from compilation at instruction " +
std::to_string(std::distance(mod->begin(), invalid)));
}
auto dangling = mod->find_dangling_reference();
if(dangling != mod->end())
{
auto index = std::distance(mod->begin(), dangling);
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index));
}
mod->finalize(this->impl->ctx);
}
}
...@@ -249,7 +256,6 @@ std::vector<argument> generic_eval(const module* mod,
}
assert(results.find(ins) != results.end());
}
return {results.at(std::prev(mod->end()))};
}
...
...@@ -254,7 +254,6 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
.def("get_parameter_names", &migraphx::program::get_parameter_names) .def("get_parameter_names", &migraphx::program::get_parameter_names)
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes) .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_output_shapes", &migraphx::program::get_output_shapes) .def("get_output_shapes", &migraphx::program::get_output_shapes)
......
...@@ -88,6 +88,29 @@ const std::vector<shape::type_t>& shape::types() ...@@ -88,6 +88,29 @@ const std::vector<shape::type_t>& shape::types()
return result; return result;
} }
std::string shape::name(shape::type_t t)
{
switch(t)
{
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
case x: return #x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE
}
MIGRAPHX_THROW("Invalid type");
}
std::string shape::cpp_type(shape::type_t t)
{
switch(t)
{
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
case x: return #t;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE
}
MIGRAPHX_THROW("Invalid type");
}
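// A small sketch of how the two new helpers can be used, e.g. when emitting generated
// source; it assumes the usual MIGRAPHX_SHAPE_VISIT_TYPES list in which float_type
// corresponds to the C++ type float.
//
//   shape::name(shape::float_type)     == "float_type"
//   shape::cpp_type(shape::float_type) == "float"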
shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
...@@ -246,17 +269,7 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const
{
switch(this->type())
{
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE(x, t) \
case x: return #x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_CASE
}
MIGRAPHX_THROW("Invalid type");
}
std::string shape::type_string() const { return name(this->type()); }
bool operator==(const shape& x, const shape& y)
{
...
...@@ -17,7 +17,6 @@
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/algorithm.hpp>
...@@ -39,13 +38,6 @@ auto conv_const_weights()
match::args(match::any(), match::is_constant().bind("w")));
}
template <class... Ms>
auto pointwise(Ms... ms)
{
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
ms...);
}
auto reduction() { return match::name_contains("reduce"); }
struct find_mul_conv
...@@ -287,7 +279,7 @@ struct find_concat_op
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::any_of(pointwise(), match::name("broadcast")), match::used_once()));
match::any_of(match::pointwise(), match::name("broadcast")), match::used_once()));
}
template <class Iterator>
...@@ -407,8 +399,8 @@ struct find_splits
{
auto matcher() const
{
return match::any(match::any_of[match::outputs()](
match::name("slice")(match::any_of[match::outputs()](pointwise(), reduction()))));
return match::any(match::any_of[match::outputs()](match::name("slice")(
match::any_of[match::outputs()](match::pointwise(), reduction()))));
}
static std::vector<std::vector<instruction_ref>>
...
#include <iterator>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
...@@ -318,11 +319,233 @@ struct find_nested_concat
}
};
struct find_resize
{
auto matcher() const
{
return match::name("gather")(
match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind")));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto ins_rsp = r.instructions["data"];
auto ins_ind = r.instructions["ind"];
// resize input shape
if(ins_rsp->get_shape().lens().size() != 1)
{
return;
}
// resize output shape
const auto& in_shape = ins_rsp->inputs().front()->get_shape();
const auto& out_shape = ins->get_shape();
// check if output shape is multiple of input shape
const auto& in_lens = in_shape.lens();
const auto& out_lens = out_shape.lens();
if(in_lens.size() != out_lens.size())
{
return;
}
// output shape must be multiple of input shape
std::vector<bool> is_multi(in_lens.size());
std::transform(
in_lens.begin(), in_lens.end(), out_lens.begin(), is_multi.begin(), [](auto x, auto y) {
return (y % x == 0);
});
if(not std::all_of(is_multi.begin(), is_multi.end(), [](auto b) { return b; }))
{
return;
}
// output must be multiple of inputs
std::vector<std::size_t> scales(in_lens.size());
std::transform(
in_lens.begin(), in_lens.end(), out_lens.begin(), scales.begin(), [](auto x, auto y) {
return y / x;
});
// if ind is not constant, cannot optimize
std::vector<int> vec_ind;
auto arg_ind = ins_ind->eval();
if(arg_ind.empty())
{
return;
}
arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); });
std::vector<int> index(out_shape.elements());
std::iota(index.begin(), index.end(), 0);
if(not std::all_of(index.begin(), index.end(), [&](auto i) {
auto out_idx = out_shape.multi(i);
auto in_idx = out_idx;
std::transform(out_idx.begin(),
out_idx.end(),
scales.begin(),
in_idx.begin(),
[&](auto io, auto scale) { return io - (io % scale); });
return vec_ind[i] == vec_ind[out_shape.index(in_idx)];
}))
{
return;
}
// wrap up shapes for multibroadcast
std::vector<std::pair<std::size_t, std::size_t>> dim_scales;
std::transform(in_lens.begin(),
in_lens.end(),
out_lens.begin(),
std::back_inserter(dim_scales),
[](auto x, auto y) { return std::make_pair(x, y / x); });
std::vector<int64_t> in_dims;
std::vector<int64_t> out_dims;
for(auto& isp : dim_scales)
{
in_dims.push_back(isp.first);
out_dims.push_back(isp.first * isp.second);
if(isp.first == 1 or isp.second == 1)
{
continue;
}
out_dims.back() = isp.first;
in_dims.push_back(1);
out_dims.push_back(isp.second);
}
auto in_rsp = ins_rsp->inputs().front();
auto rsp_data = p.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"output_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
}
};
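// A hedged worked example of the rewrite above: a 2x nearest-neighbour resize that an
// importer expressed as reshape + gather.  With input lens {1, 2}, output lens {1, 4}
// and constant indices {0, 0, 1, 1}, each output element i reads the input element at
// i - (i % scale), so the gather is equivalent to
//   reshape {1, 2, 1} -> multibroadcast {1, 2, 2} -> contiguous -> reshape {1, 4}
// and the large index literal is never materialized.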
struct find_where_op
{
auto matcher() const
{
return match::name("gather")(
match::args(match::name("reshape")(match::arg(0)(match::name("concat").bind("data"))),
match::is_constant().bind("ind")));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto concat = r.instructions["data"];
auto ins_ind = r.instructions["ind"];
std::vector<bool> vec_ind;
auto arg_ind = ins_ind->eval();
arg_ind.visit([&](auto v) { vec_ind.assign(v.begin(), v.end()); });
// ind has to be the same value
auto val = vec_ind.front();
if(not std::all_of(vec_ind.begin(), vec_ind.end(), [&](auto v) { return (v == val); }))
{
return;
}
// concat axis must be 0
auto op = any_cast<op::concat>(concat->get_operator());
if(op.axis != 0)
{
return;
}
// check concat inputs, it has to be 2 and have the same shape
const auto& inputs = concat->inputs();
if(inputs.size() != 2)
{
return;
}
if(inputs.at(0)->get_shape() != inputs.at(1)->get_shape())
{
return;
}
if(inputs.at(0)->get_shape().lens() != ins_ind->get_shape().lens())
{
return;
}
if(val)
{
p.replace_instruction(ins, inputs.at(0));
}
else
{
p.replace_instruction(ins, inputs.at(1));
}
}
};
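// A hedged example of the pattern above: Where(cond, a, b) with a constant condition is
// often imported as gather(reshape(concat(a, b)), ind).  When the index tensor is
// uniform (all elements truthy or all falsy) the gather can only ever read one of the
// two concat inputs, so the subgraph collapses to a (uniformly true) or b (uniformly
// false).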
struct find_reshape_cont
{
auto matcher() const
{
return match::pointwise(
match::nargs(2),
match::either_arg(0, 1)(
match::name("reshape")(match::args(match::name("contiguous").bind("cont")))
.bind("rsp"),
match::any()));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto ins_cont = r.instructions["cont"];
auto in_ins = r.instructions["rsp"];
auto cont_input = ins_cont->inputs().front();
auto lens = cont_input->get_shape().lens();
std::vector<int64_t> dims(lens.begin(), lens.end());
if(in_ins->get_shape() != ins->get_shape())
{
return;
}
if(not std::all_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
return i->get_shape().standard();
}))
{
return;
}
auto out_lens = ins->get_shape().lens();
std::vector<int64_t> out_dims(out_lens.begin(), out_lens.end());
std::vector<instruction_ref> inputs;
for(const auto& in : ins->inputs())
{
if(in == in_ins)
{
inputs.push_back(cont_input);
}
else
{
inputs.push_back(
p.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in));
}
}
auto out = p.insert_instruction(ins, ins->get_operator(), inputs);
p.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out);
}
};
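// A hedged example of the rewrite above: add(reshape(contiguous(x)), y), where the
// reshape only reinterprets the element order, becomes
// reshape(add(x, reshape(y, x.lens())), out_lens).  The pointwise op then consumes x
// directly, and the now-unused contiguous copy can be removed by dead-code elimination.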
void simplify_reshapes::apply(module& p) const
{
for(int i = 0; i < 2; i++)
{
match::find_matches(p,
find_where_op{},
find_resize{},
find_reshape_cont{},
find_nop_reshapes{},
find_reshaper{},
find_transpose{},
...
...@@ -12,6 +12,7 @@ add_library(migraphx_cpu
dnnl.cpp
eltwise.cpp
erf.cpp
fuse_ops.cpp
gather.cpp
gemm.cpp
layernorm.cpp
...
...@@ -11,7 +11,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.algo, "algo"));
return pack_join(self.reflect_base(self, f), pack(f(self.algo, "algo")));
}
std::string name() const { return "dnnl::binary"; }
...@@ -20,7 +20,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
{
// Compensate for allocation
inputs.pop_back();
check_shapes{inputs, *this}.has(2);
check_shapes{this->trim_post_op_inputs(inputs), *this}.has(2);
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
auto r = s0;
...
...@@ -33,9 +33,9 @@ struct dnnl_concat : dnnl_extend_op<dnnl_concat, dnnl::concat, op::concat>
return {m.at(DNNL_ARG_DST), std::size_t(op.axis), srcs};
}
auto get_primitive_desc(const desc& d) const
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
{
return dnnl::concat::primitive_desc(d.dst, d.axis, d.srcs, get_dnnl_context().engine);
return dnnl::concat::primitive_desc(d.dst, d.axis, d.srcs, get_dnnl_context().engine, attr);
}
};
...
#include <migraphx/cpu/dnnl.hpp>
#if defined(__GNUC__) && __GNUC__ <= 5
namespace std {
template <>
struct hash<dnnl::algorithm>
{
using argument_type = dnnl::algorithm;
using result_type = std::size_t;
result_type operator()(const argument_type& x) const noexcept
{
return std::hash<underlying_type_t<argument_type>>{}(
static_cast<underlying_type_t<argument_type>>(x));
}
};
} // namespace std
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
...@@ -139,6 +156,23 @@ dnnl::algorithm to_dnnl_algo(const std::string& name)
return dnnl_algo_map().at(name);
}
const std::unordered_map<dnnl::algorithm, std::string>& dnnl_algo_string_map()
{
static const std::unordered_map<dnnl::algorithm, std::string> m = {
#define MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR(x) {dnnl::algorithm::x, #x},
MIGRAPHX_VISIT_DNNL_ALGO(MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR)
#undef MIGRAPHX_DNNL_ALGO_GENERATE_VISITOR
};
return m;
}
std::string to_string(const dnnl::algorithm& algo)
{
if(dnnl_algo_string_map().count(algo) == 0)
return "unknown_" + std::to_string(static_cast<int>(algo));
return dnnl_algo_string_map().at(algo);
}
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -13,7 +13,8 @@ struct dnnl_eltwise : dnnl_op<dnnl_eltwise, dnnl::eltwise_forward>
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta"));
return pack_join(self.reflect_base(self, f),
pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta")));
}
std::string name() const { return "dnnl::eltwise"; }
...@@ -22,7 +23,7 @@ struct dnnl_eltwise : dnnl_op<dnnl_eltwise, dnnl::eltwise_forward>
{
// Compensate for allocation
inputs.pop_back();
check_shapes{inputs, *this}.has(1).packed();
check_shapes{this->trim_post_op_inputs(inputs), *this}.has(1).packed();
auto s = inputs.at(0);
auto r = s;
if(not s.packed())
...
#include <migraphx/cpu/fuse_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/context.hpp>
#include <migraphx/env.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/dead_code_elimination.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_DNNL_POST_OPS_WORKAROUND);
MIGRAPHX_PRED_MATCHER(has_post_ops, instruction_ref ins)
{
auto v = ins->get_operator().to_value();
return v.contains("post_ops");
}
MIGRAPHX_PRED_MATCHER(without_post_ops, instruction_ref ins)
{
auto v = ins->get_operator().to_value();
return v.contains("post_ops") and v["post_ops"].empty();
}
bool workaround_dnnl_broken_post_ops(const operation& op, const operation& post_op)
{
if(contains({"dnnl::dot", "dnnl::convolution"}, op.name()))
return true;
auto pv = post_op.to_value();
if(not pv.at("post_ops").empty())
return true;
auto v = op.to_value();
auto last_op = v.at("post_ops").empty() ? v : v.at("post_ops").back();
auto algo = last_op.contains("algo") ? last_op.at("algo").to<std::string>() : op.name();
auto post_algo = pv["algo"].to<std::string>();
if(starts_with(algo, "eltwise") and starts_with(post_algo, "eltwise"))
return true;
if(algo == post_algo)
return true;
return false;
}
operation merge_post_ops(const operation& op, const operation& post_op)
{
auto pv = post_op.to_value();
auto v = op.to_value();
v["post_ops"].push_back({{"algo", pv["algo"]},
{"alpha", pv["alpha"].value_or(0.0f)},
{"beta", pv["beta"].value_or(0.0f)}});
auto post_ops = pv.at("post_ops");
for(const auto& po : post_ops)
v["post_ops"].push_back(po);
return make_op(op.name(), v);
}
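// A hedged illustration of merge_post_ops: fusing a dnnl::binary (algo "binary_mul")
// followed by a dnnl::eltwise (algo "eltwise_relu") produces a single dnnl::binary whose
// value contains
//   post_ops: [{algo: "eltwise_relu", alpha: 0.0, beta: 0.0}]
// i.e. the consumer becomes an entry in the producer's post_ops list, and any post_ops
// the consumer already carried are appended after it.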
struct find_post_ops
{
context* ctx = nullptr;
match::any_matcher matcher() const
{
if(enabled(MIGRAPHX_DISABLE_DNNL_POST_OPS_WORKAROUND{}))
return match::name("dnnl::eltwise",
"dnnl::binary")(match::arg(0)(has_post_ops(), match::used_once()));
else
return match::name("dnnl::eltwise")(
without_post_ops(),
match::arg(0)(match::name("dnnl::binary")(without_post_ops(), match::used_once())));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = ins->inputs().front();
auto x = x_ins->get_operator();
if(workaround_dnnl_broken_post_ops(x, ins->get_operator()))
return;
auto op = merge_post_ops(x, ins->get_operator());
auto inputs = x_ins->inputs();
inputs.back() = ins->inputs().back();
if(ins->name() == "dnnl::binary")
inputs.insert(std::prev(inputs.end()), ins->inputs().at(1));
auto input_shapes = to_shapes(inputs);
auto new_shape = try_compute_shape(op, input_shapes);
if(new_shape.empty() or new_shape.front() != ins->get_shape())
return;
auto info = compile(op, *ctx, new_shape.front(), input_shapes);
if(info.contains("impl") and starts_with(info.at("impl").to<std::string>(), "ref:"))
return;
m.replace_instruction(ins, op, inputs);
}
};
void fuse_ops::apply(module& m) const
{
for(std::size_t i = 0; i < 4; i++)
{
match::find_matches(m, find_post_ops{ctx});
dead_code_elimination{}.apply(m);
}
}
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -9,6 +9,7 @@
#include <unordered_map>
#include <dnnl.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/assert.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,11 +42,50 @@ dnnl::memory to_dnnl_memory(const argument& a);
dnnl::algorithm to_dnnl_algo(const std::string& name);
std::string to_string(const dnnl::algorithm& algo);
struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
{
std::string algo;
float alpha = 0;
float beta = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.algo, "algo"), f(self.alpha, "alpha"), f(self.beta, "beta"));
}
};
template <class Derived, class Primitive>
struct dnnl_op : auto_register_op<Derived>
{
std::vector<post_op> post_ops;
std::function<argument(context& ctx, const std::vector<argument>& args)> execute;
template <class Self, class F>
static auto reflect_base(Self& self, F f)
{
return pack(f(self.post_ops, "post_ops"));
}
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return reflect_base(self, f);
}
std::size_t get_extra_post_op_args() const
{
return std::count_if(post_ops.begin(), post_ops.end(), [](const auto& po) {
return contains(po.algo, "binary");
});
}
static std::size_t get_binary_post_op_arg(std::size_t pos)
{
return DNNL_ARG_ATTR_MULTIPLE_POST_OP(pos) | DNNL_ARG_SRC_1; // NOLINT
}
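// For example, the extra input of the first binary post op is bound to the DNNL
// execute-time argument DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1;
// create_arg_map below appends these indices after the primitive's own arguments so
// post-op operands ride along in the same argument vector.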
static std::vector<shape> to_shapes(const std::vector<argument>& args)
{
std::vector<shape> shapes(args.size());
...@@ -54,6 +94,13 @@ struct dnnl_op : auto_register_op<Derived>
});
return shapes;
}
static std::string impl(const Primitive& prim)
{
auto desc = prim.get_primitive_desc();
const char* str = nullptr;
dnnl_primitive_desc_query(desc, dnnl_query_impl_info_str, 0, &str);
return str == nullptr ? "" : str;
}
// Map arg index to arg in dnnl
std::vector<int> arg_map(int size) const
{
...@@ -81,14 +128,44 @@ struct dnnl_op : auto_register_op<Derived>
}
return s;
}
template <class F>
void for_each_post_op(F f) const
{
int i = 0;
for(auto&& op : post_ops)
{
if(contains(op.algo, "binary"))
{
f(op, get_binary_post_op_arg(i));
}
else
{
f(op, -1);
}
i++;
}
}
shape adjust_shape(const shape& s, int) const { return base_adjust_shape(s); }
std::vector<int> create_arg_map(std::size_t input_size) const
{
const auto& self = static_cast<const Derived&>(*this);
auto npost_ops = get_extra_post_op_args();
auto prim_input_size = input_size - npost_ops;
auto m = self.arg_map(prim_input_size);
for_each_post_op([&](auto&&, auto arg) {
if(arg < 0)
return;
m.push_back(arg);
});
return m;
}
std::unordered_map<int, dnnl::memory::desc>
to_memory_desc(const shape& output_shape, const std::vector<shape>& inputs) const
{
const auto& self = static_cast<const Derived&>(*this);
std::unordered_map<int, dnnl::memory::desc> result;
result[DNNL_ARG_DST] = to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
auto m = self.arg_map(inputs.size());
auto m = create_arg_map(inputs.size());
assert(m.size() >= inputs.size());
for(int i = 0; i < inputs.size(); i++)
{
...@@ -96,17 +173,44 @@ struct dnnl_op : auto_register_op<Derived>
}
return result;
}
dnnl::primitive_attr
get_primitive_attr(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
dnnl::primitive_attr result;
dnnl::post_ops po;
for_each_post_op([&](auto&& op, auto arg) {
if(contains(op.algo, "binary_add"))
{
auto desc = m.at(arg);
if(desc == m.at(DNNL_ARG_DST))
po.append_sum(1.0f);
else
po.append_binary(to_dnnl_algo(op.algo), m.at(arg));
}
else if(contains(op.algo, "binary"))
{
po.append_binary(to_dnnl_algo(op.algo), m.at(arg));
}
else if(contains(op.algo, "eltwise"))
po.append_eltwise(1.0f, to_dnnl_algo(op.algo), op.alpha, op.beta);
else
MIGRAPHX_THROW("Unknown post op algo: " + op.algo);
});
result.set_post_ops(po);
return result;
}
template <class T>
auto get_primitive_desc(const T& desc) const
-> decltype(typename Primitive::primitive_desc(desc, get_dnnl_context().engine))
auto get_primitive_desc(const T& desc, const dnnl::primitive_attr& attr) const
-> decltype(typename Primitive::primitive_desc(desc, attr, get_dnnl_context().engine))
{
return typename Primitive::primitive_desc(desc, get_dnnl_context().engine);
return typename Primitive::primitive_desc(desc, attr, get_dnnl_context().engine);
}
Primitive get_primitive(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
const auto& self = static_cast<const Derived&>(*this);
auto desc = self.get_desc(m);
auto pd = self.get_primitive_desc(desc);
auto attr = MIGRAPHX_ASSERT_NO_THROW(this->get_primitive_attr(m));
auto pd = self.get_primitive_desc(desc, attr);
return Primitive(pd);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
...@@ -118,6 +222,15 @@ struct dnnl_op : auto_register_op<Derived>
{
return shapes.size() - 1;
}
value compile(context&, const shape& output_shape, std::vector<shape> inputs)
{
// Compensate for allocation
inputs.pop_back();
auto md = to_memory_desc(output_shape, inputs);
auto prim = get_primitive(md);
auto impl_name = impl(prim);
return {{"impl", impl_name}};
}
void finalize(context&, const shape& output_shape, std::vector<shape> inputs)
{
...@@ -127,8 +240,11 @@ struct dnnl_op : auto_register_op<Derived>
auto name = self.name();
auto md = to_memory_desc(output_shape, inputs);
auto prim = get_primitive(md);
auto arg_lookup = self.arg_map(inputs.size());
execute = [=](context&, const std::vector<argument>& args) {
auto arg_lookup = create_arg_map(inputs.size());
#ifndef NDEBUG
auto prim_attr = get_primitive_attr(md);
#endif
execute = [=](context&, const std::vector<argument>& args) {
#ifndef NDEBUG
// Check that the memory descriptors have not changed
auto debug_args = args;
...@@ -144,6 +260,53 @@ struct dnnl_op : auto_register_op<Derived>
MIGRAPHX_THROW(name +
": Memory descriptor has changed for: " + std::to_string(p.first));
}
// Check post_ops args are correct
auto pos = prim_attr.get_post_ops();
auto prim_input_size = inputs.size() - this->get_extra_post_op_args();
int j = 0;
for(int i = 0; i < pos.len(); i++)
{
auto arg = j + prim_input_size;
auto kind = pos.kind(i);
std::string mesg =
"Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": ";
try
{
dnnl::algorithm algo;
dnnl::memory::desc mdesc;
float scale = 0;
float alpha = 0;
float beta = 0;
if(kind == dnnl::primitive::kind::binary)
{
pos.get_params_binary(i, algo, mdesc);
if(mdesc != md.at(arg_lookup.at(arg)))
MIGRAPHX_THROW(mesg +
"Memory descriptor doesn't match for binary post op");
j++;
}
else if(kind == dnnl::primitive::kind::eltwise)
{
pos.get_params_eltwise(i, scale, algo, alpha, beta);
}
else if(kind == dnnl::primitive::kind::sum)
{
pos.get_params_sum(i, scale);
algo = dnnl::algorithm::binary_add;
}
else
{
MIGRAPHX_THROW("Unknown kind");
}
if(to_dnnl_algo(post_ops[i].algo) != algo)
MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " +
post_ops[i].algo + " != " + to_string(algo));
}
catch(const dnnl::error& e)
{
MIGRAPHX_THROW(mesg + "Failed to get post ops argument " + ": " + e.what());
}
}
#endif
std::unordered_map<int, dnnl::memory> m;
m[DNNL_ARG_DST] = to_dnnl_memory(md.at(DNNL_ARG_DST), args.back());
...@@ -153,6 +316,11 @@ struct dnnl_op : auto_register_op<Derived>
return args.back();
};
}
std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
{
auto prim_input_size = inputs.size() - this->get_extra_post_op_args();
return {inputs.begin(), inputs.begin() + prim_input_size};
}
};
template <class Derived, class Primitive, class Op>
...@@ -163,7 +331,7 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
return pack_join(self.reflect_base(self, f), migraphx::reflect(self.op, f));
}
// dnnl has some issues with non-packed inputs
...@@ -176,7 +344,7 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
// Compensate for allocation
inputs.pop_back();
self.required(check_shapes(inputs, self));
auto r = migraphx::compute_shape(op, inputs);
auto r = migraphx::compute_shape(op, this->trim_post_op_inputs(inputs));
// Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(r, inputs));
return r;
...
#ifndef MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#define MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace cpu {
struct context;
struct fuse_ops
{
context* ctx = nullptr;
std::string name() const { return "cpu::fuse_ops"; }
void apply(module& m) const;
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_CPU_FUSE_OPS_HPP
...@@ -15,7 +15,7 @@ namespace cpu {
struct target
{
std::string name() const;
std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const;
std::vector<pass> get_passes(migraphx::context& gctx, const compile_options&) const;
migraphx::context get_context() const { return context{}; }
argument copy_to(const argument& arg) const { return arg; }
...
...@@ -20,7 +20,7 @@ struct dnnl_layernorm : dnnl_op<dnnl_layernorm, dnnl::layer_normalization_forwar
{
// Compensate for allocation
inputs.pop_back();
check_shapes{inputs, *this}.has(1);
check_shapes{this->trim_post_op_inputs(inputs), *this}.has(1);
auto s = inputs.at(0);
// Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(s, inputs));
...
...@@ -351,7 +351,7 @@ struct cpu_apply
std::back_inserter(inputs),
[&](const auto& s) { return r.instructions.at(s); });
inputs.push_back(this->insert_allocation(ins, ins->get_shape()));
this->modl->replace_instruction(ins, op, inputs);
modl->replace_instruction(ins, op, inputs);
});
}
...
...@@ -12,7 +12,8 @@ struct dnnl_reduction : dnnl_op<dnnl_reduction, dnnl::reduction>
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.algo, "algo"), f(self.axes, "axes"));
return pack_join(self.reflect_base(self, f),
pack(f(self.algo, "algo"), f(self.axes, "axes")));
}
std::string name() const { return "dnnl::reduction"; }
...@@ -21,7 +22,7 @@ struct dnnl_reduction : dnnl_op<dnnl_reduction, dnnl::reduction>
{
// Compensate for allocation
inputs.pop_back();
check_shapes{inputs, *this}.has(1).standard();
check_shapes{this->trim_post_op_inputs(inputs), *this}.has(1).standard();
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
...
...@@ -7,12 +7,6 @@ namespace cpu {
struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
std::string name() const { return "dnnl::reorder"; } std::string name() const { return "dnnl::reorder"; }
shape adjust_shape(const shape& x, int) const { return x; } shape adjust_shape(const shape& x, int) const { return x; }
...@@ -20,7 +14,10 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder> ...@@ -20,7 +14,10 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
return inputs.back(); auto r = inputs.back();
// Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(r, inputs));
return r;
}
// Custom desc class since its missing in dnnl
struct desc
...@@ -33,10 +30,10 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_DST)};
}
auto get_primitive_desc(const desc& d) const
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
{
auto& engine = get_dnnl_context().engine;
return dnnl::reorder::primitive_desc(engine, d.src, engine, d.dst);
return dnnl::reorder::primitive_desc(engine, d.src, engine, d.dst, attr);
}
};
...
...@@ -22,9 +22,11 @@
#include <migraphx/memory_coloring.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/cpu/fuse_ops.hpp>
#include <migraphx/cpu/write_literals.hpp>
#include <migraphx/cpu/allocation_model.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/generate.hpp>
...@@ -36,8 +38,10 @@ namespace cpu {
std::string target::name() const { return "cpu"; }
std::vector<pass> target::get_passes(migraphx::context&, const compile_options&) const
// cppcheck-suppress constParameter
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&) const
{
auto& ctx = any_cast<context>(gctx);
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type);
return {normalize_ops{},
...@@ -67,6 +71,8 @@ std::vector<pass> target::get_passes(migraphx::context&, const compile_options&)
dead_code_elimination{},
adjust_allocation{cpu_allocation_model{}},
dead_code_elimination{},
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{},
dead_code_elimination{},
memory_coloring{"cpu::allocate"},
...