Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3e6eba01
Commit
3e6eba01
authored
Aug 03, 2022
by
Ted Themistokleous
Browse files
First attempt at adding proper reshape for then/else modules in parse_if
parent
aee3164f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
24 deletions
+26
-24
src/onnx/parse_if.cpp
src/onnx/parse_if.cpp
+26
-24
No files found.
src/onnx/parse_if.cpp
View file @
3e6eba01
...
...
@@ -73,36 +73,38 @@ struct parse_if : op_parser<parse_if>
{
MIGRAPHX_THROW
(
"PARSE_IF: "
+
info
.
name
+
" then and else sub_grahps must have same output type! "
+
std
::
to_string
(
then_out_shapes
.
at
(
0
).
type
(
)
)
+
" vs "
+
std
::
to_string
(
else_out_shapes
.
at
(
0
).
type
(
)
));
then_out_shapes
.
at
(
0
).
type
_string
()
+
" vs "
+
else_out_shapes
.
at
(
0
).
type
_string
());
}
// If either argument returns non scalar, promote the scalar to a 1D tensor to meet the
// shape requirements
// if and only if the first dimension matches
if
(
not
then_out_shapes
.
at
(
0
).
scalar
()
||
not
else_out_shapes
.
at
(
0
).
scalar
)
if
(
not
then_out_shapes
.
at
(
0
).
scalar
()
&&
not
else_out_shapes
.
at
(
0
).
scalar
())
{
if
(
then_out_shapes
.
at
(
0
).
scalar
())
// First dimension must agree
if
(
then_out_shapes
.
at
(
0
).
lens
().
at
(
0
)
!=
else_out_shapes
.
at
(
0
).
lens
().
at
(
0
))
{
if
(
then_out_shapes
.
at
(
0
).
lens
().
at
(
0
)
!=
else_out_shapes
.
at
(
0
).
lens
().
at
(
0
))
{
MIGRAPHX_THROW
(
"PARSE_IF: "
+
info
.
name
+
"then out incompatible output shape with else"
);
}
migraphx
::
shape
s
(
then_out_shapes
.
at
(
0
).
type
(),
{
then_out_shapes
.
at
(
0
).
lens
().
at
(
0
),
1
},
{
1
,
1
});
then_mdl
->
add_outline
(
s
);
MIGRAPHX_THROW
(
"PARSE_IF: "
+
then_out_shapes
.
at
(
0
).
type_string
()
+
" & "
+
else_out_shapes
.
at
(
0
).
type_string
()
+
" are incompatible output shapes for then/cases"
);
}
else
if
(
else_out_shapes
.
at
(
0
).
scalar
())
auto
then_out_strides
=
then_out_shapes
.
at
(
0
).
strides
();
auto
else_out_strides
=
else_out_shapes
.
at
(
0
).
strides
();
if
(
then_out_strides
.
size
()
>
else_out_strides
.
size
())
{
else_mdl
->
insert_instruction
(
std
::
prev
(
else_mdl
->
end
()),
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{{
else_out_shapes
.
at
(
0
).
lens
().
at
(
0
),
1
},
{
1
,
1
}}}}),
std
::
prev
(
else_mdl
->
end
())
->
inputs
().
front
());
}
else
if
(
then_out_strides
.
size
()
<
else_out_strides
.
size
())
{
if
(
else_out_shapes
.
at
(
0
).
lens
().
at
(
0
)
!=
else_out_shapes
.
at
(
0
).
lens
().
at
(
0
))
{
MIGRAPHX_THROW
(
"PARSE_IF: "
+
info
.
name
+
"else out incompatible output shape with then"
);
}
migraphx
::
shape
s
(
else_out_shapes
.
at
(
0
).
type
(),
{
else_out_shapes
.
at
(
0
).
lens
().
at
(
0
),
1
},
{
1
,
1
});
else_mdl
->
add_outline
(
s
);
then_mdl
->
insert_instruction
(
std
::
prev
(
then_mdl
->
end
()),
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{{
then_out_shapes
.
at
(
0
).
lens
().
at
(
0
),
1
},
{
1
,
1
}}}}),
std
::
prev
(
then_mdl
->
end
())
->
inputs
().
front
());
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment