Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
271eeddd
Unverified
Commit
271eeddd
authored
Oct 13, 2023
by
kahmed10
Committed by
GitHub
Oct 13, 2023
Browse files
Optimize broadcast + transpose for nonscalars (#2271)
parent
68161431
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
151 additions
and
9 deletions
+151
-9
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+22
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+23
-7
test/optimize_module_test.cpp
test/optimize_module_test.cpp
+37
-0
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+18
-1
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+51
-0
No files found.
src/simplify_algebra.cpp
View file @
271eeddd
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -521,6 +521,27 @@ struct find_inner_broadcast
...
@@ -521,6 +521,27 @@ struct find_inner_broadcast
})
<
(
lens
.
size
()
-
1
);
})
<
(
lens
.
size
()
-
1
);
}))
}))
return
;
return
;
if
(
broadcasts
.
size
()
>
1
)
{
auto
bcast_strides
=
broadcasts
.
front
()
->
get_shape
().
strides
().
size
();
std
::
vector
<
size_t
>
common_axis
(
bcast_strides
,
0
);
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
for
(
auto
i
=
0
;
i
<
bcast_strides
;
i
++
)
{
for
(
const
auto
&
broadcast
:
broadcasts
)
{
if
(
broadcast
->
get_shape
().
strides
()[
i
]
==
0
)
common_axis
[
i
]
++
;
}
}
// if no common broadcast axis, transformation is not useful
if
(
std
::
find_if
(
common_axis
.
begin
(),
common_axis
.
end
(),
[](
auto
num_common
)
{
return
num_common
>
1
;
})
==
common_axis
.
end
())
return
;
}
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
broadcasts
.
begin
(),
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
broadcasts
.
end
(),
...
...
src/simplify_reshapes.cpp
View file @
271eeddd
...
@@ -632,6 +632,9 @@ struct find_transpose_contiguous_reshaper_unary
...
@@ -632,6 +632,9 @@ struct find_transpose_contiguous_reshaper_unary
}
}
};
};
// simplifies broadcast->transpose to transpose->broadcast
// in the case of a scalar, simply rewrite to broadcast
// this can allow for further optimizations with find_inner_broadcast() in simplify_algebra.cpp
struct
find_broadcast_transpose
struct
find_broadcast_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -642,17 +645,30 @@ struct find_broadcast_transpose
...
@@ -642,17 +645,30 @@ struct find_broadcast_transpose
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
transpose
=
r
.
result
;
auto
ins
_lens
=
ins
->
get_shape
().
lens
();
auto
transpose
_lens
=
transpose
->
get_shape
().
lens
();
auto
bcast_ins
=
r
.
instructions
[
"bcast_ins"
];
auto
bcast_ins
=
r
.
instructions
[
"bcast_ins"
];
auto
input
=
bcast_ins
->
inputs
().
front
();
auto
input
=
bcast_ins
->
inputs
().
front
();
//
for now, focusing on
scalar transformation
// scalar transformation
does not need extra transpose
if
(
not
input
->
get_shape
().
scalar
())
if
(
not
input
->
get_shape
().
scalar
())
return
;
{
// find common shape
auto
in_lens
=
input
->
get_shape
().
lens
();
int
lens_diff
=
transpose_lens
.
size
()
-
in_lens
.
size
();
// insert unsqueeze if input lens < transpose lens
if
(
lens_diff
>
0
)
{
std
::
vector
<
size_t
>
unsqueeze_axes
(
lens_diff
);
std
::
iota
(
unsqueeze_axes
.
begin
(),
unsqueeze_axes
.
end
(),
0
);
input
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
unsqueeze_axes
}}),
input
);
}
// apply transpose before the multibroadcast
input
=
m
.
insert_instruction
(
bcast_ins
,
transpose
->
get_operator
(),
input
);
}
auto
new_mbcast
=
m
.
insert_instruction
(
auto
new_mbcast
=
m
.
insert_instruction
(
bcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ins
_lens
}}),
input
);
bcast_ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
transpose
_lens
}}),
input
);
m
.
replace_instruction
(
ins
,
new_mbcast
);
m
.
replace_instruction
(
transpose
,
new_mbcast
);
}
}
};
};
...
...
test/optimize_module_test.cpp
View file @
271eeddd
...
@@ -62,4 +62,41 @@ TEST_CASE(broadcast_transpose_inner_broadcast)
...
@@ -62,4 +62,41 @@ TEST_CASE(broadcast_transpose_inner_broadcast)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
TEST_CASE
(
broadcast_transpose_inner_broadcast_generic
)
{
// first optimizes broadcast+transpose to unsqueeze+transpose+broadcast,
// then finds inner broadcast to become mul+broadcast
migraphx
::
module
m1
;
{
auto
l1
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
,
10
}});
auto
l2
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
5
,
10
}}}),
l1
);
auto
mb2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
10
,
5
}}}),
l2
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
mb2
);
auto
mul
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb1
,
t1
);
m1
.
add_return
({
mul
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l1
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
,
10
}});
auto
l2
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
unsqueeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l2
);
auto
transpose
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
unsqueeze
);
auto
mb1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
l1
);
auto
mb2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}}),
transpose
);
auto
mul
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mb1
,
mb2
);
auto
mb3
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
5
,
10
}}}),
mul
);
m2
.
add_return
({
mb3
});
}
EXPECT
(
m1
==
m2
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/simplify_algebra_test.cpp
View file @
271eeddd
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -669,6 +669,23 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts)
...
@@ -669,6 +669,23 @@ TEST_CASE(simplify_inner_broadcast_different_broadcasts)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
TEST_CASE
(
simplify_inner_broadcast_no_common_axis
)
{
auto
b
=
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
5
,
10
}}});
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
5
,
10
}});
auto
y
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
5
,
1
}});
auto
xb
=
m1
.
add_instruction
(
b
,
x
);
auto
yb
=
m1
.
add_instruction
(
b
,
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
xb
,
yb
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
migraphx
::
module
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_add_conv1
)
TEST_CASE
(
simplify_add_conv1
)
{
{
migraphx
::
module
m
;
migraphx
::
module
m
;
...
...
test/simplify_reshapes_test.cpp
View file @
271eeddd
...
@@ -67,6 +67,57 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
...
@@ -67,6 +67,57 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
return
m
;
return
m
;
}
}
TEST_CASE
(
broadcast_transpose
)
{
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
5
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
1
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
u1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l
);
auto
t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
1
}}}),
u1
);
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
5
,
2
,
3
}}}),
t1
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_opt
)
{
// extra transpose from transformation will be optimized out
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
5
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
,
2
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
5
}});
auto
u1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
l
);
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
,
5
}}}),
u1
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_scalar
)
TEST_CASE
(
broadcast_transpose_scalar
)
{
{
migraphx
::
module
m1
;
migraphx
::
module
m1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment