Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b4a9a9c5
Commit
b4a9a9c5
authored
Oct 04, 2023
by
Khalique Ahmed
Browse files
add tests, rename variables
parent
8d7948aa
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
6 deletions
+73
-6
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+8
-6
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+17
-0
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+48
-0
No files found.
src/simplify_reshapes.cpp
View file @
b4a9a9c5
...
...
@@ -642,8 +642,8 @@ struct find_broadcast_transpose
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
ins_lens
=
ins
->
get_shape
().
lens
();
auto
transpose
=
r
.
result
;
auto
transpose_lens
=
transpose
->
get_shape
().
lens
();
auto
bcast_ins
=
r
.
instructions
[
"bcast_ins"
];
auto
input
=
bcast_ins
->
inputs
().
front
();
// scalar transformation does not need extra transpose
...
...
@@ -651,7 +651,8 @@ struct find_broadcast_transpose
{
// find common shape
auto
in_lens
=
input
->
get_shape
().
lens
();
int
lens_diff
=
ins_lens
.
size
()
-
in_lens
.
size
();
int
lens_diff
=
transpose_lens
.
size
()
-
in_lens
.
size
();
// insert unsqueeze if input lens < transpose lens
if
(
lens_diff
>
0
)
{
std
::
vector
<
size_t
>
unsqueeze_axes
(
lens_diff
);
...
...
@@ -659,11 +660,12 @@ struct find_broadcast_transpose
input
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
input
);
}
input
=
m
.
insert_instruction
(
bcast_ins
,
ins
->
get_operator
(),
input
);
// apply transpose before the multibroadcast
input
=
m
.
insert_instruction
(
bcast_ins
,
transpose
->
get_operator
(),
input
);
}
auto
new_mbcast
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ins
_lens
}}),
input
);
m
.
replace_instruction
(
ins
,
new_mbcast
);
bcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
transpose
_lens
}}),
input
);
m
.
replace_instruction
(
transpose
,
new_mbcast
);
}
};
...
...
test/simplify_algebra_test.cpp
View file @
b4a9a9c5
...
...
@@ -669,6 +669,23 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast_no_common_axis
)
{
auto
b
=
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}});
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
5
,
10
}});
auto
y
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
5
,
1
}});
auto
xb
=
m1
.
add_instruction
(
b
,
x
);
auto
yb
=
m1
.
add_instruction
(
b
,
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
xb
,
yb
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
migraphx
::
module
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_add_conv1
)
{
migraphx
::
module
m
;
...
...
test/simplify_reshapes_test.cpp
View file @
b4a9a9c5
...
...
@@ -67,6 +67,54 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
return
m
;
}
TEST_CASE
(
broadcast_transpose
)
{
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
5
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
1
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
u1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l
);
auto
t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
1
}}}),
u1
);
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
5
,
2
,
3
}}}),
t1
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_opt
)
{
// extra transpose from transformation will be optimized out
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
5
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
,
2
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
u1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l
);
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
5
}}}),
u1
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_scalar
)
{
migraphx
::
module
m1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment