Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8b2ee166
Commit
8b2ee166
authored
Oct 13, 2023
by
Umang Yadav
Browse files
Fork and merge cases working
parent
8db527c7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
45 deletions
+44
-45
src/generate_root_modules.cpp
src/generate_root_modules.cpp
+44
-45
No files found.
src/generate_root_modules.cpp
View file @
8b2ee166
...
@@ -152,29 +152,41 @@ struct auto_gen_root_modules
...
@@ -152,29 +152,41 @@ struct auto_gen_root_modules
bool
is_merge_node
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
tid
)
bool
is_merge_node
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
tid
)
{
{
const
auto
inputs
=
ins
->
inputs
();
const
auto
inputs
=
ins
->
inputs
();
return
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
input_ins
)
{
if
(
inputs
.
size
()
==
1
)
if
((
skip_ins
.
find
(
input_ins
)
!=
skip_ins
.
end
())
or
{
(
tass
.
find
(
input_ins
)
!=
tass
.
end
()
and
tass
.
at
(
ins
)
!=
tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
())))
{
return
true
;
}
return
false
;
return
false
;
});
}
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
input_ins
)
{
if
((
skip_ins
.
find
(
input_ins
)
!=
skip_ins
.
end
())
or
(
tass
.
find
(
input_ins
)
!=
tass
.
end
()
and
tass
.
at
(
ins
)
!=
tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
())))
{
return
true
;
}
return
false
;
}))
return
true
;
return
false
;
}
}
bool
is_fork_node
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
tid
)
bool
is_fork_node
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
tid
)
{
{
const
auto
outputs
=
ins
->
outputs
();
const
auto
outputs
=
ins
->
outputs
();
return
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
output_ins
)
{
if
(
outputs
.
size
()
==
1
)
if
(
tass
.
find
(
output_ins
)
!=
tass
.
end
()
and
{
tass
.
at
(
output_ins
)
!=
tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
())
and
output_ins
->
name
()
!=
"@return"
)
{
return
true
;
}
return
false
;
return
false
;
});
}
if
(
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
output_ins
)
{
if
(
tass
.
find
(
output_ins
)
!=
tass
.
end
()
and
tass
.
at
(
output_ins
)
!=
tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
())
and
output_ins
->
name
()
!=
"@return"
)
{
return
true
;
}
return
false
;
}))
return
true
;
return
false
;
}
}
void
find_subgraphs
(
migraphx
::
module_ref
mm
,
migraphx
::
program
&
p
)
void
find_subgraphs
(
migraphx
::
module_ref
mm
,
migraphx
::
program
&
p
)
...
@@ -186,17 +198,28 @@ struct auto_gen_root_modules
...
@@ -186,17 +198,28 @@ struct auto_gen_root_modules
std
::
cout
<<
"sorted module:
\n
"
;
std
::
cout
<<
"sorted module:
\n
"
;
mm
->
debug_print
();
mm
->
debug_print
();
}
}
bool
fork_node
=
false
;
std
::
optional
<
std
::
size_t
>
current_tid
=
nullopt
;
std
::
optional
<
std
::
size_t
>
current_tid
=
nullopt
;
for
(
auto
ins
:
iterator_for
(
*
mm
))
for
(
auto
ins
:
iterator_for
(
*
mm
))
{
{
if
(
enabled
(
MIGRAPHX_DEBUG_ROOT_GENERATOR
{}))
if
(
enabled
(
MIGRAPHX_DEBUG_ROOT_GENERATOR
{}))
{
{
std
::
cout
<<
"looking at instruction:
\n
"
;
std
::
cout
<<
"looking at instruction:
\n
"
;
std
::
cout
<<
"ins->name() == "
<<
ins
->
name
()
<<
std
::
endl
;
ins
->
debug_print
();
ins
->
debug_print
();
}
}
if
(
fork_node
)
{
std
::
cout
<<
"found fork node
\n
"
;
assert
(
current_tid
.
has_value
());
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
.
value
());
if
(
not
same_tid_ins_vec
.
empty
())
{
current_tid
=
nullopt
;
same_tid_ins_set
.
erase
(
ins
);
same_tid_ins_vec
.
pop_back
();
}
fork_node
=
false
;
}
// skip all params, literal and builtins other than return, skip "run_on_target_mod"
// skip all params, literal and builtins other than return, skip "run_on_target_mod"
// ins
// ins
if
((
starts_with
(
ins
->
name
(),
"@"
)
and
ins
->
name
()
!=
"@return"
)
or
if
((
starts_with
(
ins
->
name
(),
"@"
)
and
ins
->
name
()
!=
"@return"
)
or
...
@@ -212,18 +235,7 @@ struct auto_gen_root_modules
...
@@ -212,18 +235,7 @@ struct auto_gen_root_modules
update_tid_counter
(
current_tid
.
value
());
update_tid_counter
(
current_tid
.
value
());
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
if
(
is_fork_node
(
ins
,
current_tid
))
fork_node
=
is_fork_node
(
ins
,
current_tid
);
{
generate_run_on_target_modules
(
mm
,
p
,
std
::
next
(
ins
),
current_tid
.
value
());
if
(
not
same_tid_ins_vec
.
empty
())
{
// generate() method would populate these container for next(ins),
// remove them to maintain invariant
current_tid
=
nullopt
;
same_tid_ins_set
.
erase
(
std
::
next
(
ins
));
same_tid_ins_vec
.
pop_back
();
}
}
}
}
}
}
else
else
...
@@ -233,20 +245,6 @@ struct auto_gen_root_modules
...
@@ -233,20 +245,6 @@ struct auto_gen_root_modules
{
{
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
.
value
());
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
.
value
());
}
}
else
if
(
is_fork_node
(
ins
,
current_tid
))
{
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
generate_run_on_target_modules
(
mm
,
p
,
std
::
next
(
ins
),
current_tid
.
value
());
if
(
not
same_tid_ins_vec
.
empty
())
{
// generate() method would populate these container for next(ins), remove
// them to maintain invariant
current_tid
=
nullopt
;
same_tid_ins_set
.
erase
(
std
::
next
(
ins
));
same_tid_ins_vec
.
pop_back
();
}
}
else
if
(
tass
.
at
(
ins
)
==
current_tid
.
value
())
else
if
(
tass
.
at
(
ins
)
==
current_tid
.
value
())
{
{
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_vec
.
push_back
(
ins
);
...
@@ -256,6 +254,7 @@ struct auto_gen_root_modules
...
@@ -256,6 +254,7 @@ struct auto_gen_root_modules
{
{
MIGRAPHX_THROW
(
"GenerateRootModules: this case shouldn't occur"
);
MIGRAPHX_THROW
(
"GenerateRootModules: this case shouldn't occur"
);
}
}
fork_node
=
is_fork_node
(
ins
,
current_tid
);
}
}
if
(
not
ins
->
module_inputs
().
empty
())
if
(
not
ins
->
module_inputs
().
empty
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment