Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9f046d67
"vscode:/vscode.git/clone" did not exist on "9661bd57466c445545a4f432133d1581330fd8a1"
Commit
9f046d67
authored
Apr 19, 2018
by
Paul
Browse files
Parse onnx and convert to internal ir
parent
2f8e4e83
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
252 additions
and
50 deletions
+252
-50
include/rtg/literal.hpp
include/rtg/literal.hpp
+16
-2
include/rtg/operators.hpp
include/rtg/operators.hpp
+39
-23
include/rtg/shape.hpp
include/rtg/shape.hpp
+9
-1
include/rtg/stringutils.hpp
include/rtg/stringutils.hpp
+1
-1
src/program.cpp
src/program.cpp
+4
-1
src/read_onnx.cpp
src/read_onnx.cpp
+180
-19
test/eval_test.cpp
test/eval_test.cpp
+2
-2
test/literal_test.cpp
test/literal_test.cpp
+1
-1
No files found.
include/rtg/literal.hpp
View file @
9f046d67
...
@@ -28,7 +28,9 @@ struct literal : raw_data<literal>
...
@@ -28,7 +28,9 @@ struct literal : raw_data<literal>
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
std
::
copy
(
x
.
begin
(),
x
.
end
(),
reinterpret_cast
<
T
*>
(
buffer
.
data
()));
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -37,7 +39,19 @@ struct literal : raw_data<literal>
...
@@ -37,7 +39,19 @@ struct literal : raw_data<literal>
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
std
::
copy
(
x
.
begin
(),
x
.
end
(),
reinterpret_cast
<
T
*>
(
buffer
.
data
()));
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
}
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
{
assert
(
s
.
packed
());
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
}
literal
(
shape
s
,
const
char
*
x
)
literal
(
shape
s
,
const
char
*
x
)
...
...
include/rtg/operators.hpp
View file @
9f046d67
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <rtg/operand.hpp>
#include <rtg/operand.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/stringutils.hpp>
#include <cmath>
namespace
rtg
{
namespace
rtg
{
...
@@ -10,11 +11,11 @@ struct not_computable
...
@@ -10,11 +11,11 @@ struct not_computable
{
{
argument
compute
(
std
::
vector
<
argument
>
)
const
argument
compute
(
std
::
vector
<
argument
>
)
const
{
{
throw
"not computable"
;
throw
std
::
runtime_error
(
"not computable"
)
;
}
}
};
};
struct
convolution
:
not_computable
struct
convolution
{
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
...
@@ -28,26 +29,31 @@ struct convolution : not_computable
...
@@ -28,26 +29,31 @@ struct convolution : not_computable
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
size
()
!=
2
)
throw
"Wrong number of arguments"
;
if
(
inputs
.
size
()
!=
2
)
throw
std
::
runtime_error
(
"Wrong number of arguments"
)
;
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
const
shape
&
weights
=
inputs
.
at
(
1
);
if
(
input
.
type
()
!=
weights
.
type
())
throw
"Type doesn't match"
;
if
(
input
.
type
()
!=
weights
.
type
())
throw
std
::
runtime_error
(
"Type doesn't match"
)
;
if
(
input
.
size
()
!=
weights
.
size
())
throw
"Dimensions don't match"
;
if
(
input
.
lens
().
size
()
!=
weights
.
lens
().
size
())
throw
std
::
runtime_error
(
"Dimensions don't match"
)
;
if
(
input
.
size
()
!=
4
)
throw
"Only 4d convolution supported"
;
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d convolution supported"
)
;
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
return
{
t
,
{
return
{
t
,
{
input
[
0
],
input
.
lens
()
[
0
],
weights
[
0
],
weights
.
lens
()
[
0
],
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
),
1
,
(
input
.
lens
()
[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()
[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)
)
,
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
),
1
,
(
input
.
lens
()
[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()
[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)
)
,
}};
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
struct
pooling
:
not_computable
struct
pooling
{
{
std
::
string
mode
;
std
::
string
mode
;
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
...
@@ -62,24 +68,29 @@ struct pooling : not_computable
...
@@ -62,24 +68,29 @@ struct pooling : not_computable
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
!
inputs
.
empty
())
throw
"Wrong number of arguments"
;
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
)
;
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
if
(
input
.
size
()
!=
4
)
throw
"Only 4d pooling supported"
;
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d pooling supported"
)
;
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
return
{
t
,
{
return
{
t
,
{
input
[
0
],
input
.
lens
()
[
0
],
input
[
1
],
input
.
lens
()
[
1
],
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ceil
((
input
[
3
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
]))
+
1
),
1
,
std
::
ceil
((
input
.
lens
()
[
3
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
]))
+
1
)
)
,
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ceil
((
input
[
4
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
]))
+
1
),
1
,
std
::
ceil
((
input
.
lens
()
[
4
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
]))
+
1
)
)
,
}};
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
struct
activation
:
not_computable
struct
activation
{
{
std
::
string
mode
;
std
::
string
mode
;
std
::
string
name
()
const
std
::
string
name
()
const
...
@@ -88,9 +99,14 @@ struct activation : not_computable
...
@@ -88,9 +99,14 @@ struct activation : not_computable
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
!
inputs
.
empty
())
throw
"Wrong number of arguments"
;
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
)
;
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
...
...
include/rtg/shape.hpp
View file @
9f046d67
...
@@ -13,7 +13,15 @@ struct shape
...
@@ -13,7 +13,15 @@ struct shape
// Add new types here
// Add new types here
#define RTG_SHAPE_VISIT_TYPES(m) \
#define RTG_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(float_type, float) \
m(int_type, int) \
m(double_type, double) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(int16_type, int16_t) \
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum
type_t
enum
type_t
...
...
include/rtg/stringutils.hpp
View file @
9f046d67
...
@@ -72,7 +72,7 @@ inline std::string to_string(const Range& r)
...
@@ -72,7 +72,7 @@ inline std::string to_string(const Range& r)
if
(
!
r
.
empty
())
if
(
!
r
.
empty
())
{
{
ss
<<
r
.
front
();
ss
<<
r
.
front
();
std
::
for_each
(
++
r
.
begin
(),
r
.
end
(),
[
&
](
auto
&&
x
)
std
::
for_each
(
std
::
next
(
r
.
begin
()
)
,
r
.
end
(),
[
&
](
auto
&&
x
)
{
{
ss
<<
", "
<<
x
;
ss
<<
", "
<<
x
;
});
});
...
...
src/program.cpp
View file @
9f046d67
...
@@ -51,7 +51,10 @@ void program::print() const
...
@@ -51,7 +51,10 @@ void program::print() const
if
(
ins
.
op
.
name
()
==
"@literal"
)
if
(
ins
.
op
.
name
()
==
"@literal"
)
{
{
std
::
cout
<<
"{"
<<
ins
.
lit
<<
"}"
;
if
(
ins
.
lit
.
get_shape
().
elements
()
>
10
)
std
::
cout
<<
"{ ... }"
;
else
std
::
cout
<<
"{"
<<
ins
.
lit
<<
"}"
;
}
}
if
(
!
ins
.
arguments
.
empty
())
if
(
!
ins
.
arguments
.
empty
())
...
...
src/read_onnx.cpp
View file @
9f046d67
...
@@ -5,8 +5,10 @@
...
@@ -5,8 +5,10 @@
#include <iostream>
#include <iostream>
#include <fstream>
#include <fstream>
#include <unordered_map>
#include <unordered_map>
#include <functional>
#include <rtg/program.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
struct
unknown
struct
unknown
{
{
...
@@ -26,12 +28,95 @@ struct unknown
...
@@ -26,12 +28,95 @@ struct unknown
}
}
};
};
template
<
class
C
,
class
T
>
bool
contains
(
C
&&
c
,
T
&&
x
)
{
return
c
.
find
(
x
)
!=
c
.
end
();
}
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
struct
onnx_parser
struct
onnx_parser
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
nodes
;
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
rtg
::
instruction
*>
instructions
;
std
::
unordered_map
<
std
::
string
,
rtg
::
instruction
*>
instructions
;
std
::
shared_ptr
<
rtg
::
program
>
prog
=
std
::
make_shared
<
rtg
::
program
>
();
std
::
shared_ptr
<
rtg
::
program
>
prog
=
std
::
make_shared
<
rtg
::
program
>
();
std
::
unordered_map
<
std
::
string
,
std
::
function
<
rtg
::
instruction
*
(
attribute_map
,
std
::
vector
<
rtg
::
instruction
*>
)
>>
ops
;
onnx_parser
()
{
add_op
(
"Conv"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
rtg
::
convolution
op
;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"dilations"
))
{
copy
(
attributes
[
"dilations"
].
ints
(),
op
.
dilation
.
begin
());
}
return
prog
->
add_instruction
(
op
,
args
);
});
add_op
(
"MaxPool"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
rtg
::
pooling
op
{
"max"
};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
->
add_instruction
(
op
,
args
);
});
add_op
(
"Relu"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
return
prog
->
add_instruction
(
rtg
::
activation
{
"relu"
},
args
);
});
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
)
{
rtg
::
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
return
prog
->
add_literal
(
v
);
});
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
void
parse_from
(
std
::
istream
&
is
)
{
onnx
::
ModelProto
model
;
if
(
model
.
ParseFromIstream
(
&
is
))
{
if
(
model
.
has_graph
())
{
this
->
parse_graph
(
model
.
graph
());
}
}
else
{
throw
std
::
runtime_error
(
"Failed reading"
);
}
}
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
{
{
nodes
=
get_nodes
(
graph
);
nodes
=
get_nodes
(
graph
);
...
@@ -39,7 +124,8 @@ struct onnx_parser
...
@@ -39,7 +124,8 @@ struct onnx_parser
{
{
std
::
string
name
=
input
.
name
();
std
::
string
name
=
input
.
name
();
// TODO: Get shape of input parameter
// TODO: Get shape of input parameter
instructions
[
name
]
=
prog
->
add_parameter
(
name
,
rtg
::
shape
{});
rtg
::
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
->
add_parameter
(
name
,
s
);
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
{
{
...
@@ -66,11 +152,18 @@ struct onnx_parser
...
@@ -66,11 +152,18 @@ struct onnx_parser
args
.
push_back
(
instructions
.
at
(
input
));
args
.
push_back
(
instructions
.
at
(
input
));
}
}
}
}
instructions
[
name
]
=
prog
->
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
instructions
[
name
]
=
prog
->
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
}
else
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
}
}
}
}
static
std
::
unordered_map
<
std
::
string
,
onnx
::
A
ttribute
Proto
>
get_attributes
(
const
onnx
::
NodeProto
&
node
)
static
a
ttribute
_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
for
(
auto
&&
attr
:
node
.
attribute
())
for
(
auto
&&
attr
:
node
.
attribute
())
...
@@ -80,7 +173,7 @@ struct onnx_parser
...
@@ -80,7 +173,7 @@ struct onnx_parser
return
result
;
return
result
;
}
}
static
std
::
u
no
r
de
red
_map
<
std
::
string
,
onnx
::
NodeProto
>
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
...
@@ -94,21 +187,80 @@ struct onnx_parser
...
@@ -94,21 +187,80 @@ struct onnx_parser
}
}
return
result
;
return
result
;
}
}
};
std
::
shared_ptr
<
rtg
::
program
>
parse_onnx
(
std
::
istream
&
is
)
static
rtg
::
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
{
{
onnx_parser
parser
;
switch
(
attr
.
type
())
onnx
::
ModelProto
model
;
{
if
(
model
.
ParseFromIstream
(
&
is
))
{
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
if
(
model
.
has_graph
())
{
case
onnx
::
AttributeProto
::
FLOAT
:
return
rtg
::
literal
{
attr
.
f
()};
parser
.
parse_graph
(
model
.
graph
());
case
onnx
::
AttributeProto
::
INT
:
return
rtg
::
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
rtg
::
literal
{
rtg
::
shape
::
float_type
,
attr
.
floats
().
begin
(),
attr
.
floats
().
end
()};
case
onnx
::
AttributeProto
::
INTS
:
return
rtg
::
literal
{
rtg
::
shape
::
int32_type
,
attr
.
ints
().
begin
(),
attr
.
ints
().
end
()};;
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
}
static
rtg
::
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
rtg
::
literal
{{
rtg
::
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
return
rtg
::
literal
{{
rtg
::
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
return
rtg
::
literal
{{
rtg
::
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
}
}
else
{
throw
"Failed reading"
;
}
}
return
parser
.
prog
;
}
static
rtg
::
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
rtg
::
shape
::
type_t
shape_type
;
switch
(
t
.
tensor_type
().
elem_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
//throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
//throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
case
onnx
::
TensorProto
::
STRING
:
break
;
//throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
//throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
//throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
//throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
//throw std::runtime_error("Unsupported type COMPLEX128");
}
std
::
vector
<
std
::
size_t
>
dims
;
// TODO: USe std::transform
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
{
dims
.
push_back
(
d
.
dim_value
());
}
return
{
shape_type
,
dims
};
}
};
int
main
(
int
argc
,
char
const
*
argv
[])
int
main
(
int
argc
,
char
const
*
argv
[])
{
{
...
@@ -116,7 +268,16 @@ int main(int argc, char const *argv[])
...
@@ -116,7 +268,16 @@ int main(int argc, char const *argv[])
{
{
std
::
string
file
=
argv
[
1
];
std
::
string
file
=
argv
[
1
];
std
::
fstream
input
(
file
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
fstream
input
(
file
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
);
auto
prog
=
parse_onnx
(
input
);
onnx_parser
parser
;
prog
->
print
();
try
{
parser
.
parse_from
(
input
);
}
catch
(...)
{
if
(
parser
.
prog
)
parser
.
prog
->
print
();
throw
;
}
parser
.
prog
->
print
();
}
}
}
}
test/eval_test.cpp
View file @
9f046d67
...
@@ -48,8 +48,8 @@ void literal_test() {
...
@@ -48,8 +48,8 @@ void literal_test() {
void
param_test
()
{
void
param_test
()
{
rtg
::
program
p
;
rtg
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
{
rtg
::
shape
::
int_type
});
auto
x
=
p
.
add_parameter
(
"x"
,
{
rtg
::
shape
::
int
64
_type
});
auto
y
=
p
.
add_parameter
(
"y"
,
{
rtg
::
shape
::
int_type
});
auto
y
=
p
.
add_parameter
(
"y"
,
{
rtg
::
shape
::
int
64
_type
});
p
.
add_instruction
(
sum_op
{},
x
,
y
);
p
.
add_instruction
(
sum_op
{},
x
,
y
);
auto
result
=
p
.
eval
({
auto
result
=
p
.
eval
({
...
...
test/literal_test.cpp
View file @
9f046d67
...
@@ -44,7 +44,7 @@ void literal_os2()
...
@@ -44,7 +44,7 @@ void literal_os2()
void
literal_os3
()
void
literal_os3
()
{
{
rtg
::
shape
s
{
rtg
::
shape
::
int_type
,
{
3
}};
rtg
::
shape
s
{
rtg
::
shape
::
int
64
_type
,
{
3
}};
rtg
::
literal
l
{
s
,
{
1
,
2
,
3
}};
rtg
::
literal
l
{
s
,
{
1
,
2
,
3
}};
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
l
;
ss
<<
l
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment