Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0dc81c3b
Commit
0dc81c3b
authored
Oct 02, 2023
by
Khalique Ahmed
Browse files
formatting
parent
00dae07f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
18 deletions
+26
-18
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+13
-9
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+3
-2
test/optimize_module_test.cpp
test/optimize_module_test.cpp
+10
-7
No files found.
src/simplify_algebra.cpp
View file @
0dc81c3b
...
@@ -562,15 +562,19 @@ struct find_inner_broadcast
...
@@ -562,15 +562,19 @@ struct find_inner_broadcast
}));
}));
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
std
::
vector
<
shape
>
broadcast_shapes
;
std
::
vector
<
shape
>
broadcast_shapes
;
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
std
::
back_inserter
(
broadcast_shapes
),
[](
auto
broadcast
){
std
::
transform
(
broadcasts
.
begin
(),
return
broadcast
->
get_shape
();
broadcasts
.
end
(),
});
std
::
back_inserter
(
broadcast_shapes
),
[](
auto
broadcast
)
{
return
broadcast
->
get_shape
();
});
std
::
vector
<
shape
>
common_shapes
;
std
::
vector
<
shape
>
common_shapes
;
std
::
transform
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
std
::
back_inserter
(
common_shapes
),
[](
auto
common
){
std
::
transform
(
op
->
inputs
().
begin
(),
return
common
->
get_shape
();
op
->
inputs
().
end
(),
});
std
::
back_inserter
(
common_shapes
),
if
(
broadcast_shapes
==
common_shapes
and
std
::
all_of
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
[](
auto
i
){
[](
auto
common
)
{
return
common
->
get_shape
();
});
return
i
->
name
()
==
"broadcast"
or
i
->
name
()
==
"multibroadcast"
;}))
if
(
broadcast_shapes
==
common_shapes
and
std
::
all_of
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
[](
auto
i
)
{
return
i
->
name
()
==
"broadcast"
or
i
->
name
()
==
"multibroadcast"
;
}))
return
;
return
;
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
}
}
...
...
src/simplify_reshapes.cpp
View file @
0dc81c3b
...
@@ -651,7 +651,8 @@ struct find_broadcast_transpose
...
@@ -651,7 +651,8 @@ struct find_broadcast_transpose
{
{
std
::
vector
<
size_t
>
unsqueeze_axes
(
lens_diff
);
std
::
vector
<
size_t
>
unsqueeze_axes
(
lens_diff
);
std
::
iota
(
unsqueeze_axes
.
begin
(),
unsqueeze_axes
.
end
(),
0
);
std
::
iota
(
unsqueeze_axes
.
begin
(),
unsqueeze_axes
.
end
(),
0
);
input
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
input
);
input
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
input
);
}
}
input
=
m
.
insert_instruction
(
bcast_ins
,
ins
->
get_operator
(),
input
);
input
=
m
.
insert_instruction
(
bcast_ins
,
ins
->
get_operator
(),
input
);
}
}
...
...
test/optimize_module_test.cpp
View file @
0dc81c3b
...
@@ -85,12 +85,15 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic)
...
@@ -85,12 +85,15 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic)
auto
l1
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
,
10
}});
auto
l1
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
,
10
}});
auto
l2
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
l2
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
unsqueeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l2
);
auto
unsqueeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l2
);
auto
transpose
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
unsqueeze
);
auto
transpose
=
m2
.
add_instruction
(
auto
mb1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
l1
);
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
unsqueeze
);
auto
mb2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
transpose
);
auto
mb1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
l1
);
auto
mb2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
transpose
);
auto
mul
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb1
,
mb2
);
auto
mul
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb1
,
mb2
);
auto
mb3
=
auto
mb3
=
m2
.
add_instruction
(
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
5
,
10
}}}),
mul
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
5
,
10
}}}),
mul
);
m2
.
add_return
({
mb3
});
m2
.
add_return
({
mb3
});
}
}
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment