Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fa12da23
Commit
fa12da23
authored
Oct 20, 2023
by
Umang Yadav
Browse files
Changes for the order fix
parent
1e80ceef
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
261 additions
and
139 deletions
+261
-139
src/generate_root_modules.cpp
src/generate_root_modules.cpp
+85
-20
test/generate_root_modules.cpp
test/generate_root_modules.cpp
+176
-119
No files found.
src/generate_root_modules.cpp
View file @
fa12da23
...
@@ -174,18 +174,51 @@ struct auto_gen_root_modules
...
@@ -174,18 +174,51 @@ struct auto_gen_root_modules
*/
*/
bool
is_merge_node
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
ins_tid
)
bool
is_merge_node
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
ins_tid
)
{
{
const
auto
inputs
=
ins
->
inputs
();
const
auto
inputs
=
ins
->
inputs
();
if
(
inputs
.
size
()
==
1
)
size_t
in_degree
=
inputs
.
size
();
if
(
in_degree
==
1
)
{
{
return
false
;
return
false
;
}
}
return
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
input_ins
)
{
size_t
input_from_other_tid_module
=
0
;
return
(
size_t
num_default_tids
=
0
;
(
this
->
skip_ins
.
find
(
input_ins
)
!=
skip_ins
.
end
())
or
size_t
num_different_tids
=
0
;
(
tass
.
find
(
input_ins
)
!=
tass
.
end
()
and
size_t
num_same_tid
=
0
;
tass
.
at
(
input_ins
)
!=
ins_tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
())));
// std::unordered_map<size_t, size_t> in_tid_freq_map;
});
for
(
const
auto
&
input_ins
:
inputs
)
{
if
(
skip_ins
.
find
(
input_ins
)
!=
skip_ins
.
end
())
{
input_from_other_tid_module
++
;
}
else
if
(
tass
.
find
(
input_ins
)
==
tass
.
end
())
{
num_default_tids
++
;
}
else
if
(
tass
.
at
(
input_ins
)
!=
ins_tid
)
{
num_different_tids
++
;
}
else
{
num_same_tid
++
;
}
}
assert
(
input_from_other_tid_module
+
num_default_tids
+
num_different_tids
+
num_same_tid
==
in_degree
);
if
(
input_from_other_tid_module
>
1
)
{
return
true
;
}
else
if
(
input_from_other_tid_module
+
num_default_tids
==
in_degree
)
{
return
false
;
}
else
if
(
num_same_tid
+
num_default_tids
==
in_degree
)
{
return
false
;
}
return
true
;
}
}
/*
/*
...
@@ -200,21 +233,47 @@ struct auto_gen_root_modules
...
@@ -200,21 +233,47 @@ struct auto_gen_root_modules
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
itself then, it is classified as boundary for subgraph.
itself then, it is classified as boundary for subgraph.
*/
*/
bool
is_fork_node
(
migraphx
::
instruction_ref
ins
,
std
::
optional
<
std
::
size_t
>
ins_tid
)
bool
is_fork_node
(
migraphx
::
instruction_ref
ins
,
std
::
size_t
ins_tid
)
{
{
const
auto
outputs
=
ins
->
outputs
();
const
auto
outputs
=
ins
->
outputs
();
if
(
outputs
.
size
()
==
1
)
if
(
outputs
.
size
()
==
1
)
{
{
return
false
;
return
false
;
}
}
// if all the outputs are for the "default" or with same tid then it is not a fork but
// rather simply a boundary
std
::
unordered_map
<
std
::
size_t
,
std
::
size_t
>
output_tids
;
for
(
const
auto
&
output_ins
:
outputs
)
{
if
(
tass
.
find
(
output_ins
)
!=
tass
.
end
())
{
auto
out_tid
=
tass
.
at
(
output_ins
);
if
(
output_tids
.
find
(
out_tid
)
==
output_tids
.
end
())
{
output_tids
[
out_tid
]
=
1
;
}
else
{
output_tids
[
out_tid
]
++
;
}
}
}
if
(
output_tids
.
empty
())
{
return
false
;
}
else
if
(
output_tids
.
size
()
==
1
and
output_tids
.
cbegin
()
->
second
==
outputs
.
size
())
{
return
false
;
}
return
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
output_ins
)
{
return
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
output_ins
)
{
if
(
output_ins
->
name
()
==
"return"
)
if
(
output_ins
->
name
()
==
"return"
)
{
{
return
false
;
return
false
;
}
}
return
(
tass
.
find
(
output_ins
)
!=
tass
.
end
()
and
return
(
tass
.
find
(
output_ins
)
!=
tass
.
end
()
and
tass
.
at
(
output_ins
)
!=
ins_tid
);
tass
.
at
(
output_ins
)
!=
ins_tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
()));
});
});
}
}
...
@@ -262,7 +321,7 @@ struct auto_gen_root_modules
...
@@ -262,7 +321,7 @@ struct auto_gen_root_modules
current_tid
=
std
::
make_optional
<
std
::
size_t
>
(
tass
.
at
(
ins
));
current_tid
=
std
::
make_optional
<
std
::
size_t
>
(
tass
.
at
(
ins
));
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
fork_node
=
is_fork_node
(
ins
,
current_tid
);
fork_node
=
is_fork_node
(
ins
,
current_tid
.
value
()
);
}
}
}
}
else
else
...
@@ -281,7 +340,8 @@ struct auto_gen_root_modules
...
@@ -281,7 +340,8 @@ struct auto_gen_root_modules
{
{
MIGRAPHX_THROW
(
"GenerateRootModules: this case shouldn't occur"
);
MIGRAPHX_THROW
(
"GenerateRootModules: this case shouldn't occur"
);
}
}
fork_node
=
is_fork_node
(
ins
,
current_tid
);
fork_node
=
is_fork_node
(
ins
,
current_tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
()));
}
}
if
(
not
ins
->
module_inputs
().
empty
())
if
(
not
ins
->
module_inputs
().
empty
())
...
@@ -315,7 +375,8 @@ struct auto_gen_root_modules
...
@@ -315,7 +375,8 @@ struct auto_gen_root_modules
return
;
return
;
}
}
// gather all parameters
// gather all parameters
std
::
unordered_set
<
instruction_ref
>
params
;
std
::
unordered_set
<
instruction_ref
>
params_set
;
std
::
vector
<
instruction_ref
>
params_vec
;
// gather all return values
// gather all return values
std
::
vector
<
instruction_ref
>
return_ins
;
std
::
vector
<
instruction_ref
>
return_ins
;
for
(
auto
tins
:
iterator_for
(
same_tid_ins_vec
))
for
(
auto
tins
:
iterator_for
(
same_tid_ins_vec
))
...
@@ -325,11 +386,15 @@ struct auto_gen_root_modules
...
@@ -325,11 +386,15 @@ struct auto_gen_root_modules
transform_if
(
transform_if
(
inputs
.
cbegin
(),
inputs
.
cbegin
(),
inputs
.
cend
(),
inputs
.
cend
(),
std
::
inserter
(
params
,
params
.
end
()
),
std
::
back_
inserter
(
params
_vec
),
[
&
](
auto
in_param
)
{
[
&
](
auto
in_param
)
{
return
(
params
.
count
(
in_param
)
==
0
and
same_tid_ins_set
.
count
(
in_param
)
==
0
);
return
(
params_set
.
count
(
in_param
)
==
0
and
same_tid_ins_set
.
count
(
in_param
)
==
0
);
},
},
[
&
](
auto
in_param
)
{
return
in_param
;
});
[
&
](
auto
in_param
)
{
params_set
.
insert
(
in_param
);
return
in_param
;
});
if
(
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
const
auto
out_ins
)
{
if
(
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
const
auto
out_ins
)
{
return
same_tid_ins_set
.
count
(
out_ins
)
==
0
;
return
same_tid_ins_set
.
count
(
out_ins
)
==
0
;
}))
}))
...
@@ -340,7 +405,7 @@ struct auto_gen_root_modules
...
@@ -340,7 +405,7 @@ struct auto_gen_root_modules
if
(
enabled
(
MIGRAPHX_DEBUG_ROOT_GENERATOR
{}))
if
(
enabled
(
MIGRAPHX_DEBUG_ROOT_GENERATOR
{}))
{
{
std
::
cout
<<
"params ins:
\n
"
;
std
::
cout
<<
"params ins:
\n
"
;
for
(
auto
tmp
:
iterator_for
(
params
))
for
(
auto
tmp
:
iterator_for
(
params
_vec
))
{
{
(
*
tmp
)
->
debug_print
();
(
*
tmp
)
->
debug_print
();
}
}
...
@@ -357,7 +422,7 @@ struct auto_gen_root_modules
...
@@ -357,7 +422,7 @@ struct auto_gen_root_modules
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
params_map
;
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
params_map
;
std
::
size_t
param_counter
=
0
;
std
::
size_t
param_counter
=
0
;
std
::
vector
<
instruction_ref
>
rot_ins_params
;
std
::
vector
<
instruction_ref
>
rot_ins_params
;
for
(
auto
pins
:
iterator_for
(
params
))
for
(
auto
pins
:
iterator_for
(
params
_vec
))
{
{
auto
scalar
=
get_scalar
(
*
pins
);
auto
scalar
=
get_scalar
(
*
pins
);
if
(
scalar
.
empty
())
if
(
scalar
.
empty
())
...
...
test/generate_root_modules.cpp
View file @
fa12da23
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment