Commit bf01b980 authored by turneram's avatar turneram
Browse files

Formatting

parent 2c7fc04b
...@@ -17,19 +17,18 @@ struct parse_fastgelu : op_parser<parse_fastgelu> ...@@ -17,19 +17,18 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
if (args.size() != 1) if(args.size() != 1)
MIGRAPHX_THROW("FastGelu: too many arguments. Expected 1; got " + std::to_string(args.size())); MIGRAPHX_THROW("FastGelu: too many arguments. Expected 1; got " +
std::to_string(args.size()));
// silu approximation // silu approximation
auto x = args.front(); auto x = args.front();
auto x_type = x->get_shape().type(); auto x_type = x->get_shape().type();
auto lit = info.add_literal(literal{shape{x_type, {1}}, {1.702f}}); auto lit = info.add_literal(literal{shape{x_type, {1}}, {1.702f}});
auto sigmoid = info.add_broadcastable_binary_op("mul", lit, x); auto sigmoid = info.add_broadcastable_binary_op("mul", lit, x);
sigmoid = info.add_instruction(make_op("sigmoid"), sigmoid); sigmoid = info.add_instruction(make_op("sigmoid"), sigmoid);
return info.add_instruction(make_op("mul"), sigmoid, x); return info.add_instruction(make_op("mul"), sigmoid, x);
// tanh approximation // tanh approximation
/* auto x = args.front(); /* auto x = args.front();
auto x_type = x->get_shape().type(); auto x_type = x->get_shape().type();
...@@ -48,13 +47,12 @@ struct parse_fastgelu : op_parser<parse_fastgelu> ...@@ -48,13 +47,12 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
return info.add_instruction(make_op("mul"), x, tanh); */ return info.add_instruction(make_op("mul"), x, tanh); */
// tanh approximation with pow // tanh approximation with pow
/* auto x = args.front(); /* auto x = args.front();
auto x_type = x->get_shape().type(); auto x_type = x->get_shape().type();
auto three = info.add_literal(literal{shape{x_type, {1}}, {3}}); auto three = info.add_literal(literal{shape{x_type, {1}}, {3}});
three = info.add_instruction(make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), three); three = info.add_instruction(make_op("multibroadcast", {{"out_lens",
auto x3 = info.add_instruction(make_op("pow"), x, three); x->get_shape().lens()}}), three); auto x3 = info.add_instruction(make_op("pow"), x, three);
auto magic_number = info.add_literal(literal{shape{x_type, {1}}, {0.044715f}}); auto magic_number = info.add_literal(literal{shape{x_type, {1}}, {0.044715f}});
x3 = info.add_broadcastable_binary_op("mul", magic_number, x3); x3 = info.add_broadcastable_binary_op("mul", magic_number, x3);
auto product = info.add_instruction(make_op("add"), x, x3); auto product = info.add_instruction(make_op("add"), x, x3);
......
...@@ -17,18 +17,19 @@ struct parse_gelu : op_parser<parse_gelu> ...@@ -17,18 +17,19 @@ struct parse_gelu : op_parser<parse_gelu>
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
if (args.size() != 1) if(args.size() != 1)
MIGRAPHX_THROW("Gelu: too many arguments. Expected 1; got " + std::to_string(args.size())); MIGRAPHX_THROW("Gelu: too many arguments. Expected 1; got " +
std::to_string(args.size()));
auto x = args.front();
auto x_type = x->get_shape().type(); auto x = args.front();
auto x_type = x->get_shape().type();
auto root_inv = info.add_literal(literal{shape{x_type, {1}}, {1.0f / std::sqrt(2.0f)}}); auto root_inv = info.add_literal(literal{shape{x_type, {1}}, {1.0f / std::sqrt(2.0f)}});
auto product = info.add_broadcastable_binary_op("mul", x, root_inv); auto product = info.add_broadcastable_binary_op("mul", x, root_inv);
auto erf = info.add_instruction(make_op("erf"), product); auto erf = info.add_instruction(make_op("erf"), product);
auto one = info.add_literal(literal{shape{x_type, {1}}, {1.0f}}); auto one = info.add_literal(literal{shape{x_type, {1}}, {1.0f}});
erf = info.add_broadcastable_binary_op("add", one, erf); erf = info.add_broadcastable_binary_op("add", one, erf);
auto half = info.add_literal(literal{shape{x_type, {1}}, {0.5f}}); auto half = info.add_literal(literal{shape{x_type, {1}}, {0.5f}});
erf = info.add_broadcastable_binary_op("mul", half, erf); erf = info.add_broadcastable_binary_op("mul", half, erf);
return info.add_instruction(make_op("mul"), x, erf); return info.add_instruction(make_op("mul"), x, erf);
} }
......
...@@ -1671,11 +1671,7 @@ def gelu_test(): ...@@ -1671,11 +1671,7 @@ def gelu_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [16, 384, 3072]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [16, 384, 3072])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [16, 384, 3072]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [16, 384, 3072])
node = onnx.helper.make_node( node = onnx.helper.make_node('Gelu', inputs=['x'], outputs=['y'])
'Gelu',
inputs=['x'],
outputs=['y']
)
return ([node], [x], [y]) return ([node], [x], [y])
...@@ -1685,11 +1681,7 @@ def fastgelu_test(): ...@@ -1685,11 +1681,7 @@ def fastgelu_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [16, 384, 3072]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [16, 384, 3072])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [16, 384, 3072]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [16, 384, 3072])
node = onnx.helper.make_node( node = onnx.helper.make_node('FastGelu', inputs=['x'], outputs=['y'])
'FastGelu',
inputs=['x'],
outputs=['y']
)
return ([node], [x], [y]) return ([node], [x], [y])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment