Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bf01b980
Commit
bf01b980
authored
May 10, 2022
by
turneram
Browse files
Formatting
parent
2c7fc04b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
32 deletions
+23
-32
src/onnx/parse_fastgelu.cpp
src/onnx/parse_fastgelu.cpp
+9
-11
src/onnx/parse_gelu.cpp
src/onnx/parse_gelu.cpp
+12
-11
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+2
-10
No files found.
src/onnx/parse_fastgelu.cpp
View file @
bf01b980
...
@@ -17,19 +17,18 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
...
@@ -17,19 +17,18 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
{
if
(
args
.
size
()
!=
1
)
if
(
args
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"FastGelu: too many arguments. Expected 1; got "
+
std
::
to_string
(
args
.
size
()));
MIGRAPHX_THROW
(
"FastGelu: too many arguments. Expected 1; got "
+
std
::
to_string
(
args
.
size
()));
// silu approximation
// silu approximation
auto
x
=
args
.
front
();
auto
x
=
args
.
front
();
auto
x_type
=
x
->
get_shape
().
type
();
auto
x_type
=
x
->
get_shape
().
type
();
auto
lit
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
1.702
f
}});
auto
lit
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
1.702
f
}});
auto
sigmoid
=
info
.
add_broadcastable_binary_op
(
"mul"
,
lit
,
x
);
auto
sigmoid
=
info
.
add_broadcastable_binary_op
(
"mul"
,
lit
,
x
);
sigmoid
=
info
.
add_instruction
(
make_op
(
"sigmoid"
),
sigmoid
);
sigmoid
=
info
.
add_instruction
(
make_op
(
"sigmoid"
),
sigmoid
);
return
info
.
add_instruction
(
make_op
(
"mul"
),
sigmoid
,
x
);
return
info
.
add_instruction
(
make_op
(
"mul"
),
sigmoid
,
x
);
// tanh approximation
// tanh approximation
/* auto x = args.front();
/* auto x = args.front();
auto x_type = x->get_shape().type();
auto x_type = x->get_shape().type();
...
@@ -48,13 +47,12 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
...
@@ -48,13 +47,12 @@ struct parse_fastgelu : op_parser<parse_fastgelu>
return info.add_instruction(make_op("mul"), x, tanh); */
return info.add_instruction(make_op("mul"), x, tanh); */
// tanh approximation with pow
// tanh approximation with pow
/* auto x = args.front();
/* auto x = args.front();
auto x_type = x->get_shape().type();
auto x_type = x->get_shape().type();
auto three = info.add_literal(literal{shape{x_type, {1}}, {3}});
auto three = info.add_literal(literal{shape{x_type, {1}}, {3}});
three = info.add_instruction(make_op("multibroadcast", {{"out_lens",
x->get_shape().lens()}}), three);
three = info.add_instruction(make_op("multibroadcast", {{"out_lens",
auto x3 = info.add_instruction(make_op("pow"), x, three);
x->get_shape().lens()}}), three);
auto x3 = info.add_instruction(make_op("pow"), x, three);
auto magic_number = info.add_literal(literal{shape{x_type, {1}}, {0.044715f}});
auto magic_number = info.add_literal(literal{shape{x_type, {1}}, {0.044715f}});
x3 = info.add_broadcastable_binary_op("mul", magic_number, x3);
x3 = info.add_broadcastable_binary_op("mul", magic_number, x3);
auto product = info.add_instruction(make_op("add"), x, x3);
auto product = info.add_instruction(make_op("add"), x, x3);
...
...
src/onnx/parse_gelu.cpp
View file @
bf01b980
...
@@ -17,18 +17,19 @@ struct parse_gelu : op_parser<parse_gelu>
...
@@ -17,18 +17,19 @@ struct parse_gelu : op_parser<parse_gelu>
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
{
if
(
args
.
size
()
!=
1
)
if
(
args
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"Gelu: too many arguments. Expected 1; got "
+
std
::
to_string
(
args
.
size
()));
MIGRAPHX_THROW
(
"Gelu: too many arguments. Expected 1; got "
+
std
::
to_string
(
args
.
size
()));
auto
x
=
args
.
front
();
auto
x_type
=
x
->
get_shape
().
type
();
auto
x
=
args
.
front
();
auto
x_type
=
x
->
get_shape
().
type
();
auto
root_inv
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
1.0
f
/
std
::
sqrt
(
2.0
f
)}});
auto
root_inv
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
1.0
f
/
std
::
sqrt
(
2.0
f
)}});
auto
product
=
info
.
add_broadcastable_binary_op
(
"mul"
,
x
,
root_inv
);
auto
product
=
info
.
add_broadcastable_binary_op
(
"mul"
,
x
,
root_inv
);
auto
erf
=
info
.
add_instruction
(
make_op
(
"erf"
),
product
);
auto
erf
=
info
.
add_instruction
(
make_op
(
"erf"
),
product
);
auto
one
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
1.0
f
}});
auto
one
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
1.0
f
}});
erf
=
info
.
add_broadcastable_binary_op
(
"add"
,
one
,
erf
);
erf
=
info
.
add_broadcastable_binary_op
(
"add"
,
one
,
erf
);
auto
half
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
0.5
f
}});
auto
half
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
0.5
f
}});
erf
=
info
.
add_broadcastable_binary_op
(
"mul"
,
half
,
erf
);
erf
=
info
.
add_broadcastable_binary_op
(
"mul"
,
half
,
erf
);
return
info
.
add_instruction
(
make_op
(
"mul"
),
x
,
erf
);
return
info
.
add_instruction
(
make_op
(
"mul"
),
x
,
erf
);
}
}
...
...
test/onnx/gen_onnx.py
View file @
bf01b980
...
@@ -1671,11 +1671,7 @@ def gelu_test():
...
@@ -1671,11 +1671,7 @@ def gelu_test():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
16
,
384
,
3072
])
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
16
,
384
,
3072
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
16
,
384
,
3072
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
16
,
384
,
3072
])
node
=
onnx
.
helper
.
make_node
(
node
=
onnx
.
helper
.
make_node
(
'Gelu'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
])
'Gelu'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
]
)
return
([
node
],
[
x
],
[
y
])
return
([
node
],
[
x
],
[
y
])
...
@@ -1685,11 +1681,7 @@ def fastgelu_test():
...
@@ -1685,11 +1681,7 @@ def fastgelu_test():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
16
,
384
,
3072
])
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT
,
[
16
,
384
,
3072
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
16
,
384
,
3072
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT
,
[
16
,
384
,
3072
])
node
=
onnx
.
helper
.
make_node
(
node
=
onnx
.
helper
.
make_node
(
'FastGelu'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
])
'FastGelu'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
]
)
return
([
node
],
[
x
],
[
y
])
return
([
node
],
[
x
],
[
y
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment