Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
76aa5e81
Commit
76aa5e81
authored
Aug 23, 2022
by
turneram
Browse files
Merge remote-tracking branch 'origin/rewrite-fast-gelu' into bert-perf
parents
e58d20ae
9882f6db
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
327 additions
and
0 deletions
+327
-0
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraphx/match/gelu_erf.hpp
src/include/migraphx/match/gelu_erf.hpp
+35
-0
src/include/migraphx/rewrite_gelu.hpp
src/include/migraphx/rewrite_gelu.hpp
+48
-0
src/rewrite_gelu.cpp
src/rewrite_gelu.cpp
+59
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+3
-0
test/rewrite_gelu_test.cpp
test/rewrite_gelu_test.cpp
+125
-0
test/verify/test_add_gelu_half.cpp
test/verify/test_add_gelu_half.cpp
+56
-0
No files found.
src/CMakeLists.txt
View file @
76aa5e81
...
@@ -83,6 +83,7 @@ add_library(migraphx
...
@@ -83,6 +83,7 @@ add_library(migraphx
sqlite.cpp
sqlite.cpp
rewrite_batched_gemms.cpp
rewrite_batched_gemms.cpp
rewrite_batchnorm.cpp
rewrite_batchnorm.cpp
rewrite_gelu.cpp
rewrite_pooling.cpp
rewrite_pooling.cpp
rewrite_quantization.cpp
rewrite_quantization.cpp
rewrite_rnn.cpp
rewrite_rnn.cpp
...
...
src/include/migraphx/match/gelu_erf.hpp
View file @
76aa5e81
...
@@ -67,6 +67,41 @@ inline auto gelu_erf()
...
@@ -67,6 +67,41 @@ inline auto gelu_erf()
return
gelu_erf
([](
auto
x
)
{
return
name
(
x
);
});
return
gelu_erf
([](
auto
x
)
{
return
name
(
x
);
});
}
}
namespace
detail
{
/**
 * Builds a matcher for the erf-based GELU expression as it appears in BERT graphs:
 * a product (in any association/order) of 0.5, (erf(x / sqrt(2)) + 1) and one more operand.
 *
 * F maps an operator name (e.g. "erf") to a matcher for that operator, which lets callers
 * decide how operator names are matched.
 */
template <class F>
struct bert_gelu_erf_matcher
{
    F f;

    /// Matches erf(x / c) where c is within 1e-3 of 1.414; the non-constant div operand is
    /// bound as "x". Both the erf and the div must be used exactly once.
    auto erf_fn() const
    {
        auto sqrt_two = has_value(1.414f, 1e-3);
        auto input    = none_of(sqrt_two).bind("x");
        auto scaled   = f("div")(either_arg(0, 1)(input, sqrt_two));
        return f("erf")(used_once(), arg(0)(used_once(), scaled));
    }

    /// Matches erf(...) + 1 with the operands in either order; the add must be used once.
    auto add_erf() const
    {
        auto unity = has_value(1.0f);
        return f("add")(used_once(), either_arg(0, 1)(erf_fn(), unity));
    }

    /// Matches the constant 0.5 factor.
    auto one_half() const
    {
        return has_value(0.5f);
    }

    /// Matches the full mul tree whose leaves are 0.5, (erf(...) + 1) and any third operand.
    auto matcher() const
    {
        auto product = f("mul");
        return unordered_tree(product, one_half(), add_erf(), any());
    }
};
}
// namespace detail
template
<
class
F
>
auto
bert_gelu_erf
(
F
f
)
{
return
detail
::
bert_gelu_erf_matcher
<
F
>
{
f
}.
matcher
();
}
/// Convenience overload: matches the BERT-style erf GELU with operators selected via name().
inline auto bert_gelu_erf()
{
    auto by_name = [](auto op_name) { return name(op_name); };
    return bert_gelu_erf(by_name);
}
}
// namespace match
}
// namespace match
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/rewrite_gelu.hpp
0 → 100644
View file @
76aa5e81
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct module;

/**
 * Pass that rewrites the standard GELU formula as the sigmoid approximation formula.
 *
 * Only the interface is declared here; apply() is defined out of line.
 */
struct rewrite_gelu
{
    /// Name under which this pass identifies itself.
    std::string name() const
    {
        return std::string{"rewrite_gelu"};
    }

    /// Runs the rewrite on module `m`.
    void apply(module& m) const;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/rewrite_gelu.cpp
0 → 100644
View file @
76aa5e81
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
 * Finder that replaces an erf-based GELU (as matched by match::bert_gelu_erf) with
 * x / (1 + exp(-(1.702 * x))), i.e. x * sigmoid(1.702 * x).
 *
 * The rewrite is applied only when the matched input "x" has half_type; other types are
 * left untouched.
 */
struct find_gelu_erf
{
    auto matcher() const
    {
        return match::bert_gelu_erf();
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto gelu  = r.result;
        auto input = r.instructions["x"];

        // Restrict the approximation to half precision inputs.
        const auto dtype = input->get_shape().type();
        if(dtype != migraphx::shape::half_type)
            return;

        // 1.702 * x; insert_common_op is used to combine the scalar literal with x
        // (presumably it takes care of broadcasting -- confirm in migraphx/common.hpp).
        auto coeff  = m.add_literal(literal{shape{dtype}, {1.702f}});
        auto scaled = insert_common_op(m, gelu, make_op("mul"), {input, coeff});

        // exp(-(1.702 * x))
        auto negated = m.insert_instruction(gelu, make_op("neg"), scaled);
        auto expo    = m.insert_instruction(gelu, make_op("exp"), negated);

        // 1 + exp(-(1.702 * x))
        auto unity = m.add_literal(literal{shape{dtype}, {1.0f}});
        auto denom = insert_common_op(m, gelu, make_op("add"), {expo, unity});

        // x / (1 + exp(-(1.702 * x))) takes the place of the matched GELU.
        auto approx = m.insert_instruction(gelu, make_op("div"), input, denom);
        m.replace_instruction(gelu, approx);
    }
};
/// Applies the erf-GELU finder to every matching instruction in `m`.
void rewrite_gelu::apply(module& m) const
{
    find_gelu_erf finder{};
    match::find_matches(m, finder);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/target.cpp
View file @
76aa5e81
...
@@ -42,6 +42,7 @@
...
@@ -42,6 +42,7 @@
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_rnn.hpp>
...
@@ -117,6 +118,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -117,6 +118,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module
{},
inline_module
{},
rewrite_pooling
{},
rewrite_pooling
{},
dead_code_elimination
{},
dead_code_elimination
{},
rewrite_gelu
{},
dead_code_elimination
{},
eliminate_common_subexpression
{},
eliminate_common_subexpression
{},
dead_code_elimination
{},
dead_code_elimination
{},
simplify_algebra
{},
simplify_algebra
{},
...
...
test/rewrite_gelu_test.cpp
0 → 100644
View file @
76aa5e81
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/verify.hpp>
// GELU applied to the sum of two half inputs (a "bias add" feeding the activation):
// after rewrite_gelu + dead code elimination the module must equal the sigmoid form.
TEST_CASE(bias_gelu)
{
    const migraphx::shape tensor_shape{migraphx::shape::half_type, {2, 4, 8}};
    const migraphx::shape scalar_shape{migraphx::shape::half_type};

    // Input module: 0.5 * ((a + b) * (erf((a + b) / 1.4140625) + 1))
    migraphx::module m1;
    {
        auto lhs      = m1.add_parameter("a", tensor_shape);
        auto rhs      = m1.add_parameter("b", tensor_shape);
        auto biased   = m1.add_instruction(migraphx::make_op("add"), lhs, rhs);
        auto sqrt_two = m1.add_literal(migraphx::literal{scalar_shape, {1.4140625f}});
        auto scaled   = add_common_op(m1, migraphx::make_op("div"), {biased, sqrt_two});
        auto erf_out  = m1.add_instruction(migraphx::make_op("erf"), scaled);
        auto unity    = m1.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto erf_plus = add_common_op(m1, migraphx::make_op("add"), {erf_out, unity});
        auto product  = m1.add_instruction(migraphx::make_op("mul"), biased, erf_plus);
        auto half     = m1.add_literal(migraphx::literal{scalar_shape, {0.5f}});
        auto gelu     = add_common_op(m1, migraphx::make_op("mul"), {product, half});
        m1.add_return({gelu});
    }
    migraphx::rewrite_gelu{}.apply(m1);
    migraphx::dead_code_elimination{}.apply(m1);

    // Expected module: (a + b) / (1 + exp(-(1.702 * (a + b))))
    migraphx::module m2;
    {
        auto lhs     = m2.add_parameter("a", tensor_shape);
        auto rhs     = m2.add_parameter("b", tensor_shape);
        auto biased  = m2.add_instruction(migraphx::make_op("add"), lhs, rhs);
        auto coeff   = m2.add_literal(migraphx::literal{scalar_shape, {1.702f}});
        auto scaled  = add_common_op(m2, migraphx::make_op("mul"), {biased, coeff});
        auto negated = m2.add_instruction(migraphx::make_op("neg"), scaled);
        auto expo    = m2.add_instruction(migraphx::make_op("exp"), negated);
        auto unity   = m2.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto denom   = add_common_op(m2, migraphx::make_op("add"), {expo, unity});
        auto approx  = m2.add_instruction(migraphx::make_op("div"), biased, denom);
        m2.add_return({approx});
    }
    EXPECT(m1 == m2);
}
// Same as the bias case but the GELU input comes from a "sub" instead of an "add":
// the rewrite must not depend on which operator produces the GELU input.
TEST_CASE(non_bias_gelu)
{
    const migraphx::shape tensor_shape{migraphx::shape::half_type, {2, 4, 8}};
    const migraphx::shape scalar_shape{migraphx::shape::half_type};

    // Input module: 0.5 * ((a - b) * (erf((a - b) / 1.4140625) + 1))
    migraphx::module m1;
    {
        auto lhs      = m1.add_parameter("a", tensor_shape);
        auto rhs      = m1.add_parameter("b", tensor_shape);
        auto diff     = m1.add_instruction(migraphx::make_op("sub"), lhs, rhs);
        auto sqrt_two = m1.add_literal(migraphx::literal{scalar_shape, {1.4140625f}});
        auto scaled   = add_common_op(m1, migraphx::make_op("div"), {diff, sqrt_two});
        auto erf_out  = m1.add_instruction(migraphx::make_op("erf"), scaled);
        auto unity    = m1.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto erf_plus = add_common_op(m1, migraphx::make_op("add"), {erf_out, unity});
        auto product  = m1.add_instruction(migraphx::make_op("mul"), diff, erf_plus);
        auto half     = m1.add_literal(migraphx::literal{scalar_shape, {0.5f}});
        auto gelu     = add_common_op(m1, migraphx::make_op("mul"), {product, half});
        m1.add_return({gelu});
    }
    migraphx::rewrite_gelu{}.apply(m1);
    migraphx::dead_code_elimination{}.apply(m1);

    // Expected module: (a - b) / (1 + exp(-(1.702 * (a - b))))
    migraphx::module m2;
    {
        auto lhs     = m2.add_parameter("a", tensor_shape);
        auto rhs     = m2.add_parameter("b", tensor_shape);
        auto diff    = m2.add_instruction(migraphx::make_op("sub"), lhs, rhs);
        auto coeff   = m2.add_literal(migraphx::literal{scalar_shape, {1.702f}});
        auto scaled  = add_common_op(m2, migraphx::make_op("mul"), {diff, coeff});
        auto negated = m2.add_instruction(migraphx::make_op("neg"), scaled);
        auto expo    = m2.add_instruction(migraphx::make_op("exp"), negated);
        auto unity   = m2.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto denom   = add_common_op(m2, migraphx::make_op("add"), {expo, unity});
        auto approx  = m2.add_instruction(migraphx::make_op("div"), diff, denom);
        m2.add_return({approx});
    }
    EXPECT(m1 == m2);
}
// Test-driver entry point: hands the command line over to the test framework's runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}
test/verify/test_add_gelu_half.cpp
0 → 100644
View file @
76aa5e81
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
/**
 * Verify-program case for the erf form of GELU on half-precision data.
 *
 * Builds (0.5 * (x + y)) * (erf((x + y) / sqrt(2)) + 1) where x and y are half tensors of
 * shape {1, 1, 5} and every scalar constant is explicitly multibroadcast to that shape.
 */
struct test_add_gelu_half : verify_program<test_add_gelu_half>
{
    migraphx::program create_program() const
    {
        // sqrt(2) written out explicitly: M_SQRT2 is a POSIX extension macro that standard
        // C++ does not guarantee, and this file does not include <cmath>. The value is the
        // same double constant that M_SQRT2 expands to.
        constexpr double sqrt_two_value = 1.41421356237309504880;

        migraphx::program p;
        auto* mm = p.get_main_module();
        std::vector<size_t> input_lens{1, 1, 5};

        // Broadcasts a scalar literal to the input dimensions.
        auto broadcast = [&](auto scalar) {
            return mm->add_instruction(
                migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), scalar);
        };

        auto x = mm->add_parameter("x", {migraphx::shape::half_type, input_lens});
        auto y = mm->add_parameter("y", {migraphx::shape::half_type, input_lens});

        auto half_lit  = mm->add_literal(migraphx::literal{{migraphx::shape::half_type}, {0.5f}});
        auto one_lit   = mm->add_literal(migraphx::literal{{migraphx::shape::half_type}, {1.0f}});
        auto sqrt2_lit =
            mm->add_literal(migraphx::literal{{migraphx::shape::half_type}, {sqrt_two_value}});

        // x + y is the GELU input.
        auto sum = mm->add_instruction(migraphx::make_op("add"), x, y);

        // 0.5 * (x + y)
        auto half_bc  = broadcast(half_lit);
        auto half_sum = mm->add_instruction(migraphx::make_op("mul"), sum, half_bc);

        // erf((x + y) / sqrt(2))
        auto sqrt2_bc = broadcast(sqrt2_lit);
        auto scaled   = mm->add_instruction(migraphx::make_op("div"), sum, sqrt2_bc);
        auto erf_out  = mm->add_instruction(migraphx::make_op("erf"), scaled);

        // erf(...) + 1
        auto one_bc   = broadcast(one_lit);
        auto erf_plus = mm->add_instruction(migraphx::make_op("add"), erf_out, one_bc);

        // Final product is the last instruction added to the main module.
        mm->add_instruction(migraphx::make_op("mul"), half_sum, erf_plus);
        return p;
    }
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment