Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4c185e65
Commit
4c185e65
authored
Sep 30, 2022
by
Paul
Browse files
Fuse across binary slices
parent
40118191
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
77 additions
and
30 deletions
+77
-30
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+77
-30
No files found.
src/simplify_algebra.cpp
View file @
4c185e65
...
...
@@ -559,6 +559,42 @@ struct find_splits
return
true
;
}
static
std
::
vector
<
instruction_ref
>
split_nary
(
const
std
::
vector
<
instruction_ref
>&
group
)
{
// All inputs have the same slices
if
(
not
std
::
all_of
(
group
.
begin
(),
group
.
end
(),
[](
auto
ins
)
{
if
(
ins
->
inputs
().
empty
())
return
false
;
auto
first
=
ins
->
inputs
().
front
();
if
(
first
->
name
()
!=
"slice"
)
return
false
;
return
std
::
all_of
(
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
(),
[
&
](
auto
input
)
{
return
input
->
get_operator
()
==
first
->
get_operator
();
});
}))
return
{};
auto
start
=
group
.
front
();
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
start
->
inputs
().
begin
(),
start
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[](
auto
ins
)
{
return
ins
->
inputs
().
front
();
});
if
(
not
std
::
all_of
(
group
.
begin
(),
group
.
end
(),
[
&
](
auto
ins
)
{
return
std
::
equal
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
slice
,
auto
input
)
{
return
slice
->
inputs
().
front
()
==
input
;
});
}))
return
{};
return
inputs
;
}
template
<
class
Range
>
static
instruction_ref
find_last_instruction
(
const
module
&
m
,
const
Range
&
r
)
{
auto
rm
=
reverse
(
m
);
auto
it
=
std
::
find_first_of
(
rm
.
begin
(),
rm
.
end
(),
r
.
begin
(),
r
.
end
());
return
std
::
prev
(
it
.
base
());
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
...
...
@@ -591,42 +627,53 @@ struct find_splits
assert
(
not
std
::
none_of
(
start
->
inputs
().
begin
(),
start
->
inputs
().
end
(),
[](
auto
i
)
{
return
i
->
name
()
==
"slice"
;
})
&&
"one argument must be a split"
);
auto
data_idx
=
1
;
if
(
start
->
inputs
().
back
()
->
name
()
==
"slice"
)
auto
split_inputs
=
split_nary
(
group
);
if
(
not
split_inputs
.
empty
())
{
split_i
dx
=
1
;
data_idx
=
0
;
auto
last
=
find_last_instruction
(
m
,
split_i
nputs
)
;
c
=
m
.
insert_instruction
(
std
::
next
(
last
),
op
,
split_inputs
)
;
}
else
{
auto
data_idx
=
1
;
if
(
start
->
inputs
().
back
()
->
name
()
==
"slice"
)
{
split_idx
=
1
;
data_idx
=
0
;
}
std
::
vector
<
instruction_ref
>
data_args
;
std
::
transform
(
group
.
begin
(),
group
.
end
(),
std
::
back_inserter
(
data_args
),
[
&
](
auto
i
)
{
return
i
->
inputs
()[
data_idx
];
});
std
::
vector
<
instruction_ref
>
data_args
;
std
::
transform
(
group
.
begin
(),
group
.
end
(),
std
::
back_inserter
(
data_args
),
[
&
](
auto
i
)
{
return
i
->
inputs
()[
data_idx
];
});
// Data arguments must be a constant
if
(
std
::
any_of
(
data_args
.
begin
(),
data_args
.
end
(),
[](
auto
i
)
{
return
not
i
->
can_eval
();
}))
return
;
// Data arguments must be a constant
if
(
std
::
any_of
(
data_args
.
begin
(),
data_args
.
end
(),
[](
auto
i
)
{
return
not
i
->
can_eval
();
}))
return
;
for
(
auto
data
:
data_args
)
m
.
move_instructions
(
data
,
ins
);
for
(
auto
data
:
data_args
)
m
.
move_instructions
(
data
,
ins
);
auto
slice_op
=
any_cast
<
op
::
slice
>
(
splits
.
front
()
->
get_operator
());
assert
(
not
slice_op
.
axes
.
empty
());
if
(
slice_op
.
axes
.
size
()
>
1
)
return
;
auto
concat_axis
=
slice_op
.
axes
.
front
();
// TODO: Check if axises match
auto
concat
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
concat_axis
}}),
data_args
);
std
::
vector
<
instruction_ref
>
args
;
args
.
resize
(
2
);
args
[
split_idx
]
=
ins
;
args
[
data_idx
]
=
concat
;
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
op
,
args
);
}
auto
slice_op
=
any_cast
<
op
::
slice
>
(
splits
.
front
()
->
get_operator
());
assert
(
not
slice_op
.
axes
.
empty
());
if
(
slice_op
.
axes
.
size
()
>
1
)
return
;
auto
concat_axis
=
slice_op
.
axes
.
front
();
// TODO: Check if axises match
auto
concat
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
concat_axis
}}),
data_args
);
std
::
vector
<
instruction_ref
>
args
;
args
.
resize
(
2
);
args
[
split_idx
]
=
ins
;
args
[
data_idx
]
=
concat
;
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
op
,
args
);
}
if
(
c
!=
m
.
end
())
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment