Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
27dc554d
Commit
27dc554d
authored
Aug 19, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/rewrite-fast-gelu' into bert-perf2
parents
3c133f81
b878f78f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
284 additions
and
0 deletions
+284
-0
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraphx/match/gelu_erf.hpp
src/include/migraphx/match/gelu_erf.hpp
+35
-0
src/include/migraphx/rewrite_gelu.hpp
src/include/migraphx/rewrite_gelu.hpp
+48
-0
src/rewrite_gelu.cpp
src/rewrite_gelu.cpp
+61
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+3
-0
test/rewrite_gelu_test.cpp
test/rewrite_gelu_test.cpp
+136
-0
No files found.
src/CMakeLists.txt
View file @
27dc554d
...
@@ -82,6 +82,7 @@ add_library(migraphx
...
@@ -82,6 +82,7 @@ add_library(migraphx
simplify_qdq.cpp
simplify_qdq.cpp
sqlite.cpp
sqlite.cpp
rewrite_batchnorm.cpp
rewrite_batchnorm.cpp
rewrite_gelu.cpp
rewrite_pooling.cpp
rewrite_pooling.cpp
rewrite_quantization.cpp
rewrite_quantization.cpp
rewrite_rnn.cpp
rewrite_rnn.cpp
...
...
src/include/migraphx/match/gelu_erf.hpp
View file @
27dc554d
...
@@ -67,6 +67,41 @@ inline auto gelu_erf()
...
@@ -67,6 +67,41 @@ inline auto gelu_erf()
return
gelu_erf
([](
auto
x
)
{
return
name
(
x
);
});
return
gelu_erf
([](
auto
x
)
{
return
name
(
x
);
});
}
}
namespace detail {
/**
 * Builds the matcher for the erf-based gelu that is written with a division:
 *
 *     0.5 * x * (1 + erf(x / 1.414))
 *
 * `f` is a callable that turns an operator name into a matcher for that
 * operator, which lets callers decide how operators are looked up.
 * The dividend of the division is bound to "x" in the match result.
 */
template <class F>
struct bert_gelu_erf_matcher
{
    F f;

    /// Matches erf(x / 1.414), where both the erf and the div are used once.
    auto erf_fn() const
    {
        // Division is not commutative, so the operands are pinned by
        // position: argument 0 is the dividend (bound to "x") and argument 1
        // must be the constant 1.414 (within 1e-3). Accepting either order
        // would also match 1.414 / x, which is not part of the gelu formula.
        return f("erf")(
            used_once(),
            arg(0)(used_once(),
                   f("div")(arg(0)(none_of(has_value(1.414f, 1e-3)).bind("x")),
                            arg(1)(has_value(1.414f, 1e-3)))));
    }

    /// Matches erf(...) + 1 with the operands in either order.
    auto add_erf() const
    {
        return f("add")(used_once(), either_arg(0, 1)(erf_fn(), has_value(1.0f)));
    }

    /// Matches the constant factor 0.5.
    auto one_half() const { return has_value(0.5f); }

    /// Matches the full product of 0.5, (erf(...) + 1) and one more operand.
    auto matcher() const
    {
        // NOTE(review): the third operand is matched with any() and is not
        // compared with the instruction bound to "x" -- confirm that is intended.
        return unordered_tree(f("mul"), one_half(), add_erf(), any());
    }
};
} // namespace detail
template
<
class
F
>
auto
bert_gelu_erf
(
F
f
)
{
return
detail
::
bert_gelu_erf_matcher
<
F
>
{
f
}.
matcher
();
}
/// Convenience overload: builds the division-based gelu_erf matcher with
/// operators matched through name().
inline auto bert_gelu_erf()
{
    auto by_name = [](auto x) { return name(x); };
    return bert_gelu_erf(by_name);
}
}
// namespace match
}
// namespace match
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/rewrite_gelu.hpp
0 → 100644
View file @
27dc554d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

/**
 * Pass that rewrites the standard gelu formula as the sigmoid approximation
 * formula.
 */
struct rewrite_gelu
{
    /// Name of this pass.
    std::string name() const { return "rewrite_gelu"; }

    /// Applies the rewrite to module `m`; only declared here.
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
src/rewrite_gelu.cpp
0 → 100644
View file @
27dc554d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
 * Finder that replaces a matched erf-based gelu with the sigmoid form
 *
 *     x * (1 / (1 + exp(-(1.702 * x))))
 *
 * where x is the instruction bound to "x" by the matcher.
 */
struct find_gelu_erf
{
    auto matcher() const { return match::bert_gelu_erf(); }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto gelu_ins = r.result;
        auto input    = r.instructions["x"];

        // Scalar constants are created with the input's element type and then
        // broadcast to the input's dimensions before being used.
        const auto broadcast =
            make_op("multibroadcast", {{"out_lens", input->get_shape().lens()}});

        // 1.702 * x
        auto scale_lit = m.add_literal(literal{shape{input->get_shape().type()}, {1.702f}});
        auto scale     = m.insert_instruction(gelu_ins, broadcast, scale_lit);
        auto scaled    = m.insert_instruction(gelu_ins, make_op("mul"), input, scale);

        // exp(-(1.702 * x))
        auto negated  = m.insert_instruction(gelu_ins, make_op("neg"), scaled);
        auto exponent = m.insert_instruction(gelu_ins, make_op("exp"), negated);

        // 1 / (1 + exp(-(1.702 * x)))
        auto one_lit = m.add_literal(literal{shape{input->get_shape().type()}, {1.0f}});
        auto one     = m.insert_instruction(gelu_ins, broadcast, one_lit);
        auto denom   = m.insert_instruction(gelu_ins, make_op("add"), exponent, one);
        auto sigmoid = m.insert_instruction(gelu_ins, make_op("div"), one, denom);

        // x * sigmoid(1.702 * x) takes the place of the matched instruction.
        auto approx = m.insert_instruction(gelu_ins, make_op("mul"), input, sigmoid);
        m.replace_instruction(gelu_ins, approx);
    }
};
/// Runs the gelu_erf finder over every match in module `m`.
void rewrite_gelu::apply(module& m) const
{
    find_gelu_erf finder{};
    match::find_matches(m, finder);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/target.cpp
View file @
27dc554d
...
@@ -42,6 +42,7 @@
...
@@ -42,6 +42,7 @@
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_rnn.hpp>
...
@@ -116,6 +117,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -116,6 +117,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module
{},
inline_module
{},
rewrite_pooling
{},
rewrite_pooling
{},
dead_code_elimination
{},
dead_code_elimination
{},
rewrite_gelu
{},
dead_code_elimination
{},
eliminate_common_subexpression
{},
eliminate_common_subexpression
{},
dead_code_elimination
{},
dead_code_elimination
{},
simplify_algebra
{},
simplify_algebra
{},
...
...
test/rewrite_gelu_test.cpp
0 → 100644
View file @
27dc554d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/verify.hpp>
// gelu applied to a sum of two parameters: after rewrite_gelu and dead code
// elimination the module must equal the hand-built sigmoid form.
TEST_CASE(bias_gelu)
{
    migraphx::shape tensor_shape{migraphx::shape::half_type, {2, 4, 8}};
    migraphx::shape scalar_shape{migraphx::shape::half_type};
    const auto broadcast =
        migraphx::make_op("multibroadcast", {{"out_lens", tensor_shape.lens()}});

    // Input graph: ((a + b) * (erf((a + b) / 1.4140625) + 1)) * 0.5
    migraphx::module actual;
    {
        auto lhs  = actual.add_parameter("a", tensor_shape);
        auto rhs  = actual.add_parameter("b", tensor_shape);
        auto bias = actual.add_instruction(migraphx::make_op("add"), lhs, rhs);

        auto sqrt2_lit = actual.add_literal(migraphx::literal{scalar_shape, {1.4140625f}});
        auto sqrt2     = actual.add_instruction(broadcast, sqrt2_lit);
        auto quotient  = actual.add_instruction(migraphx::make_op("div"), bias, sqrt2);
        auto erf_out   = actual.add_instruction(migraphx::make_op("erf"), quotient);

        auto one_lit      = actual.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto one          = actual.add_instruction(broadcast, one_lit);
        auto erf_plus_one = actual.add_instruction(migraphx::make_op("add"), erf_out, one);
        auto product = actual.add_instruction(migraphx::make_op("mul"), bias, erf_plus_one);

        auto one_half_lit = actual.add_literal(migraphx::literal{scalar_shape, {0.5f}});
        auto one_half     = actual.add_instruction(broadcast, one_half_lit);
        auto gelu = actual.add_instruction(migraphx::make_op("mul"), product, one_half);
        actual.add_return({gelu});
    }

    migraphx::rewrite_gelu{}.apply(actual);
    migraphx::dead_code_elimination{}.apply(actual);

    // Expected graph: (a + b) * (1 / (exp(-((a + b) * 1.702)) + 1))
    migraphx::module expected;
    {
        auto lhs  = expected.add_parameter("a", tensor_shape);
        auto rhs  = expected.add_parameter("b", tensor_shape);
        auto bias = expected.add_instruction(migraphx::make_op("add"), lhs, rhs);

        auto scale_lit = expected.add_literal(migraphx::literal{scalar_shape, {1.702f}});
        auto scale     = expected.add_instruction(broadcast, scale_lit);
        auto scaled    = expected.add_instruction(migraphx::make_op("mul"), bias, scale);
        auto negated   = expected.add_instruction(migraphx::make_op("neg"), scaled);
        auto exponent  = expected.add_instruction(migraphx::make_op("exp"), negated);

        auto one_lit = expected.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto one     = expected.add_instruction(broadcast, one_lit);
        auto denom   = expected.add_instruction(migraphx::make_op("add"), exponent, one);
        auto sigmoid = expected.add_instruction(migraphx::make_op("div"), one, denom);
        auto approx  = expected.add_instruction(migraphx::make_op("mul"), bias, sigmoid);
        expected.add_return({approx});
    }

    EXPECT(actual == expected);
}
// gelu applied to a difference of two parameters: after rewrite_gelu and dead
// code elimination the module must equal the hand-built sigmoid form.
TEST_CASE(non_bias_gelu)
{
    migraphx::shape tensor_shape{migraphx::shape::half_type, {2, 4, 8}};
    migraphx::shape scalar_shape{migraphx::shape::half_type};
    const auto broadcast =
        migraphx::make_op("multibroadcast", {{"out_lens", tensor_shape.lens()}});

    // Input graph: ((a - b) * (erf((a - b) / 1.4140625) + 1)) * 0.5
    migraphx::module actual;
    {
        auto lhs  = actual.add_parameter("a", tensor_shape);
        auto rhs  = actual.add_parameter("b", tensor_shape);
        auto diff = actual.add_instruction(migraphx::make_op("sub"), lhs, rhs);

        auto sqrt2_lit = actual.add_literal(migraphx::literal{scalar_shape, {1.4140625f}});
        auto sqrt2     = actual.add_instruction(broadcast, sqrt2_lit);
        auto quotient  = actual.add_instruction(migraphx::make_op("div"), diff, sqrt2);
        auto erf_out   = actual.add_instruction(migraphx::make_op("erf"), quotient);

        auto one_lit      = actual.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto one          = actual.add_instruction(broadcast, one_lit);
        auto erf_plus_one = actual.add_instruction(migraphx::make_op("add"), erf_out, one);
        auto product = actual.add_instruction(migraphx::make_op("mul"), diff, erf_plus_one);

        auto one_half_lit = actual.add_literal(migraphx::literal{scalar_shape, {0.5f}});
        auto one_half     = actual.add_instruction(broadcast, one_half_lit);
        auto gelu = actual.add_instruction(migraphx::make_op("mul"), product, one_half);
        actual.add_return({gelu});
    }

    migraphx::rewrite_gelu{}.apply(actual);
    migraphx::dead_code_elimination{}.apply(actual);

    // Expected graph: (a - b) * (1 / (exp(-((a - b) * 1.702)) + 1))
    migraphx::module expected;
    {
        auto lhs  = expected.add_parameter("a", tensor_shape);
        auto rhs  = expected.add_parameter("b", tensor_shape);
        auto diff = expected.add_instruction(migraphx::make_op("sub"), lhs, rhs);

        auto scale_lit = expected.add_literal(migraphx::literal{scalar_shape, {1.702f}});
        auto scale     = expected.add_instruction(broadcast, scale_lit);
        auto scaled    = expected.add_instruction(migraphx::make_op("mul"), diff, scale);
        auto negated   = expected.add_instruction(migraphx::make_op("neg"), scaled);
        auto exponent  = expected.add_instruction(migraphx::make_op("exp"), negated);

        auto one_lit = expected.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto one     = expected.add_instruction(broadcast, one_lit);
        auto denom   = expected.add_instruction(migraphx::make_op("add"), exponent, one);
        auto sigmoid = expected.add_instruction(migraphx::make_op("div"), one, denom);
        auto approx  = expected.add_instruction(migraphx::make_op("mul"), diff, sigmoid);
        expected.add_return({approx});
    }

    EXPECT(actual == expected);
}
// Test driver entry point: forwards the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment