Commit dd768142 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

dnnl

parent cf02b661
...@@ -95,7 +95,7 @@ template <class Derived, class Primitive> ...@@ -95,7 +95,7 @@ template <class Derived, class Primitive>
struct dnnl_op : auto_register_op<Derived> struct dnnl_op : auto_register_op<Derived>
{ {
std::vector<post_op> post_ops; std::vector<post_op> post_ops;
std::function<argument(context& ctx, const std::vector<argument>& args)> execute; std::function<argument(context&, const std::vector<argument>&)> execute;
template <class Self, class F> template <class Self, class F>
static auto reflect_base(Self& self, F f) static auto reflect_base(Self& self, F f)
...@@ -284,7 +284,7 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -284,7 +284,7 @@ struct dnnl_op : auto_register_op<Derived>
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return static_cast<std::ptrdiff_t>(shapes.size() - 1);
} }
value compile(context&, const shape& output_shape, std::vector<shape> inputs) value compile(context&, const shape& output_shape, std::vector<shape> inputs)
{ {
...@@ -300,91 +300,128 @@ struct dnnl_op : auto_register_op<Derived> ...@@ -300,91 +300,128 @@ struct dnnl_op : auto_register_op<Derived>
{ {
// Compensate for allocation // Compensate for allocation
inputs.pop_back(); inputs.pop_back();
const auto& self = static_cast<const Derived&>(*this); auto md = to_memory_desc(output_shape, inputs);
auto name = self.name(); auto prim = get_primitive(md);
auto md = to_memory_desc(output_shape, inputs); auto arg_lookup = create_arg_map(inputs.size());
auto prim = get_primitive(md);
auto arg_lookup = create_arg_map(inputs.size());
#ifndef NDEBUG #ifndef NDEBUG
auto prim_attr = get_primitive_attr(md); // NOLINTNEXTLINE
execute = std::bind(&dnnl_op::internal,
this,
output_shape,
inputs,
md,
prim,
arg_lookup,
std::placeholders::_1,
std::placeholders::_2);
#else
// NOLINTNEXTLINE
execute = std::bind(&dnnl_op::internal,
this,
md,
prim,
arg_lookup,
std::placeholders::_1,
std::placeholders::_2);
#endif #endif
execute = [=](context&, const std::vector<argument>& args) { }
// Strip the trailing shapes that belong to fused post-ops, returning only
// the shapes consumed by the underlying dnnl primitive itself.
std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
{
    const auto primitive_arg_count = inputs.size() - this->get_extra_post_op_args();
    return std::vector<shape>{inputs.begin(), inputs.begin() + primitive_arg_count};
}
private:
#ifndef NDEBUG
// Debug builds carry the original shapes so the memory descriptors and
// post-op configuration can be re-validated on every execution.
argument internal(const shape& output_shape,
                  const std::vector<shape>& inputs,
                  std::unordered_map<int, dnnl::memory::desc> md,
                  Primitive prim,
                  std::vector<int> arg_lookup,
                  context&,
                  const std::vector<argument>& args)
#else
// Release builds only need the prepared descriptors, primitive and
// argument-index mapping captured at compile time.
argument internal(std::unordered_map<int, dnnl::memory::desc> md,
                  Primitive prim,
                  std::vector<int> arg_lookup,
                  context&,
                  const std::vector<argument>& args)
#endif
{
#ifndef NDEBUG
    const auto& self = static_cast<const Derived&>(*this);
    auto name        = self.name();
    auto prim_attr   = get_primitive_attr(md);
    // Check that the memory descriptors have not changed since compile():
    // recompute them from the runtime arguments and compare entry by entry.
    auto debug_args = args;
    debug_args.pop_back();
    auto debug_md = to_memory_desc(output_shape, to_shapes(debug_args));
    for(auto&& p : debug_md)
    {
        if(md.count(p.first) == 0)
            MIGRAPHX_THROW(name +
                           ": Missing memory descriptor for: " + std::to_string(p.first));
        if(p.second == md.at(p.first))
            continue;
        MIGRAPHX_THROW(name +
                       ": Memory descriptor has changed for: " + std::to_string(p.first));
    }
    // Check post_ops args are correct: walk the primitive's fused post-ops
    // and verify each against the migraphx-side post_ops description.
    auto pos             = prim_attr.get_post_ops();
    auto prim_input_size = inputs.size() - this->get_extra_post_op_args();
    // j counts post-ops that consume an extra trailing argument (binary ones).
    int j = 0;
    for(int i = 0; i < pos.len(); i++)
    {
        auto arg         = j + prim_input_size;
        auto kind        = pos.kind(i);
        std::string mesg = "Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": ";
        try
        {
            dnnl::algorithm algo;
            dnnl::memory::desc mdesc;
            float scale = 0;
            float alpha = 0;
            float beta  = 0;
            if(kind == dnnl::primitive::kind::binary)
            {
                pos.get_params_binary(i, algo, mdesc);
                if(mdesc != md.at(arg_lookup.at(arg)))
                    MIGRAPHX_THROW(mesg + "Memory descriptor doesn't match for binary "
                                          "post op");
                j++;
            }
            else if(kind == dnnl::primitive::kind::eltwise)
            {
                pos.get_params_eltwise(i, scale, algo, alpha, beta);
            }
            else if(kind == dnnl::primitive::kind::sum)
            {
                // dnnl reports sum post-ops without an algorithm; it is
                // semantically a binary add.
                pos.get_params_sum(i, scale);
                algo = dnnl::algorithm::binary_add;
            }
            else
            {
                MIGRAPHX_THROW("Unknown kind");
            }
            if(to_dnnl_algo(post_ops[i].algo) != algo)
                MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " +
                               post_ops[i].algo + " != " + to_string(algo));
        }
        catch(const dnnl::error& e)
        {
            MIGRAPHX_THROW(mesg + "Failed to get post ops argument " + ": " + e.what());
        }
    }
#endif
    // Bind each migraphx argument to its dnnl memory slot and run the
    // primitive. The last argument is the destination buffer.
    std::unordered_map<int, dnnl::memory> m;
    m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
        to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
    // Use an unsigned index to avoid the signed/unsigned comparison the
    // original `int i < args.size() - 1` performed; args is never empty
    // here since args.back() is the destination.
    for(std::size_t i = 0; i < args.size() - 1; i++)
        m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
    prim.execute(get_dnnl_context().stream, m);
    // The operation writes in place into the destination argument.
    return args.back();
}
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment