Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
89e71273
Commit
89e71273
authored
Aug 31, 2023
by
umang yadav
Browse files
Partitioner working for nested if then else modules
parent
6c93676a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
356 additions
and
137 deletions
+356
-137
src/partitioner.cpp
src/partitioner.cpp
+213
-136
test/multi_target/multitarget_test.cpp
test/multi_target/multitarget_test.cpp
+143
-1
No files found.
src/partitioner.cpp
View file @
89e71273
...
...
@@ -67,6 +67,175 @@ static literal get_scalar(instruction_ref ins)
}
return
r
;
}
void
update_tid_counter
(
std
::
size_t
tid
,
std
::
unordered_map
<
std
::
size_t
,
std
::
size_t
>&
tid_counter
)
{
assert
(
tid
!=
std
::
numeric_limits
<
std
::
size_t
>::
max
());
if
(
tid_counter
.
find
(
tid
)
!=
tid_counter
.
end
())
{
tid_counter
[
tid
]
++
;
}
else
{
tid_counter
[
tid
]
=
0
;
}
}
void
generate_run_on_target_modules
(
migraphx
::
module_ref
mm
,
migraphx
::
program
&
p
,
migraphx
::
instruction_ref
ins
,
std
::
size_t
&
current_tid
,
const
target_assignments
&
tass
,
std
::
unordered_set
<
instruction_ref
>&
skip_ins
,
std
::
unordered_map
<
std
::
size_t
,
std
::
size_t
>&
tid_counter
,
std
::
vector
<
instruction_ref
>&
same_tid_ins_vec
,
std
::
unordered_set
<
instruction_ref
>&
same_tid_ins_set
)
{
assert
(
same_tid_ins_vec
.
size
()
==
same_tid_ins_set
.
size
());
if
(
same_tid_ins_vec
.
empty
())
{
assert
(
current_tid
==
std
::
numeric_limits
<
std
::
size_t
>::
max
());
return
;
}
// gather all parameters
std
::
unordered_set
<
instruction_ref
>
params
;
// gather all return values
std
::
unordered_set
<
instruction_ref
>
return_ins
;
for
(
auto
tins
:
iterator_for
(
same_tid_ins_vec
))
{
auto
inputs
=
(
*
tins
)
->
inputs
();
auto
outputs
=
(
*
tins
)
->
outputs
();
transform_if
(
inputs
.
cbegin
(),
inputs
.
cend
(),
std
::
inserter
(
params
,
params
.
end
()),
[
&
](
auto
in_param
)
{
return
(
params
.
count
(
in_param
)
==
0
and
same_tid_ins_set
.
count
(
in_param
)
==
0
);
},
[
&
](
auto
in_param
)
{
return
in_param
;
});
if
(
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
const
auto
out_ins
)
{
return
same_tid_ins_set
.
count
(
out_ins
)
==
0
;
}))
{
return_ins
.
insert
(
*
tins
);
}
}
if
(
enabled
(
MIGRAPHX_DEBUG_PARTITIONER
{}))
{
std
::
cout
<<
"params ins:
\n
"
;
for
(
auto
tmp
:
iterator_for
(
params
))
{
(
*
tmp
)
->
debug_print
();
}
std
::
cout
<<
"
\n
"
;
std
::
cout
<<
"return ins:
\n
"
;
for
(
auto
tmp
:
iterator_for
(
return_ins
))
{
(
*
tmp
)
->
debug_print
();
}
std
::
cout
<<
"
\n
"
;
}
auto
*
tmod
=
p
.
create_module
(
"target_mod_"
+
std
::
to_string
(
current_tid
)
+
"_"
+
std
::
to_string
(
tid_counter
[
current_tid
]));
update_tid_counter
(
current_tid
,
tid_counter
);
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
params_map
;
std
::
size_t
param_counter
=
0
;
std
::
vector
<
instruction_ref
>
rot_ins_params
;
for
(
auto
pins
:
iterator_for
(
params
))
{
auto
scalar
=
get_scalar
(
*
pins
);
if
(
scalar
.
empty
())
{
params_map
[
*
pins
]
=
tmod
->
add_parameter
(
"param:"
+
std
::
to_string
(
param_counter
++
),
(
*
pins
)
->
get_shape
());
rot_ins_params
.
push_back
(
*
pins
);
}
else
{
params_map
[
*
pins
]
=
tmod
->
add_literal
(
scalar
);
}
}
// TODO: what if params_map is empty ?
for
(
auto
tins
:
iterator_for
(
same_tid_ins_vec
))
{
auto
inputs
=
(
*
tins
)
->
inputs
();
std
::
vector
<
instruction_ref
>
new_inputs
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
new_inputs
),
[
&
](
auto
input_ins
)
{
return
params_map
.
at
(
input_ins
);
});
// [TODO]: what if it is has module args ?
auto
tmod_tins
=
tmod
->
add_instruction
((
*
tins
)
->
get_operator
(),
new_inputs
,
(
*
tins
)
->
module_inputs
());
// add parameter to params map so that it can be looked up by other insturctions
params_map
[
*
tins
]
=
tmod_tins
;
}
std
::
vector
<
instruction_ref
>
rins
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
return_ins_idx_map
;
for
(
auto
ritr
:
iterator_for
(
return_ins
))
{
rins
.
push_back
(
params_map
.
at
(
*
ritr
));
return_ins_idx_map
[
*
ritr
]
=
std
::
distance
(
ritr
,
return_ins
.
begin
());
}
tmod
->
add_return
(
rins
);
if
(
enabled
(
MIGRAPHX_DEBUG_PARTITIONER
{}))
{
std
::
cout
<<
"tmod:
\n
"
;
tmod
->
debug_print
();
}
// add run_on_target ins
auto
rot_ins
=
mm
->
insert_instruction
(
ins
,
make_op
(
"run_on_target"
,
{{
"target_id"
,
current_tid
}}),
rot_ins_params
,
{
tmod
});
skip_ins
.
insert
(
rot_ins
);
// fetch return instructions from tuple
for
(
auto
mm_rins
:
iterator_for
(
return_ins
))
{
auto
tuple_elem_ins
=
mm
->
insert_instruction
(
ins
,
make_op
(
"get_tuple_elem"
,
{{
"index"
,
return_ins_idx_map
.
at
(
*
mm_rins
)}}),
rot_ins
);
skip_ins
.
insert
(
tuple_elem_ins
);
// replace returns from tmod in main module
mm
->
replace_instruction
(
*
mm_rins
,
tuple_elem_ins
);
}
dead_code_elimination
{}.
apply
(
*
mm
);
// update current_tid
same_tid_ins_set
.
clear
();
same_tid_ins_vec
.
clear
();
if
(
tass
.
find
(
ins
)
!=
tass
.
end
())
{
current_tid
=
tass
.
at
(
ins
);
update_tid_counter
(
current_tid
,
tid_counter
);
same_tid_ins_set
.
insert
(
ins
);
same_tid_ins_vec
.
push_back
(
ins
);
}
else
{
current_tid
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
}
if
(
enabled
(
MIGRAPHX_DEBUG_PARTITIONER
{}))
{
std
::
cout
<<
"module after creation of tmod and rot:
\n
"
;
mm
->
debug_print
();
}
}
/*
Given target assignments (tass) for the instructions, generate run_on_target modules subgraphs
automatically. Input graph should be uncompiled migraphx program. target assignments (tass) map
should have a map of instruction to target_id. Instructions that are not inside tass map are
considered to be targeted for the "Ref" by default. params, literals and other builtins shouldn't be
part of the tass, only compute and reshape instructions should be part of tass. Copy, sync and alloc
instructions would be generated by compiler at later stage, so those shouldn't be considered.
(TODO): CustomOps may require special handling.
Identify subgraph boundaries, Ref is used for instructions that do not have assignments
1. Ref --> Target X --> Ref
2. Ref --> Target X --> Target 2
3. Target X --> Target Y --> Target Z , in this case target X and target Z can be same
4. Target X --> "@return"
5. Target X --> Ref --> "@return"
*/
void
partition
(
migraphx
::
module_ref
mm
,
migraphx
::
program
&
p
,
const
target_assignments
&
tass
,
...
...
@@ -107,154 +276,62 @@ void partition(migraphx::module_ref mm,
ins
,
ins
->
get_operator
(),
ins
->
inputs
(),
ins
->
module_inputs
());
}
}
if
((
starts_with
(
ins
->
name
(),
"@"
)
and
ins
->
name
()
!=
"@return"
)
or
skip_ins
.
count
(
ins
)
!=
0
)
if
(
ins
->
name
()
==
"@return"
)
{
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
,
tass
,
skip_ins
,
tid_counter
,
same_tid_ins_vec
,
same_tid_ins_set
);
}
// skip all params, literal and builitins other than return, skip "run_on_target_mod" ins
else
if
(
starts_with
(
ins
->
name
(),
"@"
)
or
skip_ins
.
count
(
ins
)
!=
0
)
{
continue
;
}
else
if
(
ins
->
name
()
!=
"@return"
and
current_tid
==
std
::
numeric_limits
<
std
::
size_t
>::
max
())
else
if
(
tass
.
find
(
ins
)
==
tass
.
end
())
{
if
(
tass
.
find
(
ins
)
==
tass
.
end
())
{
continue
;
}
current_tid
=
tass
.
at
(
ins
);
tid_counter
[
current_tid
]
=
0
;
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
,
tass
,
skip_ins
,
tid_counter
,
same_tid_ins_vec
,
same_tid_ins_set
);
}
else
if
(
current_tid
==
std
::
numeric_limits
<
std
::
size_t
>::
max
())
{
current_tid
=
tass
.
at
(
ins
);
update_tid_counter
(
current_tid
,
tid_counter
);
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
}
else
if
(
ins
->
name
()
!=
"@return"
and
tass
.
at
(
ins
)
==
current_tid
)
else
if
(
tass
.
at
(
ins
)
==
current_tid
)
{
same_tid_ins_vec
.
push_back
(
ins
);
same_tid_ins_set
.
insert
(
ins
);
}
else
if
(
ins
->
name
()
==
"@return"
or
tass
.
at
(
ins
)
!=
current_tid
)
else
if
(
tass
.
at
(
ins
)
!=
current_tid
)
{
// gather all parameters
std
::
unordered_set
<
instruction_ref
>
params
;
// gather all return values
std
::
unordered_set
<
instruction_ref
>
return_ins
;
for
(
auto
tins
:
iterator_for
(
same_tid_ins_vec
))
{
auto
inputs
=
(
*
tins
)
->
inputs
();
auto
outputs
=
(
*
tins
)
->
outputs
();
transform_if
(
inputs
.
cbegin
(),
inputs
.
cend
(),
std
::
inserter
(
params
,
params
.
end
()),
[
&
](
auto
in_param
)
{
return
(
params
.
count
(
in_param
)
==
0
and
same_tid_ins_set
.
count
(
in_param
)
==
0
);
},
[
&
](
auto
in_param
)
{
return
in_param
;
});
if
(
std
::
any_of
(
outputs
.
begin
(),
outputs
.
end
(),
[
&
](
const
auto
out_ins
)
{
return
same_tid_ins_set
.
count
(
out_ins
)
==
0
;
}))
{
return_ins
.
insert
(
*
tins
);
}
}
if
(
enabled
(
MIGRAPHX_DEBUG_PARTITIONER
{}))
{
std
::
cout
<<
"params ins:
\n
"
;
for
(
auto
tmp
:
iterator_for
(
params
))
{
(
*
tmp
)
->
debug_print
();
}
std
::
cout
<<
"
\n
"
;
std
::
cout
<<
"return ins:
\n
"
;
for
(
auto
tmp
:
iterator_for
(
return_ins
))
{
(
*
tmp
)
->
debug_print
();
}
std
::
cout
<<
"
\n
"
;
}
auto
*
tmod
=
p
.
create_module
(
"target_mod_"
+
std
::
to_string
(
current_tid
)
+
"_"
+
std
::
to_string
(
tid_counter
[
current_tid
]));
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
params_map
;
std
::
size_t
param_counter
=
0
;
std
::
vector
<
instruction_ref
>
rot_ins_params
;
for
(
auto
pins
:
iterator_for
(
params
))
{
auto
scalar
=
get_scalar
(
*
pins
);
if
(
scalar
.
empty
())
{
params_map
[
*
pins
]
=
tmod
->
add_parameter
(
"param:"
+
std
::
to_string
(
param_counter
++
),
(
*
pins
)
->
get_shape
());
rot_ins_params
.
push_back
(
*
pins
);
}
else
{
params_map
[
*
pins
]
=
tmod
->
add_literal
(
scalar
);
}
}
// TODO: what if params_map is empty ?
for
(
auto
tins
:
iterator_for
(
same_tid_ins_vec
))
{
auto
inputs
=
(
*
tins
)
->
inputs
();
std
::
vector
<
instruction_ref
>
new_inputs
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
new_inputs
),
[
&
](
auto
input_ins
)
{
return
params_map
.
at
(
input_ins
);
});
// [TODO]: what if it is has module args ?
auto
tmod_tins
=
tmod
->
add_instruction
(
(
*
tins
)
->
get_operator
(),
new_inputs
,
(
*
tins
)
->
module_inputs
());
// add parameter to params map so that it can be looked up by other insturctions
params_map
[
*
tins
]
=
tmod_tins
;
}
std
::
vector
<
instruction_ref
>
rins
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
return_ins_idx_map
;
for
(
auto
ritr
:
iterator_for
(
return_ins
))
{
rins
.
push_back
(
params_map
.
at
(
*
ritr
));
return_ins_idx_map
[
*
ritr
]
=
std
::
distance
(
ritr
,
return_ins
.
begin
());
}
tmod
->
add_return
(
rins
);
if
(
enabled
(
MIGRAPHX_DEBUG_PARTITIONER
{}))
{
std
::
cout
<<
"tmod:
\n
"
;
tmod
->
debug_print
();
}
// add run_on_target ins
auto
rot_ins
=
mm
->
insert_instruction
(
ins
,
make_op
(
"run_on_target"
,
{{
"target_id"
,
current_tid
}}),
rot_ins_params
,
{
tmod
});
skip_ins
.
insert
(
rot_ins
);
// fetch return instructions from tuple
for
(
auto
mm_rins
:
iterator_for
(
return_ins
))
{
auto
tuple_elem_ins
=
mm
->
insert_instruction
(
ins
,
make_op
(
"get_tuple_elem"
,
{{
"index"
,
return_ins_idx_map
.
at
(
*
mm_rins
)}}),
rot_ins
);
skip_ins
.
insert
(
tuple_elem_ins
);
// replace returns from tmod in main module
mm
->
replace_instruction
(
*
mm_rins
,
tuple_elem_ins
);
}
dead_code_elimination
{}.
apply
(
*
mm
);
// update current_tid
if
(
ins
->
name
()
!=
"@return"
)
{
current_tid
=
tass
.
at
(
ins
);
if
(
tid_counter
.
count
(
current_tid
)
==
0
)
{
tid_counter
[
current_tid
]
=
0
;
}
tid_counter
[
current_tid
]
++
;
same_tid_ins_set
.
clear
();
same_tid_ins_vec
.
clear
();
same_tid_ins_set
.
insert
(
ins
);
same_tid_ins_vec
.
push_back
(
ins
);
}
if
(
enabled
(
MIGRAPHX_DEBUG_PARTITIONER
{}))
{
std
::
cout
<<
"module after creation of tmod and rot:
\n
"
;
mm
->
debug_print
();
}
generate_run_on_target_modules
(
mm
,
p
,
ins
,
current_tid
,
tass
,
skip_ins
,
tid_counter
,
same_tid_ins_vec
,
same_tid_ins_set
);
}
else
{
MIGRAPHX_THROW
(
"Partition: this shouldn't occur"
);
}
}
}
...
...
test/multi_target/multitarget_test.cpp
View file @
89e71273
...
...
@@ -223,7 +223,7 @@ TEST_CASE(single_target_multi_compile)
// eval
migraphx
::
parameter_map
params
;
std
::
vector
<
float
>
boxes_vec
=
{
0.5
,
0.5
,
1.0
,
1.0
,
0.5
,
0.6
,
1.0
,
1.0
,
0.5
,
0.4
,
1.0
,
1.0
,
0.5
,
10.5
,
1.0
,
1.0
,
0.5
,
10.6
,
1.0
,
1.0
,
0.5
,
100.5
,
1.0
,
1.0
};
0.5
,
10.5
,
1.0
,
1.0
,
0.5
,
10.6
,
1.0
,
1.0
,
0.5
,
100.5
,
1.0
,
1.0
};
params
[
"boxes"
]
=
migraphx
::
argument
(
boxes_s
,
boxes_vec
.
data
());
auto
output
=
p
.
eval
(
params
).
back
();
std
::
vector
<
int64_t
>
gold_vec
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
};
...
...
@@ -510,6 +510,148 @@ TEST_CASE(multitarget_compile_nested_if_then_else)
}
}
// TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it
TEST_CASE
(
multitarget_compile_nested_if_then_else_partition
)
{
std
::
unordered_map
<
std
::
size_t
,
std
::
size_t
>
counter_map
=
{{
0
,
0
},
{
1
,
0
}};
migraphx
::
shape
ds
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
target_assignments
tass
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
cond_s
{
migraphx
::
shape
::
bool_type
};
auto
cond_0
=
mm
->
add_parameter
(
"cond_0"
,
cond_s
);
auto
cond_1
=
mm
->
add_parameter
(
"cond_1"
,
cond_s
);
auto
x
=
mm
->
add_parameter
(
"x"
,
ds
);
auto
y
=
mm
->
add_parameter
(
"y"
,
ds
);
auto
z
=
mm
->
add_parameter
(
"z"
,
ds
);
auto
create_test_module
=
[
&
](
migraphx
::
program
&
prog
,
const
std
::
vector
<
migraphx
::
instruction_ref
>&
inputs
,
std
::
size_t
tid
)
{
std
::
string
mod_name
=
"target_"
+
std
::
to_string
(
tid
)
+
"_"
+
std
::
to_string
(
counter_map
[
tid
]
++
);
auto
*
test_mod
=
prog
.
create_module
(
mod_name
);
std
::
vector
<
float
>
data
(
ds
.
elements
(),
-
1
);
auto
l1
=
test_mod
->
add_literal
(
migraphx
::
literal
(
ds
,
data
));
auto
ins1
=
test_mod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
inputs
[
0
],
l1
);
auto
ins2
=
test_mod
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
ins1
,
inputs
[
1
]);
auto
ins3
=
test_mod
->
add_instruction
(
migraphx
::
make_op
(
"sub"
),
ins2
,
inputs
[
2
]);
test_mod
->
add_return
({
ins3
});
tass
.
insert
(
tass
.
begin
(),
std
::
make_pair
(
ins1
,
tid
));
tass
.
insert
(
tass
.
begin
(),
std
::
make_pair
(
ins2
,
tid
));
tass
.
insert
(
tass
.
begin
(),
std
::
make_pair
(
ins3
,
tid
));
return
test_mod
;
};
auto
*
then_mod
=
p
.
create_module
(
"then_mod"
);
auto
then_mod_cond
=
then_mod
->
add_parameter
(
"then_mod_cond"
,
cond_s
);
auto
then_mod_param_0
=
then_mod
->
add_parameter
(
"then_mod_param_0"
,
ds
);
auto
then_mod_param_1
=
then_mod
->
add_parameter
(
"then_mod_param_1"
,
ds
);
auto
then_mod_param_2
=
then_mod
->
add_parameter
(
"then_mod_param_2"
,
ds
);
auto
then_mod_ref_ins
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
then_mod_param_0
,
then_mod_param_1
);
tass
.
insert
(
tass
.
begin
(),
std
::
make_pair
(
then_mod_ref_ins
,
3
));
auto
then_mod_if
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
then_mod_cond
,
then_mod_param_0
,
then_mod_param_1
,
then_mod_param_2
,
then_mod_ref_ins
,
then_mod_param_1
,
then_mod_param_2
},
{
create_test_module
(
p
,
{
then_mod_param_0
,
then_mod_param_1
,
then_mod_param_2
},
1
),
create_test_module
(
p
,
{
then_mod_ref_ins
,
then_mod_param_1
,
then_mod_param_2
},
0
)});
auto
then_mod_if_0
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
then_mod_if
);
then_mod
->
add_return
({
then_mod_if_0
});
// create nested else_mod with multiple targets.
// else_mod has one instruction that runs a module on "fpga" and another instruction that
// creates nested modules using "If" that runs on "cpu" and "gpu"
auto
*
else_mod
=
p
.
create_module
(
"else_mod"
);
auto
else_mod_cond
=
else_mod
->
add_parameter
(
"else_mod_cond"
,
cond_s
);
auto
else_mod_param_0
=
else_mod
->
add_parameter
(
"else_mod_param_0"
,
ds
);
auto
else_mod_param_1
=
else_mod
->
add_parameter
(
"else_mod_param_1"
,
ds
);
auto
else_mod_param_2
=
else_mod
->
add_parameter
(
"else_mod_param_2"
,
ds
);
auto
else_mod_fpga_ins
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
else_mod_param_0
,
else_mod_param_2
);
tass
.
insert
(
tass
.
begin
(),
std
::
make_pair
(
else_mod_fpga_ins
,
2
));
auto
else_mod_if
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
else_mod_cond
,
else_mod_fpga_ins
,
else_mod_param_0
,
else_mod_param_1
,
else_mod_param_2
,
else_mod_param_1
,
else_mod_param_0
},
{
create_test_module
(
p
,
{
else_mod_fpga_ins
,
else_mod_param_0
,
else_mod_param_1
},
0
),
create_test_module
(
p
,
{
else_mod_param_2
,
else_mod_param_1
,
else_mod_param_0
},
1
)});
auto
else_mod_if_0
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
else_mod_if
);
else_mod
->
add_return
({
else_mod_if_0
});
// Create nested and multi-target main module using "If"
auto
main_if_ins
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond_0
,
cond_1
,
x
,
y
,
z
,
cond_1
,
x
,
y
,
z
},
{
then_mod
,
else_mod
});
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"get_tuple_elem"
,
{{
"index"
,
0
}}),
main_if_ins
);
mm
->
add_return
({
r
});
// compile
migraphx
::
compile_options
gpu_opts
;
gpu_opts
.
offload_copy
=
true
;
std
::
cout
<<
"before parition
\n
"
;
p
.
debug_print
();
migraphx
::
partition
(
p
,
tass
);
std
::
cout
<<
"after partition
\n
"
;
p
.
debug_print
();
p
.
compile
({
migraphx
::
make_target
(
"gpu"
),
migraphx
::
make_target
(
"cpu"
),
migraphx
::
make_target
(
"ref"
),
migraphx
::
make_target
(
"ref"
)},
{
gpu_opts
});
std
::
cout
<<
"after compilation
\n
"
;
p
.
debug_print
();
EXPECT
(
check_compiled_program
(
p
,
{
migraphx
::
make_target
(
"gpu"
),
migraphx
::
make_target
(
"cpu"
),
migraphx
::
make_target
(
"ref"
),
migraphx
::
make_target
(
"ref"
)}));
// do evaluation using different conditions
migraphx
::
parameter_map
params
;
float
x_i
=
2.0
;
float
y_i
=
3.0
;
float
z_i
=
4.0
;
params
[
"x"
]
=
migraphx
::
fill_argument
(
ds
,
x_i
);
params
[
"y"
]
=
migraphx
::
fill_argument
(
ds
,
y_i
);
params
[
"z"
]
=
migraphx
::
fill_argument
(
ds
,
z_i
);
// cover all paths with different combination of conditions
std
::
vector
<
std
::
pair
<
bool
,
bool
>>
test_conds
=
{
{
true
,
true
},
{
true
,
false
},
{
false
,
true
},
{
false
,
false
}};
for
(
auto
[
cond_val_0
,
cond_val_1
]
:
test_conds
)
{
params
[
"cond_0"
]
=
migraphx
::
argument
(
cond_s
,
&
cond_val_0
);
params
[
"cond_1"
]
=
migraphx
::
argument
(
cond_s
,
&
cond_val_1
);
auto
result
=
p
.
eval
(
params
).
back
();
// main has one instruction that is : if_then_else
// then mod is doing : {tmp = x+y; (cond) ? (((x-1)*y)-z) : (((tmp-1)*y)-z);}
// else mod is doing : {tmp = x+z; (cond) ? (((tmp-1)*x)-y) : (((z-1)*y)-x);}
float
gold_i
=
-
1.0
;
if
(
cond_val_0
)
{
float
tmp_i
=
x_i
+
y_i
;
gold_i
=
(
cond_val_1
)
?
(((
x_i
-
1
)
*
y_i
)
-
z_i
)
:
(((
tmp_i
-
1
)
*
y_i
)
-
z_i
);
}
else
{
float
tmp_i
=
x_i
+
z_i
;
gold_i
=
(
cond_val_1
)
?
(((
tmp_i
-
1
)
*
x_i
)
-
y_i
)
:
(((
z_i
-
1
)
*
y_i
)
-
x_i
);
}
auto
gold
=
migraphx
::
fill_argument
(
ds
,
gold_i
);
EXPECT
(
gold
==
result
);
}
}
// TODO : FPGA compilation is broken right now, below test mentions fpga but doesn't compile for it
TEST_CASE
(
multitarget_select_module
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment