Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cab8156e
Commit
cab8156e
authored
Jul 07, 2022
by
turneram
Browse files
Add find_gelu_erf to simplify_alegrbra
parent
f2531606
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
0 deletions
+49
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+49
-0
No files found.
src/simplify_algebra.cpp
View file @
cab8156e
...
@@ -851,6 +851,54 @@ struct find_div_const
...
@@ -851,6 +851,54 @@ struct find_div_const
}
}
};
};
struct
find_gelu_erf
{
static
auto
match_mul1
()
{
return
match
::
name
(
"mul"
)(
args
(
match
::
any
().
bind
(
"x"
),
match
::
skip_broadcasts
(
match
::
name
(
"recip"
))));
}
static
auto
match_erf
()
{
return
match
::
name
(
"erf"
)(
match
::
arg
(
0
)(
match_mul1
()));
}
static
auto
match_add2
()
{
return
match
::
name
(
"add"
)(
args
(
match_erf
(),
match
::
skip_broadcasts
(
match
::
has_value
(
1.0
f
))));
}
static
auto
match_add1
()
{
return
match
::
name
(
"add"
)(
args
(
match
::
skip_broadcasts
(
match
::
is_constant
()),
match
::
name
(
"dot"
)));
}
static
auto
match_mul2
()
{
return
match
::
name
(
"mul"
)(
match
::
args
(
match_add1
(),
match_add2
()));
}
auto
matcher
()
const
{
return
match
::
name
(
"mul"
)(
match
::
args
(
match_mul2
(),
match
::
skip_broadcasts
(
match
::
has_value
(
0.5
f
))));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x
=
r
.
instructions
[
"x"
];
auto
lit
=
m
.
add_literal
(
literal
{
shape
{
x
->
get_shape
().
type
()},
{
1.702
f
}});
auto
mul
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
x
->
get_shape
().
lens
()}}),
lit
);
mul
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
x
,
mul
);
auto
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"neg"
),
mul
);
auto
one
=
m
.
add_literal
(
literal
{
shape
{
x
->
get_shape
().
type
()},
{
1.0
f
}});
one
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
x
->
get_shape
().
lens
()}}),
one
);
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"exp"
),
sig
);
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
one
,
sig
);
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"div"
),
one
,
sig
);
sig
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
x
,
sig
);
m
.
replace_instruction
(
ins
,
sig
);
}
};
struct
find_sub_const
struct
find_sub_const
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -1040,6 +1088,7 @@ void simplify_algebra::apply(module& m) const
...
@@ -1040,6 +1088,7 @@ void simplify_algebra::apply(module& m) const
for
(
int
i
=
0
;
i
<
8
;
i
++
)
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
{
match
::
find_matches
(
m
,
match
::
find_matches
(
m
,
find_gelu_erf
{},
find_inner_broadcast
{},
find_inner_broadcast
{},
find_double_add_lit_broadcast
{},
find_double_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_lit_broadcast
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment