Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3cc4be9e
Commit
3cc4be9e
authored
Oct 19, 2023
by
Umang Yadav
Browse files
Fix conditions for fork-merge on return case
parent
44369c8e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
27 deletions
+37
-27
src/generate_root_modules.cpp
src/generate_root_modules.cpp
+21
-8
test/generate_root_modules.cpp
test/generate_root_modules.cpp
+16
-19
No files found.
src/generate_root_modules.cpp
View file @
3cc4be9e
...
@@ -149,8 +149,8 @@ struct auto_gen_root_modules
...
@@ -149,8 +149,8 @@ struct auto_gen_root_modules
}
}
}
}
bool
i
s_different_
subgraph
(
migraphx
::
instruction_ref
current_ins
,
bool
ha
s_different_
tass
(
migraphx
::
instruction_ref
current_ins
,
std
::
optional
<
std
::
size_t
>
previous_tid
)
std
::
optional
<
std
::
size_t
>
previous_tid
)
{
{
if
(
tass
.
find
(
current_ins
)
==
tass
.
end
())
if
(
tass
.
find
(
current_ins
)
==
tass
.
end
())
{
{
...
@@ -162,11 +162,13 @@ struct auto_gen_root_modules
...
@@ -162,11 +162,13 @@ struct auto_gen_root_modules
/*
/*
Merge node is defined as node where two or more branches converge.
Merge node is defined as node where two or more branches converge.
NodeX NodeY
NodeX NodeY
| |
| |
---------
---------
|
|
NodeZ
NodeZ
For the partitioner, if any of the merge node's input doesn't have same tid as the merge node
For the partitioner, if any of the merge node's input doesn't have same tid as the merge node
itself then, it is classified as boundary for subgraph.
itself then, it is classified as boundary for subgraph.
*/
*/
...
@@ -179,17 +181,22 @@ struct auto_gen_root_modules
...
@@ -179,17 +181,22 @@ struct auto_gen_root_modules
return
false
;
return
false
;
}
}
return
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
input_ins
)
{
return
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
input_ins
)
{
return
is_different_subgraph
(
input_ins
,
ins_tid
);
return
(
(
this
->
skip_ins
.
find
(
input_ins
)
!=
skip_ins
.
end
())
or
(
tass
.
find
(
input_ins
)
!=
tass
.
end
()
and
tass
.
at
(
input_ins
)
!=
ins_tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
())));
});
});
}
}
/*
/*
Fork node is defined as node where graph forks into two or more branches
Fork node is defined as node where graph forks into two or more branches
NodeX
NodeX
|
|
------------
------------
| |
| |
NodeY NodeZ
NodeY NodeZ
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
itself then, it is classified as boundary for subgraph.
itself then, it is classified as boundary for subgraph.
*/
*/
...
@@ -201,7 +208,13 @@ struct auto_gen_root_modules
...
@@ -201,7 +208,13 @@ struct auto_gen_root_modules
return
false
;
return
false
;
}
}
return
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
output_ins
)
{
return
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
auto
output_ins
)
{
return
is_different_subgraph
(
output_ins
,
ins_tid
);
if
(
output_ins
->
name
()
==
"return"
)
{
return
false
;
}
return
(
tass
.
find
(
output_ins
)
!=
tass
.
end
()
and
tass
.
at
(
output_ins
)
!=
ins_tid
.
value_or
(
std
::
numeric_limits
<
std
::
size_t
>::
max
()));
});
});
}
}
...
@@ -254,7 +267,7 @@ struct auto_gen_root_modules
...
@@ -254,7 +267,7 @@ struct auto_gen_root_modules
}
}
else
else
{
{
if
(
ins
->
name
()
==
"@return"
or
i
s_different_
subgraph
(
ins
,
current_tid
)
or
if
(
ins
->
name
()
==
"@return"
or
ha
s_different_
tass
(
ins
,
current_tid
)
or
is_merge_node
(
ins
,
current_tid
))
is_merge_node
(
ins
,
current_tid
))
{
{
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
);
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
);
...
@@ -304,7 +317,7 @@ struct auto_gen_root_modules
...
@@ -304,7 +317,7 @@ struct auto_gen_root_modules
// gather all parameters
// gather all parameters
std
::
unordered_set
<
instruction_ref
>
params
;
std
::
unordered_set
<
instruction_ref
>
params
;
// gather all return values
// gather all return values
std
::
unordered_set
<
instruction_ref
>
return_ins
;
std
::
vector
<
instruction_ref
>
return_ins
;
for
(
auto
tins
:
iterator_for
(
same_tid_ins_vec
))
for
(
auto
tins
:
iterator_for
(
same_tid_ins_vec
))
{
{
auto
inputs
=
(
*
tins
)
->
inputs
();
auto
inputs
=
(
*
tins
)
->
inputs
();
...
@@ -321,7 +334,7 @@ struct auto_gen_root_modules
...
@@ -321,7 +334,7 @@ struct auto_gen_root_modules
return
same_tid_ins_set
.
count
(
out_ins
)
==
0
;
return
same_tid_ins_set
.
count
(
out_ins
)
==
0
;
}))
}))
{
{
return_ins
.
insert
(
*
tins
);
return_ins
.
push_back
(
*
tins
);
}
}
}
}
if
(
enabled
(
MIGRAPHX_DEBUG_ROOT_GENERATOR
{}))
if
(
enabled
(
MIGRAPHX_DEBUG_ROOT_GENERATOR
{}))
...
@@ -380,7 +393,7 @@ struct auto_gen_root_modules
...
@@ -380,7 +393,7 @@ struct auto_gen_root_modules
for
(
auto
ritr
:
iterator_for
(
return_ins
))
for
(
auto
ritr
:
iterator_for
(
return_ins
))
{
{
rins
.
push_back
(
params_map
.
at
(
*
ritr
));
rins
.
push_back
(
params_map
.
at
(
*
ritr
));
return_ins_idx_map
[
*
ritr
]
=
std
::
distance
(
ritr
,
return_ins
.
begin
());
return_ins_idx_map
[
*
ritr
]
=
std
::
distance
(
return_ins
.
begin
()
,
ritr
);
}
}
tmod
->
add_return
(
rins
);
tmod
->
add_return
(
rins
);
...
...
test/generate_root_modules.cpp
View file @
3cc4be9e
...
@@ -485,9 +485,11 @@ TEST_CASE(fork_case_4)
...
@@ -485,9 +485,11 @@ TEST_CASE(fork_case_4)
|
|
---------------------------
---------------------------
| |
| |
Identity (tid = 0) Return
Identity (tid = 0) |
|
| |
Return
--------------------------
|
Return
*/
*/
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
8
}};
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
8
}};
migraphx
::
target_assignments
tass
;
migraphx
::
target_assignments
tass
;
...
@@ -505,29 +507,24 @@ TEST_CASE(fork_case_4)
...
@@ -505,29 +507,24 @@ TEST_CASE(fork_case_4)
migraphx
::
generate_root_modules
(
p1
,
tass
);
migraphx
::
generate_root_modules
(
p1
,
tass
);
migraphx
::
program
p2
;
migraphx
::
program
p2
;
{
{
migraphx
::
module_ref
mm
=
p2
.
get_main_module
();
migraphx
::
module_ref
mm
=
p2
.
get_main_module
();
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
migraphx
::
module_ref
target_mod_0_0
=
p2
.
create_module
(
"target_mod_0_0"
);
migraphx
::
module_ref
target_mod_0_0
=
p2
.
create_module
(
"target_mod_0_0"
);
auto
target_mod_0_0_param_1
=
target_mod_0_0
->
add_parameter
(
"param:1"
,
s
);
auto
target_mod_0_0_param_0
=
target_mod_0_0
->
add_parameter
(
"param:0"
,
s
);
auto
target_mod_0_0_param_0
=
target_mod_0_0
->
add_parameter
(
"param:0"
,
s
);
auto
x_target_mod_0_0_2
=
target_mod_0_0
->
add_instruction
(
auto
target_mod_0_0_param_1
=
target_mod_0_0
->
add_parameter
(
"param:1"
,
s
);
auto
x_target_mod_0_0_0
=
target_mod_0_0
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
target_mod_0_0_param_1
,
target_mod_0_0_param_0
);
migraphx
::
make_op
(
"add"
),
target_mod_0_0_param_1
,
target_mod_0_0_param_0
);
target_mod_0_0
->
add_return
({
x_target_mod_0_0_2
});
auto
x_target_mod_0_0_1
=
target_mod_0_0
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
x_target_mod_0_0_0
);
migraphx
::
module_ref
target_mod_0_1
=
p2
.
create_module
(
"target_mod_0_1"
);
target_mod_0_0
->
add_return
({
x_target_mod_0_0_0
,
x_target_mod_0_0_1
});
auto
target_mod_0_1_param_0
=
target_mod_0_1
->
add_parameter
(
"param:0"
,
s
);
auto
x_target_mod_0_1_1
=
target_mod_0_1
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
target_mod_0_1_param_0
);
target_mod_0_1
->
add_return
({
x_target_mod_0_1_1
});
auto
x_2
=
mm
->
add_instruction
(
auto
x_2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"run_on_target"
,
{{
"target_id"
,
0
}}),
{
y
,
x
},
{
target_mod_0_0
});
migraphx
::
make_op
(
"run_on_target"
,
{{
"target_id"
,
0
}}),
{
y
,
x
},
{
target_mod_0_0
});
auto
x_3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
x_2
);
auto
x_3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
x_2
);
auto
x_4
=
mm
->
add_instruction
(
auto
x_4
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
1
}}),
x_2
);
migraphx
::
make_op
(
"run_on_target"
,
{{
"target_id"
,
0
}}),
{
x_3
},
{
target_mod_0_1
});
mm
->
add_return
({
x_3
,
x_4
});
auto
x_5
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
x_4
);
mm
->
add_return
({
x_3
,
x_5
});
}
}
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment