Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1e80ceef
Commit
1e80ceef
authored
Oct 20, 2023
by
Umang Yadav
Browse files
add single target multiple returns
parent
1796d3e3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
11 deletions
+75
-11
src/module.cpp
src/module.cpp
+11
-6
test/generate_root_modules.cpp
test/generate_root_modules.cpp
+64
-5
No files found.
src/module.cpp
View file @
1e80ceef
...
@@ -114,7 +114,7 @@ struct module_impl
...
@@ -114,7 +114,7 @@ struct module_impl
const
operation
&
get_operation
(
instruction_ref
ins
)
{
return
ins
->
get_operator
();
}
const
operation
&
get_operation
(
instruction_ref
ins
)
{
return
ins
->
get_operator
();
}
module
::
module
(
const
std
::
string
&
name
)
:
impl
(
std
::
make_unique
<
module_impl
>
())
module
::
module
(
const
std
::
string
&
name
)
:
impl
(
std
::
make_unique
<
module_impl
>
())
{
{
impl
->
name
=
name
;
impl
->
name
=
name
;
}
}
...
@@ -165,7 +165,7 @@ void module::assign(const module& m)
...
@@ -165,7 +165,7 @@ void module::assign(const module& m)
auto
order
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
order
;
auto
order
=
any_cast
<
builtin
::
param
>
(
ins
->
get_operator
()).
order
;
auto
s
=
ins
->
get_shape
();
auto
s
=
ins
->
get_shape
();
copy_ins
=
impl
->
insert
(
impl
->
instructions
.
end
(),
copy_ins
=
impl
->
insert
(
impl
->
instructions
.
end
(),
{
builtin
::
param
{
name
,
order
},
std
::
move
(
s
),
{}});
{
builtin
::
param
{
name
,
order
},
std
::
move
(
s
),
{}});
impl
->
nparams
++
;
impl
->
nparams
++
;
}
}
else
if
(
ins
->
name
()
==
"@outline"
)
else
if
(
ins
->
name
()
==
"@outline"
)
...
@@ -800,8 +800,10 @@ static std::string cpp_var_name(const std::string& name)
...
@@ -800,8 +800,10 @@ static std::string cpp_var_name(const std::string& name)
{
{
std
::
string
prefix
=
"x_"
;
std
::
string
prefix
=
"x_"
;
if
(
not
contains
(
name
,
"@"
))
if
(
not
contains
(
name
,
"@"
))
prefix
=
"p_"
;
{
return
to_c_id
(
prefix
+
replace_string
(
name
,
":"
,
"_module_"
));
return
to_c_id
(
name
);
}
return
to_c_id
(
prefix
+
name
);
}
}
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
static
void
print_py_op
(
std
::
ostream
&
os
,
const
operation
&
op
)
...
@@ -825,8 +827,11 @@ static void print_make_op(std::ostream& os, const operation& op)
...
@@ -825,8 +827,11 @@ static void print_make_op(std::ostream& os, const operation& op)
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
if
(
not
v
.
empty
())
if
(
not
v
.
empty
())
{
{
os
<<
"migraphx::make_json_op("
<<
enclose_name
(
op
.
name
());
os
<<
"migraphx::make_op("
<<
enclose_name
(
op
.
name
());
os
<<
", "
<<
enclose_name
(
to_json_string
(
v
));
auto
rname
=
"{"
+
replace_string
(
to_json_string
(
v
),
"
\"
"
,
"
\\\"
"
)
+
"}"
;
rname
=
replace_string
(
rname
,
":"
,
", "
);
rname
=
replace_string
(
rname
,
"
\\
"
,
""
);
os
<<
", "
<<
rname
;
}
}
else
else
{
{
...
...
test/generate_root_modules.cpp
View file @
1e80ceef
...
@@ -222,6 +222,65 @@ TEST_CASE(two_targets_ref_inbetween)
...
@@ -222,6 +222,65 @@ TEST_CASE(two_targets_ref_inbetween)
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
}
TEST_CASE
(
single_target_multiple_returns
)
{
/*
Add (tid = 0)
|
---------------
| |
Mul Identity
(tid = 0) (tid = 0)
| |
---------------
|
Return
*/
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
8
}};
migraphx
::
target_assignments
tass
;
migraphx
::
program
p1
;
{
auto
*
mm
=
p1
.
get_main_module
();
auto
x_param
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y_param
=
mm
->
add_parameter
(
"y"
,
s
);
auto
z_param
=
mm
->
add_parameter
(
"z"
,
s
);
auto
add_ins
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x_param
,
y_param
);
auto
identity_ins
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
add_ins
);
auto
mul_ins
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
add_ins
,
z_param
);
mm
->
add_return
({
mul_ins
,
identity_ins
});
tass
.
insert
(
tass
.
begin
(),
std
::
make_pair
(
add_ins
,
0
));
tass
.
insert
(
tass
.
begin
(),
std
::
make_pair
(
mul_ins
,
0
));
tass
.
insert
(
tass
.
begin
(),
std
::
make_pair
(
identity_ins
,
0
));
}
migraphx
::
generate_root_modules
(
p1
,
tass
);
migraphx
::
program
p2
;
{
migraphx
::
module_ref
mm
=
p2
.
get_main_module
();
auto
z
=
mm
->
add_parameter
(
"z"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
migraphx
::
module_ref
target_mod_0_0
=
p2
.
create_module
(
"target_mod_0_0"
);
auto
target_mod_0_0_param_2
=
target_mod_0_0
->
add_parameter
(
"param:2"
,
s
);
auto
target_mod_0_0_param_1
=
target_mod_0_0
->
add_parameter
(
"param:1"
,
s
);
auto
target_mod_0_0_param_0
=
target_mod_0_0
->
add_parameter
(
"param:0"
,
s
);
auto
x_target_mod_0_0_2
=
target_mod_0_0
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
target_mod_0_0_param_2
,
target_mod_0_0_param_1
);
auto
x_target_mod_0_0_3
=
target_mod_0_0
->
add_instruction
(
migraphx
::
make_op
(
"identity"
),
x_target_mod_0_0_2
);
auto
x_target_mod_0_0_4
=
target_mod_0_0
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x_target_mod_0_0_2
,
target_mod_0_0_param_0
);
target_mod_0_0
->
add_return
({
x_target_mod_0_0_3
,
x_target_mod_0_0_4
});
auto
x_2
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"run_on_target"
,
{{
"target_id"
,
0
}}),
{
z
,
y
,
x
},
{
target_mod_0_0
});
auto
x_3
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
x_2
);
auto
x_4
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
1
}}),
x_2
);
mm
->
add_return
({
x_4
,
x_3
});
}
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
TEST_CASE
(
if_then_else_program
)
TEST_CASE
(
if_then_else_program
)
{
{
/*
/*
...
@@ -663,7 +722,7 @@ TEST_CASE(fork_and_merge_case_1)
...
@@ -663,7 +722,7 @@ TEST_CASE(fork_and_merge_case_1)
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
};
};
TEST_CASE
(
fork_and_
merge_case_2
)
TEST_CASE
(
fork_and_
return_as_merge_bypass_branch_and_tass_on_other
)
{
{
/*
/*
**** Fork node returning ****
**** Fork node returning ****
...
@@ -716,7 +775,7 @@ TEST_CASE(fork_and_merge_case_2)
...
@@ -716,7 +775,7 @@ TEST_CASE(fork_and_merge_case_2)
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
}
TEST_CASE
(
fork_and_
merge_case_3
)
TEST_CASE
(
fork_and_
return_as_merge_bypass_branch_and_no_tass_on_other
)
{
{
/*
/*
**** Fork node returning ****
**** Fork node returning ****
...
@@ -770,7 +829,7 @@ TEST_CASE(fork_and_merge_case_3)
...
@@ -770,7 +829,7 @@ TEST_CASE(fork_and_merge_case_3)
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
}
TEST_CASE
(
fork_and_
merge_case_4
)
TEST_CASE
(
fork_and_
return_as_merge_different_tass_on_both_branches
)
{
{
/*
/*
Add (tid = 0)
Add (tid = 0)
...
@@ -848,7 +907,7 @@ TEST_CASE(fork_and_merge_case_4)
...
@@ -848,7 +907,7 @@ TEST_CASE(fork_and_merge_case_4)
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
};
};
TEST_CASE
(
fork_and_
merge_case_5
)
TEST_CASE
(
fork_and_
return_as_merge_no_tass_on_both_branch
)
{
{
/*
/*
Add (no assignment)
Add (no assignment)
...
@@ -881,7 +940,7 @@ TEST_CASE(fork_and_merge_case_5)
...
@@ -881,7 +940,7 @@ TEST_CASE(fork_and_merge_case_5)
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
EXPECT
(
p1
.
sort
()
==
p2
.
sort
());
}
}
TEST_CASE
(
fork_and_
merge_case_6
)
TEST_CASE
(
fork_and_
return_as_merge_no_tass_on_one_branch
)
{
{
/*
/*
Add (no assignment)
Add (no assignment)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment