Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0a7ee4de
Commit
0a7ee4de
authored
Nov 08, 2023
by
Paul
Browse files
handle multi axis split
parent
3c160a3f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
118 additions
and
6 deletions
+118
-6
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+118
-6
No files found.
src/simplify_algebra.cpp
View file @
0a7ee4de
...
...
@@ -743,6 +743,36 @@ void move_instructions_back(module& m, instruction_ref pos, std::vector<instruct
}
}
optional
<
std
::
size_t
>
find_split_axis
(
const
std
::
vector
<
instruction_ref
>&
slices
)
{
auto
first
=
slices
.
front
();
auto
get_slice
=
[](
auto
&
i
)
->
auto
&
{
return
any_cast
<
op
::
slice
>
(
i
->
get_operator
());
};
auto
get_start
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
starts
;
};
auto
get_end
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
ends
;
};
auto
find_different_axis
=
[
&
](
auto
select
)
{
std
::
vector
<
int64_t
>
different
;
std
::
for_each
(
slices
.
begin
()
+
1
,
slices
.
end
(),
[
&
](
const
auto
&
slice
)
{
auto
it
=
select
(
slice
).
begin
();
while
(
it
!=
select
(
slice
).
end
())
{
auto
p
=
std
::
mismatch
(
it
,
select
(
slice
).
end
(),
select
(
first
).
begin
(),
select
(
first
).
end
());
auto
i
=
p
.
first
-
select
(
slice
).
begin
();
if
(
not
contains
(
different
,
i
))
different
.
push_back
(
i
);
it
=
p
.
first
;
}
});
return
different
;
};
auto
different_starts
=
find_different_axis
(
get_start
);
auto
different_ends
=
find_different_axis
(
get_end
);
if
(
different_ends
!=
different_starts
)
return
nullopt
;
if
(
different_starts
.
empty
())
return
nullopt
;
return
different_starts
.
front
();
}
std
::
vector
<
instruction_ref
>
get_splits
(
instruction_ref
ins
)
{
std
::
vector
<
instruction_ref
>
result
;
...
...
@@ -777,6 +807,86 @@ std::vector<instruction_ref> get_splits(instruction_ref ins)
return
result
;
}
struct
split_analyzer
{
std
::
vector
<
instruction_ref
>
slices
=
{};
std
::
size_t
axis
=
0
;
template
<
class
T
>
static
auto
&
get_slice
(
T
&
i
)
{
return
any_cast
<
op
::
slice
>
(
i
->
get_operator
());
}
split_analyzer
()
=
default
;
explicit
split_analyzer
(
instruction_ref
ins
)
{
std
::
vector
<
instruction_ref
>
result
;
std
::
copy_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
std
::
back_inserter
(
result
),
[
&
](
auto
i
)
{
return
i
->
name
()
==
"slice"
;
});
if
(
result
.
size
()
<
2
)
return
;
auto
&&
axes
=
get_slice
(
result
.
front
()).
axes
;
if
(
std
::
any_of
(
result
.
begin
(),
result
.
end
(),
[
&
](
auto
i
)
{
return
get_slice
(
i
).
axes
!=
axes
;
}))
return
;
auto
split_axis
=
find_split_axis
(
result
);
if
(
not
split_axis
.
has_value
())
return
;
axis
=
*
split_axis
;
auto
get_start
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
starts
[
axis
];
};
auto
get_end
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
ends
[
axis
];
};
std
::
sort
(
result
.
begin
(),
result
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
get_start
(
x
)
<
get_start
(
y
);
});
if
(
get_start
(
result
.
front
())
!=
0
)
return
;
auto
it
=
std
::
adjacent_find
(
result
.
begin
(),
result
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
get_end
(
x
)
!=
get_start
(
y
);
});
if
(
it
!=
result
.
end
())
return
;
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
{
if
(
ins
->
get_shape
().
lens
()[
axes
[
i
]]
!=
get_slice
(
result
.
back
()).
ends
[
i
])
return
;
}
slices
=
result
;
}
bool
has_multi_axes
()
const
{
return
get_slice
(
slices
.
front
()).
axes
.
size
()
>
1
;
}
operation
pre_split
()
const
{
auto
slice
=
get_slice
(
slices
.
front
());
auto
remove_axis
=
[
&
](
auto
&
v
)
{
v
.
erase
(
v
.
begin
()
+
axis
);
};
remove_axis
(
slice
.
axes
);
remove_axis
(
slice
.
starts
);
remove_axis
(
slice
.
ends
);
return
slice
;
}
instruction_ref
insert_pre_split
(
module
&
m
,
instruction_ref
ins
)
const
{
if
(
not
has_multi_axes
())
return
ins
;
return
m
.
insert_instruction
(
std
::
next
(
ins
),
pre_split
(),
ins
);
}
operation
post_split
(
instruction_ref
ins
)
const
{
if
(
not
has_multi_axes
())
return
ins
->
get_operator
();
auto
slice
=
get_slice
(
ins
);
return
make_op
(
"slice"
,
{{
"axes"
,
{
slice
.
axes
[
axis
]}},
{
"starts"
,
{
slice
.
starts
[
axis
]}},
{
"ends"
,
{
slice
.
ends
[
axis
]}}});
}
};
struct
find_splits
{
auto
matcher
()
const
...
...
@@ -870,14 +980,16 @@ struct find_splits
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
splits
=
get_splits
(
ins
)
;
if
(
split
s
.
empty
())
split_analyzer
analyzer
{
ins
}
;
if
(
analyzer
.
slice
s
.
empty
())
return
;
for
(
const
auto
&
group
:
get_split_groups
(
m
,
splits
))
ins
=
analyzer
.
insert_pre_split
(
m
,
ins
);
for
(
const
auto
&
group
:
get_split_groups
(
m
,
analyzer
.
slices
))
{
auto
start
=
group
.
front
();
auto
split_front
=
split
s
.
front
();
auto
split_front
=
analyzer
.
slice
s
.
front
();
auto
op
=
start
->
get_operator
();
if
(
not
is_fusable
(
start
,
split_front
))
{
...
...
@@ -920,7 +1032,7 @@ struct find_splits
move_instructions_back
(
m
,
ins
,
data_args
);
auto
slice_op
=
any_cast
<
op
::
slice
>
(
split
s
.
front
()
->
get_operator
());
auto
slice_op
=
any_cast
<
op
::
slice
>
(
analyzer
.
slice
s
.
front
()
->
get_operator
());
assert
(
not
slice_op
.
axes
.
empty
());
if
(
slice_op
.
axes
.
size
()
>
1
)
return
;
...
...
@@ -951,7 +1063,7 @@ struct find_splits
m
.
replace_instruction
(
output
,
output
->
get_operator
(),
x
);
}
m
.
replace_instruction
(
i
,
split
->
get_operator
(
),
c
);
m
.
replace_instruction
(
i
,
analyzer
.
post_split
(
split
),
c
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment