Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fa485ae6
Commit
fa485ae6
authored
Jul 01, 2019
by
Paul
Browse files
Use lazy match operators so it will still short-circuit
parent
6d56671b
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
68 additions
and
10 deletions
+68
-10
src/driver/main.cpp
src/driver/main.cpp
+23
-0
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+6
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+34
-7
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+3
-3
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+2
-0
No files found.
src/driver/main.cpp
View file @
fa485ae6
...
...
@@ -7,6 +7,14 @@
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -17,6 +25,7 @@ struct loader
std
::
string
file_type
;
bool
is_nhwc
=
true
;
unsigned
trim
=
0
;
bool
optimize
=
false
;
void
parse
(
argument_parser
&
ap
)
{
...
...
@@ -26,6 +35,7 @@ struct loader
ap
(
is_nhwc
,
{
"--nhwc"
},
ap
.
help
(
"Treat tensorflow format as nhwc"
),
ap
.
set_value
(
true
));
ap
(
is_nhwc
,
{
"--nchw"
},
ap
.
help
(
"Treat tensorflow format as nchw"
),
ap
.
set_value
(
false
));
ap
(
trim
,
{
"--trim"
,
"-t"
},
ap
.
help
(
"Trim instructions from the end"
));
ap
(
optimize
,
{
"--optimize"
},
ap
.
help
(
"Optimize when reading"
),
ap
.
set_value
(
true
));
}
program
load
()
...
...
@@ -48,6 +58,19 @@ struct loader
auto
last
=
std
::
prev
(
p
.
end
(),
trim
);
p
.
remove_instructions
(
last
,
p
.
end
());
}
if
(
optimize
)
migraphx
::
run_passes
(
p
,
{
migraphx
::
eliminate_identity
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
simplify_algebra
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
simplify_reshapes
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
propagate_constant
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
eliminate_pad
{},
migraphx
::
dead_code_elimination
{},
});
return
p
;
}
};
...
...
src/include/migraphx/functional.hpp
View file @
fa485ae6
...
...
@@ -190,6 +190,12 @@ auto pop_back_args(Ts&&... xs)
};
}
template
<
class
T
>
auto
always
(
T
x
)
{
return
[
=
](
auto
&&
...)
{
return
x
;
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
fa485ae6
...
...
@@ -240,6 +240,24 @@ void find_matches(program& p, Ms&&... ms)
}
}
struct
lazy_and
{
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
and
g
();
}
};
struct
lazy_or
{
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
or
g
();
}
};
template
<
class
Op
,
bool
Start
,
bool
Matches
>
struct
folder
{
...
...
@@ -248,8 +266,11 @@ struct folder
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
Op
op
;
auto
matched
=
[
&
](
auto
m
)
{
return
[
&
]{
return
ctx
.
matched
(
m
,
ins
);
};
};
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_foun
d
());
return
op
(
always
(
x
),
matche
d
(
y
));
})(
Start
,
ms
...);
if
(
matches
==
Matches
)
return
ins
;
...
...
@@ -265,9 +286,15 @@ struct folder
Op
op
;
bool
matches
=
Start
;
select
(
start
,
[
&
](
auto
ins
)
{
matches
=
op
(
matches
,
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
x
,
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
());
})(
Start
,
ms
...));
auto
matched
=
[
&
](
auto
m
)
{
return
[
&
]{
return
ctx
.
matched
(
m
,
ins
);
};
};
auto
fold_match
=
[
&
]
{
return
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
always
(
x
),
matched
(
y
));
})(
Start
,
ms
...);
};
matches
=
op
(
always
(
matches
),
fold_match
);
});
if
(
matches
==
Matches
)
return
start
;
...
...
@@ -277,9 +304,9 @@ struct folder
}
};
const
constexpr
auto
all_of
=
folder
<
std
::
logical_and
<
bool
>
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
folder
<
std
::
logical_or
<
bool
>
,
false
,
false
>
{};
const
constexpr
auto
all_of
=
folder
<
lazy_and
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
folder
<
lazy_or
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
folder
<
lazy_or
,
false
,
false
>
{};
inline
auto
inputs
()
{
...
...
src/simplify_reshapes.cpp
View file @
fa485ae6
...
...
@@ -162,7 +162,6 @@ struct find_transpose
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
return
;
p
.
debug_print
();
if
(
is_no_transpose
(
dims
))
{
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
...
...
@@ -219,8 +218,9 @@ void simplify_reshapes::apply(program& p) const
ins
,
find_nop_reshapes
{},
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{});
find_transpose
{}
// find_concat_transpose{}
);
}
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
fa485ae6
...
...
@@ -236,6 +236,8 @@ struct find_triadd
auto
input_ins
=
r
.
instructions
[
"input"
];
auto
ins
=
r
.
result
;
auto
args
=
add_ins
->
inputs
();
assert
(
add_ins
!=
input_ins
);
auto
is_broadcasted
=
[](
auto
arg
)
{
return
arg
->
get_shape
().
broadcasted
();
};
if
(
std
::
count_if
(
args
.
begin
(),
args
.
end
(),
is_broadcasted
)
>
1
)
return
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment