Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
06b02add
Commit
06b02add
authored
Apr 08, 2019
by
Paul
Browse files
Fix broken tests
parent
4b7a267a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
20 deletions
+15
-20
src/constant_propagate.cpp
src/constant_propagate.cpp
+11
-17
src/include/migraphx/op/binary.hpp
src/include/migraphx/op/binary.hpp
+4
-3
No files found.
src/constant_propagate.cpp
View file @
06b02add
...
@@ -7,32 +7,26 @@
...
@@ -7,32 +7,26 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
match_const_add
bool
skip_propogate
(
instruction_ref
ins
)
{
{
auto
matcher
()
const
if
(
ins
->
name
()
==
"@literal"
)
{
return
true
;
return
match
::
name
(
"add"
)(
match
::
args
(
match
::
name
(
"@literal"
),
match
::
name
(
"@literal"
)));
if
(
ins
->
get_shape
().
broadcasted
()
and
not
ins
->
get_shape
().
scalar
())
}
return
true
;
if
(
ins
->
get_shape
().
scalar
()
and
ins
->
get_shape
().
elements
()
!=
1
)
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
r
)
const
return
true
;
{
return
false
;
auto
ins
=
r
.
result
;
}
auto
arg1
=
ins
->
inputs
().
at
(
0
)
->
get_literal
();
auto
arg2
=
ins
->
inputs
().
at
(
1
)
->
get_literal
();
auto
sum
=
p
.
add_literal
(
transform
(
arg1
,
arg2
,
[](
auto
x
,
auto
y
)
{
return
x
+
y
;
}));
p
.
replace_instruction
(
ins
,
sum
);
}
};
void
constant_propagate
::
apply
(
program
&
p
)
const
void
constant_propagate
::
apply
(
program
&
p
)
const
{
{
fix
([
&
](
auto
self
,
auto
ins
)
{
fix
([
&
](
auto
self
,
auto
ins
)
{
if
(
not
ins
->
get_shape
().
broadcasted
()
and
ins
->
name
()
!=
"@literal"
)
if
(
not
skip_propogate
(
ins
)
)
{
{
auto
r
=
ins
->
eval
();
auto
r
=
ins
->
eval
();
if
(
not
r
.
empty
())
if
(
not
r
.
empty
())
{
{
assert
(
r
.
get_shape
()
==
ins
->
get_shape
());
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
p
.
replace_instruction
(
ins
,
l
);
p
.
replace_instruction
(
ins
,
l
);
return
;
return
;
...
...
src/include/migraphx/op/binary.hpp
View file @
06b02add
...
@@ -28,9 +28,10 @@ struct binary
...
@@ -28,9 +28,10 @@ struct binary
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
check_shapes
{
inputs
}.
has
(
2
).
same_type
().
same_dims
();
auto
t
=
inputs
.
at
(
0
).
type
();
const
auto
&
s
=
inputs
.
front
();
auto
lens
=
inputs
.
at
(
0
).
lens
();
if
(
s
.
scalar
()
and
s
.
elements
()
==
1
)
return
{
t
,
lens
};
return
{
s
.
type
()};
return
{
s
.
type
(),
s
.
lens
()};
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment