Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
af3e36c0
Commit
af3e36c0
authored
Nov 11, 2023
by
Ted Themistokleous
Browse files
Remove additional contiguous around reshape from fuse_pointwise and simplify_reshapes.
Update tests to reflect change accordingly
parent
492e202d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
26 deletions
+12
-26
src/fuse_pointwise.cpp
src/fuse_pointwise.cpp
+1
-2
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+2
-6
test/fuse_pointwise.cpp
test/fuse_pointwise.cpp
+5
-10
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+4
-8
No files found.
src/fuse_pointwise.cpp
View file @
af3e36c0
...
...
@@ -219,9 +219,8 @@ struct find_pointwise_reshape_pointwise
auto
reshape_input
=
[
&
](
const
auto
&
ins_to_insert
)
{
return
[
&
](
auto
input
)
{
auto
c
=
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"contiguous"
),
input
);
return
m
.
insert_instruction
(
ins_to_insert
,
make_op
(
"reshape"
,
{{
"dims"
,
cd
.
dims
}}),
c
);
ins_to_insert
,
make_op
(
"reshape"
,
{{
"dims"
,
cd
.
dims
}}),
input
);
};
};
auto
x_inputs
=
x_ins
->
inputs
();
...
...
src/simplify_reshapes.cpp
View file @
af3e36c0
...
...
@@ -103,8 +103,6 @@ struct find_reshaper
auto
input
=
mr
.
instructions
[
"x"
];
auto
dims
=
ins
->
get_shape
().
lens
();
if
(
not
input
->
get_shape
().
standard
())
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
input
);
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
input
);
}
};
...
...
@@ -475,9 +473,8 @@ struct find_resize
ins_rsp
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
in_dims
}}),
in_rsp
);
auto
mb_rsp
=
m
.
insert_instruction
(
ins_rsp
,
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_dims
}}),
rsp_data
);
auto
std_mb
=
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"contiguous"
),
mb_rsp
);
std
::
vector
<
int64_t
>
rsp_dims
(
out_lens
.
begin
(),
out_lens
.
end
());
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
rsp_dims
}}),
std_mb
);
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
rsp_dims
}}),
mb_rsp
);
}
};
...
...
@@ -626,9 +623,8 @@ struct find_transpose_contiguous_reshaper_unary
auto
cont_ins
=
r
.
instructions
[
"cont_ins"
];
auto
unary_op_name
=
ins
->
get_operator
().
name
();
auto
unary_ins
=
m
.
insert_instruction
(
cont_ins
,
make_op
(
unary_op_name
),
trans_ins
);
auto
new_cont_ins
=
m
.
insert_instruction
(
cont_ins
,
make_op
(
"contiguous"
),
unary_ins
);
// older cont and reshape are removed by deadcode elimination
m
.
replace_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
new_cont
_ins
);
m
.
replace_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
unary
_ins
);
}
};
...
...
test/fuse_pointwise.cpp
View file @
af3e36c0
...
...
@@ -414,8 +414,7 @@ TEST_CASE(add_reshape_add_nonstandard)
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
add1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
c
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
add1
);
auto
reshape
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
c
);
auto
reshape
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
add1
);
auto
add2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
reshape
,
z
);
mm
->
add_return
({
add2
});
}
...
...
@@ -426,10 +425,8 @@ TEST_CASE(add_reshape_add_nonstandard)
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
cx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
cy
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
y
);
auto
x2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
cx
);
auto
y2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
cy
);
auto
x2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
x
);
auto
y2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
y
);
auto
z2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s3
.
lens
()}}),
z
);
auto
fadd
=
add_pointwise
(
p2
,
"main:pointwise0"
,
{
x2
,
y2
,
z2
},
[
=
](
auto
*
pm
,
const
auto
&
inputs
)
{
...
...
@@ -466,10 +463,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard)
auto
x
=
mm
->
add_parameter
(
"x"
,
s1
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s1
);
auto
z
=
mm
->
add_parameter
(
"z"
,
s2
);
auto
cx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
x
);
auto
cy
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
y
);
auto
x2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
cx
);
auto
y2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
cy
);
auto
x2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
x
);
auto
y2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
s2
.
lens
()}}),
y
);
auto
fadd
=
add_pointwise
(
p2
,
"main:pointwise0"
,
{
x2
,
y2
,
z
},
[
=
](
auto
*
pm
,
const
auto
&
inputs
)
{
auto
add1
=
pm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
inputs
[
0
],
inputs
[
1
]);
...
...
test/simplify_reshapes_test.cpp
View file @
af3e36c0
...
...
@@ -888,9 +888,8 @@ TEST_CASE(optimize_resize)
std
::
vector
<
int64_t
>
mb_dims
=
{
1
,
2
,
2
,
2
,
2
,
3
};
auto
mbx
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
mb_dims
}}),
rspx
);
auto
std_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
mbx
);
std
::
vector
<
int64_t
>
orig_dims
=
{
1
,
2
,
4
,
6
};
auto
rmb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
orig_dims
}}),
std_
mb
);
auto
rmb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
orig_dims
}}),
mb
x
);
auto
r
=
m
.
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
1
}}),
rmb
);
m
.
add_return
({
r
});
...
...
@@ -1301,9 +1300,8 @@ TEST_CASE(transpose_contiguous_reshape_unary)
auto
transpose_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
3
,
4
,
1
,
5
,
2
}}}),
reshape_ins1
);
auto
relu
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"relu"
),
transpose_ins
);
auto
cont_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
relu
);
auto
reshape_ins2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
2
,
10
,
10
}}}),
cont_ins
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
2
,
10
,
10
}}}),
relu
);
m2
.
add_instruction
(
pass_op
{},
reshape_ins2
);
}
EXPECT
(
m1
==
m2
);
...
...
@@ -1328,8 +1326,7 @@ TEST_CASE(transpose_contiguous_squeeze_unary)
auto
transpose_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
x
);
auto
rsqrt
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"rsqrt"
),
transpose_ins
);
auto
cont_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
rsqrt
);
auto
sq_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
1
}}}),
cont_ins
);
auto
sq_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
,
{{
"axes"
,
{
1
}}}),
rsqrt
);
m2
.
add_instruction
(
pass_op
{},
sq_ins
);
}
EXPECT
(
m1
==
m2
);
...
...
@@ -1355,9 +1352,8 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary)
auto
transpose_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
x
);
auto
round
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"nearbyint"
),
transpose_ins
);
auto
cont_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
round
);
auto
unsq_ins
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
cont_ins
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
round
);
m2
.
add_instruction
(
pass_op
{},
unsq_ins
);
}
EXPECT
(
m1
==
m2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment