Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
34b9258a
Commit
34b9258a
authored
Oct 13, 2023
by
Umang Yadav
Browse files
Fixes
parent
40e6b38a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
34 deletions
+43
-34
src/generate_root_modules.cpp
src/generate_root_modules.cpp
+43
-34
No files found.
src/generate_root_modules.cpp
View file @
34b9258a
...
@@ -21,7 +21,6 @@
...
@@ -21,7 +21,6 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include "migraphx/instruction_ref.hpp"
#include <cstddef>
#include <cstddef>
#include <limits>
#include <limits>
#include <iterator>
#include <iterator>
...
@@ -142,16 +141,16 @@ struct auto_gen_root_modules
...
@@ -142,16 +141,16 @@ struct auto_gen_root_modules
}
}
}
}
bool
is_different_subgraph
(
migraphx
::
instruction_ref
ins
,
size_t
tid
)
bool
is_different_subgraph
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
tid
)
{
{
if
(
tass
.
find
(
ins
)
==
tass
.
end
()
or
tass
.
at
(
ins
)
!=
tid
)
if
(
tass
.
find
(
ins
)
==
tass
.
end
())
{
{
return
t
rue
;
return
t
id
.
has_value
()
;
}
}
return
false
;
return
tass
.
at
(
ins
)
!=
tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
())
;
}
}
bool
is_merge_node
(
migraphx
::
instruction_ref
ins
,
size_t
tid
)
bool
is_merge_node
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
tid
)
{
{
const
auto
inputs
=
ins
->
inputs
();
const
auto
inputs
=
ins
->
inputs
();
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
input_ins
)
{
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
input_ins
)
{
...
@@ -165,7 +164,7 @@ struct auto_gen_root_modules
...
@@ -165,7 +164,7 @@ struct auto_gen_root_modules
return
false
;
return
false
;
}
}
bool
is_fork_node
(
migraphx
::
instruction_ref
ins
,
size_t
tid
)
bool
is_fork_node
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
tid
)
{
{
const
auto
outputs
=
ins
->
outputs
();
const
auto
outputs
=
ins
->
outputs
();
if
(
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
output_ins
)
{
if
(
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
output_ins
)
{
...
@@ -189,7 +188,7 @@ struct auto_gen_root_modules
...
@@ -189,7 +188,7 @@ struct auto_gen_root_modules
mm
->
debug_print
();
mm
->
debug_print
();
}
}
size_t
current_tid
=
std
::
numeric_limits
<
std
::
size_t
>::
max
()
;
std
::
optional
<
std
::
size_t
>
current_tid
=
nullopt
;
for
(
auto
ins
:
iterator_for
(
*
mm
))
for
(
auto
ins
:
iterator_for
(
*
mm
))
{
{
if
(
enabled
(
MIGRAPHX_DEBUG_PARTITIONER
{}))
if
(
enabled
(
MIGRAPHX_DEBUG_PARTITIONER
{}))
...
@@ -205,38 +204,48 @@ struct auto_gen_root_modules
...
@@ -205,38 +204,48 @@ struct auto_gen_root_modules
{
{
continue
;
continue
;
}
}
else
if
(
ins
->
name
()
==
"@return"
or
is_different_subgraph
(
ins
,
current_tid
)
or
if
(
not
current_tid
.
has_value
())
is_merge_node
(
ins
,
current_tid
))
{
{
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
);
if
(
tass
.
find
(
ins
)
==
tass
.
end
())
}
else
if
(
is_fork_node
(
ins
,
current_tid
))
{
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
generate_run_on_target_modules
(
mm
,
p
,
std
::
next
(
ins
),
current_tid
);
if
(
not
same_tid_ins_vec
.
empty
())
{
{
current_tid
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
continue
;
same_tid_ins_set
.
erase
(
ins
);
}
same_tid_ins_vec
.
pop_back
();
else
{
current_tid
=
std
::
make_optional
<
std
::
size_t
>
(
tass
.
at
(
ins
));
update_tid_counter
(
current_tid
.
value
());
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
}
}
}
else
if
(
current_tid
==
std
::
numeric_limits
<
std
::
size_t
>::
max
())
{
current_tid
=
tass
.
at
(
ins
);
update_tid_counter
(
current_tid
);
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
}
else
if
(
tass
.
at
(
ins
)
==
current_tid
)
{
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
}
}
else
else
{
{
MIGRAPHX_THROW
(
"Partition: this case shouldn't occur"
);
if
(
ins
->
name
()
==
"@return"
or
is_different_subgraph
(
ins
,
current_tid
)
or
is_merge_node
(
ins
,
current_tid
))
{
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
.
value
());
}
else
if
(
is_fork_node
(
ins
,
current_tid
))
{
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
generate_run_on_target_modules
(
mm
,
p
,
std
::
next
(
ins
),
current_tid
.
value
());
if
(
not
same_tid_ins_vec
.
empty
())
{
current_tid
=
nullopt
;
same_tid_ins_set
.
erase
(
ins
);
same_tid_ins_vec
.
pop_back
();
}
}
else
if
(
tass
.
at
(
ins
)
==
current_tid
.
value
())
{
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
}
else
{
MIGRAPHX_THROW
(
"Partition: this case shouldn't occur"
);
}
}
}
if
(
skip_ins
.
find
(
ins
)
==
skip_ins
.
end
()
and
not
ins
->
module_inputs
().
empty
())
if
(
skip_ins
.
find
(
ins
)
==
skip_ins
.
end
()
and
not
ins
->
module_inputs
().
empty
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment