Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
18e96032
Commit
18e96032
authored
Feb 01, 2023
by
Shiv
Browse files
add layernom fuse update
parents
2c93aa87
0d94f068
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
37 additions
and
19 deletions
+37
-19
src/include/migraphx/match/layernorm.hpp
src/include/migraphx/match/layernorm.hpp
+12
-5
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+10
-3
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
+2
-2
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+13
-9
No files found.
src/include/migraphx/match/layernorm.hpp
View file @
18e96032
...
@@ -43,16 +43,23 @@ struct layernorm_matcher
...
@@ -43,16 +43,23 @@ struct layernorm_matcher
auto
variance
()
const
auto
variance
()
const
{
{
return
f
(
"reduce_mean"
)(
arg
(
0
)(
f
(
"pow"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
has_value
(
2.0
f
)))));
return
f
(
"reduce_mean"
)(
arg
(
0
)(
any_of
(
f
(
"pow"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
has_value
(
2.0
f
))),
f
(
"mul"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
x_minus_mean
())),
f
(
"sqdiff"
)(
either_arg
(
0
,
1
)(
any
().
bind
(
"x"
),
skip_broadcasts
(
f
(
"reduce_mean"
)))))));
}
}
auto
layernorm_onnx
(
)
const
auto
sqrt_add_eps
(
const
std
::
string
&
name
)
const
{
{
auto
add_eps
=
f
(
"add"
)(
either_arg
(
0
,
1
)(
variance
(),
is_constant
().
bind
(
"eps"
)));
auto
add_eps
=
f
(
"add"
)(
either_arg
(
0
,
1
)(
variance
(),
is_constant
().
bind
(
"eps"
)));
return
f
(
"div"
)(
return
skip_broadcasts
(
f
(
name
)(
arg
(
0
)(
any_of
(
add_eps
,
variance
()))));
arg
(
0
)(
x_minus_mean
()),
}
arg
(
1
)(
skip_broadcasts
(
f
(
"sqrt"
)(
arg
(
0
)(
match
::
any_of
(
add_eps
,
variance
()))))));
auto
layernorm_onnx
()
const
{
auto
div_sqrt
=
f
(
"div"
)(
arg
(
0
)(
x_minus_mean
()),
arg
(
1
)(
sqrt_add_eps
(
"sqrt"
)));
auto
mul_rsqrt
=
f
(
"mul"
)(
either_arg
(
0
,
1
)(
x_minus_mean
(),
sqrt_add_eps
(
"rsqrt"
)));
return
any
(
any_of
(
div_sqrt
,
mul_rsqrt
));
}
}
auto
matcher
()
const
{
return
layernorm_onnx
();
}
auto
matcher
()
const
{
return
layernorm_onnx
();
}
...
...
src/simplify_algebra.cpp
View file @
18e96032
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/common.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/serialize.hpp>
...
@@ -340,12 +341,18 @@ struct find_inner_broadcast
...
@@ -340,12 +341,18 @@ struct find_inner_broadcast
std
::
back_inserter
(
inputs
),
std
::
back_inserter
(
inputs
),
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
()
!=
inputs
.
front
()
->
get_shape
();
return
i
->
get_shape
()
!=
inputs
.
front
()
->
get_shape
()
and
i
->
get_shape
().
elements
()
!=
1
;
}))
}))
return
;
return
;
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
auto
b_it
=
std
::
find_if
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
[
&
](
auto
i
)
{
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
return
not
i
->
get_shape
().
scalar
();
});
if
(
b_it
==
broadcasts
.
end
())
b_it
=
broadcasts
.
begin
();
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
m
.
replace_instruction
(
ins
,
(
*
b_it
)
->
get_operator
(),
op
);
}
}
};
};
...
...
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
View file @
18e96032
...
@@ -30,14 +30,14 @@
...
@@ -30,14 +30,14 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
struct
module
_pass_manager
;
namespace
gpu
{
namespace
gpu
{
struct
prefuse_ops
struct
prefuse_ops
{
{
std
::
string
name
()
const
{
return
"gpu::prefuse_ops"
;
}
std
::
string
name
()
const
{
return
"gpu::prefuse_ops"
;
}
void
apply
(
module
&
m
)
const
;
void
apply
(
module
_pass_manager
&
m
)
const
;
};
};
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/prefuse_ops.cpp
View file @
18e96032
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -90,7 +92,9 @@ struct find_layernorm
...
@@ -90,7 +92,9 @@ struct find_layernorm
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
eps
=
r
.
instructions
[
"eps"
]
->
eval
().
at
<
float
>
();
float
eps
=
0
;
if
(
contains
(
r
.
instructions
,
"eps"
))
eps
=
r
.
instructions
[
"eps"
]
->
eval
().
at
<
float
>
();
m
.
replace_instruction
(
ins
,
layernorm
{
eps
},
x_ins
);
m
.
replace_instruction
(
ins
,
layernorm
{
eps
},
x_ins
);
}
}
...
@@ -100,26 +104,26 @@ struct find_add_layernorm
...
@@ -100,26 +104,26 @@ struct find_add_layernorm
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
layernorm
(
)(
return
match
::
name
(
"gpu::pre
layernorm
"
)(
match
::
v
ar
(
"x"
)
(
match
::
name
(
"add"
)(
match
::
used_once
()).
bind
(
"add"
)));
match
::
ar
gs
(
match
::
name
(
"add"
)(
match
::
used_once
()).
bind
(
"add"
)));
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
add_ins
=
r
.
instructions
[
"add"
];
float
eps
=
0
;
auto
op
=
any_cast
<
layernorm
>
(
ins
->
get_operator
());
if
(
contains
(
r
.
instructions
,
"eps"
))
eps
=
r
.
instructions
[
"eps"
]
->
eval
().
at
<
float
>
();
m
.
replace_instruction
(
ins
,
add_layernorm
{
eps
},
add_ins
->
inputs
());
m
.
replace_instruction
(
ins
,
add_layernorm
{
op
.
epsilon
},
add_ins
->
inputs
());
}
}
};
};
}
// namespace
}
// namespace
void
prefuse_ops
::
apply
(
module
&
m
)
const
void
prefuse_ops
::
apply
(
module
_pass_manager
&
mp
m
)
const
{
{
match
::
find_matches
(
m
,
find_add_layernorm
{},
find_layernorm
{});
match
::
find_matches
(
mpm
.
get_module
(),
find_layernorm
{});
mpm
.
run_pass
(
dead_code_elimination
{});
match
::
find_matches
(
mpm
.
get_module
(),
find_add_layernorm
{});
}
}
}
// namespace gpu
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment