Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
14d5666b
Commit
14d5666b
authored
Apr 23, 2018
by
Paul
Browse files
Add clang formatting
parent
2305ac81
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
510 additions
and
450 deletions
+510
-450
.clang-format
.clang-format
+90
-0
.githooks/install
.githooks/install
+7
-0
.githooks/pre-commit
.githooks/pre-commit
+43
-0
include/rtg/argument.hpp
include/rtg/argument.hpp
+7
-15
include/rtg/builtin.hpp
include/rtg/builtin.hpp
+7
-25
include/rtg/instruction.hpp
include/rtg/instruction.hpp
+7
-5
include/rtg/literal.hpp
include/rtg/literal.hpp
+19
-42
include/rtg/operand.hpp
include/rtg/operand.hpp
+16
-14
include/rtg/operators.hpp
include/rtg/operators.hpp
+66
-70
include/rtg/program.hpp
include/rtg/program.hpp
+18
-13
include/rtg/raw_data.hpp
include/rtg/raw_data.hpp
+20
-31
include/rtg/shape.hpp
include/rtg/shape.hpp
+25
-30
include/rtg/stringutils.hpp
include/rtg/stringutils.hpp
+2
-5
include/rtg/tensor_view.hpp
include/rtg/tensor_view.hpp
+28
-45
onnx/read_onnx.cpp
onnx/read_onnx.cpp
+102
-81
src/program.cpp
src/program.cpp
+9
-9
src/shape.cpp
src/shape.cpp
+20
-36
test/eval_test.cpp
test/eval_test.cpp
+20
-21
test/literal_test.cpp
test/literal_test.cpp
+3
-5
test/main.cpp
test/main.cpp
+1
-3
No files found.
.clang-format
0 → 100644
View file @
14d5666b
---
Language: Cpp
AccessModifierOffset: 0
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: true
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: true
AfterControlStatement: true
AfterEnum: true
AfterFunction: true
AfterNamespace: false
AfterObjCDeclaration: true
AfterStruct: true
AfterUnion: true
BeforeCatch: true
BeforeElse: true
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
- Regex: '^(<|"(gtest|isl|json)/)'
Priority: 3
- Regex: '.*'
Priority: 1
IndentCaseLabels: false
IndentWidth: 4
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: true
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
# SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: Never
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
.githooks/install
0 → 100755
View file @
14d5666b
#!/usr/bin/env bash
cd
$(
git rev-parse
--git-dir
)
echo
"Installing hooks..."
ln
-s
../.githooks hooks
echo
"Done!"
.githooks/pre-commit
0 → 100755
View file @
14d5666b
#!/bin/sh
#
# This pre-commit hook checks if any versions of clang-format
# are installed, and if so, uses the installed version to format
# the staged changes.
base
=
clang-format-5.0
format
=
""
# Redirect output to stderr.
exec
1>&2
# check if clang-format is installed
type
"
$base
"
>
/dev/null 2>&1
&&
format
=
"
$base
"
# no versions of clang-format are installed
if
[
-z
"
$format
"
]
then
echo
"
$base
is not installed. Pre-commit hook will not be executed."
exit
0
fi
# Do everything from top - level
cd
$(
git rev-parse
--show-toplevel
)
if
git rev-parse
--verify
HEAD
>
/dev/null 2>&1
then
against
=
HEAD
else
# Initial commit: diff against an empty tree object
against
=
16bbb57
fi
# do the formatting
for
file
in
$(
git diff-index
--cached
--name-only
$against
|
grep
-E
'\.h$|\.hpp$|\.cpp$|\.cl$|\.h\.in$|\.hpp\.in$|\.cpp\.in$'
)
do
if
[
-e
"
$file
"
]
then
echo
"
$format
$file
"
"
$format
"
-i
-style
=
file
"
$file
"
fi
done
include/rtg/argument.hpp
View file @
14d5666b
...
...
@@ -9,28 +9,20 @@ namespace rtg {
struct
argument
:
raw_data
<
argument
>
{
argument
()
{}
argument
()
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
shape_
(
s
)
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
shape_
(
s
)
{}
std
::
function
<
char
*
()
>
data
;
bool
empty
()
const
{
return
not
data
;
}
bool
empty
()
const
{
return
not
data
;
}
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
private:
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
private:
shape
shape_
;
};
}
}
// namespace rtg
#endif
include/rtg/builtin.hpp
View file @
14d5666b
...
...
@@ -9,38 +9,20 @@ namespace builtin {
struct
literal
{
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
throw
"builtin"
;
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
"builtin"
;
}
std
::
string
name
()
const
{
return
"@literal"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
throw
"builtin"
;
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
"builtin"
;
}
};
struct
param
{
std
::
string
parameter
;
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
throw
"builtin"
;
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
"builtin"
;
}
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
throw
"builtin"
;
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
"builtin"
;
}
};
}
}
// namespace builtin
}
// namespace rtg
...
...
include/rtg/instruction.hpp
View file @
14d5666b
...
...
@@ -14,11 +14,13 @@ struct instruction
instruction
(
operand
o
,
shape
r
,
std
::
vector
<
instruction
*>
args
)
:
op
(
std
::
move
(
o
)),
result
(
std
::
move
(
r
)),
arguments
(
std
::
move
(
args
)),
lit
()
{}
{
}
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
arguments
(),
lit
(
std
::
move
(
l
))
{}
{
}
operand
op
;
shape
result
;
...
...
@@ -26,6 +28,6 @@ struct instruction
literal
lit
;
};
}
}
// namespace rtg
#endif
include/rtg/literal.hpp
View file @
14d5666b
...
...
@@ -10,68 +10,45 @@ namespace rtg {
struct
literal
:
raw_data
<
literal
>
{
literal
()
:
buffer
(),
shape_
()
{}
literal
()
:
buffer
(),
shape_
()
{}
template
<
class
T
>
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
shape_
(
shape
::
get_type
<
T
>
{})
template
<
class
T
>
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
shape_
(
shape
::
get_type
<
T
>
{})
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
}
template
<
class
T
>
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
template
<
class
T
>
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
{
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
}
template
<
class
T
>
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
template
<
class
T
>
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
{
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
}
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
{
assert
(
s
.
packed
());
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
shape_
(
s
)
{}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
shape_
(
s
)
{}
bool
empty
()
const
{
return
this
->
buffer
.
empty
();
}
bool
empty
()
const
{
return
this
->
buffer
.
empty
();
}
const
char
*
data
()
const
{
return
this
->
buffer
.
data
();
}
const
char
*
data
()
const
{
return
this
->
buffer
.
data
();
}
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
argument
get_argument
()
const
{
...
...
@@ -79,11 +56,11 @@ struct literal : raw_data<literal>
return
{
shape_
,
[
b
]()
mutable
{
return
b
.
data
();
}};
}
private:
private:
std
::
vector
<
char
>
buffer
;
shape
shape_
;
};
}
}
// namespace rtg
#endif
include/rtg/operand.hpp
View file @
14d5666b
...
...
@@ -12,16 +12,16 @@
namespace
rtg
{
/*
* Type-erased interface for:
*
* struct operand
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* };
*
*/
* Type-erased interface for:
*
* struct operand
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* };
*
*/
struct
operand
{
...
...
@@ -80,7 +80,8 @@ struct operand
struct
handle_type_
:
handle_base_type_
{
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
handle_type_
(
TypeErased_T_
value
,
handle_type_
(
TypeErased_T_
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
TypeErased_U_
>::
value
>::
type
*
=
nullptr
)
:
value_
(
value
)
{
...
...
@@ -89,7 +90,8 @@ struct operand
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
handle_type_
(
TypeErased_T_
value
,
typename
std
::
enable_if
<!
std
::
is_reference
<
TypeErased_U_
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
value_
(
std
::
move
(
value
))
nullptr
)
noexcept
:
value_
(
std
::
move
(
value
))
{
}
...
...
@@ -134,6 +136,6 @@ struct operand
std
::
shared_ptr
<
handle_base_type_
>
handle_mem_var_
;
};
}
}
// namespace rtg
#endif
include/rtg/operators.hpp
View file @
14d5666b
...
...
@@ -9,10 +9,7 @@ namespace rtg {
struct
not_computable
{
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
struct
convolution
...
...
@@ -22,35 +19,43 @@ struct convolution
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{
1
,
1
};
std
::
string
name
()
const
{
return
"convolution[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
to_string
(
stride
)
+
"}, dilation={"
+
to_string
(
dilation
)
+
"}]"
;
return
"convolution[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
to_string
(
stride
)
+
"}, dilation={"
+
to_string
(
dilation
)
+
"}]"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
if
(
inputs
.
size
()
!=
2
)
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
if
(
inputs
.
size
()
!=
2
)
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
if
(
input
.
type
()
!=
weights
.
type
())
throw
std
::
runtime_error
(
"Type doesn't match"
);
if
(
input
.
lens
().
size
()
!=
weights
.
lens
().
size
())
throw
std
::
runtime_error
(
"Dimensions don't match"
);
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d convolution supported"
);
if
(
input
.
type
()
!=
weights
.
type
())
throw
std
::
runtime_error
(
"Type doesn't match"
);
if
(
input
.
lens
().
size
()
!=
weights
.
lens
().
size
())
throw
std
::
runtime_error
(
"Dimensions don't match"
);
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d convolution supported"
);
auto
t
=
input
.
type
();
return
{
t
,
{
return
{
t
,
{
input
.
lens
()[
0
],
weights
.
lens
()[
0
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)),
1
,
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)),
1
,
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)),
}};
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
struct
pooling
...
...
@@ -61,68 +66,63 @@ struct pooling
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{
1
,
1
};
std
::
string
name
()
const
{
return
"pooling:"
+
mode
+
"[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
to_string
(
stride
)
+
"}, lengths={"
+
to_string
(
lengths
)
+
"}]"
;
return
"pooling:"
+
mode
+
"[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
to_string
(
stride
)
+
"}, lengths={"
+
to_string
(
lengths
)
+
"}]"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
const
shape
&
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d pooling supported"
);
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d pooling supported"
);
auto
t
=
input
.
type
();
return
{
t
,
{
return
{
t
,
{
input
.
lens
()[
0
],
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ceil
((
input
.
lens
()[
3
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
]))
+
1
)),
1
,
std
::
ceil
((
input
.
lens
()[
3
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
]))
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ceil
((
input
.
lens
()[
4
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
]))
+
1
)),
1
,
std
::
ceil
((
input
.
lens
()[
4
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
]))
+
1
)),
}};
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
struct
activation
{
std
::
string
mode
;
std
::
string
name
()
const
{
return
"activation:"
+
mode
;
}
std
::
string
name
()
const
{
return
"activation:"
+
mode
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
return
inputs
.
front
();
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
struct
reshape
{
std
::
vector
<
int64_t
>
dims
;
std
::
string
name
()
const
{
return
"reshape[dims={"
+
to_string
(
dims
)
+
"}]"
;
}
std
::
string
name
()
const
{
return
"reshape[dims={"
+
to_string
(
dims
)
+
"}]"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
if
(
dims
[
i
]
==
0
)
rdims
[
i
]
=
idims
[
i
];
...
...
@@ -130,18 +130,14 @@ struct reshape
if
(
dims
.
back
()
==
-
1
)
{
rdims
.
pop_back
();
std
::
copy
(
idims
.
begin
()
+
rdims
.
size
(),
idims
.
end
(),
std
::
back_inserter
(
rdims
));
std
::
copy
(
idims
.
begin
()
+
rdims
.
size
(),
idims
.
end
(),
std
::
back_inserter
(
rdims
));
}
return
{
inputs
.
front
().
type
(),
rdims
};
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
}
// namespace rtg
#endif
include/rtg/program.hpp
View file @
14d5666b
...
...
@@ -17,31 +17,34 @@ struct program
program
(
const
program
&
)
=
delete
;
program
&
operator
=
(
const
program
&
)
=
delete
;
template
<
class
...
Ts
>
instruction
*
add_instruction
(
operand
op
,
Ts
*
...
args
)
template
<
class
...
Ts
>
instruction
*
add_instruction
(
operand
op
,
Ts
*
...
args
)
{
shape
r
=
op
.
compute_shape
({
args
->
result
...});
instructions
.
push_back
({
op
,
r
,
{
args
...}});
return
std
::
addressof
(
instructions
.
back
());
}
instruction
*
add_instruction
(
operand
op
,
std
::
vector
<
instruction
*>
args
)
instruction
*
add_instruction
(
operand
op
,
std
::
vector
<
instruction
*>
args
)
{
assert
(
std
::
all_of
(
args
.
begin
(),
args
.
end
(),
[
&
](
instruction
*
x
)
{
return
has_instruction
(
x
);
})
&&
"Argument is not an exisiting instruction"
);
assert
(
std
::
all_of
(
args
.
begin
(),
args
.
end
(),
[
&
](
instruction
*
x
)
{
return
has_instruction
(
x
);
})
&&
"Argument is not an exisiting instruction"
);
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction
*
ins
)
{
return
ins
->
result
;
});
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction
*
ins
)
{
return
ins
->
result
;
});
shape
r
=
op
.
compute_shape
(
shapes
);
instructions
.
push_back
({
op
,
r
,
args
});
assert
(
instructions
.
back
().
arguments
==
args
);
return
std
::
addressof
(
instructions
.
back
());
}
template
<
class
...
Ts
>
instruction
*
add_literal
(
Ts
&&
...
xs
)
template
<
class
...
Ts
>
instruction
*
add_literal
(
Ts
&&
...
xs
)
{
instructions
.
emplace_back
(
literal
{
std
::
forward
<
Ts
>
(
xs
)...});
return
std
::
addressof
(
instructions
.
back
());
}
instruction
*
add_parameter
(
std
::
string
name
,
shape
s
)
instruction
*
add_parameter
(
std
::
string
name
,
shape
s
)
{
instructions
.
push_back
({
builtin
::
param
{
std
::
move
(
name
)},
s
,
{}});
return
std
::
addressof
(
instructions
.
back
());
...
...
@@ -52,16 +55,18 @@ struct program
// TODO: Change to stream operator
void
print
()
const
;
bool
has_instruction
(
const
instruction
*
ins
)
const
bool
has_instruction
(
const
instruction
*
ins
)
const
{
return
std
::
find_if
(
instructions
.
begin
(),
instructions
.
end
(),
[
&
](
const
instruction
&
x
)
{
return
ins
==
std
::
addressof
(
x
);
})
!=
instructions
.
end
();
return
std
::
find_if
(
instructions
.
begin
(),
instructions
.
end
(),
[
&
](
const
instruction
&
x
)
{
return
ins
==
std
::
addressof
(
x
);
})
!=
instructions
.
end
();
}
private:
private:
// A list is used to keep references to an instruction stable
std
::
list
<
instruction
>
instructions
;
};
}
}
// namespace rtg
#endif
include/rtg/raw_data.hpp
View file @
14d5666b
...
...
@@ -6,7 +6,7 @@
namespace
rtg
{
template
<
class
Derived
>
template
<
class
Derived
>
struct
raw_data
{
friend
bool
operator
==
(
const
Derived
&
x
,
const
Derived
&
y
)
...
...
@@ -28,53 +28,42 @@ struct raw_data
return
result
;
}
friend
bool
operator
!=
(
const
Derived
&
x
,
const
Derived
&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
Derived
&
x
,
const
Derived
&
y
)
{
return
!
(
x
==
y
);
}
template
<
class
Stream
>
template
<
class
Stream
>
friend
Stream
&
operator
<<
(
Stream
&
os
,
const
Derived
&
d
)
{
d
.
visit
([
&
](
auto
x
)
{
os
<<
x
;
});
d
.
visit
([
&
](
auto
x
)
{
os
<<
x
;
});
return
os
;
}
template
<
class
Visitor
>
void
visit_at
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
template
<
class
Visitor
>
void
visit_at
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
s
.
visit_type
([
&
](
auto
as
)
{
v
(
*
(
as
.
from
(
buffer
)
+
s
.
index
(
n
)));
});
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
s
.
visit_type
([
&
](
auto
as
)
{
v
(
*
(
as
.
from
(
buffer
)
+
s
.
index
(
n
)));
});
}
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit
(
Visitor
v
)
const
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
s
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
s
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
}
bool
single
()
const
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
return
s
.
elements
()
==
1
;
}
template
<
class
T
>
T
at
(
std
::
size_t
n
=
0
)
const
template
<
class
T
>
T
at
(
std
::
size_t
n
=
0
)
const
{
T
result
;
this
->
visit_at
([
&
](
auto
x
)
{
result
=
x
;
},
n
);
this
->
visit_at
([
&
](
auto
x
)
{
result
=
x
;
},
n
);
return
result
;
}
};
...
...
include/rtg/shape.hpp
View file @
14d5666b
...
...
@@ -11,6 +11,7 @@ struct shape
{
// Add new types here
// clang-format off
#define RTG_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(double_type, double) \
...
...
@@ -23,6 +24,7 @@ struct shape
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
// clang-format on
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum
type_t
{
...
...
@@ -30,12 +32,13 @@ struct shape
};
#undef RTG_SHAPE_ENUM_TYPES
template
<
class
T
,
class
=
void
>
template
<
class
T
,
class
=
void
>
struct
get_type
;
#define RTG_SHAPE_GET_TYPE(x, t) \
template<class T> \
template
<class T>
\
struct get_type<t, T> : std::integral_constant<type_t, x> \
{};
{ \
};
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_GET_TYPE
)
#undef RTG_SHAPE_GET_TYPE
...
...
@@ -44,7 +47,6 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
...
...
@@ -63,67 +65,60 @@ struct shape
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
template
<
class
T
>
template
<
class
T
>
struct
as
{
using
type
=
T
;
template
<
class
U
>
template
<
class
U
>
T
operator
()(
U
u
)
const
{
return
T
(
u
);
}
template
<
class
U
>
template
<
class
U
>
T
*
operator
()(
U
*
u
)
const
{
return
static_cast
<
T
*>
(
u
);
}
template
<
class
U
>
template
<
class
U
>
const
T
*
operator
()(
const
U
*
u
)
const
{
return
static_cast
<
T
*>
(
u
);
}
T
operator
()()
const
{
return
{};
}
T
operator
()()
const
{
return
{};
}
std
::
size_t
size
(
std
::
size_t
n
=
1
)
const
{
return
sizeof
(
T
)
*
n
;
}
std
::
size_t
size
(
std
::
size_t
n
=
1
)
const
{
return
sizeof
(
T
)
*
n
;
}
template
<
class
U
>
T
*
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
template
<
class
U
>
T
*
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
{
return
reinterpret_cast
<
T
*>
(
buffer
)
+
n
;
return
reinterpret_cast
<
T
*>
(
buffer
)
+
n
;
}
template
<
class
U
>
const
T
*
from
(
const
U
*
buffer
,
std
::
size_t
n
=
0
)
const
template
<
class
U
>
const
T
*
from
(
const
U
*
buffer
,
std
::
size_t
n
=
0
)
const
{
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
}
};
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit_type
(
Visitor
v
)
const
{
switch
(
this
->
type_
)
{
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: \
v(as<t>()); \
return;
case x: v(as<t>()); return;
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_VISITOR_CASE
)
#undef RTG_SHAPE_VISITOR_CASE
}
assert
(
true
);
}
private:
private:
type_t
type_
;
std
::
vector
<
std
::
size_t
>
lens_
;
std
::
vector
<
std
::
size_t
>
strides_
;
...
...
@@ -134,6 +129,6 @@ private:
std
::
string
type_string
()
const
;
};
}
}
// namespace rtg
#endif
include/rtg/stringutils.hpp
View file @
14d5666b
...
...
@@ -65,17 +65,14 @@ inline std::string remove_prefix(std::string s, std::string prefix)
return
s
;
}
template
<
class
Range
>
template
<
class
Range
>
inline
std
::
string
to_string
(
const
Range
&
r
)
{
std
::
stringstream
ss
;
if
(
!
r
.
empty
())
{
ss
<<
r
.
front
();
std
::
for_each
(
std
::
next
(
r
.
begin
()),
r
.
end
(),
[
&
](
auto
&&
x
)
{
ss
<<
", "
<<
x
;
});
std
::
for_each
(
std
::
next
(
r
.
begin
()),
r
.
end
(),
[
&
](
auto
&&
x
)
{
ss
<<
", "
<<
x
;
});
}
return
ss
.
str
();
}
...
...
include/rtg/tensor_view.hpp
View file @
14d5666b
...
...
@@ -8,48 +8,29 @@
namespace
rtg
{
template
<
class
T
>
template
<
class
T
>
struct
tensor_view
{
tensor_view
()
:
data_
(
nullptr
),
shape_
()
{}
tensor_view
(
shape
s
,
T
*
d
)
:
data_
(
d
),
shape_
(
s
)
{}
tensor_view
()
:
data_
(
nullptr
),
shape_
()
{}
tensor_view
(
shape
s
,
T
*
d
)
:
data_
(
d
),
shape_
(
s
)
{}
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
bool
empty
()
const
{
return
data_
==
nullptr
||
shape_
.
lens
().
size
()
==
0
;
}
bool
empty
()
const
{
return
data_
==
nullptr
||
shape_
.
lens
().
size
()
==
0
;
}
std
::
size_t
size
()
const
{
return
shape_
.
elements
();
}
std
::
size_t
size
()
const
{
return
shape_
.
elements
();
}
T
*
data
()
{
return
this
->
data_
;
}
T
*
data
()
{
return
this
->
data_
;
}
const
T
*
data
()
const
{
return
this
->
data_
;
}
const
T
*
data
()
const
{
return
this
->
data_
;
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
const
T
&
operator
()(
Ts
...
xs
)
const
{
return
data_
[
shape_
.
index
({
xs
...})];
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
T
&
operator
()(
Ts
...
xs
)
{
return
data_
[
shape_
.
index
({
xs
...})];
...
...
@@ -82,13 +63,13 @@ struct tensor_view
T
&
back
()
{
assert
(
!
this
->
empty
());
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
}
const
T
&
back
()
const
{
assert
(
!
this
->
empty
());
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
}
// TODO: Add iterators so it can handle nonpacked tensors
...
...
@@ -101,8 +82,10 @@ struct tensor_view
T
*
end
()
{
assert
(
this
->
shape_
.
packed
());
if
(
this
->
empty
())
return
data_
;
else
return
data_
+
this
->
size
();
if
(
this
->
empty
())
return
data_
;
else
return
data_
+
this
->
size
();
}
const
T
*
begin
()
const
...
...
@@ -114,34 +97,34 @@ struct tensor_view
const
T
*
end
()
const
{
assert
(
this
->
shape_
.
packed
());
if
(
this
->
empty
())
return
data_
;
else
return
data_
+
this
->
size
();
if
(
this
->
empty
())
return
data_
;
else
return
data_
+
this
->
size
();
}
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
if
(
x
.
shape_
==
y
.
shape_
)
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
shape_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
x
.
shape_
.
elements
();
i
++
)
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
}
return
true
;
}
return
false
;
}
friend
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
return
!
(
x
==
y
);
}
friend
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
return
!
(
x
==
y
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
if
(
!
x
.
empty
())
{
os
<<
x
.
front
();
for
(
std
::
size_t
i
=
1
;
i
<
x
.
shape_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
x
.
shape_
.
elements
();
i
++
)
{
os
<<
", "
<<
x
.
data_
[
x
.
shape_
.
index
(
i
)];
}
...
...
@@ -149,12 +132,12 @@ struct tensor_view
return
os
;
}
private:
private:
T
*
data_
;
shape
shape_
;
};
template
<
class
T
>
template
<
class
T
>
tensor_view
<
T
>
make_view
(
shape
s
,
T
*
data
)
{
return
{
s
,
data
};
...
...
onnx/read_onnx.cpp
View file @
14d5666b
...
...
@@ -13,34 +13,29 @@
struct
unknown
{
std
::
string
op
;
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
input
)
const
{
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
input
)
const
{
throw
"not computable"
;
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
input
)
const
{
throw
"not computable"
;
}
};
template
<
class
C
,
class
T
>
template
<
class
C
,
class
T
>
bool
contains
(
C
&&
c
,
T
&&
x
)
{
return
c
.
find
(
x
)
!=
c
.
end
();
}
template
<
class
Range
,
class
Iterator
>
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
struct
onnx_parser
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
...
...
@@ -49,7 +44,10 @@ struct onnx_parser
std
::
unordered_map
<
std
::
string
,
rtg
::
instruction
*>
instructions
;
std
::
shared_ptr
<
rtg
::
program
>
prog
=
std
::
make_shared
<
rtg
::
program
>
();
std
::
unordered_map
<
std
::
string
,
std
::
function
<
rtg
::
instruction
*
(
attribute_map
,
std
::
vector
<
rtg
::
instruction
*>
)
>>
ops
;
std
::
unordered_map
<
std
::
string
,
std
::
function
<
rtg
::
instruction
*
(
attribute_map
,
std
::
vector
<
rtg
::
instruction
*>
)
>>
ops
;
onnx_parser
()
{
...
...
@@ -92,10 +90,7 @@ struct onnx_parser
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
rtg
::
reshape
op
;
rtg
::
literal
s
=
parse_value
(
attributes
.
at
(
"shape"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
->
add_instruction
(
op
,
args
);
});
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
)
{
...
...
@@ -104,7 +99,7 @@ struct onnx_parser
});
}
template
<
class
F
>
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
...
...
@@ -129,14 +124,14 @@ struct onnx_parser
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
{
nodes
=
get_nodes
(
graph
);
for
(
auto
&&
input
:
graph
.
input
())
for
(
auto
&&
input
:
graph
.
input
())
{
std
::
string
name
=
input
.
name
();
// TODO: Get shape of input parameter
rtg
::
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
->
add_parameter
(
name
,
s
);
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
{
this
->
parse_node
(
p
.
second
.
name
());
}
...
...
@@ -144,11 +139,11 @@ struct onnx_parser
void
parse_node
(
std
::
string
name
)
{
if
(
instructions
.
count
(
name
)
==
0
)
if
(
instructions
.
count
(
name
)
==
0
)
{
auto
&&
node
=
nodes
.
at
(
name
);
std
::
vector
<
rtg
::
instruction
*>
args
;
for
(
auto
&&
input
:
node
.
input
())
for
(
auto
&&
input
:
node
.
input
())
{
if
(
nodes
.
count
(
input
)
>
0
)
{
...
...
@@ -161,7 +156,7 @@ struct onnx_parser
args
.
push_back
(
instructions
.
at
(
input
));
}
}
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
instructions
[
name
]
=
prog
->
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
}
...
...
@@ -175,7 +170,7 @@ struct onnx_parser
static
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
for
(
auto
&&
attr
:
node
.
attribute
())
for
(
auto
&&
attr
:
node
.
attribute
())
{
result
[
attr
.
name
()]
=
attr
;
}
...
...
@@ -185,14 +180,13 @@ struct onnx_parser
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
{
result
[
node
.
name
()]
=
node
;
for
(
auto
&&
output
:
node
.
output
())
for
(
auto
&&
output
:
node
.
output
())
{
result
[
output
]
=
node
;
}
}
return
result
;
}
...
...
@@ -207,8 +201,11 @@ struct onnx_parser
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
rtg
::
literal
{
rtg
::
shape
::
float_type
,
attr
.
floats
().
begin
(),
attr
.
floats
().
end
()};
case
onnx
::
AttributeProto
::
INTS
:
return
rtg
::
literal
{
rtg
::
shape
::
int32_type
,
attr
.
ints
().
begin
(),
attr
.
ints
().
end
()};;
case
onnx
::
AttributeProto
::
FLOATS
:
return
rtg
::
literal
{
rtg
::
shape
::
float_type
,
attr
.
floats
().
begin
(),
attr
.
floats
().
end
()};
case
onnx
::
AttributeProto
::
INTS
:
return
rtg
::
literal
{
rtg
::
shape
::
int32_type
,
attr
.
ints
().
begin
(),
attr
.
ints
().
end
()};
;
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
...
...
@@ -221,17 +218,33 @@ struct onnx_parser
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
rtg
::
literal
{{
rtg
::
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT
:
return
rtg
::
literal
{
{
rtg
::
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
return
rtg
::
literal
{{
rtg
::
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
INT8
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
BOOL
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
return
rtg
::
literal
{{
rtg
::
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
DOUBLE
:
return
rtg
::
literal
{
{
rtg
::
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
...
...
@@ -244,26 +257,33 @@ struct onnx_parser
rtg
::
shape
::
type_t
shape_type
;
switch
(
t
.
tensor_type
().
elem_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
//throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
//throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
case
onnx
::
TensorProto
::
STRING
:
break
;
//throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
//throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
//throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
// throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
//throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
//throw std::runtime_error("Unsupported type COMPLEX128");
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
}
std
::
vector
<
std
::
size_t
>
dims
;
// TODO: USe std::transform
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
{
dims
.
push_back
(
d
.
dim_value
());
}
...
...
@@ -271,7 +291,7 @@ struct onnx_parser
}
};
int
main
(
int
argc
,
char
const
*
argv
[])
int
main
(
int
argc
,
char
const
*
argv
[])
{
if
(
argc
>
1
)
{
...
...
@@ -284,7 +304,8 @@ int main(int argc, char const *argv[])
}
catch
(...)
{
if
(
parser
.
prog
)
parser
.
prog
->
print
();
if
(
parser
.
prog
)
parser
.
prog
->
print
();
throw
;
}
parser
.
prog
->
print
();
...
...
src/program.cpp
View file @
14d5666b
...
...
@@ -9,7 +9,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
{
std
::
unordered_map
<
const
instruction
*
,
argument
>
results
;
argument
result
;
for
(
auto
&
ins
:
instructions
)
for
(
auto
&
ins
:
instructions
)
{
if
(
ins
.
op
.
name
()
==
"@literal"
)
{
...
...
@@ -22,9 +22,10 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
else
{
std
::
vector
<
argument
>
values
(
ins
.
arguments
.
size
());
std
::
transform
(
ins
.
arguments
.
begin
(),
ins
.
arguments
.
end
(),
values
.
begin
(),
[
&
](
instruction
*
i
)
{
return
results
.
at
(
i
);
});
std
::
transform
(
ins
.
arguments
.
begin
(),
ins
.
arguments
.
end
(),
values
.
begin
(),
[
&
](
instruction
*
i
)
{
return
results
.
at
(
i
);
});
result
=
ins
.
op
.
compute
(
values
);
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
...
...
@@ -37,7 +38,7 @@ void program::print() const
std
::
unordered_map
<
const
instruction
*
,
std
::
string
>
names
;
int
count
=
0
;
for
(
auto
&
ins
:
instructions
)
for
(
auto
&
ins
:
instructions
)
{
std
::
string
var_name
=
"@"
+
std
::
to_string
(
count
);
if
(
starts_with
(
ins
.
op
.
name
(),
"@param"
))
...
...
@@ -51,7 +52,7 @@ void program::print() const
if
(
ins
.
op
.
name
()
==
"@literal"
)
{
if
(
ins
.
lit
.
get_shape
().
elements
()
>
10
)
if
(
ins
.
lit
.
get_shape
().
elements
()
>
10
)
std
::
cout
<<
"{ ... }"
;
else
std
::
cout
<<
"{"
<<
ins
.
lit
<<
"}"
;
...
...
@@ -60,7 +61,7 @@ void program::print() const
if
(
!
ins
.
arguments
.
empty
())
{
char
delim
=
'('
;
for
(
auto
&&
arg
:
ins
.
arguments
)
for
(
auto
&&
arg
:
ins
.
arguments
)
{
assert
(
this
->
has_instruction
(
arg
)
&&
"Instruction not found"
);
std
::
cout
<<
delim
<<
names
.
at
(
arg
);
...
...
@@ -78,5 +79,4 @@ void program::print() const
}
}
}
}
// namespace rtg
src/shape.cpp
View file @
14d5666b
...
...
@@ -7,21 +7,16 @@
namespace
rtg
{
shape
::
shape
()
:
type_
(
float_type
),
lens_
(),
strides_
(),
packed_
(
false
)
{}
shape
::
shape
()
:
type_
(
float_type
),
lens_
(),
strides_
(),
packed_
(
false
)
{}
shape
::
shape
(
type_t
t
)
:
type_
(
t
),
lens_
({
1
}),
strides_
({
1
}),
packed_
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
type_
(
t
),
lens_
(
std
::
move
(
l
)),
packed_
(
true
)
shape
::
shape
(
type_t
t
)
:
type_
(
t
),
lens_
({
1
}),
strides_
({
1
}),
packed_
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
type_
(
t
),
lens_
(
std
::
move
(
l
)),
packed_
(
true
)
{
this
->
calculate_strides
();
assert
(
lens_
.
size
()
==
strides_
.
size
());
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
type_
(
t
),
lens_
(
std
::
move
(
l
)),
strides_
(
std
::
move
(
s
))
:
type_
(
t
),
lens_
(
std
::
move
(
l
)),
strides_
(
std
::
move
(
s
))
{
assert
(
lens_
.
size
()
==
strides_
.
size
());
packed_
=
this
->
elements
()
==
this
->
element_space
();
...
...
@@ -38,18 +33,9 @@ void shape::calculate_strides()
lens_
.
rbegin
(),
lens_
.
rend
()
-
1
,
strides_
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
type_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
lens_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
strides_
;
}
shape
::
type_t
shape
::
type
()
const
{
return
this
->
type_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
lens_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
strides_
;
}
std
::
size_t
shape
::
elements
()
const
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
...
...
@@ -77,13 +63,15 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
},
std
::
plus
<
std
::
size_t
>
{},
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
}
bool
shape
::
packed
()
const
{
return
this
->
packed_
;
return
std
::
inner_product
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
},
std
::
plus
<
std
::
size_t
>
{},
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
}
bool
shape
::
packed
()
const
{
return
this
->
packed_
;
}
std
::
size_t
shape
::
element_space
()
const
{
// TODO: Get rid of intermediate vector
...
...
@@ -104,8 +92,7 @@ std::string shape::type_string() const
switch
(
this
->
type_
)
{
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: \
return #x;
case x: return #x;
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_TYPE_STRING_CASE
)
#undef RTG_SHAPE_TYPE_STRING_CASE
}
...
...
@@ -116,10 +103,7 @@ bool operator==(const shape& x, const shape& y)
{
return
x
.
type
()
==
y
.
type
()
&&
x
.
lens
()
==
y
.
lens
()
&&
x
.
strides
()
==
y
.
strides
();
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
!
(
x
==
y
);
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
!
(
x
==
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
{
...
...
@@ -129,4 +113,4 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return
os
;
}
}
}
// namespace rtg
test/eval_test.cpp
View file @
14d5666b
...
...
@@ -4,37 +4,37 @@
#include <rtg/shape.hpp>
#include "test.hpp"
struct
sum_op
{
std
::
string
name
()
const
{
return
"sum"
;
}
std
::
string
name
()
const
{
return
"sum"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
{
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
()
!=
args
[
1
].
get_shape
())
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
size
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
front
()
!=
1
)
throw
"Wrong args"
;
if
(
args
.
size
()
!=
2
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
()
!=
args
[
1
].
get_shape
())
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
size
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
front
()
!=
1
)
throw
"Wrong args"
;
args
[
0
].
visit_at
([
&
](
auto
x
)
{
args
[
1
].
visit_at
([
&
](
auto
y
)
{
result
=
rtg
::
literal
{
x
+
y
}.
get_argument
();
});
args
[
1
].
visit_at
([
&
](
auto
y
)
{
result
=
rtg
::
literal
{
x
+
y
}.
get_argument
();
});
});
return
result
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
inputs
)
const
{
if
(
inputs
.
size
()
!=
2
)
throw
"Wrong inputs"
;
if
(
inputs
.
size
()
!=
2
)
throw
"Wrong inputs"
;
return
inputs
.
front
();
}
};
void
literal_test
()
{
void
literal_test
()
{
rtg
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -45,23 +45,22 @@ void literal_test() {
EXPECT
(
result
!=
rtg
::
literal
{
4
});
}
void
param_test
()
{
void
param_test
()
{
rtg
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
{
rtg
::
shape
::
int64_type
});
auto
y
=
p
.
add_parameter
(
"y"
,
{
rtg
::
shape
::
int64_type
});
p
.
add_instruction
(
sum_op
{},
x
,
y
);
auto
result
=
p
.
eval
({
{
"x"
,
rtg
::
literal
{
1
}.
get_argument
()},
{
"y"
,
rtg
::
literal
{
2
}.
get_argument
()}
});
auto
result
=
p
.
eval
({{
"x"
,
rtg
::
literal
{
1
}.
get_argument
()},
{
"y"
,
rtg
::
literal
{
2
}.
get_argument
()}});
EXPECT
(
result
==
rtg
::
literal
{
3
});
EXPECT
(
result
!=
rtg
::
literal
{
4
});
}
int
main
()
{
int
main
()
{
literal_test
();
param_test
();
}
test/literal_test.cpp
View file @
14d5666b
...
...
@@ -4,7 +4,6 @@
#include <string>
#include "test.hpp"
void
literal_test
()
{
EXPECT
(
rtg
::
literal
{
1
}
==
rtg
::
literal
{
1
});
...
...
@@ -51,10 +50,9 @@ void literal_os3()
EXPECT
(
ss
.
str
()
==
"1, 2, 3"
);
}
int
main
()
{
int
main
()
{
literal_test
();
literal_os1
();
literal_os2
();
}
test/main.cpp
View file @
14d5666b
int
main
()
{
}
int
main
()
{}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment