Commit bf01b980 authored by turneram's avatar turneram
Browse files

Formatting

parent 2c7fc04b
......@@ -17,19 +17,18 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
if (args.size() != 1)
MIGRAPHX_THROW("FastGelu: too many arguments. Expected 1; got " + std::to_string(args.size()));
if(args.size() != 1)
MIGRAPHX_THROW("FastGelu: too many arguments. Expected 1; got " +
std::to_string(args.size()));
// silu approximation
auto x = args.front();
auto x_type = x->get_shape().type();
auto lit = info.add_literal(literal{shape{x_type, {1}}, {1.702f}});
auto x = args.front();
auto x_type = x->get_shape().type();
auto lit = info.add_literal(literal{shape{x_type, {1}}, {1.702f}});
auto sigmoid = info.add_broadcastable_binary_op("mul", lit, x);
sigmoid = info.add_instruction(make_op("sigmoid"), sigmoid);
sigmoid = info.add_instruction(make_op("sigmoid"), sigmoid);
return info.add_instruction(make_op("mul"), sigmoid, x);
// tanh approximation
/* auto x = args.front();
auto x_type = x->get_shape().type();
......@@ -48,13 +47,12 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
return info.add_instruction(make_op("mul"), x, tanh); */
// tanh approximation with pow
/* auto x = args.front();
auto x_type = x->get_shape().type();
auto three = info.add_literal(literal{shape{x_type, {1}}, {3}});
three = info.add_instruction(make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), three);
auto x3 = info.add_instruction(make_op("pow"), x, three);
three = info.add_instruction(make_op("multibroadcast", {{"out_lens",
x->get_shape().lens()}}), three); auto x3 = info.add_instruction(make_op("pow"), x, three);
auto magic_number = info.add_literal(literal{shape{x_type, {1}}, {0.044715f}});
x3 = info.add_broadcastable_binary_op("mul", magic_number, x3);
auto product = info.add_instruction(make_op("add"), x, x3);
......
......@@ -17,18 +17,19 @@ struct parse_gelu : op_parser<parse_gelu>
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
if (args.size() != 1)
MIGRAPHX_THROW("Gelu: too many arguments. Expected 1; got " + std::to_string(args.size()));
auto x = args.front();
auto x_type = x->get_shape().type();
if(args.size() != 1)
MIGRAPHX_THROW("Gelu: too many arguments. Expected 1; got " +
std::to_string(args.size()));
auto x = args.front();
auto x_type = x->get_shape().type();
auto root_inv = info.add_literal(literal{shape{x_type, {1}}, {1.0f / std::sqrt(2.0f)}});
auto product = info.add_broadcastable_binary_op("mul", x, root_inv);
auto erf = info.add_instruction(make_op("erf"), product);
auto one = info.add_literal(literal{shape{x_type, {1}}, {1.0f}});
erf = info.add_broadcastable_binary_op("add", one, erf);
auto half = info.add_literal(literal{shape{x_type, {1}}, {0.5f}});
erf = info.add_broadcastable_binary_op("mul", half, erf);
auto product = info.add_broadcastable_binary_op("mul", x, root_inv);
auto erf = info.add_instruction(make_op("erf"), product);
auto one = info.add_literal(literal{shape{x_type, {1}}, {1.0f}});
erf = info.add_broadcastable_binary_op("add", one, erf);
auto half = info.add_literal(literal{shape{x_type, {1}}, {0.5f}});
erf = info.add_broadcastable_binary_op("mul", half, erf);
return info.add_instruction(make_op("mul"), x, erf);
}
......
......@@ -1671,11 +1671,7 @@ def gelu_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [16, 384, 3072])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [16, 384, 3072])
node = onnx.helper.make_node(
'Gelu',
inputs=['x'],
outputs=['y']
)
node = onnx.helper.make_node('Gelu', inputs=['x'], outputs=['y'])
return ([node], [x], [y])
......@@ -1685,11 +1681,7 @@ def fastgelu_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [16, 384, 3072])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [16, 384, 3072])
node = onnx.helper.make_node(
'FastGelu',
inputs=['x'],
outputs=['y']
)
node = onnx.helper.make_node('FastGelu', inputs=['x'], outputs=['y'])
return ([node], [x], [y])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment