Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c297ce5f
Commit
c297ce5f
authored
Nov 04, 2022
by
Ted Themistokleous
Browse files
Fixes to handle constants
parent
b6ca9b26
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
29 deletions
+34
-29
src/onnx/parse_constant.cpp
src/onnx/parse_constant.cpp
+2
-1
src/onnx/parse_constant_of_shape.cpp
src/onnx/parse_constant_of_shape.cpp
+11
-3
src/onnx/parse_if.cpp
src/onnx/parse_if.cpp
+15
-10
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+6
-15
No files found.
src/onnx/parse_constant.cpp
View file @
c297ce5f
...
...
@@ -43,7 +43,8 @@ struct parse_constant : op_parser<parse_constant>
// return empty literal
if
(
v
.
get_shape
().
elements
()
==
0
)
{
return
info
.
add_literal
(
literal
{
v
.
get_shape
().
type
()});
migraphx
::
shape
empty_constant
(
v
.
get_shape
().
type
(),
{
1
},
{
0
});
return
info
.
add_literal
(
literal
{
empty_constant
,
{
0
}});
}
auto
dim_size
=
info
.
attributes
.
at
(
"value"
).
t
().
dims_size
();
...
...
src/onnx/parse_constant_of_shape.cpp
View file @
c297ce5f
...
...
@@ -68,7 +68,7 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
// empty input tensor, output is a scalar
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
{
s
=
migraphx
::
shape
{
type
,
{
1
},
{
0
}};
s
=
migraphx
::
shape
{
type
,
{
1
},
{}};
}
else
{
...
...
@@ -84,8 +84,16 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
l_val
.
visit
([
&
](
auto
val
)
{
using
val_type
=
std
::
remove_cv_t
<
typename
decltype
(
val
)
::
value_type
>
;
// l_val contains only one element
std
::
vector
<
val_type
>
out_vec
(
s
.
elements
(),
val
.
front
());
l_out
=
literal
(
s
,
out_vec
);
if
(
s
.
elements
()
>
0
)
{
std
::
vector
<
val_type
>
out_vec
(
s
.
elements
(),
val
.
front
());
l_out
=
literal
(
s
,
out_vec
);
}
else
{
std
::
vector
<
val_type
>
out_vec
{
val
.
front
()};
l_out
=
literal
(
s
,
out_vec
);
}
});
return
info
.
add_literal
(
l_out
);
...
...
src/onnx/parse_if.cpp
View file @
c297ce5f
...
...
@@ -31,6 +31,7 @@
#include <migraphx/reduce_dims.hpp>
#include <algorithm>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
...
...
@@ -92,7 +93,9 @@ struct parse_if : op_parser<parse_if>
auto
throw_shapes
=
[
&
]()
{
MIGRAPHX_THROW
(
"PARSE_IF: "
+
info
.
name
+
" then and else sub_graphs must have compatible shapes "
);
" then and else sub_graphs must have compatible shapes "
+
to_string_range
(
then_out_shapes
)
+
" vs "
+
to_string_range
(
else_out_shapes
));
};
if
(
then_out_shapes
.
size
()
!=
else_out_shapes
.
size
())
...
...
@@ -126,16 +129,14 @@ struct parse_if : op_parser<parse_if>
assert
(
not
(
then_lens
.
empty
()
and
else_lens
.
empty
()));
auto
handle_empty_branch
=
[](
module_ref
&
mdl
,
int
index
,
const
shape
&
out_shape
)
{
shape
gen_shape
(
shape
(
out_shape
.
type
(),
{
1
},
{
0
}));
auto
literal_ins
=
mdl
->
add_literal
(
literal
(
gen_shape
,
{
0
}));
auto
unsqueeze_ins
=
mdl
->
insert_instruction
(
std
::
prev
(
mdl
->
end
()),
make_op
(
"scalar"
,
{{
"scalar_bcst_dims"
,
out_shape
.
lens
()}}),
literal_ins
);
auto
scalar_ins
=
mdl
->
insert_instruction
(
std
::
prev
(
mdl
->
end
()),
make_op
(
"scalar"
,
{{
"out_lens"
,
out_shape
.
lens
()}}),
std
::
prev
(
mdl
->
end
()));
auto
broad_ins
=
mdl
->
insert_instruction
(
std
::
prev
(
mdl
->
end
()),
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_shape
.
lens
()}}),
unsqueeze
_ins
);
scalar
_ins
);
auto
contig_out
=
mdl
->
insert_instruction
(
std
::
prev
(
mdl
->
end
()),
make_op
(
"contiguous"
),
broad_ins
);
mdl
->
replace_instruction
(
std
::
prev
(
mdl
->
end
())
->
inputs
().
at
(
index
),
contig_out
);
...
...
@@ -144,11 +145,12 @@ struct parse_if : op_parser<parse_if>
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
if
(
then_lens
.
empty
())
if
(
then_out_shape
.
strides
().
empty
())
{
then_lens
=
handle_empty_branch
(
then_mdl
,
i
,
else_out_shape
);
}
else
if
(
else_
lens
.
empty
())
else
if
(
else_
out_shape
.
strides
()
.
empty
())
{
else_lens
=
handle_empty_branch
(
else_mdl
,
i
,
then_out_shape
);
}
...
...
@@ -183,6 +185,9 @@ struct parse_if : op_parser<parse_if>
}
}
then_mdl
->
debug_print
();
else_mdl
->
debug_print
();
auto
if_ret
=
info
.
add_instruction
(
make_op
(
"if"
),
args
,
{
then_mdl
,
else_mdl
});
auto
out_s
=
if_ret
->
get_shape
();
assert
(
out_s
.
type
()
==
shape
::
tuple_type
);
...
...
test/onnx/onnx_test.cpp
View file @
c297ce5f
...
...
@@ -761,7 +761,7 @@ TEST_CASE(constant_empty_scalar_int64_test)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
int64_type
});
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
int64_type
,
{
0
}
});
auto
prog
=
optimize_onnx
(
"constant_empty_scalar_int64_test.onnx"
);
EXPECT
(
p
==
prog
);
...
...
@@ -781,8 +781,8 @@ TEST_CASE(const_of_shape_empty_input_test)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
mm
->
add_literal
(
migraphx
::
literal
(
migraphx
::
shape
::
int32_type
));
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
1
}
,
{
0
}
);
mm
->
add_literal
(
migraphx
::
literal
(
migraphx
::
shape
::
int32_type
,
{
0
}
));
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
1
});
std
::
vector
<
int64_t
>
vec
(
s
.
elements
(),
10
);
mm
->
add_literal
(
migraphx
::
literal
(
s
,
vec
));
...
...
@@ -2425,17 +2425,18 @@ TEST_CASE(if_literal_test)
auto
cond
=
mm
->
add_parameter
(
"cond"
,
cond_s
);
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
5
}};
migraphx
::
shape
empty_const
(
migraphx
::
shape
::
float_type
,
{
1
},
{
0
});
auto
*
then_mod
=
p
.
create_module
(
"If_1_if"
);
std
::
vector
<
float
>
data1
=
{
1
,
2
,
3
,
4
,
5
};
auto
l1
=
then_mod
->
add_literal
(
migraphx
::
literal
(
s
,
data1
));
then_mod
->
add_literal
({});
then_mod
->
add_literal
({
empty_const
,
{
0
}
});
then_mod
->
add_return
({
l1
});
auto
*
else_mod
=
p
.
create_module
(
"If_1_else"
);
std
::
vector
<
float
>
data2
=
{
5
,
4
,
3
,
2
,
1
};
auto
l2
=
else_mod
->
add_literal
(
migraphx
::
literal
(
s
,
data2
));
else_mod
->
add_literal
({});
else_mod
->
add_literal
({
empty_const
,
{
0
}
});
else_mod
->
add_return
({
l2
});
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond
},
{
then_mod
,
else_mod
});
...
...
@@ -2599,8 +2600,6 @@ TEST_CASE(if_then_empty_constant_test)
auto
*
then_mod
=
p
.
create_module
(
"If_4_if"
);
then_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
migraphx
::
shape
gen_shape
(
migraphx
::
shape
(
s
.
type
(),
{
1
},
{
0
}));
auto
literal_ins
=
then_mod
->
add_literal
(
migraphx
::
literal
(
gen_shape
,
{
0
}));
auto
unsqueeze_ins
=
then_mod
->
add_instruction
(
...
...
@@ -2636,9 +2635,6 @@ TEST_CASE(if_then_empty_constant_multi_output_test)
auto
*
then_mod
=
p
.
create_module
(
"If_4_if"
);
then_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
then_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
migraphx
::
shape
gen_shape
(
migraphx
::
shape
(
s
.
type
(),
{
1
},
{
0
}));
auto
literal_ins
=
then_mod
->
add_literal
(
migraphx
::
literal
(
gen_shape
,
{
0
}));
...
...
@@ -2691,8 +2687,6 @@ TEST_CASE(if_else_empty_constant_test)
auto
*
else_mod
=
p
.
create_module
(
"If_4_else"
);
else_mod
->
add_literal
(
s
.
type
());
migraphx
::
shape
gen_shape
(
migraphx
::
shape
(
s
.
type
(),
{
1
},
{
0
}));
auto
literal_ins
=
else_mod
->
add_literal
(
migraphx
::
literal
(
gen_shape
,
{
0
}));
...
...
@@ -2731,9 +2725,6 @@ TEST_CASE(if_else_empty_constant_multi_output_test)
auto
*
else_mod
=
p
.
create_module
(
"If_4_else"
);
else_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
else_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
migraphx
::
shape
gen_shape
(
migraphx
::
shape
(
s
.
type
(),
{
1
},
{
0
}));
auto
literal_ins
=
else_mod
->
add_literal
(
migraphx
::
literal
(
gen_shape
,
{
0
}));
auto
unsqueeze_ins
=
else_mod
->
add_instruction
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment