Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2bca512e
"src/targets/cpu/sub.cpp" did not exist on "96358e41cc883791c8d3ad50280bea4871a18000"
Commit
2bca512e
authored
Dec 12, 2023
by
charlie
Browse files
progress
parent
6c41008a
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
101 additions
and
22 deletions
+101
-22
src/include/migraphx/match/gelu_erf.hpp
src/include/migraphx/match/gelu_erf.hpp
+3
-3
src/include/migraphx/match/gelu_tanh.hpp
src/include/migraphx/match/gelu_tanh.hpp
+22
-11
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+8
-3
src/rewrite_gelu.cpp
src/rewrite_gelu.cpp
+13
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+1
-1
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+2
-1
test/rewrite_gelu_test.cpp
test/rewrite_gelu_test.cpp
+45
-0
tools/accuracy/accuracy_checker.py
tools/accuracy/accuracy_checker.py
+7
-3
No files found.
src/include/migraphx/match/gelu_erf.hpp
View file @
2bca512e
...
@@ -38,10 +38,10 @@ struct gelu_erf_matcher
...
@@ -38,10 +38,10 @@ struct gelu_erf_matcher
F
f
;
F
f
;
auto
erf_fn
()
const
auto
erf_fn
()
const
{
{
auto
mul_1_sqrt_2
=
f
(
"mul"
)(
either_arg
(
0
,
1
)(
none_of
(
has_value
(
M_SQRT1_2
,
1e-3
)).
bind
(
"x"
),
auto
mul_1_sqrt_2
=
f
(
"mul"
)(
has_value
(
M_SQRT1_2
,
1e-3
)));
either_arg
(
0
,
1
)(
none_of
(
has_value
(
M_SQRT1_2
)).
bind
(
"x"
),
has_value
(
M_SQRT1_2
,
1e-3
)));
auto
div_sqrt_2
=
auto
div_sqrt_2
=
f
(
"div"
)(
args
(
none_of
(
has_value
(
M_SQRT2
,
1e-3
)).
bind
(
"x"
),
has_value
(
M_SQRT2
,
1e-3
)));
f
(
"div"
)(
args
(
none_of
(
has_value
(
M_SQRT2
,
1e-3
)).
bind
(
"x"
),
has_value
(
M_SQRT2
)));
return
f
(
"erf"
)(
used_once
(),
arg
(
0
)(
used_once
(),
any_of
(
mul_1_sqrt_2
,
div_sqrt_2
)));
return
f
(
"erf"
)(
used_once
(),
arg
(
0
)(
used_once
(),
any_of
(
mul_1_sqrt_2
,
div_sqrt_2
)));
}
}
...
...
src/include/migraphx/match/gelu_tanh.hpp
View file @
2bca512e
...
@@ -36,23 +36,34 @@ template <class F>
...
@@ -36,23 +36,34 @@ template <class F>
struct
gelu_tanh_matcher
struct
gelu_tanh_matcher
{
{
F
f
;
F
f
;
/// x ^ 3
auto
pow_fn
()
const
{
return
f
(
"pow"
)(
used_once
(),
arg
(
1
)(
has_value
(
3.0
f
)));
}
auto
pow_fn
()
const
{
return
f
(
"pow"
)(
used_once
(),
arg
(
1
)(
has_value
(
3.0
f
)));
}
/// tanh( sqrt(2/M_PI) * (x + 0.044715 * x ^ 3 )
auto
tanh_fn
()
const
auto
tanh_fn
()
const
{
{
return
f
(
"tanh"
)(
auto
mul_const_pow
=
f
(
"mul"
)(
either_arg
(
0
,
1
)(
has_value
(
0.044715
f
),
pow_fn
()));
used_once
(),
auto
add_any_mul
=
f
(
"add"
)(
any_arg
(
0
,
1
)(
mul_const_pow
));
arg
(
0
)(
f
(
"mul"
)(
either_arg
(
0
,
1
)(
has_value
(
sqrt
(
M_2_PI
),
1e-3
),
auto
either_SQRT2RPI_add
=
either_arg
(
0
,
1
)(
has_value
(
sqrt
(
M_2_PI
)),
add_any_mul
);
f
(
"add"
)(
any_arg
(
0
,
1
)(
f
(
"mul"
)(
either_arg
(
0
,
1
)(
return
f
(
"tanh"
)(
used_once
(),
arg
(
0
)(
f
(
"mul"
)(
either_SQRT2RPI_add
)));
has_value
(
0.044715
f
),
pow_fn
()))))))));
}
/// x * (0.5? + 0.5 * tanh( sqrt(2/M_PI) * (x? + 0.044715 * x? ^ 3) ) )
/// <item>? question mark means it doesn't explicitly match that item (anything will work)
auto
matcher_v0
()
const
{
auto
mul_half_tanh
=
f
(
"mul"
)(
either_arg
(
0
,
1
)(
has_value
(
0.5
f
),
tanh_fn
()));
auto
add_any_mul
=
f
(
"add"
)(
any_arg
(
0
,
1
)(
mul_half_tanh
));
return
f
(
"mul"
)(
either_arg
(
0
,
1
)(
any
().
bind
(
"x"
),
add_any_mul
));
}
}
auto
matcher
()
const
/// x * 0.5 * (1.0 + tanh( sqrt(2/M_PI) * (x + 0.044715 * x ^ 3) ) )
auto
matcher_v1
()
const
{
{
return
f
(
"mul"
)(
used_once
(),
auto
add_one_tanh
=
f
(
"add"
)(
used_once
(),
either_arg
(
0
,
1
)(
has_value
(
1.0
),
tanh_fn
()));
either_arg
(
0
,
1
)(
any
().
bind
(
"x"
),
auto
mul_half_x
=
f
(
"mul"
)(
used_once
(),
either_arg
(
0
,
1
)(
has_value
(
0.5
),
any
().
bind
(
"x"
)));
f
(
"add"
)(
any_arg
(
0
,
1
)(
f
(
"mul"
)(
return
f
(
"mul"
)(
either_arg
(
0
,
1
)(
mul_half_x
,
add_one_tanh
));
either_arg
(
0
,
1
)(
has_value
(
0.5
f
),
tanh_fn
()))))));
}
}
};
};
}
// namespace detail
}
// namespace detail
...
@@ -60,7 +71,7 @@ struct gelu_tanh_matcher
...
@@ -60,7 +71,7 @@ struct gelu_tanh_matcher
template
<
class
F
>
template
<
class
F
>
auto
gelu_tanh
(
F
f
)
auto
gelu_tanh
(
F
f
)
{
{
return
detail
::
gelu_tanh_matcher
<
F
>
{
f
}.
matcher
();
return
detail
::
gelu_tanh_matcher
<
F
>
{
f
}.
matcher
_v1
();
}
}
inline
auto
gelu_tanh
()
inline
auto
gelu_tanh
()
...
...
src/include/migraphx/matcher.hpp
View file @
2bca512e
...
@@ -864,7 +864,7 @@ auto skip_broadcasts_transposes_contiguous(Ms... ms)
...
@@ -864,7 +864,7 @@ auto skip_broadcasts_transposes_contiguous(Ms... ms)
}
}
template
<
class
T
>
template
<
class
T
>
inline
auto
has_value
(
T
x
,
floa
t
tol
erance
=
1
e-6
)
inline
auto
has_value
(
T
x
,
std
::
size_
t
a
tol
_mult
=
10
,
std
::
size_t
rtol_mult
=
1
0
)
{
{
return
skip_broadcasts_converts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
skip_broadcasts_converts
(
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
if
(
ins
->
name
()
!=
"@literal"
)
...
@@ -874,8 +874,13 @@ inline auto has_value(T x, float tolerance = 1e-6)
...
@@ -874,8 +874,13 @@ inline auto has_value(T x, float tolerance = 1e-6)
return
false
;
return
false
;
bool
b
=
false
;
bool
b
=
false
;
l
.
visit
([
&
](
auto
v
)
{
l
.
visit
([
&
](
auto
v
)
{
if
(
std
::
all_of
(
// cast to the literal's data type before comparing
v
.
begin
(),
v
.
end
(),
[
&
](
auto
val
)
{
return
std
::
fabs
(
val
-
x
)
<
tolerance
;
}))
using
type
=
typename
decltype
(
v
)
::
value_type
;
auto
eps
=
std
::
numeric_limits
<
type
>::
epsilon
();
if
(
std
::
all_of
(
v
.
begin
(),
v
.
end
(),
[
&
](
auto
val
)
{
return
std
::
fabs
(
val
-
static_cast
<
type
>
(
x
))
<
(
atol_mult
*
eps
+
rtol_mult
*
eps
*
std
::
fabs
(
val
));
}))
b
=
true
;
b
=
true
;
});
});
return
b
;
return
b
;
...
...
src/rewrite_gelu.cpp
View file @
2bca512e
...
@@ -70,6 +70,7 @@ struct find_tanh_fast_gelu
...
@@ -70,6 +70,7 @@ struct find_tanh_fast_gelu
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
/*
auto ins = r.result;
auto ins = r.result;
auto x = r.instructions["x"];
auto x = r.instructions["x"];
auto sqrt_2_rpi = m.add_literal(
auto sqrt_2_rpi = m.add_literal(
...
@@ -89,6 +90,18 @@ struct find_tanh_fast_gelu
...
@@ -89,6 +90,18 @@ struct find_tanh_fast_gelu
auto cdf = insert_common_op(m, ins, make_op("div"), {one, e});
auto cdf = insert_common_op(m, ins, make_op("div"), {one, e});
auto y = m.insert_instruction(ins, make_op("mul"), x, cdf);
auto y = m.insert_instruction(ins, make_op("mul"), x, cdf);
m.replace_instruction(ins, y);
m.replace_instruction(ins, y);
*/
auto
ins
=
r
.
result
;
auto
x
=
r
.
instructions
[
"x"
];
auto
sqrt1_2
=
m
.
add_literal
(
literal
{
shape
{
x
->
get_shape
().
type
()},
{
M_SQRT1_2
}});
auto
one
=
m
.
add_literal
(
literal
{
shape
{
x
->
get_shape
().
type
()},
{
1.0
f
}});
auto
one_half
=
m
.
add_literal
(
literal
{
shape
{
x
->
get_shape
().
type
()},
{
0.5
f
}});
auto
a
=
insert_common_op
(
m
,
ins
,
make_op
(
"mul"
),
{
x
,
sqrt1_2
});
auto
erf
=
m
.
insert_instruction
(
ins
,
make_op
(
"erf"
),
a
);
auto
add_erf
=
insert_common_op
(
m
,
ins
,
make_op
(
"add"
),
{
one
,
erf
});
auto
b
=
insert_common_op
(
m
,
ins
,
make_op
(
"mul"
),
{
one_half
,
add_erf
});
auto
y
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
x
,
b
);
m
.
replace_instruction
(
ins
,
y
);
}
}
};
};
...
...
src/simplify_algebra.cpp
View file @
2bca512e
...
@@ -1217,7 +1217,7 @@ struct find_unit_ops
...
@@ -1217,7 +1217,7 @@ struct find_unit_ops
auto
div_1
=
auto
div_1
=
match
::
name
(
"div"
)(
match
::
args
(
match
::
any
().
bind
(
"x"
),
match
::
has_value
(
1.0
f
)));
match
::
name
(
"div"
)(
match
::
args
(
match
::
any
().
bind
(
"x"
),
match
::
has_value
(
1.0
f
)));
auto
add_0
=
match
::
name
(
"add"
)(
auto
add_0
=
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
has_value
(
0.0
f
,
1e-12
),
match
::
any
().
bind
(
"x"
)));
match
::
either_arg
(
0
,
1
)(
match
::
has_value
(
0.0
f
,
0
,
0
),
match
::
any
().
bind
(
"x"
)));
auto
sub_0
=
auto
sub_0
=
match
::
name
(
"sub"
)(
match
::
args
(
match
::
any
().
bind
(
"x"
),
match
::
has_value
(
0.0
f
)));
match
::
name
(
"sub"
)(
match
::
args
(
match
::
any
().
bind
(
"x"
),
match
::
has_value
(
0.0
f
)));
return
match
::
any_of
(
mul_1
,
div_1
,
add_0
,
sub_0
);
return
match
::
any_of
(
mul_1
,
div_1
,
add_0
,
sub_0
);
...
...
src/targets/gpu/lowering.cpp
View file @
2bca512e
...
@@ -83,7 +83,8 @@ struct miopen_apply
...
@@ -83,7 +83,8 @@ struct miopen_apply
assert
(
mod
!=
nullptr
);
assert
(
mod
!=
nullptr
);
assert
(
pass
!=
nullptr
);
assert
(
pass
!=
nullptr
);
compute_fp32
=
get_compute_fp32_flag
();
// compute_fp32 = get_compute_fp32_flag();
compute_fp32
=
true
;
offload_copy
=
(
mod
==
mpm
->
get_root_module
())
?
pass
->
offload_copy
:
false
;
offload_copy
=
(
mod
==
mpm
->
get_root_module
())
?
pass
->
offload_copy
:
false
;
add_generic_op
(
"contiguous"
);
add_generic_op
(
"contiguous"
);
...
...
test/rewrite_gelu_test.cpp
View file @
2bca512e
...
@@ -122,4 +122,49 @@ TEST_CASE(non_bias_gelu)
...
@@ -122,4 +122,49 @@ TEST_CASE(non_bias_gelu)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
TEST_CASE
(
tanh_gelu_distilgpt2_fp16
)
{
// Uses constant values seen in the distilgpt2_fp16 model, note how they're not exactly right
migraphx
::
shape
s1
{
migraphx
::
shape
::
half_type
,
{
2
,
4
,
8
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
half_type
};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
s1
);
auto
fit_const
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
0.044708251953125
}});
auto
sqrt_2_rpi
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
0.7978515625
}});
auto
one
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.0
f
}});
auto
one_half
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
0.5
f
}});
auto
three
=
m1
.
add_literal
(
migraphx
::
literal
{
s2
,
{
3.0
f
}});
auto
pow0
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"pow"
),
{
x
,
three
});
auto
mul0
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"mul"
),
{
pow0
,
fit_const
});
auto
add0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
{
mul0
,
x
});
auto
mul1
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"mul"
),
{
add0
,
sqrt_2_rpi
});
auto
tanh0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"tanh"
),
mul1
);
auto
add1
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"add"
),
{
tanh0
,
one
});
auto
mul2
=
add_common_op
(
m1
,
migraphx
::
make_op
(
"mul"
),
{
x
,
one_half
});
auto
y
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
{
add1
,
mul2
});
m1
.
add_return
({
y
});
}
migraphx
::
rewrite_gelu
pass
;
pass
.
apply
(
m1
);
migraphx
::
dead_code_elimination
dce
;
dce
.
apply
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
s1
);
auto
sqrt1_2
=
m2
.
add_literal
(
migraphx
::
literal
{
s2
,
{
M_SQRT1_2
}});
auto
one
=
m2
.
add_literal
(
migraphx
::
literal
{
s2
,
{
1.0
f
}});
auto
one_half
=
m2
.
add_literal
(
migraphx
::
literal
{
s2
,
{
0.5
f
}});
auto
a
=
add_common_op
(
m2
,
migraphx
::
make_op
(
"mul"
),
{
x
,
sqrt1_2
});
auto
erf
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"erf"
),
a
);
auto
add_erf
=
add_common_op
(
m2
,
migraphx
::
make_op
(
"add"
),
{
one
,
erf
});
auto
b
=
add_common_op
(
m2
,
migraphx
::
make_op
(
"mul"
),
{
one_half
,
add_erf
});
auto
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x
,
b
);
m2
.
add_return
({
y
});
}
EXPECT
(
m1
==
m2
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
tools/accuracy/accuracy_checker.py
View file @
2bca512e
...
@@ -136,9 +136,13 @@ def check_correctness(gold_outputs,
...
@@ -136,9 +136,13 @@ def check_correctness(gold_outputs,
if
verbose
:
if
verbose
:
with
np
.
printoptions
(
threshold
=
np
.
inf
):
with
np
.
printoptions
(
threshold
=
np
.
inf
):
print
(
'
\n
Output {} is incorrect ...'
.
format
(
i
))
print
(
'
\n
Output {} is incorrect ...'
.
format
(
i
))
print
(
'Expected value:
\n
{}
\n
'
.
format
(
gold_outputs
[
i
]))
#print('Expected value: \n{}'.format(gold_outputs[i]))
print
(
'
\n
......
\n
'
)
#print('\n......\n')
print
(
'Actual value:
\n
{}
\n
'
.
format
(
outputs
[
i
]))
#print('Actual value: \n{}\n'.format(outputs[i]))
diff
=
gold_outputs
[
i
]
-
outputs
[
i
]
#print(f'Difference: {diff}')
max_diff
=
np
.
max
(
np
.
abs
(
diff
))
print
(
f
'Max Difference:
{
max_diff
}
'
)
else
:
else
:
print
(
'Outputs do not match'
)
print
(
'Outputs do not match'
)
break
break
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment