Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
00dae07f
Commit
00dae07f
authored
Oct 02, 2023
by
Khalique Ahmed
Browse files
add generic case for broadcast transpose
parent
434a06cf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
76 additions
and
3 deletions
+76
-3
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+29
-0
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+13
-3
test/optimize_module_test.cpp
test/optimize_module_test.cpp
+34
-0
No files found.
src/simplify_algebra.cpp
View file @
00dae07f
...
@@ -521,6 +521,24 @@ struct find_inner_broadcast
...
@@ -521,6 +521,24 @@ struct find_inner_broadcast
})
<
(
lens
.
size
()
-
1
);
})
<
(
lens
.
size
()
-
1
);
}))
}))
return
;
return
;
auto
bcast_strides
=
broadcasts
.
front
()
->
get_shape
().
strides
().
size
();
std
::
vector
<
size_t
>
common_axis
(
bcast_strides
,
0
);
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
for
(
auto
i
=
0
;
i
<
bcast_strides
;
i
++
)
{
for
(
auto
j
=
0
;
j
<
broadcasts
.
size
();
j
++
)
{
if
(
broadcasts
[
j
]
->
get_shape
().
strides
()[
i
]
==
0
)
common_axis
[
i
]
++
;
}
}
// if no common broadcast axis, transformation is not useful
if
(
std
::
find_if
(
common_axis
.
begin
(),
common_axis
.
end
(),
[](
auto
num_common
)
{
return
num_common
>
1
;
})
==
common_axis
.
end
())
return
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
broadcasts
.
begin
(),
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
broadcasts
.
end
(),
...
@@ -543,6 +561,17 @@ struct find_inner_broadcast
...
@@ -543,6 +561,17 @@ struct find_inner_broadcast
return
3
;
return
3
;
}));
}));
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
std
::
vector
<
shape
>
broadcast_shapes
;
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
std
::
back_inserter
(
broadcast_shapes
),
[](
auto
broadcast
){
return
broadcast
->
get_shape
();
});
std
::
vector
<
shape
>
common_shapes
;
std
::
transform
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
std
::
back_inserter
(
common_shapes
),
[](
auto
common
){
return
common
->
get_shape
();
});
if
(
broadcast_shapes
==
common_shapes
and
std
::
all_of
(
op
->
inputs
().
begin
(),
op
->
inputs
().
end
(),
[](
auto
i
){
return
i
->
name
()
==
"broadcast"
or
i
->
name
()
==
"multibroadcast"
;}))
return
;
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
}
}
};
};
...
...
src/simplify_reshapes.cpp
View file @
00dae07f
...
@@ -641,10 +641,20 @@ struct find_broadcast_transpose
...
@@ -641,10 +641,20 @@ struct find_broadcast_transpose
auto
ins_lens
=
ins
->
get_shape
().
lens
();
auto
ins_lens
=
ins
->
get_shape
().
lens
();
auto
bcast_ins
=
r
.
instructions
[
"bcast_ins"
];
auto
bcast_ins
=
r
.
instructions
[
"bcast_ins"
];
auto
input
=
bcast_ins
->
inputs
().
front
();
auto
input
=
bcast_ins
->
inputs
().
front
();
//
for now, focusing on
scalar transformation
// scalar transformation
does not need extra transpose
if
(
not
input
->
get_shape
().
scalar
())
if
(
not
input
->
get_shape
().
scalar
())
return
;
{
// find common shape
auto
in_lens
=
input
->
get_shape
().
lens
();
int
lens_diff
=
ins_lens
.
size
()
-
in_lens
.
size
();
if
(
lens_diff
>
0
)
{
std
::
vector
<
size_t
>
unsqueeze_axes
(
lens_diff
);
std
::
iota
(
unsqueeze_axes
.
begin
(),
unsqueeze_axes
.
end
(),
0
);
input
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
input
);
}
input
=
m
.
insert_instruction
(
bcast_ins
,
ins
->
get_operator
(),
input
);
}
auto
new_mbcast
=
m
.
insert_instruction
(
auto
new_mbcast
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ins_lens
}}),
input
);
bcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ins_lens
}}),
input
);
m
.
replace_instruction
(
ins
,
new_mbcast
);
m
.
replace_instruction
(
ins
,
new_mbcast
);
...
...
test/optimize_module_test.cpp
View file @
00dae07f
...
@@ -62,4 +62,38 @@ TEST_CASE(broadcast_transpose_inner_broadcast)
...
@@ -62,4 +62,38 @@ TEST_CASE(broadcast_transpose_inner_broadcast)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
TEST_CASE
(
broadcast_transpose_inner_broadcast_generic
)
{
// first optimizes broadcast+transpose to unsqueeze+transpose+broadcast,
// then finds inner broadcast to become mul+broadcast
migraphx
::
module
m1
;
{
auto
l1
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
,
10
}});
auto
l2
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
5
,
10
}}}),
l1
);
auto
mb2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
10
,
5
}}}),
l2
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
mb2
);
auto
mul
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb1
,
t1
);
m1
.
add_return
({
mul
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l1
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
,
10
}});
auto
l2
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
unsqueeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l2
);
auto
transpose
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
unsqueeze
);
auto
mb1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
l1
);
auto
mb2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
transpose
);
auto
mul
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb1
,
mb2
);
auto
mb3
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
5
,
10
}}}),
mul
);
m2
.
add_return
({
mb3
});
}
EXPECT
(
m1
==
m2
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment