Commit 1bc16951 authored by turneram's avatar turneram
Browse files

Formatting

parent c7096299
...@@ -19,11 +19,11 @@ struct parse_layernorm : op_parser<parse_layernorm> ...@@ -19,11 +19,11 @@ struct parse_layernorm : op_parser<parse_layernorm>
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
// un-fuse layernorm op so migraphx can handle fusion instead // un-fuse layernorm op so migraphx can handle fusion instead
auto x = args.front(); auto x = args.front();
auto x_type = x->get_shape().type(); auto x_type = x->get_shape().type();
auto weights = args.at(1); auto weights = args.at(1);
auto bias = args.at(2); auto bias = args.at(2);
float epsilon = 1e-12f; float epsilon = 1e-12f;
int64_t axis = -1; int64_t axis = -1;
...@@ -35,23 +35,23 @@ struct parse_layernorm : op_parser<parse_layernorm> ...@@ -35,23 +35,23 @@ struct parse_layernorm : op_parser<parse_layernorm>
{ {
epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>(); epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
} }
auto epsilon_lit = info.add_literal(literal{shape{x_type, {1}}, {epsilon}}); auto epsilon_lit = info.add_literal(literal{shape{x_type, {1}}, {epsilon}});
auto exponent = info.add_literal(literal{shape{x_type, {1}}, {2.0}}); auto exponent = info.add_literal(literal{shape{x_type, {1}}, {2.0}});
auto dims = x->get_shape().lens(); auto dims = x->get_shape().lens();
auto mean = info.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x); auto mean = info.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x);
auto mean_mbcast = auto mean_mbcast =
info.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); info.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = info.add_instruction(migraphx::make_op("sub"), x, mean_mbcast); auto sub = info.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
auto exponent_mbcast = auto exponent_mbcast = info.add_instruction(
info.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent); migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = info.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast); auto pow = info.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = info.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), pow); auto var = info.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), pow);
auto add_epsilon = info.add_broadcastable_binary_op("add", var, epsilon_lit); auto add_epsilon = info.add_broadcastable_binary_op("add", var, epsilon_lit);
auto sqrt = info.add_instruction(migraphx::make_op("sqrt"), add_epsilon); auto sqrt = info.add_instruction(migraphx::make_op("sqrt"), add_epsilon);
auto div = info.add_broadcastable_binary_op("div", sub, sqrt); auto div = info.add_broadcastable_binary_op("div", sub, sqrt);
auto mul = info.add_broadcastable_binary_op("mul", div, weights); auto mul = info.add_broadcastable_binary_op("mul", div, weights);
return info.add_broadcastable_binary_op("add", mul, bias); return info.add_broadcastable_binary_op("add", mul, bias);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment