Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0b840acc
Commit
0b840acc
authored
Nov 08, 2023
by
Paul
Browse files
Format
parent
0a7ee4de
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
34 deletions
+35
-34
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+35
-34
No files found.
src/simplify_algebra.cpp
View file @
0b840acc
...
@@ -745,19 +745,20 @@ void move_instructions_back(module& m, instruction_ref pos, std::vector<instruct
...
@@ -745,19 +745,20 @@ void move_instructions_back(module& m, instruction_ref pos, std::vector<instruct
optional
<
std
::
size_t
>
find_split_axis
(
const
std
::
vector
<
instruction_ref
>&
slices
)
optional
<
std
::
size_t
>
find_split_axis
(
const
std
::
vector
<
instruction_ref
>&
slices
)
{
{
auto
first
=
slices
.
front
();
auto
first
=
slices
.
front
();
auto
get_slice
=
[](
auto
&
i
)
->
auto
&
{
return
any_cast
<
op
::
slice
>
(
i
->
get_operator
());
};
auto
get_slice
=
[](
auto
&
i
)
->
auto
&
{
return
any_cast
<
op
::
slice
>
(
i
->
get_operator
());
};
auto
get_start
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
starts
;
};
auto
get_start
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
starts
;
};
auto
get_end
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
ends
;
};
auto
get_end
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
ends
;
};
auto
find_different_axis
=
[
&
](
auto
select
)
{
auto
find_different_axis
=
[
&
](
auto
select
)
{
std
::
vector
<
int64_t
>
different
;
std
::
vector
<
int64_t
>
different
;
std
::
for_each
(
slices
.
begin
()
+
1
,
slices
.
end
(),
[
&
](
const
auto
&
slice
)
{
std
::
for_each
(
slices
.
begin
()
+
1
,
slices
.
end
(),
[
&
](
const
auto
&
slice
)
{
auto
it
=
select
(
slice
).
begin
();
auto
it
=
select
(
slice
).
begin
();
while
(
it
!=
select
(
slice
).
end
())
while
(
it
!=
select
(
slice
).
end
())
{
{
auto
p
=
std
::
mismatch
(
it
,
select
(
slice
).
end
(),
select
(
first
).
begin
(),
select
(
first
).
end
());
auto
p
=
std
::
mismatch
(
it
,
select
(
slice
).
end
(),
select
(
first
).
begin
(),
select
(
first
).
end
());
auto
i
=
p
.
first
-
select
(
slice
).
begin
();
auto
i
=
p
.
first
-
select
(
slice
).
begin
();
if
(
not
contains
(
different
,
i
))
if
(
not
contains
(
different
,
i
))
different
.
push_back
(
i
);
different
.
push_back
(
i
);
it
=
p
.
first
;
it
=
p
.
first
;
}
}
...
@@ -765,10 +766,10 @@ optional<std::size_t> find_split_axis(const std::vector<instruction_ref>& slices
...
@@ -765,10 +766,10 @@ optional<std::size_t> find_split_axis(const std::vector<instruction_ref>& slices
return
different
;
return
different
;
};
};
auto
different_starts
=
find_different_axis
(
get_start
);
auto
different_starts
=
find_different_axis
(
get_start
);
auto
different_ends
=
find_different_axis
(
get_end
);
auto
different_ends
=
find_different_axis
(
get_end
);
if
(
different_ends
!=
different_starts
)
if
(
different_ends
!=
different_starts
)
return
nullopt
;
return
nullopt
;
if
(
different_starts
.
empty
())
if
(
different_starts
.
empty
())
return
nullopt
;
return
nullopt
;
return
different_starts
.
front
();
return
different_starts
.
front
();
}
}
...
@@ -810,9 +811,9 @@ std::vector<instruction_ref> get_splits(instruction_ref ins)
...
@@ -810,9 +811,9 @@ std::vector<instruction_ref> get_splits(instruction_ref ins)
struct
split_analyzer
struct
split_analyzer
{
{
std
::
vector
<
instruction_ref
>
slices
=
{};
std
::
vector
<
instruction_ref
>
slices
=
{};
std
::
size_t
axis
=
0
;
std
::
size_t
axis
=
0
;
template
<
class
T
>
template
<
class
T
>
static
auto
&
get_slice
(
T
&
i
)
static
auto
&
get_slice
(
T
&
i
)
{
{
return
any_cast
<
op
::
slice
>
(
i
->
get_operator
());
return
any_cast
<
op
::
slice
>
(
i
->
get_operator
());
...
@@ -823,26 +824,29 @@ struct split_analyzer
...
@@ -823,26 +824,29 @@ struct split_analyzer
{
{
std
::
vector
<
instruction_ref
>
result
;
std
::
vector
<
instruction_ref
>
result
;
std
::
copy_if
(
ins
->
outputs
().
begin
(),
std
::
copy_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
ins
->
outputs
().
end
(),
std
::
back_inserter
(
result
),
std
::
back_inserter
(
result
),
[
&
](
auto
i
)
{
return
i
->
name
()
==
"slice"
;
});
[
&
](
auto
i
)
{
return
i
->
name
()
==
"slice"
;
});
if
(
result
.
size
()
<
2
)
if
(
result
.
size
()
<
2
)
return
;
return
;
auto
&&
axes
=
get_slice
(
result
.
front
()).
axes
;
auto
&&
axes
=
get_slice
(
result
.
front
()).
axes
;
if
(
std
::
any_of
(
result
.
begin
(),
result
.
end
(),
[
&
](
auto
i
)
{
return
get_slice
(
i
).
axes
!=
axes
;
}))
if
(
std
::
any_of
(
result
.
begin
(),
result
.
end
(),
[
&
](
auto
i
)
{
return
get_slice
(
i
).
axes
!=
axes
;
}))
return
;
return
;
auto
split_axis
=
find_split_axis
(
result
);
auto
split_axis
=
find_split_axis
(
result
);
if
(
not
split_axis
.
has_value
())
if
(
not
split_axis
.
has_value
())
return
;
return
;
axis
=
*
split_axis
;
axis
=
*
split_axis
;
auto
get_start
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
starts
[
axis
];
};
auto
get_start
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
starts
[
axis
];
};
auto
get_end
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
ends
[
axis
];
};
auto
get_end
=
[
&
](
auto
&
i
)
->
auto
&
{
return
get_slice
(
i
).
ends
[
axis
];
};
std
::
sort
(
std
::
sort
(
result
.
begin
(),
result
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
result
.
begin
(),
result
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
get_start
(
x
)
<
get_start
(
y
);
});
return
get_start
(
x
)
<
get_start
(
y
);
if
(
get_start
(
result
.
front
())
!=
0
)
});
if
(
get_start
(
result
.
front
())
!=
0
)
return
;
return
;
auto
it
=
std
::
adjacent_find
(
auto
it
=
std
::
adjacent_find
(
result
.
begin
(),
result
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
result
.
begin
(),
result
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
get_end
(
x
)
!=
get_start
(
y
);
});
return
get_end
(
x
)
!=
get_start
(
y
);
});
if
(
it
!=
result
.
end
())
if
(
it
!=
result
.
end
())
return
;
return
;
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
axes
.
size
();
i
++
)
...
@@ -853,18 +857,12 @@ struct split_analyzer
...
@@ -853,18 +857,12 @@ struct split_analyzer
slices
=
result
;
slices
=
result
;
}
}
bool
has_multi_axes
()
const
bool
has_multi_axes
()
const
{
return
get_slice
(
slices
.
front
()).
axes
.
size
()
>
1
;
}
{
return
get_slice
(
slices
.
front
()).
axes
.
size
()
>
1
;
}
operation
pre_split
()
const
operation
pre_split
()
const
{
{
auto
slice
=
get_slice
(
slices
.
front
());
auto
slice
=
get_slice
(
slices
.
front
());
auto
remove_axis
=
[
&
](
auto
&
v
)
auto
remove_axis
=
[
&
](
auto
&
v
)
{
v
.
erase
(
v
.
begin
()
+
axis
);
};
{
v
.
erase
(
v
.
begin
()
+
axis
);
};
remove_axis
(
slice
.
axes
);
remove_axis
(
slice
.
axes
);
remove_axis
(
slice
.
starts
);
remove_axis
(
slice
.
starts
);
remove_axis
(
slice
.
ends
);
remove_axis
(
slice
.
ends
);
...
@@ -873,17 +871,20 @@ struct split_analyzer
...
@@ -873,17 +871,20 @@ struct split_analyzer
instruction_ref
insert_pre_split
(
module
&
m
,
instruction_ref
ins
)
const
instruction_ref
insert_pre_split
(
module
&
m
,
instruction_ref
ins
)
const
{
{
if
(
not
has_multi_axes
())
if
(
not
has_multi_axes
())
return
ins
;
return
ins
;
return
m
.
insert_instruction
(
std
::
next
(
ins
),
pre_split
(),
ins
);
return
m
.
insert_instruction
(
std
::
next
(
ins
),
pre_split
(),
ins
);
}
}
operation
post_split
(
instruction_ref
ins
)
const
operation
post_split
(
instruction_ref
ins
)
const
{
{
if
(
not
has_multi_axes
())
if
(
not
has_multi_axes
())
return
ins
->
get_operator
();
return
ins
->
get_operator
();
auto
slice
=
get_slice
(
ins
);
auto
slice
=
get_slice
(
ins
);
return
make_op
(
"slice"
,
{{
"axes"
,
{
slice
.
axes
[
axis
]}},
{
"starts"
,
{
slice
.
starts
[
axis
]}},
{
"ends"
,
{
slice
.
ends
[
axis
]}}});
return
make_op
(
"slice"
,
{{
"axes"
,
{
slice
.
axes
[
axis
]}},
{
"starts"
,
{
slice
.
starts
[
axis
]}},
{
"ends"
,
{
slice
.
ends
[
axis
]}}});
}
}
};
};
...
@@ -981,7 +982,7 @@ struct find_splits
...
@@ -981,7 +982,7 @@ struct find_splits
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
split_analyzer
analyzer
{
ins
};
split_analyzer
analyzer
{
ins
};
if
(
analyzer
.
slices
.
empty
())
if
(
analyzer
.
slices
.
empty
())
return
;
return
;
ins
=
analyzer
.
insert_pre_split
(
m
,
ins
);
ins
=
analyzer
.
insert_pre_split
(
m
,
ins
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment