Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
50e6d5eb
Commit
50e6d5eb
authored
Oct 16, 2019
by
Paul Fultz II
Committed by
mvermeulen
Oct 16, 2019
Browse files
Flatten nested concats (#391)
* Flatten nested concats * Formatting * Rename tests
parent
756c5908
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
1 deletion
+73
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+34
-1
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+39
-0
No files found.
src/simplify_reshapes.cpp
View file @
50e6d5eb
...
@@ -179,6 +179,38 @@ struct find_concat_transpose
...
@@ -179,6 +179,38 @@ struct find_concat_transpose
}
}
};
};
struct
find_nested_concat
{
auto
matcher
()
const
{
return
match
::
name
(
"concat"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"concat"
)));
}
static
std
::
size_t
get_axis
(
instruction_ref
ins
)
{
auto
op
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
());
return
op
.
axis
;
}
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
axis
=
get_axis
(
ins
);
std
::
vector
<
instruction_ref
>
args
;
fix
([
&
](
auto
self
,
auto
&&
inputs
)
{
for
(
auto
&&
i
:
inputs
)
{
if
(
i
->
name
()
==
"concat"
and
get_axis
(
i
)
==
axis
and
i
->
outputs
().
size
()
==
1
)
self
(
i
->
inputs
());
else
args
.
push_back
(
i
);
}
})(
ins
->
inputs
());
p
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
args
);
}
};
void
simplify_reshapes
::
apply
(
program
&
p
)
const
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
for
(
int
i
=
0
;
i
<
2
;
i
++
)
...
@@ -196,7 +228,8 @@ void simplify_reshapes::apply(program& p) const
...
@@ -196,7 +228,8 @@ void simplify_reshapes::apply(program& p) const
find_nop_reshapes
{},
find_nop_reshapes
{},
find_reshaper
{},
find_reshaper
{},
find_transpose
{},
find_transpose
{},
find_concat_transpose
{});
find_concat_transpose
{},
find_nested_concat
{});
}
}
}
}
}
}
...
...
test/simplify_reshapes_test.cpp
View file @
50e6d5eb
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
#include <test.hpp>
...
@@ -328,4 +329,42 @@ TEST_CASE(concat_transpose3)
...
@@ -328,4 +329,42 @@ TEST_CASE(concat_transpose3)
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
}
}
TEST_CASE
(
nested_concat
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
y
=
p
.
add_parameter
(
"y"
,
s
);
auto
concat1
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
1
},
x
,
y
);
auto
concat2
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
1
},
y
,
x
);
auto
concat3
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
1
},
concat1
,
concat2
);
p
.
add_instruction
(
pass_op
{},
concat3
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
2
);
EXPECT
(
std
::
count_if
(
p
.
begin
(),
p
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
})
==
1
);
}
TEST_CASE
(
nested_concat_partial
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
y
=
p
.
add_parameter
(
"y"
,
s
);
auto
l
=
p
.
add_literal
(
migraphx
::
generate_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
3
,
4
}}));
auto
concat1
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
1
},
x
,
y
);
auto
concat2
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
1
},
y
,
x
);
auto
concat3
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
1
},
concat1
,
concat2
,
l
);
p
.
add_instruction
(
pass_op
{},
concat3
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
2
);
EXPECT
(
std
::
count_if
(
p
.
begin
(),
p
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
})
==
1
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment