Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fa485ae6
Commit
fa485ae6
authored
Jul 01, 2019
by
Paul
Browse files
Use lazy match operators so it will still short-circuit
parent
6d56671b
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
68 additions
and
10 deletions
+68
-10
src/driver/main.cpp
src/driver/main.cpp
+23
-0
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+6
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+34
-7
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+3
-3
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+2
-0
No files found.
src/driver/main.cpp
View file @
fa485ae6
...
...
@@ -7,6 +7,14 @@
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -17,6 +25,7 @@ struct loader
std
::
string
file_type
;
bool
is_nhwc
=
true
;
unsigned
trim
=
0
;
bool
optimize
=
false
;
void
parse
(
argument_parser
&
ap
)
{
...
...
@@ -26,6 +35,7 @@ struct loader
ap
(
is_nhwc
,
{
"--nhwc"
},
ap
.
help
(
"Treat tensorflow format as nhwc"
),
ap
.
set_value
(
true
));
ap
(
is_nhwc
,
{
"--nchw"
},
ap
.
help
(
"Treat tensorflow format as nchw"
),
ap
.
set_value
(
false
));
ap
(
trim
,
{
"--trim"
,
"-t"
},
ap
.
help
(
"Trim instructions from the end"
));
ap
(
optimize
,
{
"--optimize"
},
ap
.
help
(
"Optimize when reading"
),
ap
.
set_value
(
true
));
}
program
load
()
...
...
@@ -48,6 +58,19 @@ struct loader
auto
last
=
std
::
prev
(
p
.
end
(),
trim
);
p
.
remove_instructions
(
last
,
p
.
end
());
}
if
(
optimize
)
migraphx
::
run_passes
(
p
,
{
migraphx
::
eliminate_identity
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
simplify_algebra
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
simplify_reshapes
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
propagate_constant
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
eliminate_pad
{},
migraphx
::
dead_code_elimination
{},
});
return
p
;
}
};
...
...
src/include/migraphx/functional.hpp
View file @
fa485ae6
...
...
@@ -190,6 +190,12 @@ auto pop_back_args(Ts&&... xs)
};
}
template
<
class
T
>
auto
always
(
T
x
)
{
return
[
=
](
auto
&&
...)
{
return
x
;
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
fa485ae6
...
...
@@ -240,6 +240,24 @@ void find_matches(program& p, Ms&&... ms)
}
}
struct
lazy_and
{
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
and
g
();
}
};
struct
lazy_or
{
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
or
g
();
}
};
template
<
class
Op
,
bool
Start
,
bool
Matches
>
struct
folder
{
...
...
@@ -248,8 +266,11 @@ struct folder
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
Op
op
;
auto
matched
=
[
&
](
auto
m
)
{
return
[
&
]{
return
ctx
.
matched
(
m
,
ins
);
};
};
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_foun
d
());
return
op
(
always
(
x
),
matche
d
(
y
));
})(
Start
,
ms
...);
if
(
matches
==
Matches
)
return
ins
;
...
...
@@ -265,9 +286,15 @@ struct folder
Op
op
;
bool
matches
=
Start
;
select
(
start
,
[
&
](
auto
ins
)
{
matches
=
op
(
matches
,
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
());
})(
Start
,
ms
...));
auto
matched
=
[
&
](
auto
m
)
{
return
[
&
]{
return
ctx
.
matched
(
m
,
ins
);
};
};
auto
fold_match
=
[
&
]
{
return
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
always
(
x
),
matched
(
y
));
})(
Start
,
ms
...);
};
matches
=
op
(
always
(
matches
),
fold_match
);
});
if
(
matches
==
Matches
)
return
start
;
...
...
@@ -277,9 +304,9 @@ struct folder
}
};
const
constexpr
auto
all_of
=
folder
<
std
::
logical_and
<
bool
>
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
false
>
{};
const
constexpr
auto
all_of
=
folder
<
lazy_and
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
folder
<
lazy_or
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
folder
<
lazy_or
,
false
,
false
>
{};
inline
auto
inputs
()
{
...
...
src/simplify_reshapes.cpp
View file @
fa485ae6
...
...
@@ -162,7 +162,6 @@ struct find_transpose
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
return
;
p
.
debug_print
();
if
(
is_no_transpose
(
dims
))
{
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
...
...
@@ -219,8 +218,9 @@ void simplify_reshapes::apply(program& p) const
ins
,
find_nop_reshapes
{},
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{});
find_transpose
{}
// find_concat_transpose{}
);
}
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
fa485ae6
...
...
@@ -236,6 +236,8 @@ struct find_triadd
auto
input_ins
=
r
.
instructions
[
"input"
];
auto
ins
=
r
.
result
;
auto
args
=
add_ins
->
inputs
();
assert
(
add_ins
!=
input_ins
);
auto
is_broadcasted
=
[](
auto
arg
)
{
return
arg
->
get_shape
().
broadcasted
();
};
if
(
std
::
count_if
(
args
.
begin
(),
args
.
end
(),
is_broadcasted
)
>
1
)
return
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment