Commit f99a3036 authored by turneram's avatar turneram
Browse files

Formatting

parent 2237c5de
......@@ -17,9 +17,9 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
if (args.size() != 1)
MIGRAPHX_THROW("FastGelu: too many arguments. Expected 1; got " + std::to_string(args.size()));
if(args.size() != 1)
MIGRAPHX_THROW("FastGelu: too many arguments. Expected 1; got " +
std::to_string(args.size()));
// silu approximation
auto x = args.front();
......@@ -29,7 +29,6 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
sigmoid = info.add_instruction(make_op("sigmoid"), sigmoid);
return info.add_instruction(make_op("mul"), sigmoid, x);
// tanh approximation
/* auto x = args.front();
auto x_type = x->get_shape().type();
......@@ -48,13 +47,12 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
return info.add_instruction(make_op("mul"), x, tanh); */
// tanh approximation with pow
/* auto x = args.front();
auto x_type = x->get_shape().type();
auto three = info.add_literal(literal{shape{x_type, {1}}, {3}});
three = info.add_instruction(make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), three);
auto x3 = info.add_instruction(make_op("pow"), x, three);
three = info.add_instruction(make_op("multibroadcast", {{"out_lens",
x->get_shape().lens()}}), three); auto x3 = info.add_instruction(make_op("pow"), x, three);
auto magic_number = info.add_literal(literal{shape{x_type, {1}}, {0.044715f}});
x3 = info.add_broadcastable_binary_op("mul", magic_number, x3);
auto product = info.add_instruction(make_op("add"), x, x3);
......
......@@ -17,8 +17,9 @@ struct parse_gelu : op_parser<parse_gelu>
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
if (args.size() != 1)
MIGRAPHX_THROW("Gelu: too many arguments. Expected 1; got " + std::to_string(args.size()));
if(args.size() != 1)
MIGRAPHX_THROW("Gelu: too many arguments. Expected 1; got " +
std::to_string(args.size()));
auto x = args.front();
auto x_type = x->get_shape().type();
......
......@@ -1671,11 +1671,7 @@ def gelu_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [16, 384, 3072])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [16, 384, 3072])
node = onnx.helper.make_node(
'Gelu',
inputs=['x'],
outputs=['y']
)
node = onnx.helper.make_node('Gelu', inputs=['x'], outputs=['y'])
return ([node], [x], [y])
......@@ -1685,11 +1681,7 @@ def fastgelu_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [16, 384, 3072])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [16, 384, 3072])
node = onnx.helper.make_node(
'FastGelu',
inputs=['x'],
outputs=['y']
)
node = onnx.helper.make_node('FastGelu', inputs=['x'], outputs=['y'])
return ([node], [x], [y])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment