Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
abf62c5d
"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "954f0725760d5b40e70dd0598306f9ad4c843954"
Commit
abf62c5d
authored
Jun 16, 2022
by
charlie
Browse files
Handle dynamic graph input shapes in ONNX parser
parent
faefeef9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
108 additions
and
20 deletions
+108
-20
src/include/migraphx/onnx.hpp
src/include/migraphx/onnx.hpp
+7
-3
src/onnx/include/migraphx/onnx/onnx_parser.hpp
src/onnx/include/migraphx/onnx/onnx_parser.hpp
+2
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+10
-1
src/onnx/onnx_parser.cpp
src/onnx/onnx_parser.cpp
+40
-14
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+49
-1
No files found.
src/include/migraphx/onnx.hpp
View file @
abf62c5d
...
@@ -10,10 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,10 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser
/// struct to pass in onnx options to parser
struct
onnx_options
struct
onnx_options
{
{
/// default batch size to use (if not specified in onnx file)
/// Old way to set default fixed dimension size (priority over default_dyn_dim_value)
std
::
size_t
default_dim_value
=
1
;
std
::
size_t
default_dim_value
=
0
;
/// Explicitly specify the dims of an input
/// Default dynamic dimension size (if not specified in onnx file)
shape
::
dynamic_dimension
default_dyn_dim_value
=
{
1
,
1
,
0
};
/// Explicitly specify the dims of an input (priority over map_dyn_input_dims)
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
=
{};
/// Explicitly specify dynamic dims of an input
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
=
{};
/// Continue parsing onnx file if an unknown operator is found
/// Continue parsing onnx file if an unknown operator is found
bool
skip_unknown_operators
=
false
;
bool
skip_unknown_operators
=
false
;
/// Print program if an error occurs
/// Print program if an error occurs
...
...
src/onnx/include/migraphx/onnx/onnx_parser.hpp
View file @
abf62c5d
...
@@ -71,8 +71,9 @@ struct onnx_parser
...
@@ -71,8 +71,9 @@ struct onnx_parser
node_map
nodes
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
program
prog
=
program
();
s
td
::
size_t
default_dim_value
=
1
;
s
hape
::
dynamic_dimension
default_
dyn_
dim_value
=
{
1
,
1
,
0
}
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
shape
::
dynamic_dimension
>>
map_dyn_input_dims
;
bool
skip_unknown_operators
=
false
;
bool
skip_unknown_operators
=
false
;
int64_t
max_loop_iterations
=
10
;
int64_t
max_loop_iterations
=
10
;
int64_t
opset_version
=
13
;
int64_t
opset_version
=
13
;
...
...
src/onnx/onnx.cpp
View file @
abf62c5d
...
@@ -19,7 +19,16 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
...
@@ -19,7 +19,16 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
{
onnx
::
onnx_parser
parser
;
onnx
::
onnx_parser
parser
;
parser
.
map_input_dims
=
options
.
map_input_dims
;
parser
.
map_input_dims
=
options
.
map_input_dims
;
parser
.
default_dim_value
=
options
.
default_dim_value
;
parser
.
map_dyn_input_dims
=
options
.
map_dyn_input_dims
;
auto
dim_val
=
options
.
default_dim_value
;
if
(
dim_val
!=
0
)
{
parser
.
default_dyn_dim_value
=
{
dim_val
,
dim_val
,
0
};
}
else
{
parser
.
default_dyn_dim_value
=
options
.
default_dyn_dim_value
;
}
parser
.
skip_unknown_operators
=
options
.
skip_unknown_operators
;
parser
.
skip_unknown_operators
=
options
.
skip_unknown_operators
;
parser
.
max_loop_iterations
=
options
.
max_loop_iterations
;
parser
.
max_loop_iterations
=
options
.
max_loop_iterations
;
...
...
src/onnx/onnx_parser.cpp
View file @
abf62c5d
...
@@ -12,9 +12,11 @@
...
@@ -12,9 +12,11 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/env.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
namespace
onnx
{
static
onnx_parser
::
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
static
onnx_parser
::
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
...
@@ -245,7 +247,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -245,7 +247,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// input not in initializer_data, so it is a real input
// input not in initializer_data, so it is a real input
if
(
!
contains
(
mod_insts
,
name
))
if
(
!
contains
(
mod_insts
,
name
))
{
{
// ONNX specification does not specify h
w
o to deal with the
// ONNX specification does not specify ho
w
to deal with the
// scenario that a nested subgraph contains a parameter with the
// scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph.
// name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that.
// In the current implementation, MIGraphX throws an exception for that.
...
@@ -254,14 +256,23 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
...
@@ -254,14 +256,23 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
MIGRAPHX_THROW
(
"module
\"
"
+
mod
->
name
()
+
"
\"
has parameter name
\"
"
+
name
+
MIGRAPHX_THROW
(
"module
\"
"
+
mod
->
name
()
+
"
\"
has parameter name
\"
"
+
name
+
"
\"
existing in parent graph!"
);
"
\"
existing in parent graph!"
);
}
}
shape
s
;
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
if
(
map_input_dims
.
count
(
name
)
>
0
)
if
(
map_input_dims
.
count
(
name
)
>
0
)
{
{
dims
=
map_input_dims
.
at
(
name
);
dims
=
map_input_dims
.
at
(
name
);
s
=
parse_type
(
input
.
type
(),
dims
);
}
else
if
(
map_dyn_input_dims
.
count
(
name
)
>
0
)
{
shape
::
type_t
shape_type
=
get_type
(
input
.
type
().
tensor_type
().
elem_type
());
s
=
{
shape_type
,
map_dyn_input_dims
.
at
(
name
)};
}
else
{
s
=
parse_type
(
input
.
type
(),
dims
);
}
}
shape
s
=
parse_type
(
input
.
type
(),
dims
);
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
mod_insts
[
name
]
=
mod
->
add_parameter
(
name
,
s
);
}
}
}
}
...
@@ -416,30 +427,45 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
...
@@ -416,30 +427,45 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return
{
shape_type
,
input_dims
};
return
{
shape_type
,
input_dims
};
}
}
std
::
vector
<
s
td
::
size_t
>
dims
;
std
::
vector
<
s
hape
::
dynamic_dimension
>
dynamic_
dims
;
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
std
::
transform
(
tensor_dims
.
begin
(),
std
::
transform
(
tensor_dims
.
begin
(),
tensor_dims
.
end
(),
tensor_dims
.
end
(),
std
::
back_inserter
(
dims
),
std
::
back_inserter
(
dynamic_
dims
),
[
&
](
auto
&&
d
)
->
s
td
::
size_t
{
[
&
](
auto
&&
d
)
->
s
hape
::
dynamic_dimension
{
if
(
d
.
has_dim_value
())
if
(
d
.
has_dim_value
())
{
{
if
(
static_cast
<
int
>
(
d
.
dim_value
())
<=
0
)
if
(
static_cast
<
int
>
(
d
.
dim_value
())
<=
0
)
{
{
return
default_dim_value
;
return
default_
dyn_
dim_value
;
}
}
return
d
.
dim_value
();
auto
tmp
=
static_cast
<
std
::
size_t
>
(
d
.
dim_value
());
return
{
tmp
,
tmp
,
0
};
}
}
else
else
{
{
return
default_dim_value
;
return
default_
dyn_
dim_value
;
}
}
});
});
if
(
dims
.
empty
())
if
(
dynamic_dims
.
empty
())
{
return
{
shape_type
};
return
{
shape_type
};
}
return
{
shape_type
,
dims
};
if
(
std
::
all_of
(
dynamic_dims
.
begin
(),
dynamic_dims
.
end
(),
[](
auto
dd
)
{
return
dd
.
is_fixed
();
}))
{
std
::
vector
<
std
::
size_t
>
dims
;
std
::
transform
(
dynamic_dims
.
begin
(),
dynamic_dims
.
end
(),
std
::
back_inserter
(
dims
),
[](
auto
d
)
{
return
d
.
max
;
}
);
return
{
shape_type
,
dims
};
}
return
{
shape_type
,
dynamic_dims
};
}
}
shape
::
type_t
get_type
(
int
dtype
)
shape
::
type_t
get_type
(
int
dtype
)
...
...
test/onnx/onnx_test.cpp
View file @
abf62c5d
...
@@ -5411,7 +5411,55 @@ TEST_CASE(variable_batch_test)
...
@@ -5411,7 +5411,55 @@ TEST_CASE(variable_batch_test)
EXPECT(p == prog);
EXPECT(p == prog);
}
}
TEST_CASE
(
variable_batch_user_input_test
)
TEST_CASE(variable_batch_user_input_test1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 2, 0};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 5, 0};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test4)
{
{
migraphx::program p;
migraphx::program p;
auto* mm = p.get_main_module();
auto* mm = p.get_main_module();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment