Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c297ce5f
"test/ut/compression/vscode:/vscode.git/clone" did not exist on "2566badb06095b9e3ea16eb6f00fd58da65a95fd"
Commit
c297ce5f
authored
Nov 04, 2022
by
Ted Themistokleous
Browse files
Fixes to handle constants
parent
b6ca9b26
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
29 deletions
+34
-29
src/onnx/parse_constant.cpp
src/onnx/parse_constant.cpp
+2
-1
src/onnx/parse_constant_of_shape.cpp
src/onnx/parse_constant_of_shape.cpp
+11
-3
src/onnx/parse_if.cpp
src/onnx/parse_if.cpp
+15
-10
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+6
-15
No files found.
src/onnx/parse_constant.cpp
View file @
c297ce5f
...
@@ -43,7 +43,8 @@ struct parse_constant : op_parser<parse_constant>
...
@@ -43,7 +43,8 @@ struct parse_constant : op_parser<parse_constant>
// return empty literal
// return empty literal
if
(
v
.
get_shape
().
elements
()
==
0
)
if
(
v
.
get_shape
().
elements
()
==
0
)
{
{
return
info
.
add_literal
(
literal
{
v
.
get_shape
().
type
()});
migraphx
::
shape
empty_constant
(
v
.
get_shape
().
type
(),
{
1
},
{
0
});
return
info
.
add_literal
(
literal
{
empty_constant
,
{
0
}});
}
}
auto
dim_size
=
info
.
attributes
.
at
(
"value"
).
t
().
dims_size
();
auto
dim_size
=
info
.
attributes
.
at
(
"value"
).
t
().
dims_size
();
...
...
src/onnx/parse_constant_of_shape.cpp
View file @
c297ce5f
...
@@ -68,7 +68,7 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
...
@@ -68,7 +68,7 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
// empty input tensor, output is a scalar
// empty input tensor, output is a scalar
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
if
(
args
[
0
]
->
get_shape
().
elements
()
==
0
)
{
{
s
=
migraphx
::
shape
{
type
,
{
1
},
{
0
}};
s
=
migraphx
::
shape
{
type
,
{
1
},
{}};
}
}
else
else
{
{
...
@@ -84,8 +84,16 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
...
@@ -84,8 +84,16 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
l_val
.
visit
([
&
](
auto
val
)
{
l_val
.
visit
([
&
](
auto
val
)
{
using
val_type
=
std
::
remove_cv_t
<
typename
decltype
(
val
)
::
value_type
>
;
using
val_type
=
std
::
remove_cv_t
<
typename
decltype
(
val
)
::
value_type
>
;
// l_val contains only one element
// l_val contains only one element
std
::
vector
<
val_type
>
out_vec
(
s
.
elements
(),
val
.
front
());
if
(
s
.
elements
()
>
0
)
l_out
=
literal
(
s
,
out_vec
);
{
std
::
vector
<
val_type
>
out_vec
(
s
.
elements
(),
val
.
front
());
l_out
=
literal
(
s
,
out_vec
);
}
else
{
std
::
vector
<
val_type
>
out_vec
{
val
.
front
()};
l_out
=
literal
(
s
,
out_vec
);
}
});
});
return
info
.
add_literal
(
l_out
);
return
info
.
add_literal
(
l_out
);
...
...
src/onnx/parse_if.cpp
View file @
c297ce5f
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <algorithm>
#include <algorithm>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
namespace
onnx
{
...
@@ -92,7 +93,9 @@ struct parse_if : op_parser<parse_if>
...
@@ -92,7 +93,9 @@ struct parse_if : op_parser<parse_if>
auto
throw_shapes
=
[
&
]()
{
auto
throw_shapes
=
[
&
]()
{
MIGRAPHX_THROW
(
"PARSE_IF: "
+
info
.
name
+
MIGRAPHX_THROW
(
"PARSE_IF: "
+
info
.
name
+
" then and else sub_graphs must have compatible shapes "
);
" then and else sub_graphs must have compatible shapes "
+
to_string_range
(
then_out_shapes
)
+
" vs "
+
to_string_range
(
else_out_shapes
));
};
};
if
(
then_out_shapes
.
size
()
!=
else_out_shapes
.
size
())
if
(
then_out_shapes
.
size
()
!=
else_out_shapes
.
size
())
...
@@ -126,16 +129,14 @@ struct parse_if : op_parser<parse_if>
...
@@ -126,16 +129,14 @@ struct parse_if : op_parser<parse_if>
assert
(
not
(
then_lens
.
empty
()
and
else_lens
.
empty
()));
assert
(
not
(
then_lens
.
empty
()
and
else_lens
.
empty
()));
auto
handle_empty_branch
=
[](
module_ref
&
mdl
,
int
index
,
const
shape
&
out_shape
)
{
auto
handle_empty_branch
=
[](
module_ref
&
mdl
,
int
index
,
const
shape
&
out_shape
)
{
shape
gen_shape
(
shape
(
out_shape
.
type
(),
{
1
},
{
0
}));
auto
scalar_ins
=
auto
literal_ins
=
mdl
->
add_literal
(
literal
(
gen_shape
,
{
0
}));
mdl
->
insert_instruction
(
std
::
prev
(
mdl
->
end
()),
auto
unsqueeze_ins
=
mdl
->
insert_instruction
(
make_op
(
"scalar"
,
{{
"out_lens"
,
out_shape
.
lens
()}}),
std
::
prev
(
mdl
->
end
()),
std
::
prev
(
mdl
->
end
()));
make_op
(
"scalar"
,
{{
"scalar_bcst_dims"
,
out_shape
.
lens
()}}),
literal_ins
);
auto
broad_ins
=
mdl
->
insert_instruction
(
auto
broad_ins
=
mdl
->
insert_instruction
(
std
::
prev
(
mdl
->
end
()),
std
::
prev
(
mdl
->
end
()),
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_shape
.
lens
()}}),
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_shape
.
lens
()}}),
unsqueeze
_ins
);
scalar
_ins
);
auto
contig_out
=
mdl
->
insert_instruction
(
auto
contig_out
=
mdl
->
insert_instruction
(
std
::
prev
(
mdl
->
end
()),
make_op
(
"contiguous"
),
broad_ins
);
std
::
prev
(
mdl
->
end
()),
make_op
(
"contiguous"
),
broad_ins
);
mdl
->
replace_instruction
(
std
::
prev
(
mdl
->
end
())
->
inputs
().
at
(
index
),
contig_out
);
mdl
->
replace_instruction
(
std
::
prev
(
mdl
->
end
())
->
inputs
().
at
(
index
),
contig_out
);
...
@@ -144,11 +145,12 @@ struct parse_if : op_parser<parse_if>
...
@@ -144,11 +145,12 @@ struct parse_if : op_parser<parse_if>
// Handle one empty branch by setting output identical to the other
// Handle one empty branch by setting output identical to the other
// need to update the then_shape before we do further checks
// need to update the then_shape before we do further checks
if
(
then_lens
.
empty
())
if
(
then_out_shape
.
strides
().
empty
())
{
{
then_lens
=
handle_empty_branch
(
then_mdl
,
i
,
else_out_shape
);
then_lens
=
handle_empty_branch
(
then_mdl
,
i
,
else_out_shape
);
}
}
else
if
(
else_
lens
.
empty
())
else
if
(
else_
out_shape
.
strides
()
.
empty
())
{
{
else_lens
=
handle_empty_branch
(
else_mdl
,
i
,
then_out_shape
);
else_lens
=
handle_empty_branch
(
else_mdl
,
i
,
then_out_shape
);
}
}
...
@@ -183,6 +185,9 @@ struct parse_if : op_parser<parse_if>
...
@@ -183,6 +185,9 @@ struct parse_if : op_parser<parse_if>
}
}
}
}
then_mdl
->
debug_print
();
else_mdl
->
debug_print
();
auto
if_ret
=
info
.
add_instruction
(
make_op
(
"if"
),
args
,
{
then_mdl
,
else_mdl
});
auto
if_ret
=
info
.
add_instruction
(
make_op
(
"if"
),
args
,
{
then_mdl
,
else_mdl
});
auto
out_s
=
if_ret
->
get_shape
();
auto
out_s
=
if_ret
->
get_shape
();
assert
(
out_s
.
type
()
==
shape
::
tuple_type
);
assert
(
out_s
.
type
()
==
shape
::
tuple_type
);
...
...
test/onnx/onnx_test.cpp
View file @
c297ce5f
...
@@ -761,7 +761,7 @@ TEST_CASE(constant_empty_scalar_int64_test)
...
@@ -761,7 +761,7 @@ TEST_CASE(constant_empty_scalar_int64_test)
{
{
migraphx::program p;
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
mm
->
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
::
int64_type
});
mm->add_literal(migraphx::literal{migraphx::shape::int64_type
, {0}
});
auto prog = optimize_onnx("constant_empty_scalar_int64_test.onnx");
auto prog = optimize_onnx("constant_empty_scalar_int64_test.onnx");
EXPECT(p == prog);
EXPECT(p == prog);
...
@@ -781,8 +781,8 @@ TEST_CASE(const_of_shape_empty_input_test)
...
@@ -781,8 +781,8 @@ TEST_CASE(const_of_shape_empty_input_test)
{
{
migraphx::program p;
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
mm
->
add_literal
(
migraphx
::
literal
(
migraphx
::
shape
::
int32_type
));
mm->add_literal(migraphx::literal(migraphx::shape::int32_type
, {0}
));
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
1
}
,
{
0
}
);
migraphx::shape s(migraphx::shape::int64_type, {1});
std::vector<int64_t> vec(s.elements(), 10);
std::vector<int64_t> vec(s.elements(), 10);
mm->add_literal(migraphx::literal(s, vec));
mm->add_literal(migraphx::literal(s, vec));
...
@@ -2425,17 +2425,18 @@ TEST_CASE(if_literal_test)
...
@@ -2425,17 +2425,18 @@ TEST_CASE(if_literal_test)
auto cond = mm->add_parameter("cond", cond_s);
auto cond = mm->add_parameter("cond", cond_s);
migraphx::shape s{migraphx::shape::float_type, {5}};
migraphx::shape s{migraphx::shape::float_type, {5}};
migraphx::shape empty_const(migraphx::shape::float_type, {1}, {0});
auto* then_mod = p.create_module("If_1_if");
auto* then_mod = p.create_module("If_1_if");
std::vector<float> data1 = {1, 2, 3, 4, 5};
std::vector<float> data1 = {1, 2, 3, 4, 5};
auto l1 = then_mod->add_literal(migraphx::literal(s, data1));
auto l1 = then_mod->add_literal(migraphx::literal(s, data1));
then_mod
->
add_literal
({});
then_mod->add_literal({
empty_const, {0}
});
then_mod->add_return({l1});
then_mod->add_return({l1});
auto* else_mod = p.create_module("If_1_else");
auto* else_mod = p.create_module("If_1_else");
std::vector<float> data2 = {5, 4, 3, 2, 1};
std::vector<float> data2 = {5, 4, 3, 2, 1};
auto l2 = else_mod->add_literal(migraphx::literal(s, data2));
auto l2 = else_mod->add_literal(migraphx::literal(s, data2));
else_mod
->
add_literal
({});
else_mod->add_literal({
empty_const, {0}
});
else_mod->add_return({l2});
else_mod->add_return({l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
...
@@ -2599,8 +2600,6 @@ TEST_CASE(if_then_empty_constant_test)
...
@@ -2599,8 +2600,6 @@ TEST_CASE(if_then_empty_constant_test)
auto* then_mod = p.create_module("If_4_if");
auto* then_mod = p.create_module("If_4_if");
then_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = then_mod->add_instruction(
auto unsqueeze_ins = then_mod->add_instruction(
...
@@ -2636,9 +2635,6 @@ TEST_CASE(if_then_empty_constant_multi_output_test)
...
@@ -2636,9 +2635,6 @@ TEST_CASE(if_then_empty_constant_multi_output_test)
auto* then_mod = p.create_module("If_4_if");
auto* then_mod = p.create_module("If_4_if");
then_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
then_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto literal_ins = then_mod->add_literal(migraphx::literal(gen_shape, {0}));
...
@@ -2691,8 +2687,6 @@ TEST_CASE(if_else_empty_constant_test)
...
@@ -2691,8 +2687,6 @@ TEST_CASE(if_else_empty_constant_test)
auto* else_mod = p.create_module("If_4_else");
auto* else_mod = p.create_module("If_4_else");
else_mod
->
add_literal
(
s
.
type
());
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
...
@@ -2731,9 +2725,6 @@ TEST_CASE(if_else_empty_constant_multi_output_test)
...
@@ -2731,9 +2725,6 @@ TEST_CASE(if_else_empty_constant_multi_output_test)
auto* else_mod = p.create_module("If_4_else");
auto* else_mod = p.create_module("If_4_else");
else_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
else_mod
->
add_literal
(
migraphx
::
shape
::
int64_type
);
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
migraphx::shape gen_shape(migraphx::shape(s.type(), {1}, {0}));
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto literal_ins = else_mod->add_literal(migraphx::literal(gen_shape, {0}));
auto unsqueeze_ins = else_mod->add_instruction(
auto unsqueeze_ins = else_mod->add_instruction(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment