Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
14d5666b
Commit
14d5666b
authored
Apr 23, 2018
by
Paul
Browse files
Add clang formatting
parent
2305ac81
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
510 additions
and
450 deletions
+510
-450
.clang-format
.clang-format
+90
-0
.githooks/install
.githooks/install
+7
-0
.githooks/pre-commit
.githooks/pre-commit
+43
-0
include/rtg/argument.hpp
include/rtg/argument.hpp
+7
-15
include/rtg/builtin.hpp
include/rtg/builtin.hpp
+7
-25
include/rtg/instruction.hpp
include/rtg/instruction.hpp
+7
-5
include/rtg/literal.hpp
include/rtg/literal.hpp
+19
-42
include/rtg/operand.hpp
include/rtg/operand.hpp
+16
-14
include/rtg/operators.hpp
include/rtg/operators.hpp
+66
-70
include/rtg/program.hpp
include/rtg/program.hpp
+18
-13
include/rtg/raw_data.hpp
include/rtg/raw_data.hpp
+20
-31
include/rtg/shape.hpp
include/rtg/shape.hpp
+25
-30
include/rtg/stringutils.hpp
include/rtg/stringutils.hpp
+2
-5
include/rtg/tensor_view.hpp
include/rtg/tensor_view.hpp
+28
-45
onnx/read_onnx.cpp
onnx/read_onnx.cpp
+102
-81
src/program.cpp
src/program.cpp
+9
-9
src/shape.cpp
src/shape.cpp
+20
-36
test/eval_test.cpp
test/eval_test.cpp
+20
-21
test/literal_test.cpp
test/literal_test.cpp
+3
-5
test/main.cpp
test/main.cpp
+1
-3
No files found.
.clang-format
0 → 100644
View file @
14d5666b
---
Language: Cpp
AccessModifierOffset: 0
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: true
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: true
AfterControlStatement: true
AfterEnum: true
AfterFunction: true
AfterNamespace: false
AfterObjCDeclaration: true
AfterStruct: true
AfterUnion: true
BeforeCatch: true
BeforeElse: true
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
- Regex: '^(<|"(gtest|isl|json)/)'
Priority: 3
- Regex: '.*'
Priority: 1
IndentCaseLabels: false
IndentWidth: 4
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: true
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
# SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: Never
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
.githooks/install
0 → 100755
View file @
14d5666b
#!/usr/bin/env bash
cd
$(
git rev-parse
--git-dir
)
echo
"Installing hooks..."
ln
-s
../.githooks hooks
echo
"Done!"
.githooks/pre-commit
0 → 100755
View file @
14d5666b
#!/bin/sh
#
# This pre-commit hook checks if any versions of clang-format
# are installed, and if so, uses the installed version to format
# the staged changes.
base
=
clang-format-5.0
format
=
""
# Redirect output to stderr.
exec
1>&2
# check if clang-format is installed
type
"
$base
"
>
/dev/null 2>&1
&&
format
=
"
$base
"
# no versions of clang-format are installed
if
[
-z
"
$format
"
]
then
echo
"
$base
is not installed. Pre-commit hook will not be executed."
exit
0
fi
# Do everything from top - level
cd
$(
git rev-parse
--show-toplevel
)
if
git rev-parse
--verify
HEAD
>
/dev/null 2>&1
then
against
=
HEAD
else
# Initial commit: diff against an empty tree object
against
=
16bbb57
fi
# do the formatting
for
file
in
$(
git diff-index
--cached
--name-only
$against
|
grep
-E
'\.h$|\.hpp$|\.cpp$|\.cl$|\.h\.in$|\.hpp\.in$|\.cpp\.in$'
)
do
if
[
-e
"
$file
"
]
then
echo
"
$format
$file
"
"
$format
"
-i
-style
=
file
"
$file
"
fi
done
include/rtg/argument.hpp
View file @
14d5666b
...
@@ -9,28 +9,20 @@ namespace rtg {
...
@@ -9,28 +9,20 @@ namespace rtg {
struct
argument
:
raw_data
<
argument
>
struct
argument
:
raw_data
<
argument
>
{
{
argument
()
argument
()
{}
{}
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
argument
(
shape
s
,
std
::
function
<
char
*
()
>
d
)
:
data
(
d
),
shape_
(
s
)
{}
:
data
(
d
),
shape_
(
s
)
{}
std
::
function
<
char
*
()
>
data
;
std
::
function
<
char
*
()
>
data
;
bool
empty
()
const
bool
empty
()
const
{
return
not
data
;
}
{
return
not
data
;
}
const
shape
&
get_shape
()
const
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
{
return
this
->
shape_
;
private:
}
private:
shape
shape_
;
shape
shape_
;
};
};
}
}
// namespace rtg
#endif
#endif
include/rtg/builtin.hpp
View file @
14d5666b
...
@@ -9,38 +9,20 @@ namespace builtin {
...
@@ -9,38 +9,20 @@ namespace builtin {
struct
literal
struct
literal
{
{
std
::
string
name
()
const
std
::
string
name
()
const
{
return
"@literal"
;
}
{
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
throw
"builtin"
;
}
return
"@literal"
;
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
"builtin"
;
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
throw
"builtin"
;
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
"builtin"
;
}
};
};
struct
param
struct
param
{
{
std
::
string
parameter
;
std
::
string
parameter
;
std
::
string
name
()
const
std
::
string
name
()
const
{
return
"@param:"
+
parameter
;
}
{
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
throw
"builtin"
;
}
return
"@param:"
+
parameter
;
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
"builtin"
;
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
)
const
{
throw
"builtin"
;
}
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
"builtin"
;
}
};
};
}
}
// namespace builtin
}
// namespace rtg
}
// namespace rtg
...
...
include/rtg/instruction.hpp
View file @
14d5666b
...
@@ -13,12 +13,14 @@ struct instruction
...
@@ -13,12 +13,14 @@ struct instruction
instruction
()
{}
instruction
()
{}
instruction
(
operand
o
,
shape
r
,
std
::
vector
<
instruction
*>
args
)
instruction
(
operand
o
,
shape
r
,
std
::
vector
<
instruction
*>
args
)
:
op
(
std
::
move
(
o
)),
result
(
std
::
move
(
r
)),
arguments
(
std
::
move
(
args
)),
lit
()
:
op
(
std
::
move
(
o
)),
result
(
std
::
move
(
r
)),
arguments
(
std
::
move
(
args
)),
lit
()
{}
{
}
instruction
(
literal
l
)
instruction
(
literal
l
)
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
arguments
(),
lit
(
std
::
move
(
l
))
:
op
(
builtin
::
literal
{}),
result
(
l
.
get_shape
()),
arguments
(),
lit
(
std
::
move
(
l
))
{}
{
}
operand
op
;
operand
op
;
shape
result
;
shape
result
;
...
@@ -26,6 +28,6 @@ struct instruction
...
@@ -26,6 +28,6 @@ struct instruction
literal
lit
;
literal
lit
;
};
};
}
}
// namespace rtg
#endif
#endif
include/rtg/literal.hpp
View file @
14d5666b
...
@@ -10,68 +10,45 @@ namespace rtg {
...
@@ -10,68 +10,45 @@ namespace rtg {
struct
literal
:
raw_data
<
literal
>
struct
literal
:
raw_data
<
literal
>
{
{
literal
()
literal
()
:
buffer
(),
shape_
()
{}
:
buffer
(),
shape_
()
{}
template
<
class
T
>
template
<
class
T
>
literal
(
T
x
)
literal
(
T
x
)
:
buffer
(
sizeof
(
T
),
0
),
shape_
(
shape
::
get_type
<
T
>
{})
:
buffer
(
sizeof
(
T
),
0
),
shape_
(
shape
::
get_type
<
T
>
{})
{
{
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
buffer
.
data
()))
=
x
;
}
}
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
literal
(
shape
s
,
const
std
::
vector
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
s
.
visit_type
([
&
](
auto
as
)
{
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
}
}
template
<
class
T
>
template
<
class
T
>
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
literal
(
shape
s
,
const
std
::
initializer_list
<
T
>&
x
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
static_assert
(
std
::
is_trivial
<
T
>
{},
"Literals can only be trivial types"
);
s
.
visit_type
([
&
](
auto
as
)
{
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
std
::
copy
(
x
.
begin
(),
x
.
end
(),
as
.
from
(
buffer
.
data
()));
});
}
}
template
<
class
Iterator
>
template
<
class
Iterator
>
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
literal
(
shape
s
,
Iterator
start
,
Iterator
end
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
:
buffer
(
s
.
bytes
(),
0
),
shape_
(
s
)
{
{
assert
(
s
.
packed
());
assert
(
s
.
packed
());
s
.
visit_type
([
&
](
auto
as
)
{
s
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
data
()));
});
}
}
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
shape_
(
s
)
{}
bool
empty
()
const
literal
(
shape
s
,
const
char
*
x
)
:
buffer
(
x
,
x
+
s
.
bytes
()),
shape_
(
s
)
{}
{
return
this
->
buffer
.
empty
();
}
const
char
*
data
()
const
bool
empty
()
const
{
return
this
->
buffer
.
empty
();
}
{
return
this
->
buffer
.
data
();
}
const
shape
&
get_shape
()
const
const
char
*
data
()
const
{
return
this
->
buffer
.
data
();
}
{
return
this
->
shape_
;
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
}
argument
get_argument
()
const
argument
get_argument
()
const
{
{
...
@@ -79,11 +56,11 @@ struct literal : raw_data<literal>
...
@@ -79,11 +56,11 @@ struct literal : raw_data<literal>
return
{
shape_
,
[
b
]()
mutable
{
return
b
.
data
();
}};
return
{
shape_
,
[
b
]()
mutable
{
return
b
.
data
();
}};
}
}
private:
private:
std
::
vector
<
char
>
buffer
;
std
::
vector
<
char
>
buffer
;
shape
shape_
;
shape
shape_
;
};
};
}
}
// namespace rtg
#endif
#endif
include/rtg/operand.hpp
View file @
14d5666b
...
@@ -12,16 +12,16 @@
...
@@ -12,16 +12,16 @@
namespace
rtg
{
namespace
rtg
{
/*
/*
* Type-erased interface for:
* Type-erased interface for:
*
*
* struct operand
* struct operand
* {
* {
* std::string name() const;
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* argument compute(std::vector<argument> input) const;
* };
* };
*
*
*/
*/
struct
operand
struct
operand
{
{
...
@@ -80,8 +80,9 @@ struct operand
...
@@ -80,8 +80,9 @@ struct operand
struct
handle_type_
:
handle_base_type_
struct
handle_type_
:
handle_base_type_
{
{
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
handle_type_
(
TypeErased_T_
value
,
handle_type_
(
typename
std
::
enable_if
<
std
::
is_reference
<
TypeErased_U_
>::
value
>::
type
*
=
nullptr
)
TypeErased_T_
value
,
typename
std
::
enable_if
<
std
::
is_reference
<
TypeErased_U_
>::
value
>::
type
*
=
nullptr
)
:
value_
(
value
)
:
value_
(
value
)
{
{
}
}
...
@@ -89,7 +90,8 @@ struct operand
...
@@ -89,7 +90,8 @@ struct operand
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
template
<
typename
TypeErased_U_
=
TypeErased_T_
>
handle_type_
(
TypeErased_T_
value
,
handle_type_
(
TypeErased_T_
value
,
typename
std
::
enable_if
<!
std
::
is_reference
<
TypeErased_U_
>::
value
,
int
>::
type
*
=
typename
std
::
enable_if
<!
std
::
is_reference
<
TypeErased_U_
>::
value
,
int
>::
type
*
=
nullptr
)
noexcept
:
value_
(
std
::
move
(
value
))
nullptr
)
noexcept
:
value_
(
std
::
move
(
value
))
{
{
}
}
...
@@ -134,6 +136,6 @@ struct operand
...
@@ -134,6 +136,6 @@ struct operand
std
::
shared_ptr
<
handle_base_type_
>
handle_mem_var_
;
std
::
shared_ptr
<
handle_base_type_
>
handle_mem_var_
;
};
};
}
}
// namespace rtg
#endif
#endif
include/rtg/operators.hpp
View file @
14d5666b
...
@@ -9,120 +9,120 @@ namespace rtg {
...
@@ -9,120 +9,120 @@ namespace rtg {
struct
not_computable
struct
not_computable
{
{
argument
compute
(
std
::
vector
<
argument
>
)
const
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
struct
convolution
struct
convolution
{
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{
1
,
1
};
std
::
string
name
()
const
std
::
string
name
()
const
{
{
return
"convolution[padding={"
+
to_string
(
padding
)
+
return
"convolution[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
to_string
(
stride
)
+
"}, stride={"
+
to_string
(
stride
)
+
"}, dilation={"
+
to_string
(
dilation
)
+
"}]"
;
"}, dilation={"
+
to_string
(
dilation
)
+
"}]"
;
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
size
()
!=
2
)
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
if
(
inputs
.
size
()
!=
2
)
const
shape
&
input
=
inputs
.
at
(
0
);
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
const
shape
&
weights
=
inputs
.
at
(
1
);
if
(
input
.
type
()
!=
weights
.
type
())
throw
std
::
runtime_error
(
"Type doesn't match"
);
if
(
input
.
type
()
!=
weights
.
type
())
if
(
input
.
lens
().
size
()
!=
weights
.
lens
().
size
())
throw
std
::
runtime_error
(
"Dimensions don't match"
);
throw
std
::
runtime_error
(
"Type doesn't match"
);
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d convolution supported"
);
if
(
input
.
lens
().
size
()
!=
weights
.
lens
().
size
())
throw
std
::
runtime_error
(
"Dimensions don't match"
);
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d convolution supported"
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
return
{
t
,
{
return
{
t
,
input
.
lens
()[
0
],
{
weights
.
lens
()[
0
],
input
.
lens
()[
0
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
weights
.
lens
()[
0
],
1
,
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()[
2
]
-
1
))
+
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)),
(
input
.
lens
()[
2
]
-
(
1
+
dilation
[
0
]
*
(
weights
.
lens
()[
2
]
-
1
))
+
}};
2
*
padding
[
0
])
/
stride
[
0
]
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
(
input
.
lens
()[
3
]
-
(
1
+
dilation
[
1
]
*
(
weights
.
lens
()[
3
]
-
1
))
+
2
*
padding
[
1
])
/
stride
[
1
]
+
1
)),
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
struct
pooling
struct
pooling
{
{
std
::
string
mode
;
std
::
string
mode
;
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{
1
,
1
};
std
::
string
name
()
const
std
::
string
name
()
const
{
{
return
"pooling:"
+
mode
+
"[padding={"
+
to_string
(
padding
)
+
return
"pooling:"
+
mode
+
"[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
"}, stride={"
+
to_string
(
stride
)
+
to_string
(
stride
)
+
"}, lengths={"
+
to_string
(
lengths
)
+
"}]"
;
"}, lengths={"
+
to_string
(
lengths
)
+
"}]"
;
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
if
(
inputs
.
empty
())
const
shape
&
input
=
inputs
.
at
(
0
);
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d pooling supported"
);
const
shape
&
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
size
()
!=
4
)
throw
std
::
runtime_error
(
"Only 4d pooling supported"
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
return
{
t
,
{
return
{
t
,
input
.
lens
()[
0
],
{
input
.
lens
()[
1
],
input
.
lens
()[
0
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
input
.
lens
()[
1
],
1
,
std
::
ceil
((
input
.
lens
()[
3
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
]))
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
1
,
std
::
ceil
((
input
.
lens
()[
4
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
]))
+
1
)),
std
::
ceil
((
input
.
lens
()[
3
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
}};
static_cast
<
float
>
(
stride
[
0
]))
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ceil
((
input
.
lens
()[
4
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
]))
+
1
)),
}};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
struct
activation
struct
activation
{
{
std
::
string
mode
;
std
::
string
mode
;
std
::
string
name
()
const
std
::
string
name
()
const
{
return
"activation:"
+
mode
;
}
{
return
"activation:"
+
mode
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
return
inputs
.
front
();
return
inputs
.
front
();
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
struct
reshape
struct
reshape
{
{
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
std
::
string
name
()
const
std
::
string
name
()
const
{
return
"reshape[dims={"
+
to_string
(
dims
)
+
"}]"
;
}
{
return
"reshape[dims={"
+
to_string
(
dims
)
+
"}]"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
if
(
inputs
.
empty
())
throw
std
::
runtime_error
(
"Wrong number of arguments"
);
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
idims
=
inputs
.
front
().
lens
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
{
if
(
dims
[
i
]
==
0
)
if
(
dims
[
i
]
==
0
)
rdims
[
i
]
=
idims
[
i
];
rdims
[
i
]
=
idims
[
i
];
...
@@ -130,18 +130,14 @@ struct reshape
...
@@ -130,18 +130,14 @@ struct reshape
if
(
dims
.
back
()
==
-
1
)
if
(
dims
.
back
()
==
-
1
)
{
{
rdims
.
pop_back
();
rdims
.
pop_back
();
std
::
copy
(
idims
.
begin
()
+
rdims
.
size
(),
idims
.
end
(),
std
::
back_inserter
(
rdims
));
std
::
copy
(
idims
.
begin
()
+
rdims
.
size
(),
idims
.
end
(),
std
::
back_inserter
(
rdims
));
}
}
return
{
inputs
.
front
().
type
(),
rdims
};
return
{
inputs
.
front
().
type
(),
rdims
};
}
}
argument
compute
(
std
::
vector
<
argument
>
)
const
argument
compute
(
std
::
vector
<
argument
>
)
const
{
throw
std
::
runtime_error
(
"not computable"
);
}
{
throw
std
::
runtime_error
(
"not computable"
);
}
};
};
}
// namespace rtg
}
// namespace rtg
#endif
#endif
include/rtg/program.hpp
View file @
14d5666b
...
@@ -13,35 +13,38 @@ namespace rtg {
...
@@ -13,35 +13,38 @@ namespace rtg {
struct
program
struct
program
{
{
// TODO: A program should be copyable
// TODO: A program should be copyable
program
()
=
default
;
program
()
=
default
;
program
(
const
program
&
)
=
delete
;
program
(
const
program
&
)
=
delete
;
program
&
operator
=
(
const
program
&
)
=
delete
;
program
&
operator
=
(
const
program
&
)
=
delete
;
template
<
class
...
Ts
>
template
<
class
...
Ts
>
instruction
*
add_instruction
(
operand
op
,
Ts
*
...
args
)
instruction
*
add_instruction
(
operand
op
,
Ts
*
...
args
)
{
{
shape
r
=
op
.
compute_shape
({
args
->
result
...});
shape
r
=
op
.
compute_shape
({
args
->
result
...});
instructions
.
push_back
({
op
,
r
,
{
args
...}});
instructions
.
push_back
({
op
,
r
,
{
args
...}});
return
std
::
addressof
(
instructions
.
back
());
return
std
::
addressof
(
instructions
.
back
());
}
}
instruction
*
add_instruction
(
operand
op
,
std
::
vector
<
instruction
*>
args
)
instruction
*
add_instruction
(
operand
op
,
std
::
vector
<
instruction
*>
args
)
{
{
assert
(
std
::
all_of
(
args
.
begin
(),
args
.
end
(),
[
&
](
instruction
*
x
)
{
return
has_instruction
(
x
);
})
&&
"Argument is not an exisiting instruction"
);
assert
(
std
::
all_of
(
args
.
begin
(),
args
.
end
(),
[
&
](
instruction
*
x
)
{
return
has_instruction
(
x
);
})
&&
"Argument is not an exisiting instruction"
);
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
vector
<
shape
>
shapes
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction
*
ins
)
{
return
ins
->
result
;
});
std
::
transform
(
args
.
begin
(),
args
.
end
(),
shapes
.
begin
(),
[](
instruction
*
ins
)
{
return
ins
->
result
;
});
shape
r
=
op
.
compute_shape
(
shapes
);
shape
r
=
op
.
compute_shape
(
shapes
);
instructions
.
push_back
({
op
,
r
,
args
});
instructions
.
push_back
({
op
,
r
,
args
});
assert
(
instructions
.
back
().
arguments
==
args
);
assert
(
instructions
.
back
().
arguments
==
args
);
return
std
::
addressof
(
instructions
.
back
());
return
std
::
addressof
(
instructions
.
back
());
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
instruction
*
add_literal
(
Ts
&&
...
xs
)
instruction
*
add_literal
(
Ts
&&
...
xs
)
{
{
instructions
.
emplace_back
(
literal
{
std
::
forward
<
Ts
>
(
xs
)...});
instructions
.
emplace_back
(
literal
{
std
::
forward
<
Ts
>
(
xs
)...});
return
std
::
addressof
(
instructions
.
back
());
return
std
::
addressof
(
instructions
.
back
());
}
}
instruction
*
add_parameter
(
std
::
string
name
,
shape
s
)
instruction
*
add_parameter
(
std
::
string
name
,
shape
s
)
{
{
instructions
.
push_back
({
builtin
::
param
{
std
::
move
(
name
)},
s
,
{}});
instructions
.
push_back
({
builtin
::
param
{
std
::
move
(
name
)},
s
,
{}});
return
std
::
addressof
(
instructions
.
back
());
return
std
::
addressof
(
instructions
.
back
());
...
@@ -52,16 +55,18 @@ struct program
...
@@ -52,16 +55,18 @@ struct program
// TODO: Change to stream operator
// TODO: Change to stream operator
void
print
()
const
;
void
print
()
const
;
bool
has_instruction
(
const
instruction
*
ins
)
const
bool
has_instruction
(
const
instruction
*
ins
)
const
{
{
return
std
::
find_if
(
instructions
.
begin
(),
instructions
.
end
(),
[
&
](
const
instruction
&
x
)
{
return
ins
==
std
::
addressof
(
x
);
})
!=
instructions
.
end
();
return
std
::
find_if
(
instructions
.
begin
(),
instructions
.
end
(),
[
&
](
const
instruction
&
x
)
{
return
ins
==
std
::
addressof
(
x
);
})
!=
instructions
.
end
();
}
}
private:
private:
// A list is used to keep references to an instruction stable
// A list is used to keep references to an instruction stable
std
::
list
<
instruction
>
instructions
;
std
::
list
<
instruction
>
instructions
;
};
};
}
}
// namespace rtg
#endif
#endif
include/rtg/raw_data.hpp
View file @
14d5666b
...
@@ -6,14 +6,14 @@
...
@@ -6,14 +6,14 @@
namespace
rtg
{
namespace
rtg
{
template
<
class
Derived
>
template
<
class
Derived
>
struct
raw_data
struct
raw_data
{
{
friend
bool
operator
==
(
const
Derived
&
x
,
const
Derived
&
y
)
friend
bool
operator
==
(
const
Derived
&
x
,
const
Derived
&
y
)
{
{
auto
&&
xshape
=
x
.
get_shape
();
auto
&&
xshape
=
x
.
get_shape
();
auto
&&
yshape
=
y
.
get_shape
();
auto
&&
yshape
=
y
.
get_shape
();
bool
result
=
x
.
empty
()
&&
y
.
empty
();
bool
result
=
x
.
empty
()
&&
y
.
empty
();
if
(
not
result
&&
xshape
==
yshape
)
if
(
not
result
&&
xshape
==
yshape
)
{
{
auto
&&
xbuffer
=
x
.
data
();
auto
&&
xbuffer
=
x
.
data
();
...
@@ -22,59 +22,48 @@ struct raw_data
...
@@ -22,59 +22,48 @@ struct raw_data
xshape
.
visit_type
([
&
](
auto
as
)
{
xshape
.
visit_type
([
&
](
auto
as
)
{
auto
xview
=
make_view
(
xshape
,
as
.
from
(
xbuffer
));
auto
xview
=
make_view
(
xshape
,
as
.
from
(
xbuffer
));
auto
yview
=
make_view
(
yshape
,
as
.
from
(
ybuffer
));
auto
yview
=
make_view
(
yshape
,
as
.
from
(
ybuffer
));
result
=
xview
==
yview
;
result
=
xview
==
yview
;
});
});
}
}
return
result
;
return
result
;
}
}
friend
bool
operator
!=
(
const
Derived
&
x
,
const
Derived
&
y
)
friend
bool
operator
!=
(
const
Derived
&
x
,
const
Derived
&
y
)
{
return
!
(
x
==
y
);
}
{
return
!
(
x
==
y
);
template
<
class
Stream
>
}
template
<
class
Stream
>
friend
Stream
&
operator
<<
(
Stream
&
os
,
const
Derived
&
d
)
friend
Stream
&
operator
<<
(
Stream
&
os
,
const
Derived
&
d
)
{
{
d
.
visit
([
&
](
auto
x
)
{
d
.
visit
([
&
](
auto
x
)
{
os
<<
x
;
});
os
<<
x
;
});
return
os
;
return
os
;
}
}
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit_at
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
void
visit_at
(
Visitor
v
,
std
::
size_t
n
=
0
)
const
{
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
s
.
visit_type
([
&
](
auto
as
)
{
s
.
visit_type
([
&
](
auto
as
)
{
v
(
*
(
as
.
from
(
buffer
)
+
s
.
index
(
n
)));
});
v
(
*
(
as
.
from
(
buffer
)
+
s
.
index
(
n
)));
});
}
}
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit
(
Visitor
v
)
const
void
visit
(
Visitor
v
)
const
{
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
auto
&&
buffer
=
static_cast
<
const
Derived
&>
(
*
this
).
data
();
s
.
visit_type
([
&
](
auto
as
)
{
s
.
visit_type
([
&
](
auto
as
)
{
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
v
(
make_view
(
s
,
as
.
from
(
buffer
)));
});
}
}
bool
single
()
const
bool
single
()
const
{
{
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
auto
&&
s
=
static_cast
<
const
Derived
&>
(
*
this
).
get_shape
();
return
s
.
elements
()
==
1
;
return
s
.
elements
()
==
1
;
}
}
template
<
class
T
>
template
<
class
T
>
T
at
(
std
::
size_t
n
=
0
)
const
T
at
(
std
::
size_t
n
=
0
)
const
{
{
T
result
;
T
result
;
this
->
visit_at
([
&
](
auto
x
)
{
this
->
visit_at
([
&
](
auto
x
)
{
result
=
x
;
},
n
);
result
=
x
;
},
n
);
return
result
;
return
result
;
}
}
};
};
...
...
include/rtg/shape.hpp
View file @
14d5666b
...
@@ -11,6 +11,7 @@ struct shape
...
@@ -11,6 +11,7 @@ struct shape
{
{
// Add new types here
// Add new types here
// clang-format off
#define RTG_SHAPE_VISIT_TYPES(m) \
#define RTG_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(float_type, float) \
m(double_type, double) \
m(double_type, double) \
...
@@ -23,6 +24,7 @@ struct shape
...
@@ -23,6 +24,7 @@ struct shape
m(uint32_type, uint32_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(uint64_type, uint64_t) \
// clang-format on
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum
type_t
enum
type_t
{
{
...
@@ -30,12 +32,13 @@ struct shape
...
@@ -30,12 +32,13 @@ struct shape
};
};
#undef RTG_SHAPE_ENUM_TYPES
#undef RTG_SHAPE_ENUM_TYPES
template
<
class
T
,
class
=
void
>
template
<
class
T
,
class
=
void
>
struct
get_type
;
struct
get_type
;
#define RTG_SHAPE_GET_TYPE(x, t) \
#define RTG_SHAPE_GET_TYPE(x, t)
\
template<class T> \
template
<class T>
\
struct get_type<t, T> : std::integral_constant<type_t, x> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{};
{ \
};
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_GET_TYPE
)
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_GET_TYPE
)
#undef RTG_SHAPE_GET_TYPE
#undef RTG_SHAPE_GET_TYPE
...
@@ -44,7 +47,6 @@ struct shape
...
@@ -44,7 +47,6 @@ struct shape
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
);
type_t
type
()
const
;
type_t
type
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
lens
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
const
std
::
vector
<
std
::
size_t
>&
strides
()
const
;
...
@@ -63,67 +65,60 @@ struct shape
...
@@ -63,67 +65,60 @@ struct shape
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
);
template
<
class
T
>
template
<
class
T
>
struct
as
struct
as
{
{
using
type
=
T
;
using
type
=
T
;
template
<
class
U
>
template
<
class
U
>
T
operator
()(
U
u
)
const
T
operator
()(
U
u
)
const
{
{
return
T
(
u
);
return
T
(
u
);
}
}
template
<
class
U
>
template
<
class
U
>
T
*
operator
()(
U
*
u
)
const
T
*
operator
()(
U
*
u
)
const
{
{
return
static_cast
<
T
*>
(
u
);
return
static_cast
<
T
*>
(
u
);
}
}
template
<
class
U
>
template
<
class
U
>
const
T
*
operator
()(
const
U
*
u
)
const
const
T
*
operator
()(
const
U
*
u
)
const
{
{
return
static_cast
<
T
*>
(
u
);
return
static_cast
<
T
*>
(
u
);
}
}
T
operator
()()
const
T
operator
()()
const
{
return
{};
}
{
return
{};
}
std
::
size_t
size
(
std
::
size_t
n
=
1
)
const
std
::
size_t
size
(
std
::
size_t
n
=
1
)
const
{
return
sizeof
(
T
)
*
n
;
}
{
return
sizeof
(
T
)
*
n
;
}
template
<
class
U
>
template
<
class
U
>
T
*
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
T
*
from
(
U
*
buffer
,
std
::
size_t
n
=
0
)
const
{
{
return
reinterpret_cast
<
T
*>
(
buffer
)
+
n
;
return
reinterpret_cast
<
T
*>
(
buffer
)
+
n
;
}
}
template
<
class
U
>
template
<
class
U
>
const
T
*
from
(
const
U
*
buffer
,
std
::
size_t
n
=
0
)
const
const
T
*
from
(
const
U
*
buffer
,
std
::
size_t
n
=
0
)
const
{
{
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
return
reinterpret_cast
<
const
T
*>
(
buffer
)
+
n
;
}
}
};
};
template
<
class
Visitor
>
template
<
class
Visitor
>
void
visit_type
(
Visitor
v
)
const
void
visit_type
(
Visitor
v
)
const
{
{
switch
(
this
->
type_
)
switch
(
this
->
type_
)
{
{
#define RTG_SHAPE_VISITOR_CASE(x, t) \
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: \
case x: v(as<t>()); return;
v(as<t>()); \
return;
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_VISITOR_CASE
)
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_VISITOR_CASE
)
#undef RTG_SHAPE_VISITOR_CASE
#undef RTG_SHAPE_VISITOR_CASE
}
}
assert
(
true
);
assert
(
true
);
}
}
private:
private:
type_t
type_
;
type_t
type_
;
std
::
vector
<
std
::
size_t
>
lens_
;
std
::
vector
<
std
::
size_t
>
lens_
;
std
::
vector
<
std
::
size_t
>
strides_
;
std
::
vector
<
std
::
size_t
>
strides_
;
...
@@ -134,6 +129,6 @@ private:
...
@@ -134,6 +129,6 @@ private:
std
::
string
type_string
()
const
;
std
::
string
type_string
()
const
;
};
};
}
}
// namespace rtg
#endif
#endif
include/rtg/stringutils.hpp
View file @
14d5666b
...
@@ -65,17 +65,14 @@ inline std::string remove_prefix(std::string s, std::string prefix)
...
@@ -65,17 +65,14 @@ inline std::string remove_prefix(std::string s, std::string prefix)
return
s
;
return
s
;
}
}
template
<
class
Range
>
template
<
class
Range
>
inline
std
::
string
to_string
(
const
Range
&
r
)
inline
std
::
string
to_string
(
const
Range
&
r
)
{
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
if
(
!
r
.
empty
())
if
(
!
r
.
empty
())
{
{
ss
<<
r
.
front
();
ss
<<
r
.
front
();
std
::
for_each
(
std
::
next
(
r
.
begin
()),
r
.
end
(),
[
&
](
auto
&&
x
)
std
::
for_each
(
std
::
next
(
r
.
begin
()),
r
.
end
(),
[
&
](
auto
&&
x
)
{
ss
<<
", "
<<
x
;
});
{
ss
<<
", "
<<
x
;
});
}
}
return
ss
.
str
();
return
ss
.
str
();
}
}
...
...
include/rtg/tensor_view.hpp
View file @
14d5666b
...
@@ -8,48 +8,29 @@
...
@@ -8,48 +8,29 @@
namespace
rtg
{
namespace
rtg
{
template
<
class
T
>
template
<
class
T
>
struct
tensor_view
struct
tensor_view
{
{
tensor_view
()
tensor_view
()
:
data_
(
nullptr
),
shape_
()
{}
:
data_
(
nullptr
),
shape_
()
tensor_view
(
shape
s
,
T
*
d
)
:
data_
(
d
),
shape_
(
s
)
{}
{}
tensor_view
(
shape
s
,
T
*
d
)
:
data_
(
d
),
shape_
(
s
)
{}
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
bool
empty
()
const
const
shape
&
get_shape
()
const
{
return
this
->
shape_
;
}
{
return
data_
==
nullptr
||
shape_
.
lens
().
size
()
==
0
;
}
std
::
size_t
size
()
const
bool
empty
()
const
{
return
data_
==
nullptr
||
shape_
.
lens
().
size
()
==
0
;
}
{
return
shape_
.
elements
();
}
T
*
data
()
std
::
size_t
size
()
const
{
return
shape_
.
elements
();
}
{
return
this
->
data_
;
}
const
T
*
data
()
const
T
*
data
()
{
return
this
->
data_
;
}
{
return
this
->
data_
;
}
template
<
class
...
Ts
>
const
T
*
data
()
const
{
return
this
->
data_
;
}
template
<
class
...
Ts
>
const
T
&
operator
()(
Ts
...
xs
)
const
const
T
&
operator
()(
Ts
...
xs
)
const
{
{
return
data_
[
shape_
.
index
({
xs
...})];
return
data_
[
shape_
.
index
({
xs
...})];
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
T
&
operator
()(
Ts
...
xs
)
T
&
operator
()(
Ts
...
xs
)
{
{
return
data_
[
shape_
.
index
({
xs
...})];
return
data_
[
shape_
.
index
({
xs
...})];
...
@@ -82,13 +63,13 @@ struct tensor_view
...
@@ -82,13 +63,13 @@ struct tensor_view
T
&
back
()
T
&
back
()
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
}
}
const
T
&
back
()
const
const
T
&
back
()
const
{
{
assert
(
!
this
->
empty
());
assert
(
!
this
->
empty
());
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
return
data_
[
shape_
.
index
(
this
->
size
()
-
1
)];
}
}
// TODO: Add iterators so it can handle nonpacked tensors
// TODO: Add iterators so it can handle nonpacked tensors
...
@@ -101,8 +82,10 @@ struct tensor_view
...
@@ -101,8 +82,10 @@ struct tensor_view
T
*
end
()
T
*
end
()
{
{
assert
(
this
->
shape_
.
packed
());
assert
(
this
->
shape_
.
packed
());
if
(
this
->
empty
())
return
data_
;
if
(
this
->
empty
())
else
return
data_
+
this
->
size
();
return
data_
;
else
return
data_
+
this
->
size
();
}
}
const
T
*
begin
()
const
const
T
*
begin
()
const
...
@@ -114,34 +97,34 @@ struct tensor_view
...
@@ -114,34 +97,34 @@ struct tensor_view
const
T
*
end
()
const
const
T
*
end
()
const
{
{
assert
(
this
->
shape_
.
packed
());
assert
(
this
->
shape_
.
packed
());
if
(
this
->
empty
())
return
data_
;
if
(
this
->
empty
())
else
return
data_
+
this
->
size
();
return
data_
;
else
return
data_
+
this
->
size
();
}
}
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
friend
bool
operator
==
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
{
if
(
x
.
shape_
==
y
.
shape_
)
if
(
x
.
shape_
==
y
.
shape_
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
x
.
shape_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
x
.
shape_
.
elements
();
i
++
)
{
{
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
if
(
!
float_equal
(
x
[
i
],
y
[
i
]))
return
false
;
}
}
return
true
;
return
true
;
}
}
return
false
;
return
false
;
}
}
friend
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
friend
bool
operator
!=
(
const
tensor_view
<
T
>&
x
,
const
tensor_view
<
T
>&
y
)
{
return
!
(
x
==
y
);
}
{
return
!
(
x
==
y
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
{
if
(
!
x
.
empty
())
if
(
!
x
.
empty
())
{
{
os
<<
x
.
front
();
os
<<
x
.
front
();
for
(
std
::
size_t
i
=
1
;
i
<
x
.
shape_
.
elements
();
i
++
)
for
(
std
::
size_t
i
=
1
;
i
<
x
.
shape_
.
elements
();
i
++
)
{
{
os
<<
", "
<<
x
.
data_
[
x
.
shape_
.
index
(
i
)];
os
<<
", "
<<
x
.
data_
[
x
.
shape_
.
index
(
i
)];
}
}
...
@@ -149,12 +132,12 @@ struct tensor_view
...
@@ -149,12 +132,12 @@ struct tensor_view
return
os
;
return
os
;
}
}
private:
private:
T
*
data_
;
T
*
data_
;
shape
shape_
;
shape
shape_
;
};
};
template
<
class
T
>
template
<
class
T
>
tensor_view
<
T
>
make_view
(
shape
s
,
T
*
data
)
tensor_view
<
T
>
make_view
(
shape
s
,
T
*
data
)
{
{
return
{
s
,
data
};
return
{
s
,
data
};
...
...
onnx/read_onnx.cpp
View file @
14d5666b
...
@@ -13,43 +13,41 @@
...
@@ -13,43 +13,41 @@
struct
unknown
struct
unknown
{
{
std
::
string
op
;
std
::
string
op
;
std
::
string
name
()
const
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
{
return
"unknown:"
+
op
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
input
)
const
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
input
)
const
{
{
if
(
input
.
empty
())
return
{};
if
(
input
.
empty
())
else
return
input
.
front
();
return
{};
}
else
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
input
)
const
return
input
.
front
();
{
throw
"not computable"
;
}
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
input
)
const
{
throw
"not computable"
;
}
};
};
template
<
class
C
,
class
T
>
template
<
class
C
,
class
T
>
bool
contains
(
C
&&
c
,
T
&&
x
)
bool
contains
(
C
&&
c
,
T
&&
x
)
{
{
return
c
.
find
(
x
)
!=
c
.
end
();
return
c
.
find
(
x
)
!=
c
.
end
();
}
}
template
<
class
Range
,
class
Iterator
>
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
void
copy
(
Range
&&
r
,
Iterator
it
)
{
{
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
}
struct
onnx_parser
struct
onnx_parser
{
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
node_map
nodes
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
rtg
::
instruction
*>
instructions
;
std
::
unordered_map
<
std
::
string
,
rtg
::
instruction
*>
instructions
;
std
::
shared_ptr
<
rtg
::
program
>
prog
=
std
::
make_shared
<
rtg
::
program
>
();
std
::
shared_ptr
<
rtg
::
program
>
prog
=
std
::
make_shared
<
rtg
::
program
>
();
std
::
unordered_map
<
std
::
string
,
std
::
function
<
rtg
::
instruction
*
(
attribute_map
,
std
::
vector
<
rtg
::
instruction
*>
)
>>
ops
;
std
::
unordered_map
<
std
::
string
,
std
::
function
<
rtg
::
instruction
*
(
attribute_map
,
std
::
vector
<
rtg
::
instruction
*>
)
>>
ops
;
onnx_parser
()
onnx_parser
()
{
{
...
@@ -92,10 +90,7 @@ struct onnx_parser
...
@@ -92,10 +90,7 @@ struct onnx_parser
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
rtg
::
reshape
op
;
rtg
::
reshape
op
;
rtg
::
literal
s
=
parse_value
(
attributes
.
at
(
"shape"
));
rtg
::
literal
s
=
parse_value
(
attributes
.
at
(
"shape"
));
s
.
visit
([
&
](
auto
v
)
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
->
add_instruction
(
op
,
args
);
return
prog
->
add_instruction
(
op
,
args
);
});
});
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
)
{
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
)
{
...
@@ -104,7 +99,7 @@ struct onnx_parser
...
@@ -104,7 +99,7 @@ struct onnx_parser
});
});
}
}
template
<
class
F
>
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
void
add_op
(
std
::
string
name
,
F
f
)
{
{
ops
.
emplace
(
name
,
f
);
ops
.
emplace
(
name
,
f
);
...
@@ -113,14 +108,14 @@ struct onnx_parser
...
@@ -113,14 +108,14 @@ struct onnx_parser
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
{
{
onnx
::
ModelProto
model
;
onnx
::
ModelProto
model
;
if
(
model
.
ParseFromIstream
(
&
is
))
if
(
model
.
ParseFromIstream
(
&
is
))
{
{
if
(
model
.
has_graph
())
if
(
model
.
has_graph
())
{
{
this
->
parse_graph
(
model
.
graph
());
this
->
parse_graph
(
model
.
graph
());
}
}
}
}
else
else
{
{
throw
std
::
runtime_error
(
"Failed reading"
);
throw
std
::
runtime_error
(
"Failed reading"
);
}
}
...
@@ -129,14 +124,14 @@ struct onnx_parser
...
@@ -129,14 +124,14 @@ struct onnx_parser
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
{
{
nodes
=
get_nodes
(
graph
);
nodes
=
get_nodes
(
graph
);
for
(
auto
&&
input
:
graph
.
input
())
for
(
auto
&&
input
:
graph
.
input
())
{
{
std
::
string
name
=
input
.
name
();
std
::
string
name
=
input
.
name
();
// TODO: Get shape of input parameter
// TODO: Get shape of input parameter
rtg
::
shape
s
=
parse_type
(
input
.
type
());
rtg
::
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
->
add_parameter
(
name
,
s
);
instructions
[
name
]
=
prog
->
add_parameter
(
name
,
s
);
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
{
{
this
->
parse_node
(
p
.
second
.
name
());
this
->
parse_node
(
p
.
second
.
name
());
}
}
...
@@ -144,11 +139,11 @@ struct onnx_parser
...
@@ -144,11 +139,11 @@ struct onnx_parser
void
parse_node
(
std
::
string
name
)
void
parse_node
(
std
::
string
name
)
{
{
if
(
instructions
.
count
(
name
)
==
0
)
if
(
instructions
.
count
(
name
)
==
0
)
{
{
auto
&&
node
=
nodes
.
at
(
name
);
auto
&&
node
=
nodes
.
at
(
name
);
std
::
vector
<
rtg
::
instruction
*>
args
;
std
::
vector
<
rtg
::
instruction
*>
args
;
for
(
auto
&&
input
:
node
.
input
())
for
(
auto
&&
input
:
node
.
input
())
{
{
if
(
nodes
.
count
(
input
)
>
0
)
if
(
nodes
.
count
(
input
)
>
0
)
{
{
...
@@ -161,7 +156,7 @@ struct onnx_parser
...
@@ -161,7 +156,7 @@ struct onnx_parser
args
.
push_back
(
instructions
.
at
(
input
));
args
.
push_back
(
instructions
.
at
(
input
));
}
}
}
}
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
{
instructions
[
name
]
=
prog
->
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
instructions
[
name
]
=
prog
->
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
}
}
...
@@ -175,7 +170,7 @@ struct onnx_parser
...
@@ -175,7 +170,7 @@ struct onnx_parser
static
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
static
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
for
(
auto
&&
attr
:
node
.
attribute
())
for
(
auto
&&
attr
:
node
.
attribute
())
{
{
result
[
attr
.
name
()]
=
attr
;
result
[
attr
.
name
()]
=
attr
;
}
}
...
@@ -185,14 +180,13 @@ struct onnx_parser
...
@@ -185,14 +180,13 @@ struct onnx_parser
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
{
{
result
[
node
.
name
()]
=
node
;
result
[
node
.
name
()]
=
node
;
for
(
auto
&&
output
:
node
.
output
())
for
(
auto
&&
output
:
node
.
output
())
{
{
result
[
output
]
=
node
;
result
[
output
]
=
node
;
}
}
}
}
return
result
;
return
result
;
}
}
...
@@ -201,17 +195,20 @@ struct onnx_parser
...
@@ -201,17 +195,20 @@ struct onnx_parser
{
{
switch
(
attr
.
type
())
switch
(
attr
.
type
())
{
{
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
case
onnx
::
AttributeProto
::
FLOAT
:
return
rtg
::
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
FLOAT
:
return
rtg
::
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
INT
:
return
rtg
::
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
INT
:
return
rtg
::
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
rtg
::
literal
{
rtg
::
shape
::
float_type
,
attr
.
floats
().
begin
(),
attr
.
floats
().
end
()};
case
onnx
::
AttributeProto
::
FLOATS
:
case
onnx
::
AttributeProto
::
INTS
:
return
rtg
::
literal
{
rtg
::
shape
::
int32_type
,
attr
.
ints
().
begin
(),
attr
.
ints
().
end
()};;
return
rtg
::
literal
{
rtg
::
shape
::
float_type
,
attr
.
floats
().
begin
(),
attr
.
floats
().
end
()};
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
INTS
:
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
return
rtg
::
literal
{
rtg
::
shape
::
int32_type
,
attr
.
ints
().
begin
(),
attr
.
ints
().
end
()};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
;
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
}
}
}
...
@@ -220,22 +217,38 @@ struct onnx_parser
...
@@ -220,22 +217,38 @@ struct onnx_parser
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
switch
(
t
.
data_type
())
switch
(
t
.
data_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
rtg
::
literal
{{
rtg
::
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT
:
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
return
rtg
::
literal
{
case
onnx
::
TensorProto
::
INT8
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
{
rtg
::
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT16
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT8
:
case
onnx
::
TensorProto
::
INT32
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
return
rtg
::
literal
{
case
onnx
::
TensorProto
::
INT64
:
return
rtg
::
literal
{{
rtg
::
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
BOOL
:
return
rtg
::
literal
{{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
return
rtg
::
literal
{
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
DOUBLE
:
return
rtg
::
literal
{{
rtg
::
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
return
rtg
::
literal
{
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT32
:
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
return
rtg
::
literal
{
{
rtg
::
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
}
}
}
...
@@ -244,26 +257,33 @@ struct onnx_parser
...
@@ -244,26 +257,33 @@ struct onnx_parser
rtg
::
shape
::
type_t
shape_type
;
rtg
::
shape
::
type_t
shape_type
;
switch
(
t
.
tensor_type
().
elem_type
())
switch
(
t
.
tensor_type
().
elem_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
//throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
UNDEFINED
:
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
UINT8
:
break
;
//throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
case
onnx
::
TensorProto
::
UINT8
:
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
case
onnx
::
TensorProto
::
STRING
:
break
;
//throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
case
onnx
::
TensorProto
::
BOOL
:
break
;
//throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
//throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
case
onnx
::
TensorProto
::
BOOL
:
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
//throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
FLOAT16
:
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
//throw std::runtime_error("Unsupported type COMPLEX128");
break
;
// throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
}
}
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
// TODO: USe std::transform
// TODO: USe std::transform
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
{
{
dims
.
push_back
(
d
.
dim_value
());
dims
.
push_back
(
d
.
dim_value
());
}
}
...
@@ -271,7 +291,7 @@ struct onnx_parser
...
@@ -271,7 +291,7 @@ struct onnx_parser
}
}
};
};
int
main
(
int
argc
,
char
const
*
argv
[])
int
main
(
int
argc
,
char
const
*
argv
[])
{
{
if
(
argc
>
1
)
if
(
argc
>
1
)
{
{
...
@@ -284,7 +304,8 @@ int main(int argc, char const *argv[])
...
@@ -284,7 +304,8 @@ int main(int argc, char const *argv[])
}
}
catch
(...)
catch
(...)
{
{
if
(
parser
.
prog
)
parser
.
prog
->
print
();
if
(
parser
.
prog
)
parser
.
prog
->
print
();
throw
;
throw
;
}
}
parser
.
prog
->
print
();
parser
.
prog
->
print
();
...
...
src/program.cpp
View file @
14d5666b
...
@@ -9,7 +9,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
...
@@ -9,7 +9,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
{
{
std
::
unordered_map
<
const
instruction
*
,
argument
>
results
;
std
::
unordered_map
<
const
instruction
*
,
argument
>
results
;
argument
result
;
argument
result
;
for
(
auto
&
ins
:
instructions
)
for
(
auto
&
ins
:
instructions
)
{
{
if
(
ins
.
op
.
name
()
==
"@literal"
)
if
(
ins
.
op
.
name
()
==
"@literal"
)
{
{
...
@@ -22,9 +22,10 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
...
@@ -22,9 +22,10 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
else
else
{
{
std
::
vector
<
argument
>
values
(
ins
.
arguments
.
size
());
std
::
vector
<
argument
>
values
(
ins
.
arguments
.
size
());
std
::
transform
(
ins
.
arguments
.
begin
(),
ins
.
arguments
.
end
(),
values
.
begin
(),
[
&
](
instruction
*
i
)
{
std
::
transform
(
ins
.
arguments
.
begin
(),
return
results
.
at
(
i
);
ins
.
arguments
.
end
(),
});
values
.
begin
(),
[
&
](
instruction
*
i
)
{
return
results
.
at
(
i
);
});
result
=
ins
.
op
.
compute
(
values
);
result
=
ins
.
op
.
compute
(
values
);
}
}
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
results
.
emplace
(
std
::
addressof
(
ins
),
result
);
...
@@ -37,7 +38,7 @@ void program::print() const
...
@@ -37,7 +38,7 @@ void program::print() const
std
::
unordered_map
<
const
instruction
*
,
std
::
string
>
names
;
std
::
unordered_map
<
const
instruction
*
,
std
::
string
>
names
;
int
count
=
0
;
int
count
=
0
;
for
(
auto
&
ins
:
instructions
)
for
(
auto
&
ins
:
instructions
)
{
{
std
::
string
var_name
=
"@"
+
std
::
to_string
(
count
);
std
::
string
var_name
=
"@"
+
std
::
to_string
(
count
);
if
(
starts_with
(
ins
.
op
.
name
(),
"@param"
))
if
(
starts_with
(
ins
.
op
.
name
(),
"@param"
))
...
@@ -51,7 +52,7 @@ void program::print() const
...
@@ -51,7 +52,7 @@ void program::print() const
if
(
ins
.
op
.
name
()
==
"@literal"
)
if
(
ins
.
op
.
name
()
==
"@literal"
)
{
{
if
(
ins
.
lit
.
get_shape
().
elements
()
>
10
)
if
(
ins
.
lit
.
get_shape
().
elements
()
>
10
)
std
::
cout
<<
"{ ... }"
;
std
::
cout
<<
"{ ... }"
;
else
else
std
::
cout
<<
"{"
<<
ins
.
lit
<<
"}"
;
std
::
cout
<<
"{"
<<
ins
.
lit
<<
"}"
;
...
@@ -60,7 +61,7 @@ void program::print() const
...
@@ -60,7 +61,7 @@ void program::print() const
if
(
!
ins
.
arguments
.
empty
())
if
(
!
ins
.
arguments
.
empty
())
{
{
char
delim
=
'('
;
char
delim
=
'('
;
for
(
auto
&&
arg
:
ins
.
arguments
)
for
(
auto
&&
arg
:
ins
.
arguments
)
{
{
assert
(
this
->
has_instruction
(
arg
)
&&
"Instruction not found"
);
assert
(
this
->
has_instruction
(
arg
)
&&
"Instruction not found"
);
std
::
cout
<<
delim
<<
names
.
at
(
arg
);
std
::
cout
<<
delim
<<
names
.
at
(
arg
);
...
@@ -78,5 +79,4 @@ void program::print() const
...
@@ -78,5 +79,4 @@ void program::print() const
}
}
}
}
}
}
// namespace rtg
src/shape.cpp
View file @
14d5666b
...
@@ -7,21 +7,16 @@
...
@@ -7,21 +7,16 @@
namespace
rtg
{
namespace
rtg
{
shape
::
shape
()
shape
::
shape
()
:
type_
(
float_type
),
lens_
(),
strides_
(),
packed_
(
false
)
{}
:
type_
(
float_type
),
lens_
(),
strides_
(),
packed_
(
false
)
{}
shape
::
shape
(
type_t
t
)
shape
::
shape
(
type_t
t
)
:
type_
(
t
),
lens_
({
1
}),
strides_
({
1
}),
packed_
(
true
)
{}
:
type_
(
t
),
lens_
({
1
}),
strides_
({
1
}),
packed_
(
true
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
type_
(
t
),
lens_
(
std
::
move
(
l
)),
packed_
(
true
)
{}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
)
:
type_
(
t
),
lens_
(
std
::
move
(
l
)),
packed_
(
true
)
{
{
this
->
calculate_strides
();
this
->
calculate_strides
();
assert
(
lens_
.
size
()
==
strides_
.
size
());
assert
(
lens_
.
size
()
==
strides_
.
size
());
}
}
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
shape
::
shape
(
type_t
t
,
std
::
vector
<
std
::
size_t
>
l
,
std
::
vector
<
std
::
size_t
>
s
)
:
type_
(
t
),
lens_
(
std
::
move
(
l
)),
strides_
(
std
::
move
(
s
))
:
type_
(
t
),
lens_
(
std
::
move
(
l
)),
strides_
(
std
::
move
(
s
))
{
{
assert
(
lens_
.
size
()
==
strides_
.
size
());
assert
(
lens_
.
size
()
==
strides_
.
size
());
packed_
=
this
->
elements
()
==
this
->
element_space
();
packed_
=
this
->
elements
()
==
this
->
element_space
();
...
@@ -38,18 +33,9 @@ void shape::calculate_strides()
...
@@ -38,18 +33,9 @@ void shape::calculate_strides()
lens_
.
rbegin
(),
lens_
.
rend
()
-
1
,
strides_
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
lens_
.
rbegin
(),
lens_
.
rend
()
-
1
,
strides_
.
rbegin
()
+
1
,
std
::
multiplies
<
std
::
size_t
>
());
}
}
shape
::
type_t
shape
::
type
()
const
shape
::
type_t
shape
::
type
()
const
{
return
this
->
type_
;
}
{
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
lens_
;
}
return
this
->
type_
;
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
strides_
;
}
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
lens
()
const
{
return
this
->
lens_
;
}
const
std
::
vector
<
std
::
size_t
>&
shape
::
strides
()
const
{
return
this
->
strides_
;
}
std
::
size_t
shape
::
elements
()
const
std
::
size_t
shape
::
elements
()
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
...
@@ -77,13 +63,15 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
...
@@ -77,13 +63,15 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
std
::
size_t
shape
::
index
(
std
::
size_t
i
)
const
{
{
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
assert
(
this
->
lens
().
size
()
==
this
->
strides
().
size
());
return
std
::
inner_product
(
this
->
lens
().
begin
(),
this
->
lens
().
end
(),
this
->
strides
().
begin
(),
std
::
size_t
{
0
},
std
::
plus
<
std
::
size_t
>
{},
return
std
::
inner_product
(
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
this
->
lens
().
begin
(),
}
this
->
lens
().
end
(),
bool
shape
::
packed
()
const
this
->
strides
().
begin
(),
{
std
::
size_t
{
0
},
return
this
->
packed_
;
std
::
plus
<
std
::
size_t
>
{},
}
[
&
](
std
::
size_t
len
,
std
::
size_t
stride
)
{
return
((
i
/
stride
)
%
len
)
*
stride
;
});
}
bool
shape
::
packed
()
const
{
return
this
->
packed_
;
}
std
::
size_t
shape
::
element_space
()
const
std
::
size_t
shape
::
element_space
()
const
{
{
// TODO: Get rid of intermediate vector
// TODO: Get rid of intermediate vector
...
@@ -101,11 +89,10 @@ std::size_t shape::element_space() const
...
@@ -101,11 +89,10 @@ std::size_t shape::element_space() const
std
::
string
shape
::
type_string
()
const
std
::
string
shape
::
type_string
()
const
{
{
switch
(
this
->
type_
)
switch
(
this
->
type_
)
{
{
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: \
case x: return #x;
return #x;
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_TYPE_STRING_CASE
)
RTG_SHAPE_VISIT_TYPES
(
RTG_SHAPE_TYPE_STRING_CASE
)
#undef RTG_SHAPE_TYPE_STRING_CASE
#undef RTG_SHAPE_TYPE_STRING_CASE
}
}
...
@@ -116,10 +103,7 @@ bool operator==(const shape& x, const shape& y)
...
@@ -116,10 +103,7 @@ bool operator==(const shape& x, const shape& y)
{
{
return
x
.
type
()
==
y
.
type
()
&&
x
.
lens
()
==
y
.
lens
()
&&
x
.
strides
()
==
y
.
strides
();
return
x
.
type
()
==
y
.
type
()
&&
x
.
lens
()
==
y
.
lens
()
&&
x
.
strides
()
==
y
.
strides
();
}
}
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
bool
operator
!=
(
const
shape
&
x
,
const
shape
&
y
)
{
return
!
(
x
==
y
);
}
{
return
!
(
x
==
y
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
shape
&
x
)
{
{
...
@@ -129,4 +113,4 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
...
@@ -129,4 +113,4 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return
os
;
return
os
;
}
}
}
}
// namespace rtg
test/eval_test.cpp
View file @
14d5666b
...
@@ -4,37 +4,37 @@
...
@@ -4,37 +4,37 @@
#include <rtg/shape.hpp>
#include <rtg/shape.hpp>
#include "test.hpp"
#include "test.hpp"
struct
sum_op
struct
sum_op
{
{
std
::
string
name
()
const
std
::
string
name
()
const
{
return
"sum"
;
}
{
return
"sum"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
args
)
const
{
{
rtg
::
argument
result
;
rtg
::
argument
result
;
if
(
args
.
size
()
!=
2
)
throw
"Wrong args"
;
if
(
args
.
size
()
!=
2
)
if
(
args
[
0
].
get_shape
()
!=
args
[
1
].
get_shape
())
throw
"Wrong args"
;
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
size
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
()
!=
args
[
1
].
get_shape
())
if
(
args
[
0
].
get_shape
().
lens
().
front
()
!=
1
)
throw
"Wrong args"
;
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
size
()
!=
1
)
throw
"Wrong args"
;
if
(
args
[
0
].
get_shape
().
lens
().
front
()
!=
1
)
throw
"Wrong args"
;
args
[
0
].
visit_at
([
&
](
auto
x
)
{
args
[
0
].
visit_at
([
&
](
auto
x
)
{
args
[
1
].
visit_at
([
&
](
auto
y
)
{
args
[
1
].
visit_at
([
&
](
auto
y
)
{
result
=
rtg
::
literal
{
x
+
y
}.
get_argument
();
});
result
=
rtg
::
literal
{
x
+
y
}.
get_argument
();
});
});
});
return
result
;
return
result
;
}
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
inputs
)
const
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
inputs
)
const
{
{
if
(
inputs
.
size
()
!=
2
)
throw
"Wrong inputs"
;
if
(
inputs
.
size
()
!=
2
)
throw
"Wrong inputs"
;
return
inputs
.
front
();
return
inputs
.
front
();
}
}
};
};
void
literal_test
()
{
void
literal_test
()
{
rtg
::
program
p
;
rtg
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
one
=
p
.
add_literal
(
1
);
...
@@ -45,23 +45,22 @@ void literal_test() {
...
@@ -45,23 +45,22 @@ void literal_test() {
EXPECT
(
result
!=
rtg
::
literal
{
4
});
EXPECT
(
result
!=
rtg
::
literal
{
4
});
}
}
void
param_test
()
{
void
param_test
()
{
rtg
::
program
p
;
rtg
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
{
rtg
::
shape
::
int64_type
});
auto
x
=
p
.
add_parameter
(
"x"
,
{
rtg
::
shape
::
int64_type
});
auto
y
=
p
.
add_parameter
(
"y"
,
{
rtg
::
shape
::
int64_type
});
auto
y
=
p
.
add_parameter
(
"y"
,
{
rtg
::
shape
::
int64_type
});
p
.
add_instruction
(
sum_op
{},
x
,
y
);
p
.
add_instruction
(
sum_op
{},
x
,
y
);
auto
result
=
p
.
eval
({
auto
result
=
{
"x"
,
rtg
::
literal
{
1
}.
get_argument
()},
p
.
eval
({{
"x"
,
rtg
::
literal
{
1
}.
get_argument
()},
{
"y"
,
rtg
::
literal
{
2
}.
get_argument
()}});
{
"y"
,
rtg
::
literal
{
2
}.
get_argument
()}
});
EXPECT
(
result
==
rtg
::
literal
{
3
});
EXPECT
(
result
==
rtg
::
literal
{
3
});
EXPECT
(
result
!=
rtg
::
literal
{
4
});
EXPECT
(
result
!=
rtg
::
literal
{
4
});
}
}
int
main
()
{
int
main
()
{
literal_test
();
literal_test
();
param_test
();
param_test
();
}
}
test/literal_test.cpp
View file @
14d5666b
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
#include <string>
#include <string>
#include "test.hpp"
#include "test.hpp"
void
literal_test
()
void
literal_test
()
{
{
EXPECT
(
rtg
::
literal
{
1
}
==
rtg
::
literal
{
1
});
EXPECT
(
rtg
::
literal
{
1
}
==
rtg
::
literal
{
1
});
...
@@ -23,7 +22,7 @@ void literal_test()
...
@@ -23,7 +22,7 @@ void literal_test()
rtg
::
literal
l4
{};
rtg
::
literal
l4
{};
EXPECT
(
l3
==
l4
);
EXPECT
(
l3
==
l4
);
EXPECT
(
l3
.
empty
());
EXPECT
(
l3
.
empty
());
EXPECT
(
l4
.
empty
());
EXPECT
(
l4
.
empty
());
}
}
void
literal_os1
()
void
literal_os1
()
...
@@ -51,10 +50,9 @@ void literal_os3()
...
@@ -51,10 +50,9 @@ void literal_os3()
EXPECT
(
ss
.
str
()
==
"1, 2, 3"
);
EXPECT
(
ss
.
str
()
==
"1, 2, 3"
);
}
}
int
main
()
{
int
main
()
{
literal_test
();
literal_test
();
literal_os1
();
literal_os1
();
literal_os2
();
literal_os2
();
}
}
test/main.cpp
View file @
14d5666b
int
main
()
{
int
main
()
{}
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment