Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
34aeda23
Commit
34aeda23
authored
Oct 15, 2022
by
Ted Themistokleous
Browse files
Fix tidy issue and clean up logic
parent
372206b3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
42 deletions
+39
-42
src/onnx/parse_if.cpp
src/onnx/parse_if.cpp
+39
-42
No files found.
src/onnx/parse_if.cpp
View file @
34aeda23
...
@@ -115,57 +115,54 @@ struct parse_if : op_parser<parse_if>
...
@@ -115,57 +115,54 @@ struct parse_if : op_parser<parse_if>
{
{
else_lens
=
handle_empty_branch
(
else_mdl
,
then_out_shapes
.
at
(
0
));
else_lens
=
handle_empty_branch
(
else_mdl
,
then_out_shapes
.
at
(
0
));
}
}
else
{
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int
dim_delta
=
abs
((
static_cast
<
int
>
(
then_lens
.
size
()
-
else_lens
.
size
())));
if
(
dim_delta
<=
1
)
{
auto
all_but_last_dims_equal
=
[](
std
::
vector
<
size_t
>&
lens_A
,
std
::
vector
<
size_t
>&
lens_B
)
{
if
(
lens_A
.
size
()
<=
lens_B
.
size
())
{
return
equal
(
lens_A
.
begin
(),
lens_A
.
end
(),
lens_B
.
begin
());
}
else
{
return
equal
(
lens_B
.
begin
(),
lens_B
.
end
(),
lens_A
.
begin
());
}
};
// make sure dims are equivalent in static shapes
if
(
not
all_but_last_dims_equal
(
then_lens
,
else_lens
))
{
throw_shapes
();
}
auto
unsqueeze_last_op
=
[](
module_ref
&
mdl
,
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
const
std
::
vector
<
size_t
>&
out_shape
)
{
int
dim_delta
=
abs
((
static_cast
<
int
>
(
then_lens
.
size
()
-
else_lens
.
size
())));
auto
convert_ins
=
mdl
->
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
out_shape
.
size
()
-
1
}}}),
std
::
prev
(
std
::
prev
(
mdl
->
end
())));
mdl
->
replace_return
({
convert_ins
});
mdl
->
remove_instruction
({
std
::
prev
(
convert_ins
)});
};
auto
last_then
=
*
(
std
::
prev
(
then_lens
.
end
()));
if
(
dim_delta
==
1
)
auto
last_else
=
*
(
std
::
prev
(
else_lens
.
end
()));
{
auto
all_but_last_dims_equal
=
[](
std
::
vector
<
size_t
>&
lens_a
,
// Find which dim to unsqueeze
std
::
vector
<
size_t
>&
lens_b
)
{
if
(
(
then_
lens
.
size
()
<
else_
lens
.
size
())
&&
(
last_else
==
1
))
if
(
lens
_a
.
size
()
<
=
lens
_b
.
size
())
{
{
unsqueeze_last_op
(
then_mdl
,
else_lens
);
return
equal
(
lens_a
.
begin
(),
lens_a
.
end
(),
lens_b
.
begin
()
);
}
}
else
if
((
then_lens
.
size
()
>
else_lens
.
size
())
&&
(
last_then
==
1
))
else
{
{
unsqueeze_last_op
(
else_mdl
,
then_lens
);
return
equal
(
lens_b
.
begin
(),
lens_b
.
end
(),
lens_a
.
begin
()
);
}
}
}
};
else
// make sure dims are equivalent in static shapes
if
(
not
all_but_last_dims_equal
(
then_lens
,
else_lens
))
{
{
throw_shapes
();
throw_shapes
();
}
}
auto
unsqueeze_last_op
=
[](
module_ref
&
mdl
,
const
std
::
vector
<
size_t
>&
out_shape
)
{
auto
convert_ins
=
mdl
->
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
out_shape
.
size
()
-
1
}}}),
std
::
prev
(
std
::
prev
(
mdl
->
end
())));
mdl
->
replace_return
({
convert_ins
});
mdl
->
remove_instruction
({
std
::
prev
(
convert_ins
)});
};
auto
last_then
=
*
(
std
::
prev
(
then_lens
.
end
()));
auto
last_else
=
*
(
std
::
prev
(
else_lens
.
end
()));
// Find which dim to unsqueeze
if
((
then_lens
.
size
()
<
else_lens
.
size
())
&&
(
last_else
==
1
))
{
unsqueeze_last_op
(
then_mdl
,
else_lens
);
}
else
if
((
then_lens
.
size
()
>
else_lens
.
size
())
&&
(
last_then
==
1
))
{
unsqueeze_last_op
(
else_mdl
,
then_lens
);
}
}
else
if
(
dim_delta
>
1
)
{
throw_shapes
();
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment