Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9f046d67
Commit
9f046d67
authored
Apr 19, 2018
by
Paul
Browse files
Parse onnx and convert to internal ir
parent
2f8e4e83
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
252 additions
and
50 deletions
+252
-50
include/rtg/literal.hpp
include/rtg/literal.hpp
+16
-2
include/rtg/operators.hpp
include/rtg/operators.hpp
+39
-23
include/rtg/shape.hpp
include/rtg/shape.hpp
+9
-1
include/rtg/stringutils.hpp
include/rtg/stringutils.hpp
+1
-1
src/program.cpp
src/program.cpp
+4
-1
src/read_onnx.cpp
src/read_onnx.cpp
+180
-19
test/eval_test.cpp
test/eval_test.cpp
+2
-2
test/literal_test.cpp
test/literal_test.cpp
+1
-1
No files found.
include/rtg/literal.hpp
View file @
9f046d67
...
@@ -28,7 +28,9 @@ struct literal : raw_data<literal>
...
@@ -28,7 +28,9 @@ struct literal : raw_data<literal>
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
std
::
copy
(
x
.
begin
(),
x
.
end
(),
reinterpret_cast
<
T
*>
(
buffer
.
data
()));
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -37,7 +39,19 @@ struct literal : raw_data<literal>
...
@@ -37,7 +39,19 @@ struct literal : raw_data<literal>
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
std
::
copy
(
x
.
begin
(),
x
.
end
(),
reinterpret_cast
<
T
*>
(
buffer
.
data
()));
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
}
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
{
assert
(
s
.
packed
());
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
}
literal
(
shape
s
,
const
char
*
x
)
literal
(
shape
s
,
const
char
*
x
)
...
...
include/rtg/operators.hpp
View file @
9f046d67
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <rtg/operand.hpp>
#include <rtg/operand.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/stringutils.hpp>
#include <cmath>
namespace
rtg
{
namespace
rtg
{
...
@@ -10,11 +11,11 @@ struct not_computable
...
@@ -10,11 +11,11 @@ struct not_computable
{
{
argument
compute
(
std
::
vector
<
argument
>
)
const
argument
compute
(
std
::
vector
<
argument
>
)
const
{
{
throw
"not computable"
;
throw
std
::
runtime_error
(
"not computable"
)
;
}
}
};
};
struct
convolution
:
not_computable
struct
convolution
{
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
...
@@ -28,26 +29,31 @@ struct convolution : not_computable
...
@@ -28,26 +29,31 @@ struct convolution : not_computable
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
size
()
!=
2
)
throw
"Wrong number of arguments"
;
if
(
inputs
.
size
()
!=
2
)
throw
std
::
runtime_error
(
"Wrong number of arguments"
)
;
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
const
shape
&
weights
=
inputs
.
at
(
1
);
if
(
input
.
type
()
!=
weights
.
type
())
throw
"Type doesn't match"
;
if
(
input
.
type
()
!=
weights
.
type
())
throw
std
::
runtime_error
(
"Type doesn't match"
)
;
if
(
input
.
size
()
!=
weights
.
size
())
throw
"Dimensions don't match"
;
if
(
input
.
lens
().
size
()
!=
weights
.
lens
().
size
())
throw
std
::
runtime_error
(
"Dimensions don't match"
)
;
if
(
input
.
size
()
!=
4
)
throw
"Only 4d convolution supported"
;
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d convolution supported"
)
;
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
return
{
t
,
{
return
{
t
,
{
input
[
0
],
input
.
lens
()
[
0
],
weights
[
0
],
weights
.
lens
()
[
0
],
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
),
1
,
(
input
.
lens
()
[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()
[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)
)
,
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
),
1
,
(
input
.
lens
()
[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()
[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)
)
,
}};
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
struct
pooling
:
not_computable
struct
pooling
{
{
std
::
string
mode
;
std
::
string
mode
;
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
...
@@ -62,24 +68,29 @@ struct pooling : not_computable
...
@@ -62,24 +68,29 @@ struct pooling : not_computable
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
!
inputs
.
empty
())
throw
"Wrong number of arguments"
;
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
)
;
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
if
(
input
.
size
()
!=
4
)
throw
"Only 4d pooling supported"
;
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d pooling supported"
)
;
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
return
{
t
,
{
return
{
t
,
{
input
[
0
],
input
.
lens
()
[
0
],
input
[
1
],
input
.
lens
()
[
1
],
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ceil
((
input
[
3
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
]))
+
1
),
1
,
std
::
ceil
((
input
.
lens
()
[
3
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
]))
+
1
)
)
,
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ceil
((
input
[
4
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
]))
+
1
),
1
,
std
::
ceil
((
input
.
lens
()
[
4
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
]))
+
1
)
)
,
}};
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
struct
activation
:
not_computable
struct
activation
{
{
std
::
string
mode
;
std
::
string
mode
;
std
::
string
name
()
const
std
::
string
name
()
const
...
@@ -88,9 +99,14 @@ struct activation : not_computable
...
@@ -88,9 +99,14 @@ struct activation : not_computable
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
!
inputs
.
empty
())
throw
"Wrong number of arguments"
;
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
)
;
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
...
...
include/rtg/shape.hpp
View file @
9f046d67
...
@@ -13,7 +13,15 @@ struct shape
...
@@ -13,7 +13,15 @@ struct shape
// Add new types here
// Add new types here
#define RTG_SHAPE_VISIT_TYPES(m) \
#define RTG_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(float_type, float) \
m(int_type, int) \
m(double_type, double) \
m(uint8_type, uint8_t) \
m(int8_type, int8_t) \
m(uint16_type, uint16_t) \
m(int16_type, int16_t) \
m(int32_type, int32_t) \
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum
type_t
enum
type_t
...
...
include/rtg/stringutils.hpp
View file @
9f046d67
...
@@ -72,7 +72,7 @@ inline std::string to_string(const Range& r)
...
@@ -72,7 +72,7 @@ inline std::string to_string(const Range& r)
if
(
!
r
.
empty
())
if
(
!
r
.
empty
())
{
{
ss
<<
r
.
front
();
ss
<<
r
.
front
();
std
::
for_each
(
++
r
.
begin
(),
r
.
end
(),
[
&
](
auto
&&
x
)
std
::
for_each
(
std
::
next
(
r
.
begin
()
)
,
r
.
end
(),
[
&
](
auto
&&
x
)
{
{
ss
<<
", "
<<
x
;
ss
<<
", "
<<
x
;
});
});
...
...
src/program.cpp
View file @
9f046d67
...
@@ -51,7 +51,10 @@ void program::print() const
...
@@ -51,7 +51,10 @@ void program::print() const
if
(
ins
.
op
.
name
()
==
"@literal"
)
if
(
ins
.
op
.
name
()
==
"@literal"
)
{
{
std
::
cout
<<
"{"
<<
ins
.
lit
<<
"}"
;
if
(
ins
.
lit
.
get_shape
().
elements
()
>
10
)
std
::
cout
<<
"{ ... }"
;
else
std
::
cout
<<
"{"
<<
ins
.
lit
<<
"}"
;
}
}
if
(
!
ins
.
arguments
.
empty
())
if
(
!
ins
.
arguments
.
empty
())
...
...
src/read_onnx.cpp
View file @
9f046d67
...
@@ -5,8 +5,10 @@
...
@@ -5,8 +5,10 @@
#include <iostream>
#include <iostream>
#include <fstream>
#include <fstream>
#include <unordered_map>
#include <unordered_map>
#include <functional>
#include <rtg/program.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
struct
unknown
struct
unknown
{
{
...
@@ -26,12 +28,95 @@ struct unknown
...
@@ -26,12 +28,95 @@ struct unknown
}
}
};
};
template
<
class
C
,
class
T
>
bool
contains
(
C
&&
c
,
T
&&
x
)
{
return
c
.
find
(
x
)
!=
c
.
end
();
}
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
struct
onnx_parser
struct
onnx_parser
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
nodes
;
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
rtg
::
instruction
*>
instructions
;
std
::
unordered_map
<
std
::
string
,
rtg
::
instruction
*>
instructions
;
std
::
shared_ptr
<
rtg
::
program
>
prog
=
std
::
make_shared
<
rtg
::
program
>
();
std
::
shared_ptr
<
rtg
::
program
>
prog
=
std
::
make_shared
<
rtg
::
program
>
();
std
::
unordered_map
<
std
::
string
,
std
::
function
<
rtg
::
instruction
*
(
attribute_map
,
std
::
vector
<
rtg
::
instruction
*>
)
>>
ops
;
onnx_parser
()
{
add_op
(
"Conv"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
rtg
::
convolution
op
;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"dilations"
))
{
copy
(
attributes
[
"dilations"
].
ints
(),
op
.
dilation
.
begin
());
}
return
prog
->
add_instruction
(
op
,
args
);
});
add_op
(
"MaxPool"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
rtg
::
pooling
op
{
"max"
};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
->
add_instruction
(
op
,
args
);
});
add_op
(
"Relu"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
return
prog
->
add_instruction
(
rtg
::
activation
{
"relu"
},
args
);
});
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
)
{
rtg
::
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
return
prog
->
add_literal
(
v
);
});
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
void
parse_from
(
std
::
istream
&
is
)
{
onnx
::
ModelProto
model
;
if
(
model
.
ParseFromIstream
(
&
is
))
{
if
(
model
.
has_graph
())
{
this
->
parse_graph
(
model
.
graph
());
}
}
else
{
throw
std
::
runtime_error
(
"Failed reading"
);
}
}
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
{
{
nodes
=
get_nodes
(
graph
);
nodes
=
get_nodes
(
graph
);
...
@@ -39,7 +124,8 @@ struct onnx_parser
...
@@ -39,7 +124,8 @@ struct onnx_parser
{
{
std
::
string
name
=
input
.
name
();
std
::
string
name
=
input
.
name
();
// TODO: Get shape of input parameter
// TODO: Get shape of input parameter
instructions
[
name
]
=
prog
->
add_parameter
(
name
,
rtg
::
shape
{});
rtg
::
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
->
add_parameter
(
name
,
s
);
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
{
{
...
@@ -66,11 +152,18 @@ struct onnx_parser
...
@@ -66,11 +152,18 @@ struct onnx_parser
args
.
push_back
(
instructions
.
at
(
input
));
args
.
push_back
(
instructions
.
at
(
input
));
}
}
}
}
instructions
[
name
]
=
prog
->
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
instructions
[
name
]
=
prog
->
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
}
else
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
}
}
}
}
static
std
::
unordered_map
<
std
::
string
,
onnx
::
A
ttribute
Proto
>
get_attributes
(
const
onnx
::
NodeProto
&
node
)
static
a
ttribute
_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
for
(
auto
&&
attr
:
node
.
attribute
())
for
(
auto
&&
attr
:
node
.
attribute
())
...
@@ -80,7 +173,7 @@ struct onnx_parser
...
@@ -80,7 +173,7 @@ struct onnx_parser
return
result
;
return
result
;
}
}
static
std
::
u
no
r
de
red
_map
<
std
::
string
,
onnx
::
NodeProto
>
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
...
@@ -94,21 +187,80 @@ struct onnx_parser
...
@@ -94,21 +187,80 @@ struct onnx_parser
}
}
return
result
;
return
result
;
}
}
};
std
::
shared_ptr
<
rtg
::
program
>
parse_onnx
(
std
::
istream
&
is
)
static
rtg
::
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
{
{
onnx_parser
parser
;
switch
(
attr
.
type
())
onnx
::
ModelProto
model
;
{
if
(
model
.
ParseFromIstream
(
&
is
))
{
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
if
(
model
.
has_graph
())
{
case
onnx
::
AttributeProto
::
FLOAT
:
return
rtg
::
literal
{
attr
.
f
()};
parser
.
parse_graph
(
model
.
graph
());
case
onnx
::
AttributeProto
::
INT
:
return
rtg
::
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
rtg
::
literal
{
rtg
::
shape
::
float_type
,
attr
.
floats
().
begin
(),
attr
.
floats
().
end
()};
case
onnx
::
AttributeProto
::
INTS
:
return
rtg
::
literal
{
rtg
::
shape
::
int32_type
,
attr
.
ints
().
begin
(),
attr
.
ints
().
end
()};;
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
}
static
rtg
::
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
rtg
::
literal
{{
rtg
::
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
return
rtg
::
literal
{{
rtg
::
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
return
rtg
::
literal
{{
rtg
::
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
}
}
else
{
throw
"Failed reading"
;
}
}
return
parser
.
prog
;
}
static
rtg
::
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
rtg
::
shape
::
type_t
shape_type
;
switch
(
t
.
tensor_type
().
elem_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
//throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
//throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
case
onnx
::
TensorProto
::
STRING
:
break
;
//throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
//throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
//throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
//throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
//throw std::runtime_error("Unsupported type COMPLEX128");
}
std
::
vector
<
std
::
size_t
>
dims
;
// TODO: USe std::transform
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
{
dims
.
push_back
(
d
.
dim_value
());
}
return
{
shape_type
,
dims
};
}
};
int
main
(
int
argc
,
char
const
*
argv
[])
int
main
(
int
argc
,
char
const
*
argv
[])
{
{
...
@@ -116,7 +268,16 @@ int main(int argc, char const *argv[])
...
@@ -116,7 +268,16 @@ int main(int argc, char const *argv[])
{
{
std
::
string
file
=
argv
[
1
];
std
::
string
file
=
argv
[
1
];
std
::
fstream
input
(
file
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
fstream
input
(
file
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
);
auto
prog
=
parse_onnx
(
input
);
onnx_parser
parser
;
prog
->
print
();
try
{
parser
.
parse_from
(
input
);
}
catch
(...)
{
if
(
parser
.
prog
)
parser
.
prog
->
print
();
throw
;
}
parser
.
prog
->
print
();
}
}
}
}
test/eval_test.cpp
View file @
9f046d67
...
@@ -48,8 +48,8 @@ void literal_test() {
...
@@ -48,8 +48,8 @@ void literal_test() {
void
param_test
()
{
void
param_test
()
{
rtg
::
program
p
;
rtg
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
{
rtg
::
shape
::
int_type
});
auto
x
=
p
.
add_parameter
(
"x"
,
{
rtg
::
shape
::
int
64
_type
});
auto
y
=
p
.
add_parameter
(
"y"
,
{
rtg
::
shape
::
int_type
});
auto
y
=
p
.
add_parameter
(
"y"
,
{
rtg
::
shape
::
int
64
_type
});
p
.
add_instruction
(
sum_op
{},
x
,
y
);
p
.
add_instruction
(
sum_op
{},
x
,
y
);
auto
result
=
p
.
eval
({
auto
result
=
p
.
eval
({
...
...
test/literal_test.cpp
View file @
9f046d67
...
@@ -44,7 +44,7 @@ void literal_os2()
...
@@ -44,7 +44,7 @@ void literal_os2()
void
literal_os3
()
void
literal_os3
()
{
{
rtg
::
shape
s
{
rtg
::
shape
::
int_type
,
{
3
}};
rtg
::
shape
s
{
rtg
::
shape
::
int
64
_type
,
{
3
}};
rtg
::
literal
l
{
s
,
{
1
,
2
,
3
}};
rtg
::
literal
l
{
s
,
{
1
,
2
,
3
}};
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
l
;
ss
<<
l
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment