Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
66858a44
Commit
66858a44
authored
May 19, 2022
by
Paul
Browse files
Merge branch 'fuse-horiz-contiguous' into jit-contiguous2
parents
2e5116bf
dfc7bbac
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
135 additions
and
15 deletions
+135
-15
src/eliminate_contiguous.cpp
src/eliminate_contiguous.cpp
+33
-15
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+102
-0
No files found.
src/eliminate_contiguous.cpp
View file @
66858a44
...
@@ -69,22 +69,30 @@ static bool try_compute_shape(instruction_ref ins,
...
@@ -69,22 +69,30 @@ static bool try_compute_shape(instruction_ref ins,
return
try_compute_shape
(
ins
,
inputs
,
mods
);
return
try_compute_shape
(
ins
,
inputs
,
mods
);
}
}
void
eliminate_contiguous
::
apply
(
module
&
m
)
const
template
<
class
F
>
static
void
remove_contiguous
(
const
std
::
string
&
op_name
,
module
&
m
,
F
f
)
{
{
auto
last
=
std
::
prev
(
m
.
end
());
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// return instruction should have inputs with standard shape
// return instruction should have inputs with standard shape
if
(
ins
->
name
()
==
"@return"
)
if
(
ins
->
name
()
==
"@return"
)
continue
;
continue
;
if
(
ins
!=
last
and
ins
->
outputs
().
empty
())
continue
;
if
(
not
f
(
ins
))
continue
;
// Make a copy so we can modify it while we iterate
// Make a copy so we can modify it while we iterate
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
auto
new_args
=
args
;
auto
new_args
=
args
;
auto
mod_args
=
ins
->
module_inputs
();
auto
mod_args
=
ins
->
module_inputs
();
for
(
auto
arg
:
ins
->
inputs
())
for
(
auto
arg
:
ins
->
inputs
())
{
{
if
(
arg
->
name
()
=
=
op_name
)
if
(
arg
->
name
()
!
=
op_name
)
{
continue
;
auto
prev
=
arg
->
inputs
().
front
();
auto
prev
=
arg
->
inputs
().
front
();
replace
(
new_args
,
arg
,
prev
);
replace
(
new_args
,
arg
,
prev
);
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
if
(
try_compute_shape
(
ins
,
new_args
,
mod_args
))
...
@@ -101,7 +109,17 @@ void eliminate_contiguous::apply(module& m) const
...
@@ -101,7 +109,17 @@ void eliminate_contiguous::apply(module& m) const
}
}
}
}
}
}
}
}
void
eliminate_contiguous
::
apply
(
module
&
m
)
const
{
// Skip contiguous from splits first
remove_contiguous
(
op_name
,
m
,
[](
auto
ins
)
{
if
(
ins
->
name
()
!=
"slice"
)
return
true
;
return
(
ins
->
inputs
().
front
()
->
outputs
().
size
()
==
1
);
});
remove_contiguous
(
op_name
,
m
,
[](
auto
)
{
return
true
;
});
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
test/simplify_reshapes_test.cpp
View file @
66858a44
...
@@ -1118,4 +1118,106 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
...
@@ -1118,4 +1118,106 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
TEST_CASE
(
transpose_slice
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
384
,
36
,
64
}});
auto
slice1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
x
);
auto
transpose1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
slice1
);
auto
slice2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
x
);
auto
transpose2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
slice2
);
auto
slice3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
36
}}}),
x
);
auto
transpose3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
slice3
);
m1
.
add_return
({
transpose1
,
transpose2
,
transpose3
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
384
,
36
,
64
}});
auto
transpose
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
x
);
auto
slice1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
transpose
);
auto
slice2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
transpose
);
auto
slice3
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
36
}}}),
transpose
);
m2
.
add_return
({
slice1
,
slice2
,
slice3
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
transpose_slice_diff_perm
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
384
,
36
,
64
}});
auto
slice1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
x
);
auto
transpose1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
slice1
);
auto
slice2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
x
);
auto
transpose2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
slice2
);
auto
slice3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
36
}}}),
x
);
auto
transpose3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
slice3
);
m1
.
add_return
({
transpose1
,
transpose2
,
transpose3
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
384
,
36
,
64
}});
auto
transpose
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
x
);
auto
slice1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
transpose
);
auto
slice2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
transpose
);
auto
transpose2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
slice2
);
auto
slice3
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
36
}}}),
transpose
);
m2
.
add_return
({
slice1
,
transpose2
,
slice3
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
transpose_slice_single_transpose
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
,
384
,
36
,
64
}});
auto
slice1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
x
);
auto
sqrt1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
slice1
);
auto
slice2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
x
);
auto
transpose2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
slice2
);
auto
slice3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
36
}}}),
x
);
auto
sqrt3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
slice3
);
m1
.
add_return
({
sqrt1
,
transpose2
,
sqrt3
});
}
migraphx
::
module
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment