Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7f4c7809
Commit
7f4c7809
authored
Jun 30, 2019
by
Paul
Browse files
Formatting
parent
2e51006e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
37 deletions
+32
-37
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+18
-18
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+14
-19
No files found.
src/include/migraphx/matcher.hpp
View file @
7f4c7809
...
@@ -21,7 +21,7 @@ struct matcher_context
...
@@ -21,7 +21,7 @@ struct matcher_context
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
not_found
()
const
{
return
last
;
}
instruction_ref
not_found
()
const
{
return
last
;
}
template
<
class
M
>
template
<
class
M
>
bool
matched
(
M
m
,
instruction_ref
ins
)
bool
matched
(
M
m
,
instruction_ref
ins
)
{
{
return
m
.
match
(
*
this
,
ins
)
!=
this
->
not_found
();
return
m
.
match
(
*
this
,
ins
)
!=
this
->
not_found
();
...
@@ -240,7 +240,7 @@ void find_matches(program& p, Ms&&... ms)
...
@@ -240,7 +240,7 @@ void find_matches(program& p, Ms&&... ms)
}
}
}
}
template
<
class
Op
,
bool
Start
,
bool
Matches
>
template
<
class
Op
,
bool
Start
,
bool
Matches
>
struct
folder
struct
folder
{
{
template
<
class
...
Ts
>
template
<
class
...
Ts
>
...
@@ -257,7 +257,7 @@ struct folder
...
@@ -257,7 +257,7 @@ struct folder
});
});
}
}
template
<
class
Selector
>
template
<
class
Selector
>
auto
operator
[](
Selector
select
)
const
auto
operator
[](
Selector
select
)
const
{
{
return
[
=
](
auto
...
ms
)
{
return
[
=
](
auto
...
ms
)
{
...
@@ -284,7 +284,7 @@ const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{};
...
@@ -284,7 +284,7 @@ const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{};
inline
auto
inputs
()
inline
auto
inputs
()
{
{
return
[](
auto
ins
,
auto
f
)
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
inputs
())
for
(
auto
&&
x
:
ins
->
inputs
())
f
(
x
);
f
(
x
);
};
};
}
}
...
@@ -292,7 +292,7 @@ inline auto inputs()
...
@@ -292,7 +292,7 @@ inline auto inputs()
inline
auto
outputs
()
inline
auto
outputs
()
{
{
return
[](
auto
ins
,
auto
f
)
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
outputs
())
for
(
auto
&&
x
:
ins
->
outputs
())
f
(
x
);
f
(
x
);
};
};
}
}
...
@@ -312,12 +312,11 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
...
@@ -312,12 +312,11 @@ MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
MIGRAPHX_PRED_MATCHER
(
same_shapes
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
same_shapes
,
instruction_ref
ins
)
{
{
if
(
ins
->
inputs
().
empty
())
if
(
ins
->
inputs
().
empty
())
return
false
;
return
false
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
std
::
all_of
(
return
x
->
get_shape
()
==
s
;
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
});
}
}
MIGRAPHX_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
MIGRAPHX_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
...
@@ -336,7 +335,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
...
@@ -336,7 +335,7 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return
ctx
.
not_found
();
return
ctx
.
not_found
();
}
}
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
skip_output
(
Ms
...
ms
)
auto
skip_output
(
Ms
...
ms
)
{
{
auto
m
=
any_of
(
ms
...);
auto
m
=
any_of
(
ms
...);
...
@@ -345,10 +344,10 @@ auto skip_output(Ms... ms)
...
@@ -345,10 +344,10 @@ auto skip_output(Ms... ms)
if
(
ins
->
outputs
().
size
()
==
1
)
if
(
ins
->
outputs
().
size
()
==
1
)
{
{
auto
next
=
ins
->
outputs
().
front
();
auto
next
=
ins
->
outputs
().
front
();
if
(
ctx
.
matched
(
m
,
next
))
if
(
ctx
.
matched
(
m
,
next
))
{
{
auto
skipped_next
=
self
(
next
);
auto
skipped_next
=
self
(
next
);
if
(
skipped_next
!=
ctx
.
not_found
())
if
(
skipped_next
!=
ctx
.
not_found
())
return
skipped_next
;
return
skipped_next
;
}
}
return
next
;
return
next
;
...
@@ -366,8 +365,9 @@ inline auto name(std::string s)
...
@@ -366,8 +365,9 @@ inline auto name(std::string s)
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
{
return
make_basic_pred_matcher
(
return
make_basic_pred_matcher
([
=
,
names
=
std
::
move
(
names
)
](
instruction_ref
ins
)
{
[
=
,
names
=
std
::
move
(
names
)
](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
())
>
0
;
});
return
names
.
count
(
ins
->
name
())
>
0
;
});
}
}
inline
auto
nargs
(
std
::
size_t
n
)
inline
auto
nargs
(
std
::
size_t
n
)
...
...
src/simplify_reshapes.cpp
View file @
7f4c7809
...
@@ -25,10 +25,7 @@ const auto& reshaper_names()
...
@@ -25,10 +25,7 @@ const auto& reshaper_names()
return
names
;
return
names
;
}
}
bool
is_reshaper
(
instruction_ref
ins
)
bool
is_reshaper
(
instruction_ref
ins
)
{
return
contains
(
reshaper_names
(),
ins
->
name
());
}
{
return
contains
(
reshaper_names
(),
ins
->
name
());
}
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
{
{
...
@@ -90,7 +87,8 @@ struct find_reshaper
...
@@ -90,7 +87,8 @@ struct find_reshaper
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
reshaper_names
())(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
return
match
::
name
(
reshaper_names
())(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
...
@@ -139,7 +137,8 @@ struct find_transpose
...
@@ -139,7 +137,8 @@ struct find_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"transpose"
)(
match
::
none_of
(
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
return
match
::
name
(
"transpose"
)(
match
::
none_of
(
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
...
@@ -172,7 +171,8 @@ struct find_concat_transpose
...
@@ -172,7 +171,8 @@ struct find_concat_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"concat"
)(
match
::
same_shapes
(),
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
return
match
::
name
(
"concat"
)(
match
::
same_shapes
(),
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
mr
)
const
...
@@ -187,10 +187,9 @@ struct find_concat_transpose
...
@@ -187,10 +187,9 @@ struct find_concat_transpose
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
ins
->
inputs
().
end
(),
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
std
::
back_inserter
(
inputs
),
});
[
&
](
auto
i
)
{
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
});
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
t
=
p
.
insert_instruction
(
ins
,
op
::
transpose
{
ipermutaion
},
concat
);
auto
t
=
p
.
insert_instruction
(
ins
,
op
::
transpose
{
ipermutaion
},
concat
);
p
.
replace_instruction
(
ins
,
t
);
p
.
replace_instruction
(
ins
,
t
);
...
@@ -207,11 +206,7 @@ void simplify_reshapes::apply(program& p) const
...
@@ -207,11 +206,7 @@ void simplify_reshapes::apply(program& p) const
// Skip possible dead instructions
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
continue
;
match
::
find_matches
(
p
,
ins
,
match
::
find_matches
(
p
,
ins
,
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{});
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{}
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment