Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
34aeda23
Commit
34aeda23
authored
Oct 15, 2022
by
Ted Themistokleous
Browse files
Fix tidy issue and clean up logic
parent
372206b3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
42 deletions
+39
-42
src/onnx/parse_if.cpp
src/onnx/parse_if.cpp
+39
-42
No files found.
src/onnx/parse_if.cpp
View file @
34aeda23
...
@@ -115,22 +115,21 @@ struct parse_if : op_parser<parse_if>
...
@@ -115,22 +115,21 @@ struct parse_if : op_parser<parse_if>
{
{
else_lens
=
handle_empty_branch
(
else_mdl
,
then_out_shapes
.
at
(
0
));
else_lens
=
handle_empty_branch
(
else_mdl
,
then_out_shapes
.
at
(
0
));
}
}
else
{
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
// check equivilant length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int
dim_delta
=
abs
((
static_cast
<
int
>
(
then_lens
.
size
()
-
else_lens
.
size
())));
int
dim_delta
=
abs
((
static_cast
<
int
>
(
then_lens
.
size
()
-
else_lens
.
size
())));
if
(
dim_delta
<
=
1
)
if
(
dim_delta
=
=
1
)
{
{
auto
all_but_last_dims_equal
=
[](
std
::
vector
<
size_t
>&
lens_
A
,
auto
all_but_last_dims_equal
=
[](
std
::
vector
<
size_t
>&
lens_
a
,
std
::
vector
<
size_t
>&
lens_
B
)
{
std
::
vector
<
size_t
>&
lens_
b
)
{
if
(
lens_
A
.
size
()
<=
lens_
B
.
size
())
if
(
lens_
a
.
size
()
<=
lens_
b
.
size
())
{
{
return
equal
(
lens_
A
.
begin
(),
lens_
A
.
end
(),
lens_
B
.
begin
());
return
equal
(
lens_
a
.
begin
(),
lens_
a
.
end
(),
lens_
b
.
begin
());
}
}
else
else
{
{
return
equal
(
lens_
B
.
begin
(),
lens_
B
.
end
(),
lens_
A
.
begin
());
return
equal
(
lens_
b
.
begin
(),
lens_
b
.
end
(),
lens_
a
.
begin
());
}
}
};
};
...
@@ -140,8 +139,7 @@ struct parse_if : op_parser<parse_if>
...
@@ -140,8 +139,7 @@ struct parse_if : op_parser<parse_if>
throw_shapes
();
throw_shapes
();
}
}
auto
unsqueeze_last_op
=
[](
module_ref
&
mdl
,
auto
unsqueeze_last_op
=
[](
module_ref
&
mdl
,
const
std
::
vector
<
size_t
>&
out_shape
)
{
const
std
::
vector
<
size_t
>&
out_shape
)
{
auto
convert_ins
=
mdl
->
add_instruction
(
auto
convert_ins
=
mdl
->
add_instruction
(
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
out_shape
.
size
()
-
1
}}}),
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
out_shape
.
size
()
-
1
}}}),
std
::
prev
(
std
::
prev
(
mdl
->
end
())));
std
::
prev
(
std
::
prev
(
mdl
->
end
())));
...
@@ -162,12 +160,11 @@ struct parse_if : op_parser<parse_if>
...
@@ -162,12 +160,11 @@ struct parse_if : op_parser<parse_if>
unsqueeze_last_op
(
else_mdl
,
then_lens
);
unsqueeze_last_op
(
else_mdl
,
then_lens
);
}
}
}
}
else
else
if
(
dim_delta
>
1
)
{
{
throw_shapes
();
throw_shapes
();
}
}
}
}
}
auto
if_ret
=
info
.
add_instruction
(
make_op
(
"if"
),
args
,
{
then_mdl
,
else_mdl
});
auto
if_ret
=
info
.
add_instruction
(
make_op
(
"if"
),
args
,
{
then_mdl
,
else_mdl
});
auto
out_s
=
if_ret
->
get_shape
();
auto
out_s
=
if_ret
->
get_shape
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment