Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b8202d61
Unverified
Commit
b8202d61
authored
Nov 08, 2023
by
Attila Dusnoki
Committed by
GitHub
Nov 07, 2023
Browse files
Add scales attribute parse in upsample for older opset versions (#2336)
parent
d7c8b66f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
126 additions
and
65 deletions
+126
-65
src/onnx/parse_resize.cpp
src/onnx/parse_resize.cpp
+90
-63
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+14
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+22
-2
test/onnx/upsample_ver7_test.onnx
test/onnx/upsample_ver7_test.onnx
+0
-0
No files found.
src/onnx/parse_resize.cpp
View file @
b8202d61
...
...
@@ -181,41 +181,23 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
return
nearest_mode
;
}
st
ruct
parse_resize
:
op_parser
<
parse_resize
>
st
atic
std
::
vector
<
double
>
get_scales
(
const
onnx_parser
::
attribute_map
&
attr
)
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Resize"
},
{
"Upsample"
}};
}
instruction_ref
parse
(
const
op_desc
&
opd
,
const
onnx_parser
&
/*parser*/
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
// coord transform mode
std
::
string
coord_trans_mode
=
get_coord_trans_mode
(
info
.
attributes
);
// mode: only nearest and linear modes are supported for now
std
::
string
mode
=
get_mode
(
info
.
attributes
);
// nearest mode
std
::
string
nearest_mode
=
get_nearest_mode
(
info
.
attributes
);
// check exclude_outside, only support 0
if
(
contains
(
info
.
attributes
,
"exclude_outside"
)
and
info
.
attributes
.
at
(
"exclude_outside"
).
i
()
==
1
)
std
::
vector
<
double
>
scales
;
if
(
contains
(
attr
,
"scales"
))
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": exclude_outside 1 is not supported!"
);
copy
(
attr
.
at
(
"scales"
).
floats
(),
std
::
back_inserter
(
scales
)
);
}
// input data shape info
auto
in_s
=
args
[
0
]
->
get_shape
();
auto
in_lens
=
in_s
.
lens
();
// output shape is explicitly specified
std
::
vector
<
std
::
size_t
>
out_lens
(
in_lens
.
size
());
// scale
std
::
vector
<
double
>
vec_scale
;
return
scales
;
}
static
void
parse_args
(
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
size_t
>&
in_lens
,
const
std
::
string
&
op_name
,
std
::
vector
<
double
>&
vec_scale
,
std
::
vector
<
std
::
size_t
>&
out_lens
)
{
for
(
const
auto
&
arg
:
args
)
{
if
(
arg
->
name
()
==
"undefined"
or
arg
==
args
.
front
())
...
...
@@ -236,12 +218,12 @@ struct parse_resize : op_parser<parse_resize>
{
auto
arg_out_s
=
arg
->
eval
();
check_arg_empty
(
arg_out_s
,
"PARSE_"
+
opd
.
op_name
+
": dynamic output size is not supported!"
);
"PARSE_"
+
op_name
+
": dynamic output size is not supported!"
);
arg_out_s
.
visit
([
&
](
const
auto
&
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
if
(
out_lens
.
size
()
!=
in_lens
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
MIGRAPHX_THROW
(
"PARSE_"
+
op_name
+
": specified output size does not match input size"
);
}
...
...
@@ -261,25 +243,71 @@ struct parse_resize : op_parser<parse_resize>
{
auto
arg_scale
=
arg
->
eval
();
check_arg_empty
(
arg_scale
,
"PARSE_"
+
opd
.
op_name
+
": dynamic input scale is not supported!"
);
"PARSE_"
+
op_name
+
": dynamic input scale is not supported!"
);
arg_scale
.
visit
([
&
](
const
auto
&
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
}
}
}
}
struct
parse_resize
:
op_parser
<
parse_resize
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Resize"
},
{
"Upsample"
}};
}
instruction_ref
parse
(
const
op_desc
&
opd
,
const
onnx_parser
&
/*parser*/
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
// coord transform mode
std
::
string
coord_trans_mode
=
get_coord_trans_mode
(
info
.
attributes
);
// mode: only nearest and linear modes are supported for now
std
::
string
mode
=
get_mode
(
info
.
attributes
);
// nearest mode
std
::
string
nearest_mode
=
get_nearest_mode
(
info
.
attributes
);
// check exclude_outside, only support 0
if
(
contains
(
info
.
attributes
,
"exclude_outside"
)
and
info
.
attributes
.
at
(
"exclude_outside"
).
i
()
==
1
)
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": exclude_outside 1 is not supported!"
);
}
// input data shape info
auto
in_s
=
args
[
0
]
->
get_shape
();
auto
in_lens
=
in_s
.
lens
();
// output shape is explicitly specified
std
::
vector
<
std
::
size_t
>
out_lens
(
in_lens
.
size
());
// scale
std
::
vector
<
double
>
vec_scale
=
get_scales
(
info
.
attributes
);
// If `scales` was not an attribute, it must be an input
if
(
vec_scale
.
empty
())
{
// Depending on the args, it *must* populate the `vec_scale`, and might populate
// `out_lens`
parse_args
(
args
,
in_lens
,
opd
.
op_name
,
vec_scale
,
out_lens
);
}
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": ranks of input and scale are different!"
);
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": ranks of input and scale are different!"
);
}
std
::
transform
(
in_lens
.
begin
(),
// if the output was not calculated yet, we update it based on the scales
if
(
all_of
(
out_lens
.
cbegin
(),
out_lens
.
cend
(),
[](
auto
o
)
{
return
o
==
0
;
}))
{
std
::
transform
(
in_lens
.
begin
(),
in_lens
.
end
(),
vec_scale
.
begin
(),
out_lens
.
begin
(),
[
&
](
auto
idx
,
auto
scale
)
{
return
static_cast
<
std
::
size_t
>
(
idx
*
scale
);
});
}
}
[
&
](
auto
idx
,
auto
scale
)
{
return
static_cast
<
std
::
size_t
>
(
idx
*
scale
);
});
}
shape
out_s
{
in_s
.
type
(),
out_lens
};
...
...
@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize>
// reshape input to one-dimension
std
::
vector
<
int64_t
>
rsp_lens
=
{
static_cast
<
int64_t
>
(
in_s
.
elements
())};
args
[
0
]
=
info
.
make_contiguous
(
args
[
0
]);
auto
rsp
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
args
[
0
]);
if
(
mode
==
"nearest"
)
...
...
test/onnx/gen_onnx.py
View file @
b8202d61
...
...
@@ -9031,6 +9031,20 @@ def upsample_test():
return
([
node
],
[
X
],
[
Y
],
[
scale_tensor
])
@
onnx_test
()
def
upsample_ver7_test
():
X
=
helper
.
make_tensor_value_info
(
'X'
,
TensorProto
.
FLOAT
,
[
1
,
1
,
2
,
2
])
Y
=
helper
.
make_tensor_value_info
(
'Y'
,
TensorProto
.
FLOAT
,
[
1
,
1
,
4
,
6
])
node
=
onnx
.
helper
.
make_node
(
'Upsample'
,
inputs
=
[
'X'
],
outputs
=
[
'Y'
],
mode
=
'nearest'
,
scales
=
[
1.0
,
1.0
,
2.0
,
3.0
])
return
([
node
],
[
X
],
[
Y
])
@
onnx_test
()
def
variable_batch_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
...
...
test/onnx/onnx_test.cpp
View file @
b8202d61
...
...
@@ -6557,9 +6557,8 @@ TEST_CASE(resize_nonstd_input_test)
auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), inx);
mm->add_instruction(migraphx::make_op("undefined"));
auto tx_cont = mm->add_instruction(migraphx::make_op("contiguous"), tx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), tx
_cont
);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), tx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
...
...
@@ -8418,6 +8417,27 @@ TEST_CASE(upsample_test)
EXPECT(p == prog);
}
TEST_CASE(upsample_ver7_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto ix = mm->add_parameter("X", sx);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), ix);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, li);
mm->add_return({r});
auto prog = migraphx::parse_onnx("upsample_ver7_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test_throw_print_error)
{
migraphx::onnx_options options;
...
...
test/onnx/upsample_ver7_test.onnx
0 → 100644
View file @
b8202d61
File added
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment