Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bb90f0eb
Commit
bb90f0eb
authored
Apr 04, 2023
by
Shiv
Browse files
update conv dot fusion
parent
e7ec374f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
12 deletions
+75
-12
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+15
-10
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+60
-2
No files found.
src/simplify_algebra.cpp
View file @
bb90f0eb
...
...
@@ -896,21 +896,26 @@ struct find_conv_dot_horiz_fusion
int64_t
offset
=
0
;
for
(
auto
arg
:
range
(
start
,
last
))
{
auto
outputs
=
arg
->
outputs
();
for
(
auto
output
:
outputs
)
{
if
(
output
->
name
()
!=
"reshape"
)
continue
;
auto
x
=
m
.
insert_instruction
(
output
,
make_op
(
"contiguous"
),
arg
);
m
.
replace_instruction
(
output
,
output
->
get_operator
(),
x
);
}
auto
outputs
=
arg
->
outputs
();
auto
requires_contiguous
=
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[](
auto
o
)
{
return
o
->
get_shape
().
standard
();
});
int64_t
len
=
arg
->
get_shape
().
lens
()[
axis
];
m
.
replace_instruction
(
arg
,
auto
slice
=
m
.
insert_instruction
(
std
::
prev
(
arg
),
make_op
(
"slice"
,
{{
"axes"
,
{
axis
}},
{
"starts"
,
{
offset
}},
{
"ends"
,
{
offset
+
len
}}}),
fused
);
if
(
requires_contiguous
)
{
m
.
replace_instruction
(
arg
,
make_op
(
"contiguous"
),
slice
);
}
else
{
m
.
replace_instruction
(
arg
,
slice
);
}
offset
+=
len
;
}
};
...
...
test/simplify_algebra_test.cpp
View file @
bb90f0eb
...
...
@@ -2133,12 +2133,53 @@ TEST_CASE(simplify_dot_horiz)
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
2
}}}),
dot
);
auto
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
4
}}}),
dot
);
x
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
m2
.
add_instruction
(
pass_op
{},
sum
);
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
simplify_dot_horiz_nonstandard
)
{
auto
s1
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
4
,
24
,
24
}};
auto
s2
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
4
,
24
,
24
},
{
0
,
1
,
24
}};
migraphx
::
module
m1
;
{
auto
input
=
m1
.
add_parameter
(
"input"
,
s1
);
auto
a
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s2
,
0
));
auto
b
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s2
,
1
));
auto
c
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s2
,
2
));
auto
zeros
=
m1
.
add_literal
(
0
);
zeros
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
24
,
24
}}}),
zeros
);
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
a
);
auto
y
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
b
);
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
c
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
y
,
zeros
);
auto
rsp
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
,
24
,
3
,
8
}}}),
sum
);
m1
.
add_instruction
(
pass_op
{},
rsp
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
input
=
m2
.
add_parameter
(
"input"
,
s1
);
auto
a
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s2
,
0
));
auto
b
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s2
,
1
));
auto
c
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s2
,
2
));
auto
concat
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
2
}}),
a
,
b
,
c
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
concat
);
auto
x
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
48
}}}),
dot
);
x
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
,
24
,
3
,
8
}}}),
x
);
m2
.
add_instruction
(
pass_op
{},
rsp
);
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
simplify_dot_horiz_same_constant
)
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
3
,
2
,
2
}};
...
...
@@ -2163,6 +2204,8 @@ TEST_CASE(simplify_dot_horiz_same_constant)
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
2
}}}),
dot
);
auto
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
4
}}}),
dot
);
x
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
m2
.
add_instruction
(
pass_op
{},
sum
);
}
...
...
@@ -2219,10 +2262,11 @@ TEST_CASE(simplify_dot_horiz_reshape)
auto
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
4
}},
{
"ends"
,
{
8
}}}),
dot
);
auto
x_cont
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
y_cont
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
y
);
auto
x_rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
3
,
4
,
2
,
2
}}}),
x_cont
);
auto
y_rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
y
);
auto
y_rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
y
_cont
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
{
x_rsp
,
y_rsp
});
m2
.
add_instruction
(
pass_op
{},
sum
);
}
...
...
@@ -2257,6 +2301,8 @@ TEST_CASE(simplify_conv_horiz)
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
conv
);
auto
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
conv
);
x
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
y
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
m2
.
add_instruction
(
pass_op
{},
sum
);
}
...
...
@@ -2333,12 +2379,16 @@ TEST_CASE(simplify_conv_horiz_grouped)
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
6
}}}),
conv
);
auto
convy
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
6
}},
{
"ends"
,
{
12
}}}),
conv
);
convx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
convx
);
convy
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
convy
);
auto
sum1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
convx
,
convy
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
concat2
);
auto
dotx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
64
}}}),
dot
);
auto
doty
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
64
}},
{
"ends"
,
{
128
}}}),
dot
);
dotx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
dotx
);
doty
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
doty
);
auto
sum2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
dotx
,
doty
);
auto
sum3
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
sum1
,
sum2
);
m2
.
add_instruction
(
pass_op
{},
sum3
);
...
...
@@ -2391,12 +2441,16 @@ TEST_CASE(simplify_conv_horiz_grouped_extra1)
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
6
}}}),
conv
);
auto
convy
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
6
}},
{
"ends"
,
{
12
}}}),
conv
);
convx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
convx
);
convy
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
convy
);
auto
sum1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
convx
,
convy
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
concat2
);
auto
dotx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
64
}}}),
dot
);
auto
doty
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
64
}},
{
"ends"
,
{
128
}}}),
dot
);
dotx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
dotx
);
doty
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
doty
);
auto
sum2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
dotx
,
doty
);
auto
sqdiffx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"sqdiff"
),
input
,
e
);
auto
sum3
=
sqdiffx
;
...
...
@@ -2455,12 +2509,16 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2)
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
6
}}}),
conv
);
auto
convy
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
6
}},
{
"ends"
,
{
12
}}}),
conv
);
convx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
convx
);
convy
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
convy
);
auto
sum1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
convx
,
convy
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
input
,
concat2
);
auto
dotx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
64
}}}),
dot
);
auto
doty
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
64
}},
{
"ends"
,
{
128
}}}),
dot
);
dotx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
dotx
);
doty
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
doty
);
auto
sum2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
dotx
,
doty
);
auto
sqdiffx
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"sqdiff"
),
input
,
e
);
auto
sqdiffy
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"sqdiff"
),
input
,
f
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment