Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e612b60c
"vscode:/vscode.git/clone" did not exist on "2e414b7c922b516152ca7d9586a76dc3734aa212"
Commit
e612b60c
authored
Oct 28, 2022
by
Ted Themistokleous
Browse files
Avoid unneeded nesting
parent
881a4bd4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
58 deletions
+60
-58
src/onnx/parse_if.cpp
src/onnx/parse_if.cpp
+60
-58
No files found.
src/onnx/parse_if.cpp
View file @
e612b60c
...
@@ -115,76 +115,78 @@ struct parse_if : op_parser<parse_if>
...
@@ -115,76 +115,78 @@ struct parse_if : op_parser<parse_if>
else_out_shape
.
type_string
());
else_out_shape
.
type_string
());
}
}
if
(
not
then_out_shape
.
dynamic
()
and
not
else_out_shape
.
dynamic
())
if
(
then_out_shape
.
dynamic
()
or
else_out_shape
.
dynamic
())
{
{
auto
then_lens
=
then_out_shape
.
lens
()
;
continue
;
auto
else_lens
=
else_out_shape
.
lens
();
}
// Throw error if both branches have zero output shapes. Not possible for static
auto
then_lens
=
then_out_shape
.
lens
();
// inputs
auto
else_lens
=
else_out_shape
.
lens
();
if
(
then_lens
.
empty
()
and
else_lens
.
empty
())
{
throw_shapes
();
}
auto
handle_empty_branch
=
[](
module_ref
&
mdl
,
int
index
,
const
shape
&
out_shape
)
{
// Throw error if both branches have zero output shapes. Not possible for static
shape
gen_shape
(
shape
(
out_shape
.
type
(),
{
1
},
{
0
}));
// inputs
auto
literal_ins
=
if
(
then_lens
.
empty
()
and
else_lens
.
empty
())
mdl
->
insert_literal
(
std
::
prev
(
mdl
->
end
()),
literal
(
gen_shape
,
{
0
}));
{
auto
unsqueeze_ins
=
mdl
->
insert_instruction
(
throw_shapes
();
std
::
prev
(
mdl
->
end
()),
}
make_op
(
"scalar"
,
{{
"scalar_bcst_dims"
,
out_shape
.
lens
()}}),
literal_ins
);
auto
handle_empty_branch
=
[](
module_ref
&
mdl
,
int
index
,
const
shape
&
out_shape
)
{
auto
broad_ins
=
mdl
->
insert_instruction
(
shape
gen_shape
(
shape
(
out_shape
.
type
(),
{
1
},
{
0
}));
std
::
prev
(
mdl
->
end
()),
auto
literal_ins
=
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_shape
.
lens
()}}),
mdl
->
insert_literal
(
std
::
prev
(
mdl
->
end
()),
literal
(
gen_shape
,
{
0
}));
unsqueeze_ins
);
auto
unsqueeze_ins
=
mdl
->
insert_instruction
(
auto
contig_out
=
mdl
->
insert_instruction
(
std
::
prev
(
mdl
->
end
()),
std
::
prev
(
mdl
->
end
()),
make_op
(
"contiguous"
),
broad_ins
);
make_op
(
"scalar"
,
{{
"scalar_bcst_dims"
,
out_shape
.
lens
()}}),
mdl
->
replace_instruction
(
std
::
prev
(
mdl
->
end
())
->
inputs
().
at
(
index
),
contig_out
);
literal_ins
);
return
out_shape
.
lens
();
auto
broad_ins
=
mdl
->
insert_instruction
(
};
std
::
prev
(
mdl
->
end
()),
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_shape
.
lens
()}}),
// Handle one empty branch by setting output identical to the other
unsqueeze_ins
);
// need to update the then_shape before we do further checks
auto
contig_out
=
mdl
->
insert_instruction
(
if
(
then_lens
.
empty
())
std
::
prev
(
mdl
->
end
()),
make_op
(
"contiguous"
),
broad_ins
);
{
mdl
->
replace_instruction
(
std
::
prev
(
mdl
->
end
())
->
inputs
().
at
(
index
),
contig_out
);
then_lens
=
handle_empty_branch
(
then_mdl
,
i
,
else_out_shape
);
return
out_shape
.
lens
();
}
};
else
if
(
else_lens
.
empty
())
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if
(
then_lens
.
empty
())
{
then_lens
=
handle_empty_branch
(
then_mdl
,
i
,
else_out_shape
);
}
else
if
(
else_lens
.
empty
())
{
else_lens
=
handle_empty_branch
(
else_mdl
,
i
,
then_out_shape
);
}
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
int
rank_delta
=
abs
((
static_cast
<
int
>
(
then_lens
.
size
()
-
else_lens
.
size
())));
if
(
rank_delta
==
1
)
{
// make sure dims are equivalent in static shapes
if
(
not
all_but_last_dims_equal
(
then_lens
,
else_lens
))
{
{
else_lens
=
handle_empty_branch
(
else_mdl
,
i
,
then_out
_shape
);
throw
_shape
s
(
);
}
}
// check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
auto
last_then
=
then_lens
.
back
();
int
rank_delta
=
abs
((
static_cast
<
int
>
(
then_lens
.
size
()
-
else_lens
.
size
()))
);
auto
last_else
=
else_lens
.
back
(
);
if
(
rank_delta
==
1
)
// Find which dim to unsqueeze
if
((
then_lens
.
size
()
<
else_lens
.
size
())
&&
(
last_else
==
1
))
{
{
// make sure dims are equivalent in static shapes
unsqueeze_last_op
(
then_mdl
,
i
,
else_lens
);
if
(
not
all_but_last_dims_equal
(
then_lens
,
else_lens
))
{
throw_shapes
();
}
auto
last_then
=
then_lens
.
back
();
auto
last_else
=
else_lens
.
back
();
// Find which dim to unsqueeze
if
((
then_lens
.
size
()
<
else_lens
.
size
())
&&
(
last_else
==
1
))
{
unsqueeze_last_op
(
then_mdl
,
i
,
else_lens
);
}
else
if
((
then_lens
.
size
()
>
else_lens
.
size
())
&&
(
last_then
==
1
))
{
unsqueeze_last_op
(
else_mdl
,
i
,
then_lens
);
}
}
}
else
if
(
rank_delta
>
1
)
else
if
(
(
then_lens
.
size
()
>
else_lens
.
size
())
&&
(
last_then
==
1
)
)
{
{
throw_shapes
(
);
unsqueeze_last_op
(
else_mdl
,
i
,
then_lens
);
}
}
}
}
else
if
(
rank_delta
>
1
)
{
throw_shapes
();
}
}
}
auto
if_ret
=
info
.
add_instruction
(
make_op
(
"if"
),
args
,
{
then_mdl
,
else_mdl
});
auto
if_ret
=
info
.
add_instruction
(
make_op
(
"if"
),
args
,
{
then_mdl
,
else_mdl
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment