Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b8202d61
Unverified
Commit
b8202d61
authored
Nov 08, 2023
by
Attila Dusnoki
Committed by
GitHub
Nov 07, 2023
Browse files
Add scales attribute parse in upsample for older opset versions (#2336)
parent
d7c8b66f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
126 additions
and
65 deletions
+126
-65
src/onnx/parse_resize.cpp
src/onnx/parse_resize.cpp
+90
-63
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+14
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+22
-2
test/onnx/upsample_ver7_test.onnx
test/onnx/upsample_ver7_test.onnx
+0
-0
No files found.
src/onnx/parse_resize.cpp
View file @
b8202d61
...
...
@@ -181,6 +181,76 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
return
nearest_mode
;
}
static
std
::
vector
<
double
>
get_scales
(
const
onnx_parser
::
attribute_map
&
attr
)
{
std
::
vector
<
double
>
scales
;
if
(
contains
(
attr
,
"scales"
))
{
copy
(
attr
.
at
(
"scales"
).
floats
(),
std
::
back_inserter
(
scales
));
}
return
scales
;
}
static
void
parse_args
(
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
size_t
>&
in_lens
,
const
std
::
string
&
op_name
,
std
::
vector
<
double
>&
vec_scale
,
std
::
vector
<
std
::
size_t
>&
out_lens
)
{
for
(
const
auto
&
arg
:
args
)
{
if
(
arg
->
name
()
==
"undefined"
or
arg
==
args
.
front
())
{
continue
;
}
// skipped empty input
auto
lens
=
arg
->
get_shape
().
lens
();
if
(
lens
.
empty
())
{
continue
;
}
auto
type
=
arg
->
get_shape
().
type
();
// output size
if
(
type
==
shape
::
int64_type
)
{
auto
arg_out_s
=
arg
->
eval
();
check_arg_empty
(
arg_out_s
,
"PARSE_"
+
op_name
+
": dynamic output size is not supported!"
);
arg_out_s
.
visit
([
&
](
const
auto
&
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
if
(
out_lens
.
size
()
!=
in_lens
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
op_name
+
": specified output size does not match input size"
);
}
// compute the scale
vec_scale
.
resize
(
in_lens
.
size
());
std
::
transform
(
in_lens
.
begin
(),
in_lens
.
end
(),
out_lens
.
begin
(),
vec_scale
.
begin
(),
[](
auto
iss
,
auto
oss
)
{
return
1.0
*
oss
/
iss
;
});
}
else
{
// scale input
if
(
lens
[
0
]
==
in_lens
.
size
())
{
auto
arg_scale
=
arg
->
eval
();
check_arg_empty
(
arg_scale
,
"PARSE_"
+
op_name
+
": dynamic input scale is not supported!"
);
arg_scale
.
visit
([
&
](
const
auto
&
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
}
}
}
}
struct
parse_resize
:
op_parser
<
parse_resize
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Resize"
},
{
"Upsample"
}};
}
...
...
@@ -214,72 +284,30 @@ struct parse_resize : op_parser<parse_resize>
std
::
vector
<
std
::
size_t
>
out_lens
(
in_lens
.
size
());
// scale
std
::
vector
<
double
>
vec_scale
;
std
::
vector
<
double
>
vec_scale
=
get_scales
(
info
.
attributes
)
;
for
(
const
auto
&
arg
:
args
)
// If `scales` was not an attribute, it must be an input
if
(
vec_scale
.
empty
())
{
if
(
arg
->
name
()
==
"undefined"
or
arg
==
args
.
front
())
{
continue
;
}
// skipped empty input
auto
lens
=
arg
->
get_shape
().
lens
();
if
(
lens
.
empty
())
{
continue
;
}
auto
type
=
arg
->
get_shape
().
type
();
// output size
if
(
type
==
shape
::
int64_type
)
{
auto
arg_out_s
=
arg
->
eval
();
check_arg_empty
(
arg_out_s
,
"PARSE_"
+
opd
.
op_name
+
": dynamic output size is not supported!"
);
arg_out_s
.
visit
([
&
](
const
auto
&
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
if
(
out_lens
.
size
()
!=
in_lens
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": specified output size does not match input size"
);
}
// Depending on the args, it *must* populate the `vec_scale`, and might populate
// `out_lens`
parse_args
(
args
,
in_lens
,
opd
.
op_name
,
vec_scale
,
out_lens
);
}
// compute the scale
vec_scale
.
resize
(
in_lens
.
size
());
std
::
transform
(
in_lens
.
begin
(),
in_lens
.
end
(),
out_lens
.
begin
(),
vec_scale
.
begin
(),
[](
auto
iss
,
auto
oss
)
{
return
1.0
*
oss
/
iss
;
});
}
else
{
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": ranks of input and scale are different!"
);
}
// scale input
if
(
lens
[
0
]
==
in_lens
.
size
())
{
auto
arg_scale
=
arg
->
eval
();
check_arg_empty
(
arg_scale
,
"PARSE_"
+
opd
.
op_name
+
": dynamic input scale is not supported!"
);
arg_scale
.
visit
([
&
](
const
auto
&
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": ranks of input and scale are different!"
);
}
std
::
transform
(
in_lens
.
begin
(),
in_lens
.
end
(),
vec_scale
.
begin
(),
out_lens
.
begin
(),
[
&
](
auto
idx
,
auto
scale
)
{
return
static_cast
<
std
::
size_t
>
(
idx
*
scale
);
});
}
}
// if the output was not calculated yet, we update it based on the scales
if
(
all_of
(
out_lens
.
cbegin
(),
out_lens
.
cend
(),
[](
auto
o
)
{
return
o
==
0
;
}))
{
std
::
transform
(
in_lens
.
begin
(),
in_lens
.
end
(),
vec_scale
.
begin
(),
out_lens
.
begin
(),
[
&
](
auto
idx
,
auto
scale
)
{
return
static_cast
<
std
::
size_t
>
(
idx
*
scale
);
});
}
shape
out_s
{
in_s
.
type
(),
out_lens
};
...
...
@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize>
// reshape input to one-dimension
std
::
vector
<
int64_t
>
rsp_lens
=
{
static_cast
<
int64_t
>
(
in_s
.
elements
())};
args
[
0
]
=
info
.
make_contiguous
(
args
[
0
]);
auto
rsp
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
args
[
0
]);
if
(
mode
==
"nearest"
)
...
...
test/onnx/gen_onnx.py
View file @
b8202d61
...
...
@@ -9031,6 +9031,20 @@ def upsample_test():
return
([
node
],
[
X
],
[
Y
],
[
scale_tensor
])
@
onnx_test
()
def
upsample_ver7_test
():
X
=
helper
.
make_tensor_value_info
(
'X'
,
TensorProto
.
FLOAT
,
[
1
,
1
,
2
,
2
])
Y
=
helper
.
make_tensor_value_info
(
'Y'
,
TensorProto
.
FLOAT
,
[
1
,
1
,
4
,
6
])
node
=
onnx
.
helper
.
make_node
(
'Upsample'
,
inputs
=
[
'X'
],
outputs
=
[
'Y'
],
mode
=
'nearest'
,
scales
=
[
1.0
,
1.0
,
2.0
,
3.0
])
return
([
node
],
[
X
],
[
Y
])
@
onnx_test
()
def
variable_batch_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
...
...
test/onnx/onnx_test.cpp
View file @
b8202d61
...
...
@@ -6557,9 +6557,8 @@ TEST_CASE(resize_nonstd_input_test)
auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), inx);
mm->add_instruction(migraphx::make_op("undefined"));
auto
tx_cont
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
tx
);
auto
lrsp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
8
}}}),
tx
_cont
);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), tx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
...
...
@@ -8418,6 +8417,27 @@ TEST_CASE(upsample_test)
EXPECT(p == prog);
}
TEST_CASE(upsample_ver7_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto ix = mm->add_parameter("X", sx);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), ix);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, li);
mm->add_return({r});
auto prog = migraphx::parse_onnx("upsample_ver7_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test_throw_print_error)
{
migraphx::onnx_options options;
...
...
test/onnx/upsample_ver7_test.onnx
0 → 100644
View file @
b8202d61
File added
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment