Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2265e0d8
Commit
2265e0d8
authored
Jun 21, 2018
by
Paul
Browse files
Formatting
parent
7ac2d6a5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
24 deletions
+12
-24
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+11
-23
src/onnx/read_onnx.cpp
src/onnx/read_onnx.cpp
+1
-1
No files found.
src/onnx/onnx.cpp
View file @
2265e0d8
...
@@ -24,10 +24,7 @@ struct unknown
...
@@ -24,10 +24,7 @@ struct unknown
else
else
return
input
.
front
();
return
input
.
front
();
}
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
{
os
<<
x
.
name
();
os
<<
x
.
name
();
...
@@ -51,8 +48,7 @@ struct onnx_parser
...
@@ -51,8 +48,7 @@ struct onnx_parser
{
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
program
prog
=
program
();
...
@@ -165,7 +161,7 @@ struct onnx_parser
...
@@ -165,7 +161,7 @@ struct onnx_parser
{
{
const
std
::
string
&
name
=
input
.
name
();
const
std
::
string
&
name
=
input
.
name
();
// TODO: Get shape of input parameter
// TODO: Get shape of input parameter
shape
s
=
parse_type
(
input
.
type
());
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
...
@@ -245,8 +241,7 @@ struct onnx_parser
...
@@ -245,8 +241,7 @@ struct onnx_parser
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
...
@@ -262,28 +257,21 @@ struct onnx_parser
...
@@ -262,28 +257,21 @@ struct onnx_parser
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{
return
literal
{{
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
{
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
case
onnx
::
TensorProto
::
INT8
:
return
literal
{
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
case
onnx
::
TensorProto
::
INT16
:
return
literal
{
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
case
onnx
::
TensorProto
::
INT32
:
return
literal
{
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
case
onnx
::
TensorProto
::
INT64
:
return
literal
{
return
literal
{{
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
{
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{
return
literal
{
...
...
src/onnx/read_onnx.cpp
View file @
2265e0d8
...
@@ -25,7 +25,7 @@ int main(int argc, char const* argv[])
...
@@ -25,7 +25,7 @@ int main(int argc, char const* argv[])
if
(
argc
>
1
)
if
(
argc
>
1
)
{
{
std
::
string
file
=
argv
[
1
];
std
::
string
file
=
argv
[
1
];
auto
prog
=
rtg
::
parse_onnx
(
file
);
auto
prog
=
rtg
::
parse_onnx
(
file
);
prog
.
compile
(
rtg
::
cpu
::
cpu_target
{});
prog
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
s
=
prog
.
get_parameter_shape
(
"Input3"
);
auto
s
=
prog
.
get_parameter_shape
(
"Input3"
);
auto
input3
=
get_tensor_argument
(
s
);
auto
input3
=
get_tensor_argument
(
s
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment