Commit b973563b authored by Paul's avatar Paul
Browse files

Format

parent 5a140c90
...@@ -136,16 +136,14 @@ std::string generate_pointwise(const module& pm, const std::string& name) ...@@ -136,16 +136,14 @@ std::string generate_pointwise(const module& pm, const std::string& name)
g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})"); g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign", g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})"); g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})"); g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})"); g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})"); g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions // Add explict conversions
g.fresult([](const shape& s) { g.fresult(
return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
});
g.create_function( g.create_function(
g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name)); g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
return g.str(); return g.str();
......
...@@ -118,8 +118,8 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -118,8 +118,8 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto pf = generate_pointwise(*pm, "inner_pointwise"); auto pf = generate_pointwise(*pm, "inner_pointwise");
std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)"; std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)";
auto kernel_name = generate_name_from_ops(*pm) + "_kernel"; auto kernel_name = generate_name_from_ops(*pm) + "_kernel";
return replace(compile_op( return replace(
ctx, compile_op(ctx,
to_shapes(ins->inputs()), to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}})); {{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}}));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment