Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0dc81c3b
"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "aa7b76b50895fe52675877619fea508d07d06993"
Commit
0dc81c3b
authored
Oct 02, 2023
by
Khalique Ahmed
Browse files
formatting
parent
00dae07f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
18 deletions
+26
-18
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+13
-9
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+3
-2
test/optimize_module_test.cpp
test/optimize_module_test.cpp
+10
-7
No files found.
src/simplify_algebra.cpp
View file @
0dc81c3b
...
@@ -524,7 +524,7 @@ struct find_inner_broadcast
...
@@ -524,7 +524,7 @@ struct find_inner_broadcast
auto
bcast_strides
=
broadcasts
.
front
()
->
get_shape
().
strides
().
size
();
auto
bcast_strides
=
broadcasts
.
front
()
->
get_shape
().
strides
().
size
();
std
::
vector
<
size_t
>
common_axis
(
bcast_strides
,
0
);
std
::
vector
<
size_t
>
common_axis
(
bcast_strides
,
0
);
// go through the strides of each broadcast,
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
// keep track of values that are equal to 0 in a dimension
for
(
auto
i
=
0
;
i
<
bcast_strides
;
i
++
)
for
(
auto
i
=
0
;
i
<
bcast_strides
;
i
++
)
{
{
for
(
auto
j
=
0
;
j
<
broadcasts
.
size
();
j
++
)
for
(
auto
j
=
0
;
j
<
broadcasts
.
size
();
j
++
)
...
@@ -562,15 +562,19 @@ struct find_inner_broadcast
...
@@ -562,15 +562,19 @@ struct find_inner_broadcast
}));
}));
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
std
::
vector
<
shape
>
broadcast_shapes
;
std
::
vector
<
shape
>
broadcast_shapes
;
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
std
::
back_inserter
(
broadcast_shapes
),
[](
auto
broadcast
){
std
::
transform
(
broadcasts
.
begin
(),
return
broadcast
->
get_shape
();
broadcasts
.
end
(),
});
std
::
back_inserter
(
broadcast_shapes
),
[](
auto
broadcast
)
{
return
broadcast
->
get_shape
();
});
std
::
vector
<
shape
>
common_shapes
;
std
::
vector
<
shape
>
common_shapes
;
std
::
transform
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
std
::
back_inserter
(
common_shapes
),
[](
auto
common
){
std
::
transform
(
op
->
inputs
().
begin
(),
return
common
->
get_shape
();
op
->
inputs
().
end
(),
});
std
::
back_inserter
(
common_shapes
),
if
(
broadcast_shapes
==
common_shapes
and
std
::
all_of
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
[](
auto
i
){
[](
auto
common
)
{
return
common
->
get_shape
();
});
return
i
->
name
()
==
"broadcast"
or
i
->
name
()
==
"multibroadcast"
;}))
if
(
broadcast_shapes
==
common_shapes
and
std
::
all_of
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
[](
auto
i
)
{
return
i
->
name
()
==
"broadcast"
or
i
->
name
()
==
"multibroadcast"
;
}))
return
;
return
;
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
}
}
...
...
src/simplify_reshapes.cpp
View file @
0dc81c3b
...
@@ -645,13 +645,14 @@ struct find_broadcast_transpose
...
@@ -645,13 +645,14 @@ struct find_broadcast_transpose
if
(
not
input
->
get_shape
().
scalar
())
if
(
not
input
->
get_shape
().
scalar
())
{
{
// find common shape
// find common shape
auto
in_lens
=
input
->
get_shape
().
lens
();
auto
in_lens
=
input
->
get_shape
().
lens
();
int
lens_diff
=
ins_lens
.
size
()
-
in_lens
.
size
();
int
lens_diff
=
ins_lens
.
size
()
-
in_lens
.
size
();
if
(
lens_diff
>
0
)
if
(
lens_diff
>
0
)
{
{
std
::
vector
<
size_t
>
unsqueeze_axes
(
lens_diff
);
std
::
vector
<
size_t
>
unsqueeze_axes
(
lens_diff
);
std
::
iota
(
unsqueeze_axes
.
begin
(),
unsqueeze_axes
.
end
(),
0
);
std
::
iota
(
unsqueeze_axes
.
begin
(),
unsqueeze_axes
.
end
(),
0
);
input
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
input
);
input
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
input
);
}
}
input
=
m
.
insert_instruction
(
bcast_ins
,
ins
->
get_operator
(),
input
);
input
=
m
.
insert_instruction
(
bcast_ins
,
ins
->
get_operator
(),
input
);
}
}
...
...
test/optimize_module_test.cpp
View file @
0dc81c3b
...
@@ -82,15 +82,18 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic)
...
@@ -82,15 +82,18 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic)
run_pass
(
m1
);
run_pass
(
m1
);
migraphx
::
module
m2
;
migraphx
::
module
m2
;
{
{
auto
l1
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
,
10
}});
auto
l1
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
,
10
}});
auto
l2
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
l2
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
unsqueeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l2
);
auto
unsqueeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l2
);
auto
transpose
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
unsqueeze
);
auto
transpose
=
m2
.
add_instruction
(
auto
mb1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
l1
);
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
unsqueeze
);
auto
mb2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
transpose
);
auto
mb1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
l1
);
auto
mb2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
transpose
);
auto
mul
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb1
,
mb2
);
auto
mul
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb1
,
mb2
);
auto
mb3
=
auto
mb3
=
m2
.
add_instruction
(
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
5
,
10
}}}),
mul
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
5
,
10
}}}),
mul
);
m2
.
add_return
({
mb3
});
m2
.
add_return
({
mb3
});
}
}
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment