Commit 14d5666b authored by Paul's avatar Paul
Browse files

Add clang formatting

parent 2305ac81
---
Language: Cpp
AccessModifierOffset: 0
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: true
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: true
AfterControlStatement: true
AfterEnum: true
AfterFunction: true
AfterNamespace: false
AfterObjCDeclaration: true
AfterStruct: true
AfterUnion: true
BeforeCatch: true
BeforeElse: true
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
- Regex: '^(<|"(gtest|isl|json)/)'
Priority: 3
- Regex: '.*'
Priority: 1
IndentCaseLabels: false
IndentWidth: 4
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: true
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
# SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: Never
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
#!/usr/bin/env bash
cd $(git rev-parse --git-dir)
echo "Installing hooks..."
ln -s ../.githooks hooks
echo "Done!"
#!/bin/sh
#
# This pre-commit hook checks if any versions of clang-format
# are installed, and if so, uses the installed version to format
# the staged changes.
base=clang-format-5.0
format=""
# Redirect output to stderr.
exec 1>&2
# check if clang-format is installed
type "$base" >/dev/null 2>&1 && format="$base"
# no versions of clang-format are installed
if [ -z "$format" ]
then
echo "$base is not installed. Pre-commit hook will not be executed."
exit 0
fi
# Do everything from top - level
cd $(git rev-parse --show-toplevel)
if git rev-parse --verify HEAD >/dev/null 2>&1
then
against=HEAD
else
# Initial commit: diff against an empty tree object
against=16bbb57
fi
# do the formatting
for file in $(git diff-index --cached --name-only $against | grep -E '\.h$|\.hpp$|\.cpp$|\.cl$|\.h\.in$|\.hpp\.in$|\.cpp\.in$')
do
if [ -e "$file" ]
then
echo "$format $file"
"$format" -i -style=file "$file"
fi
done
...@@ -9,28 +9,20 @@ namespace rtg { ...@@ -9,28 +9,20 @@ namespace rtg {
struct argument : raw_data<argument> struct argument : raw_data<argument>
{ {
argument() argument() {}
{}
argument(shape s, std::function<char*()> d) argument(shape s, std::function<char*()> d) : data(d), shape_(s) {}
: data(d), shape_(s)
{}
std::function<char*()> data; std::function<char*()> data;
bool empty() const bool empty() const { return not data; }
{
return not data;
}
const shape& get_shape() const const shape& get_shape() const { return this->shape_; }
{
return this->shape_; private:
}
private:
shape shape_; shape shape_;
}; };
} } // namespace rtg
#endif #endif
...@@ -9,38 +9,20 @@ namespace builtin { ...@@ -9,38 +9,20 @@ namespace builtin {
struct literal struct literal
{ {
std::string name() const std::string name() const { return "@literal"; }
{ shape compute_shape(std::vector<shape>) const { throw "builtin"; }
return "@literal"; argument compute(std::vector<argument>) const { throw "builtin"; }
}
shape compute_shape(std::vector<shape>) const
{
throw "builtin";
}
argument compute(std::vector<argument>) const
{
throw "builtin";
}
}; };
struct param struct param
{ {
std::string parameter; std::string parameter;
std::string name() const std::string name() const { return "@param:" + parameter; }
{ shape compute_shape(std::vector<shape>) const { throw "builtin"; }
return "@param:" + parameter; argument compute(std::vector<argument>) const { throw "builtin"; }
}
shape compute_shape(std::vector<shape>) const
{
throw "builtin";
}
argument compute(std::vector<argument>) const
{
throw "builtin";
}
}; };
} } // namespace builtin
} // namespace rtg } // namespace rtg
......
...@@ -13,12 +13,14 @@ struct instruction ...@@ -13,12 +13,14 @@ struct instruction
instruction() {} instruction() {}
instruction(operand o, shape r, std::vector<instruction*> args) instruction(operand o, shape r, std::vector<instruction*> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)), lit() : op(std::move(o)), result(std::move(r)), arguments(std::move(args)), lit()
{} {
}
instruction(literal l) instruction(literal l)
: op(builtin::literal{}), result(l.get_shape()), arguments(), lit(std::move(l)) : op(builtin::literal{}), result(l.get_shape()), arguments(), lit(std::move(l))
{} {
}
operand op; operand op;
shape result; shape result;
...@@ -26,6 +28,6 @@ struct instruction ...@@ -26,6 +28,6 @@ struct instruction
literal lit; literal lit;
}; };
} } // namespace rtg
#endif #endif
...@@ -10,68 +10,45 @@ namespace rtg { ...@@ -10,68 +10,45 @@ namespace rtg {
struct literal : raw_data<literal> struct literal : raw_data<literal>
{ {
literal() literal() : buffer(), shape_() {}
: buffer(), shape_()
{}
template<class T> template <class T>
literal(T x) literal(T x) : buffer(sizeof(T), 0), shape_(shape::get_type<T>{})
: buffer(sizeof(T), 0), shape_(shape::get_type<T>{})
{ {
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
*(reinterpret_cast<T*>(buffer.data())) = x; *(reinterpret_cast<T*>(buffer.data())) = x;
} }
template<class T> template <class T>
literal(shape s, const std::vector<T>& x) literal(shape s, const std::vector<T>& x) : buffer(s.bytes(), 0), shape_(s)
: buffer(s.bytes(), 0), shape_(s)
{ {
assert(s.packed()); assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) { s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); });
std::copy(x.begin(), x.end(), as.from(buffer.data()));
});
} }
template<class T> template <class T>
literal(shape s, const std::initializer_list<T>& x) literal(shape s, const std::initializer_list<T>& x) : buffer(s.bytes(), 0), shape_(s)
: buffer(s.bytes(), 0), shape_(s)
{ {
assert(s.packed()); assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types"); static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) { s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); });
std::copy(x.begin(), x.end(), as.from(buffer.data()));
});
} }
template<class Iterator> template <class Iterator>
literal(shape s, Iterator start, Iterator end) literal(shape s, Iterator start, Iterator end) : buffer(s.bytes(), 0), shape_(s)
: buffer(s.bytes(), 0), shape_(s)
{ {
assert(s.packed()); assert(s.packed());
s.visit_type([&](auto as) { s.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
std::copy(start, end, as.from(buffer.data()));
});
} }
literal(shape s, const char* x)
: buffer(x, x+s.bytes()), shape_(s)
{}
bool empty() const literal(shape s, const char* x) : buffer(x, x + s.bytes()), shape_(s) {}
{
return this->buffer.empty();
}
const char* data() const bool empty() const { return this->buffer.empty(); }
{
return this->buffer.data();
}
const shape& get_shape() const const char* data() const { return this->buffer.data(); }
{
return this->shape_; const shape& get_shape() const { return this->shape_; }
}
argument get_argument() const argument get_argument() const
{ {
...@@ -79,11 +56,11 @@ struct literal : raw_data<literal> ...@@ -79,11 +56,11 @@ struct literal : raw_data<literal>
return {shape_, [b]() mutable { return b.data(); }}; return {shape_, [b]() mutable { return b.data(); }};
} }
private: private:
std::vector<char> buffer; std::vector<char> buffer;
shape shape_; shape shape_;
}; };
} } // namespace rtg
#endif #endif
...@@ -12,16 +12,16 @@ ...@@ -12,16 +12,16 @@
namespace rtg { namespace rtg {
/* /*
* Type-erased interface for: * Type-erased interface for:
* *
* struct operand * struct operand
* { * {
* std::string name() const; * std::string name() const;
* shape compute_shape(std::vector<shape> input) const; * shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const; * argument compute(std::vector<argument> input) const;
* }; * };
* *
*/ */
struct operand struct operand
{ {
...@@ -80,8 +80,9 @@ struct operand ...@@ -80,8 +80,9 @@ struct operand
struct handle_type_ : handle_base_type_ struct handle_type_ : handle_base_type_
{ {
template <typename TypeErased_U_ = TypeErased_T_> template <typename TypeErased_U_ = TypeErased_T_>
handle_type_(TypeErased_T_ value, handle_type_(
typename std::enable_if<std::is_reference<TypeErased_U_>::value>::type* = nullptr) TypeErased_T_ value,
typename std::enable_if<std::is_reference<TypeErased_U_>::value>::type* = nullptr)
: value_(value) : value_(value)
{ {
} }
...@@ -89,7 +90,8 @@ struct operand ...@@ -89,7 +90,8 @@ struct operand
template <typename TypeErased_U_ = TypeErased_T_> template <typename TypeErased_U_ = TypeErased_T_>
handle_type_(TypeErased_T_ value, handle_type_(TypeErased_T_ value,
typename std::enable_if<!std::is_reference<TypeErased_U_>::value, int>::type* = typename std::enable_if<!std::is_reference<TypeErased_U_>::value, int>::type* =
nullptr) noexcept : value_(std::move(value)) nullptr) noexcept
: value_(std::move(value))
{ {
} }
...@@ -134,6 +136,6 @@ struct operand ...@@ -134,6 +136,6 @@ struct operand
std::shared_ptr<handle_base_type_> handle_mem_var_; std::shared_ptr<handle_base_type_> handle_mem_var_;
}; };
} } // namespace rtg
#endif #endif
...@@ -9,120 +9,120 @@ namespace rtg { ...@@ -9,120 +9,120 @@ namespace rtg {
struct not_computable struct not_computable
{ {
argument compute(std::vector<argument>) const argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
{
throw std::runtime_error("not computable");
}
}; };
struct convolution struct convolution
{ {
std::array<std::size_t, 2> padding = {0, 0}; std::array<std::size_t, 2> padding = {0, 0};
std::array<std::size_t, 2> stride = {1, 1}; std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> dilation = {1, 1}; std::array<std::size_t, 2> dilation = {1, 1};
std::string name() const std::string name() const
{ {
return "convolution[padding={" + to_string(padding) + return "convolution[padding={" + to_string(padding) + "}, stride={" + to_string(stride) +
"}, stride={" + to_string(stride) + "}, dilation={" + to_string(dilation) + "}]";
"}, dilation={" + to_string(dilation) +
"}]";
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.size() != 2) throw std::runtime_error("Wrong number of arguments"); if(inputs.size() != 2)
const shape& input = inputs.at(0); throw std::runtime_error("Wrong number of arguments");
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
if(input.type() != weights.type()) throw std::runtime_error("Type doesn't match"); if(input.type() != weights.type())
if(input.lens().size() != weights.lens().size()) throw std::runtime_error("Dimensions don't match"); throw std::runtime_error("Type doesn't match");
if(input.lens().size() != 4) throw std::runtime_error("Only 4d convolution supported"); if(input.lens().size() != weights.lens().size())
throw std::runtime_error("Dimensions don't match");
if(input.lens().size() != 4)
throw std::runtime_error("Only 4d convolution supported");
auto t = input.type(); auto t = input.type();
return {t, { return {t,
input.lens()[0], {
weights.lens()[0], input.lens()[0],
std::size_t(std::max<std::ptrdiff_t>( weights.lens()[0],
1, (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + 2 * padding[0]) / stride[0] + 1)), std::size_t(std::max<std::ptrdiff_t>(
std::size_t(std::max<std::ptrdiff_t>( 1,
1, (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) + 2 * padding[1]) / stride[1] + 1)), (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
}}; 2 * padding[0]) /
stride[0] +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
2 * padding[1]) /
stride[1] +
1)),
}};
} }
argument compute(std::vector<argument>) const argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
{
throw std::runtime_error("not computable");
}
}; };
struct pooling struct pooling
{ {
std::string mode; std::string mode;
std::array<std::size_t, 2> padding = {0, 0}; std::array<std::size_t, 2> padding = {0, 0};
std::array<std::size_t, 2> stride = {1, 1}; std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> lengths = {1, 1}; std::array<std::size_t, 2> lengths = {1, 1};
std::string name() const std::string name() const
{ {
return "pooling:" + mode + "[padding={" + to_string(padding) + return "pooling:" + mode + "[padding={" + to_string(padding) + "}, stride={" +
"}, stride={" + to_string(stride) + to_string(stride) + "}, lengths={" + to_string(lengths) + "}]";
"}, lengths={" + to_string(lengths) +
"}]";
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) throw std::runtime_error("Wrong number of arguments"); if(inputs.empty())
const shape& input = inputs.at(0); throw std::runtime_error("Wrong number of arguments");
if(input.lens().size() != 4) throw std::runtime_error("Only 4d pooling supported"); const shape& input = inputs.at(0);
if(input.lens().size() != 4)
throw std::runtime_error("Only 4d pooling supported");
auto t = input.type(); auto t = input.type();
return {t, { return {t,
input.lens()[0], {
input.lens()[1], input.lens()[0],
std::size_t(std::max<std::ptrdiff_t>( input.lens()[1],
1, std::ceil((input.lens()[3] + 2 * padding[0] - lengths[0]) / static_cast<float>(stride[0])) + 1)), std::size_t(std::max<std::ptrdiff_t>(
std::size_t(std::max<std::ptrdiff_t>( 1,
1, std::ceil((input.lens()[4] + 2 * padding[1] - lengths[1]) / static_cast<float>(stride[1])) + 1)), std::ceil((input.lens()[3] + 2 * padding[0] - lengths[0]) /
}}; static_cast<float>(stride[0])) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ceil((input.lens()[4] + 2 * padding[1] - lengths[1]) /
static_cast<float>(stride[1])) +
1)),
}};
} }
argument compute(std::vector<argument>) const argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
{
throw std::runtime_error("not computable");
}
}; };
struct activation struct activation
{ {
std::string mode; std::string mode;
std::string name() const std::string name() const { return "activation:" + mode; }
{
return "activation:" + mode;
}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) throw std::runtime_error("Wrong number of arguments"); if(inputs.empty())
throw std::runtime_error("Wrong number of arguments");
return inputs.front(); return inputs.front();
} }
argument compute(std::vector<argument>) const argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
{
throw std::runtime_error("not computable");
}
}; };
struct reshape struct reshape
{ {
std::vector<int64_t> dims; std::vector<int64_t> dims;
std::string name() const std::string name() const { return "reshape[dims={" + to_string(dims) + "}]"; }
{
return "reshape[dims={" + to_string(dims) +
"}]";
}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) throw std::runtime_error("Wrong number of arguments"); if(inputs.empty())
throw std::runtime_error("Wrong number of arguments");
auto&& idims = inputs.front().lens(); auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end()); std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0;i < dims.size();i++) for(std::size_t i = 0; i < dims.size(); i++)
{ {
if(dims[i] == 0) if(dims[i] == 0)
rdims[i] = idims[i]; rdims[i] = idims[i];
...@@ -130,18 +130,14 @@ struct reshape ...@@ -130,18 +130,14 @@ struct reshape
if(dims.back() == -1) if(dims.back() == -1)
{ {
rdims.pop_back(); rdims.pop_back();
std::copy(idims.begin()+rdims.size(), idims.end(), std::back_inserter(rdims)); std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
} }
return {inputs.front().type(), rdims}; return {inputs.front().type(), rdims};
} }
argument compute(std::vector<argument>) const argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
{
throw std::runtime_error("not computable");
}
}; };
} // namespace rtg } // namespace rtg
#endif #endif
...@@ -13,35 +13,38 @@ namespace rtg { ...@@ -13,35 +13,38 @@ namespace rtg {
struct program struct program
{ {
// TODO: A program should be copyable // TODO: A program should be copyable
program() = default; program() = default;
program(const program&) = delete; program(const program&) = delete;
program& operator=(const program&) = delete; program& operator=(const program&) = delete;
template<class... Ts> template <class... Ts>
instruction * add_instruction(operand op, Ts*... args) instruction* add_instruction(operand op, Ts*... args)
{ {
shape r = op.compute_shape({args->result...}); shape r = op.compute_shape({args->result...});
instructions.push_back({op, r, {args...}}); instructions.push_back({op, r, {args...}});
return std::addressof(instructions.back()); return std::addressof(instructions.back());
} }
instruction * add_instruction(operand op, std::vector<instruction*> args) instruction* add_instruction(operand op, std::vector<instruction*> args)
{ {
assert(std::all_of(args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) && "Argument is not an exisiting instruction"); assert(std::all_of(
args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size()); std::vector<shape> shapes(args.size());
std::transform(args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; }); std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
shape r = op.compute_shape(shapes); shape r = op.compute_shape(shapes);
instructions.push_back({op, r, args}); instructions.push_back({op, r, args});
assert(instructions.back().arguments == args); assert(instructions.back().arguments == args);
return std::addressof(instructions.back()); return std::addressof(instructions.back());
} }
template<class... Ts> template <class... Ts>
instruction * add_literal(Ts&&... xs) instruction* add_literal(Ts&&... xs)
{ {
instructions.emplace_back(literal{std::forward<Ts>(xs)...}); instructions.emplace_back(literal{std::forward<Ts>(xs)...});
return std::addressof(instructions.back()); return std::addressof(instructions.back());
} }
instruction * add_parameter(std::string name, shape s) instruction* add_parameter(std::string name, shape s)
{ {
instructions.push_back({builtin::param{std::move(name)}, s, {}}); instructions.push_back({builtin::param{std::move(name)}, s, {}});
return std::addressof(instructions.back()); return std::addressof(instructions.back());
...@@ -52,16 +55,18 @@ struct program ...@@ -52,16 +55,18 @@ struct program
// TODO: Change to stream operator // TODO: Change to stream operator
void print() const; void print() const;
bool has_instruction(const instruction * ins) const bool has_instruction(const instruction* ins) const
{ {
return std::find_if(instructions.begin(), instructions.end(), [&](const instruction& x) {return ins == std::addressof(x); }) != instructions.end(); return std::find_if(instructions.begin(), instructions.end(), [&](const instruction& x) {
return ins == std::addressof(x);
}) != instructions.end();
} }
private: private:
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
std::list<instruction> instructions; std::list<instruction> instructions;
}; };
} } // namespace rtg
#endif #endif
...@@ -6,14 +6,14 @@ ...@@ -6,14 +6,14 @@
namespace rtg { namespace rtg {
template<class Derived> template <class Derived>
struct raw_data struct raw_data
{ {
friend bool operator==(const Derived& x, const Derived& y) friend bool operator==(const Derived& x, const Derived& y)
{ {
auto&& xshape = x.get_shape(); auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape(); auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty(); bool result = x.empty() && y.empty();
if(not result && xshape == yshape) if(not result && xshape == yshape)
{ {
auto&& xbuffer = x.data(); auto&& xbuffer = x.data();
...@@ -22,59 +22,48 @@ struct raw_data ...@@ -22,59 +22,48 @@ struct raw_data
xshape.visit_type([&](auto as) { xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer)); auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer)); auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview; result = xview == yview;
}); });
} }
return result; return result;
} }
friend bool operator!=(const Derived& x, const Derived& y) friend bool operator!=(const Derived& x, const Derived& y) { return !(x == y); }
{
return !(x == y); template <class Stream>
}
template<class Stream>
friend Stream& operator<<(Stream& os, const Derived& d) friend Stream& operator<<(Stream& os, const Derived& d)
{ {
d.visit([&](auto x) { d.visit([&](auto x) { os << x; });
os << x;
});
return os; return os;
} }
template<class Visitor> template <class Visitor>
void visit_at(Visitor v, std::size_t n=0) const void visit_at(Visitor v, std::size_t n = 0) const
{ {
auto && s = static_cast<const Derived&>(*this).get_shape(); auto&& s = static_cast<const Derived&>(*this).get_shape();
auto && buffer = static_cast<const Derived&>(*this).data(); auto&& buffer = static_cast<const Derived&>(*this).data();
s.visit_type([&](auto as) { s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); });
v(*(as.from(buffer)+s.index(n)));
});
} }
template<class Visitor> template <class Visitor>
void visit(Visitor v) const void visit(Visitor v) const
{ {
auto && s = static_cast<const Derived&>(*this).get_shape(); auto&& s = static_cast<const Derived&>(*this).get_shape();
auto && buffer = static_cast<const Derived&>(*this).data(); auto&& buffer = static_cast<const Derived&>(*this).data();
s.visit_type([&](auto as) { s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
v(make_view(s, as.from(buffer)));
});
} }
bool single() const bool single() const
{ {
auto && s = static_cast<const Derived&>(*this).get_shape(); auto&& s = static_cast<const Derived&>(*this).get_shape();
return s.elements() == 1; return s.elements() == 1;
} }
template<class T> template <class T>
T at(std::size_t n=0) const T at(std::size_t n = 0) const
{ {
T result; T result;
this->visit_at([&](auto x) { this->visit_at([&](auto x) { result = x; }, n);
result = x;
}, n);
return result; return result;
} }
}; };
......
...@@ -11,6 +11,7 @@ struct shape ...@@ -11,6 +11,7 @@ struct shape
{ {
// Add new types here // Add new types here
// clang-format off
#define RTG_SHAPE_VISIT_TYPES(m) \ #define RTG_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \ m(float_type, float) \
m(double_type, double) \ m(double_type, double) \
...@@ -23,6 +24,7 @@ struct shape ...@@ -23,6 +24,7 @@ struct shape
m(uint32_type, uint32_t) \ m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \ m(uint64_type, uint64_t) \
// clang-format on
#define RTG_SHAPE_ENUM_TYPES(x, t) x, #define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
{ {
...@@ -30,12 +32,13 @@ struct shape ...@@ -30,12 +32,13 @@ struct shape
}; };
#undef RTG_SHAPE_ENUM_TYPES #undef RTG_SHAPE_ENUM_TYPES
template<class T, class=void> template <class T, class = void>
struct get_type; struct get_type;
#define RTG_SHAPE_GET_TYPE(x, t) \ #define RTG_SHAPE_GET_TYPE(x, t) \
template<class T> \ template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \ struct get_type<t, T> : std::integral_constant<type_t, x> \
{}; { \
};
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE) RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE #undef RTG_SHAPE_GET_TYPE
...@@ -44,7 +47,6 @@ struct shape ...@@ -44,7 +47,6 @@ struct shape
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s); shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
type_t type() const; type_t type() const;
const std::vector<std::size_t>& lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
...@@ -63,67 +65,60 @@ struct shape ...@@ -63,67 +65,60 @@ struct shape
friend bool operator!=(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x); friend std::ostream& operator<<(std::ostream& os, const shape& x);
template<class T> template <class T>
struct as struct as
{ {
using type = T; using type = T;
template<class U> template <class U>
T operator()(U u) const T operator()(U u) const
{ {
return T(u); return T(u);
} }
template<class U> template <class U>
T* operator()(U* u) const T* operator()(U* u) const
{ {
return static_cast<T*>(u); return static_cast<T*>(u);
} }
template<class U> template <class U>
const T* operator()(const U* u) const const T* operator()(const U* u) const
{ {
return static_cast<T*>(u); return static_cast<T*>(u);
} }
T operator()() const T operator()() const { return {}; }
{
return {};
}
std::size_t size(std::size_t n=1) const std::size_t size(std::size_t n = 1) const { return sizeof(T) * n; }
{
return sizeof(T)*n;
}
template<class U> template <class U>
T* from(U* buffer, std::size_t n=0) const T* from(U* buffer, std::size_t n = 0) const
{ {
return reinterpret_cast<T*>(buffer)+n; return reinterpret_cast<T*>(buffer) + n;
} }
template<class U> template <class U>
const T* from(const U* buffer, std::size_t n=0) const const T* from(const U* buffer, std::size_t n = 0) const
{ {
return reinterpret_cast<const T*>(buffer)+n; return reinterpret_cast<const T*>(buffer) + n;
} }
}; };
template<class Visitor> template <class Visitor>
void visit_type(Visitor v) const void visit_type(Visitor v) const
{ {
switch(this->type_) switch(this->type_)
{ {
#define RTG_SHAPE_VISITOR_CASE(x, t) \ #define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: \ case x: v(as<t>()); return;
v(as<t>()); \
return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE) RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE #undef RTG_SHAPE_VISITOR_CASE
} }
assert(true); assert(true);
} }
private:
private:
type_t type_; type_t type_;
std::vector<std::size_t> lens_; std::vector<std::size_t> lens_;
std::vector<std::size_t> strides_; std::vector<std::size_t> strides_;
...@@ -134,6 +129,6 @@ private: ...@@ -134,6 +129,6 @@ private:
std::string type_string() const; std::string type_string() const;
}; };
} } // namespace rtg
#endif #endif
...@@ -65,17 +65,14 @@ inline std::string remove_prefix(std::string s, std::string prefix) ...@@ -65,17 +65,14 @@ inline std::string remove_prefix(std::string s, std::string prefix)
return s; return s;
} }
template<class Range> template <class Range>
inline std::string to_string(const Range& r) inline std::string to_string(const Range& r)
{ {
std::stringstream ss; std::stringstream ss;
if(!r.empty()) if(!r.empty())
{ {
ss << r.front(); ss << r.front();
std::for_each(std::next(r.begin()), r.end(), [&](auto&& x) std::for_each(std::next(r.begin()), r.end(), [&](auto&& x) { ss << ", " << x; });
{
ss << ", " << x;
});
} }
return ss.str(); return ss.str();
} }
......
...@@ -8,48 +8,29 @@ ...@@ -8,48 +8,29 @@
namespace rtg { namespace rtg {
template<class T> template <class T>
struct tensor_view struct tensor_view
{ {
tensor_view() tensor_view() : data_(nullptr), shape_() {}
: data_(nullptr), shape_() tensor_view(shape s, T* d) : data_(d), shape_(s) {}
{}
tensor_view(shape s, T* d)
: data_(d), shape_(s)
{}
const shape& get_shape() const
{
return this->shape_;
}
bool empty() const const shape& get_shape() const { return this->shape_; }
{
return data_ == nullptr || shape_.lens().size() == 0;
}
std::size_t size() const bool empty() const { return data_ == nullptr || shape_.lens().size() == 0; }
{
return shape_.elements();
}
T* data() std::size_t size() const { return shape_.elements(); }
{
return this->data_;
}
const T* data() const T* data() { return this->data_; }
{
return this->data_;
}
template<class... Ts> const T* data() const { return this->data_; }
template <class... Ts>
const T& operator()(Ts... xs) const const T& operator()(Ts... xs) const
{ {
return data_[shape_.index({xs...})]; return data_[shape_.index({xs...})];
} }
template<class... Ts> template <class... Ts>
T& operator()(Ts... xs) T& operator()(Ts... xs)
{ {
return data_[shape_.index({xs...})]; return data_[shape_.index({xs...})];
...@@ -82,13 +63,13 @@ struct tensor_view ...@@ -82,13 +63,13 @@ struct tensor_view
T& back() T& back()
{ {
assert(!this->empty()); assert(!this->empty());
return data_[shape_.index(this->size()-1)]; return data_[shape_.index(this->size() - 1)];
} }
const T& back() const const T& back() const
{ {
assert(!this->empty()); assert(!this->empty());
return data_[shape_.index(this->size()-1)]; return data_[shape_.index(this->size() - 1)];
} }
// TODO: Add iterators so it can handle nonpacked tensors // TODO: Add iterators so it can handle nonpacked tensors
...@@ -101,8 +82,10 @@ struct tensor_view ...@@ -101,8 +82,10 @@ struct tensor_view
T* end() T* end()
{ {
assert(this->shape_.packed()); assert(this->shape_.packed());
if(this->empty()) return data_; if(this->empty())
else return data_+this->size(); return data_;
else
return data_ + this->size();
} }
const T* begin() const const T* begin() const
...@@ -114,34 +97,34 @@ struct tensor_view ...@@ -114,34 +97,34 @@ struct tensor_view
const T* end() const const T* end() const
{ {
assert(this->shape_.packed()); assert(this->shape_.packed());
if(this->empty()) return data_; if(this->empty())
else return data_+this->size(); return data_;
else
return data_ + this->size();
} }
friend bool operator==(const tensor_view<T>& x, const tensor_view<T>& y) friend bool operator==(const tensor_view<T>& x, const tensor_view<T>& y)
{ {
if(x.shape_ == y.shape_) if(x.shape_ == y.shape_)
{ {
for(std::size_t i = 0;i < x.shape_.elements();i++) for(std::size_t i = 0; i < x.shape_.elements(); i++)
{ {
if(!float_equal(x[i], y[i])) return false; if(!float_equal(x[i], y[i]))
return false;
} }
return true; return true;
} }
return false; return false;
} }
friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y) friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y) { return !(x == y); }
{
return !(x == y);
}
friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x) friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
{ {
if(!x.empty()) if(!x.empty())
{ {
os << x.front(); os << x.front();
for(std::size_t i = 1;i < x.shape_.elements();i++) for(std::size_t i = 1; i < x.shape_.elements(); i++)
{ {
os << ", " << x.data_[x.shape_.index(i)]; os << ", " << x.data_[x.shape_.index(i)];
} }
...@@ -149,12 +132,12 @@ struct tensor_view ...@@ -149,12 +132,12 @@ struct tensor_view
return os; return os;
} }
private: private:
T* data_; T* data_;
shape shape_; shape shape_;
}; };
template<class T> template <class T>
tensor_view<T> make_view(shape s, T* data) tensor_view<T> make_view(shape s, T* data)
{ {
return {s, data}; return {s, data};
......
...@@ -13,43 +13,41 @@ ...@@ -13,43 +13,41 @@
struct unknown struct unknown
{ {
std::string op; std::string op;
std::string name() const std::string name() const { return "unknown:" + op; }
{
return "unknown:"+op;
}
rtg::shape compute_shape(std::vector<rtg::shape> input) const rtg::shape compute_shape(std::vector<rtg::shape> input) const
{ {
if(input.empty()) return {}; if(input.empty())
else return input.front(); return {};
} else
rtg::argument compute(std::vector<rtg::argument> input) const return input.front();
{
throw "not computable";
} }
rtg::argument compute(std::vector<rtg::argument> input) const { throw "not computable"; }
}; };
template<class C, class T> template <class C, class T>
bool contains(C&& c, T&& x) bool contains(C&& c, T&& x)
{ {
return c.find(x) != c.end(); return c.find(x) != c.end();
} }
template<class Range, class Iterator> template <class Range, class Iterator>
void copy(Range&& r, Iterator it) void copy(Range&& r, Iterator it)
{ {
std::copy(r.begin(), r.end(), it); std::copy(r.begin(), r.end(), it);
} }
struct onnx_parser
struct onnx_parser
{ {
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>; using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>; using node_map = std::unordered_map<std::string, onnx::NodeProto>;
node_map nodes; node_map nodes;
std::unordered_map<std::string, rtg::instruction*> instructions; std::unordered_map<std::string, rtg::instruction*> instructions;
std::shared_ptr<rtg::program> prog = std::make_shared<rtg::program>(); std::shared_ptr<rtg::program> prog = std::make_shared<rtg::program>();
std::unordered_map<std::string, std::function<rtg::instruction*(attribute_map, std::vector<rtg::instruction*>)>> ops; std::unordered_map<
std::string,
std::function<rtg::instruction*(attribute_map, std::vector<rtg::instruction*>)>>
ops;
onnx_parser() onnx_parser()
{ {
...@@ -92,10 +90,7 @@ struct onnx_parser ...@@ -92,10 +90,7 @@ struct onnx_parser
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) { add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
rtg::reshape op; rtg::reshape op;
rtg::literal s = parse_value(attributes.at("shape")); rtg::literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
{
copy(v, std::back_inserter(op.dims));
});
return prog->add_instruction(op, args); return prog->add_instruction(op, args);
}); });
add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) { add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) {
...@@ -104,7 +99,7 @@ struct onnx_parser ...@@ -104,7 +99,7 @@ struct onnx_parser
}); });
} }
template<class F> template <class F>
void add_op(std::string name, F f) void add_op(std::string name, F f)
{ {
ops.emplace(name, f); ops.emplace(name, f);
...@@ -113,14 +108,14 @@ struct onnx_parser ...@@ -113,14 +108,14 @@ struct onnx_parser
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
onnx::ModelProto model; onnx::ModelProto model;
if(model.ParseFromIstream(&is)) if(model.ParseFromIstream(&is))
{ {
if(model.has_graph()) if(model.has_graph())
{ {
this->parse_graph(model.graph()); this->parse_graph(model.graph());
} }
} }
else else
{ {
throw std::runtime_error("Failed reading"); throw std::runtime_error("Failed reading");
} }
...@@ -129,14 +124,14 @@ struct onnx_parser ...@@ -129,14 +124,14 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph) void parse_graph(const onnx::GraphProto& graph)
{ {
nodes = get_nodes(graph); nodes = get_nodes(graph);
for(auto&& input:graph.input()) for(auto&& input : graph.input())
{ {
std::string name = input.name(); std::string name = input.name();
// TODO: Get shape of input parameter // TODO: Get shape of input parameter
rtg::shape s = parse_type(input.type()); rtg::shape s = parse_type(input.type());
instructions[name] = prog->add_parameter(name, s); instructions[name] = prog->add_parameter(name, s);
} }
for(auto&& p:nodes) for(auto&& p : nodes)
{ {
this->parse_node(p.second.name()); this->parse_node(p.second.name());
} }
...@@ -144,11 +139,11 @@ struct onnx_parser ...@@ -144,11 +139,11 @@ struct onnx_parser
void parse_node(std::string name) void parse_node(std::string name)
{ {
if (instructions.count(name) == 0) if(instructions.count(name) == 0)
{ {
auto&& node = nodes.at(name); auto&& node = nodes.at(name);
std::vector<rtg::instruction*> args; std::vector<rtg::instruction*> args;
for(auto&& input:node.input()) for(auto&& input : node.input())
{ {
if(nodes.count(input) > 0) if(nodes.count(input) > 0)
{ {
...@@ -161,7 +156,7 @@ struct onnx_parser ...@@ -161,7 +156,7 @@ struct onnx_parser
args.push_back(instructions.at(input)); args.push_back(instructions.at(input));
} }
} }
if (ops.count(node.op_type()) == 0) if(ops.count(node.op_type()) == 0)
{ {
instructions[name] = prog->add_instruction(unknown{node.op_type()}, args); instructions[name] = prog->add_instruction(unknown{node.op_type()}, args);
} }
...@@ -175,7 +170,7 @@ struct onnx_parser ...@@ -175,7 +170,7 @@ struct onnx_parser
static attribute_map get_attributes(const onnx::NodeProto& node) static attribute_map get_attributes(const onnx::NodeProto& node)
{ {
std::unordered_map<std::string, onnx::AttributeProto> result; std::unordered_map<std::string, onnx::AttributeProto> result;
for(auto&& attr:node.attribute()) for(auto&& attr : node.attribute())
{ {
result[attr.name()] = attr; result[attr.name()] = attr;
} }
...@@ -185,14 +180,13 @@ struct onnx_parser ...@@ -185,14 +180,13 @@ struct onnx_parser
static node_map get_nodes(const onnx::GraphProto& graph) static node_map get_nodes(const onnx::GraphProto& graph)
{ {
std::unordered_map<std::string, onnx::NodeProto> result; std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node:graph.node()) for(auto&& node : graph.node())
{ {
result[node.name()] = node; result[node.name()] = node;
for(auto&& output:node.output()) for(auto&& output : node.output())
{ {
result[output] = node; result[output] = node;
} }
} }
return result; return result;
} }
...@@ -201,17 +195,20 @@ struct onnx_parser ...@@ -201,17 +195,20 @@ struct onnx_parser
{ {
switch(attr.type()) switch(attr.type())
{ {
case onnx::AttributeProto::UNDEFINED: return {}; case onnx::AttributeProto::UNDEFINED: return {};
case onnx::AttributeProto::FLOAT: return rtg::literal{attr.f()}; case onnx::AttributeProto::FLOAT: return rtg::literal{attr.f()};
case onnx::AttributeProto::INT: return rtg::literal{attr.i()}; case onnx::AttributeProto::INT: return rtg::literal{attr.i()};
case onnx::AttributeProto::STRING: return {}; case onnx::AttributeProto::STRING: return {};
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t()); case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::GRAPH: return {}; case onnx::AttributeProto::GRAPH: return {};
case onnx::AttributeProto::FLOATS: return rtg::literal{rtg::shape::float_type, attr.floats().begin(), attr.floats().end()}; case onnx::AttributeProto::FLOATS:
case onnx::AttributeProto::INTS: return rtg::literal{rtg::shape::int32_type, attr.ints().begin(), attr.ints().end()};; return rtg::literal{rtg::shape::float_type, attr.floats().begin(), attr.floats().end()};
case onnx::AttributeProto::STRINGS: return {}; case onnx::AttributeProto::INTS:
case onnx::AttributeProto::TENSORS: return {}; return rtg::literal{rtg::shape::int32_type, attr.ints().begin(), attr.ints().end()};
case onnx::AttributeProto::GRAPHS: return {}; ;
case onnx::AttributeProto::STRINGS: return {};
case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::GRAPHS: return {};
} }
} }
...@@ -220,22 +217,38 @@ struct onnx_parser ...@@ -220,22 +217,38 @@ struct onnx_parser
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end()); std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
switch(t.data_type()) switch(t.data_type())
{ {
case onnx::TensorProto::UNDEFINED: throw std::runtime_error(""); case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: return rtg::literal{{rtg::shape::float_type, dims}, t.float_data().begin(), t.float_data().end()}; case onnx::TensorProto::FLOAT:
case onnx::TensorProto::UINT8: throw std::runtime_error(""); return rtg::literal{
case onnx::TensorProto::INT8: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; {rtg::shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
case onnx::TensorProto::UINT16: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT16: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; case onnx::TensorProto::INT8:
case onnx::TensorProto::INT32: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return rtg::literal{
case onnx::TensorProto::INT64: return rtg::literal{{rtg::shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()}; {rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::UINT16:
case onnx::TensorProto::BOOL: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()}; return rtg::literal{
case onnx::TensorProto::FLOAT16: throw std::runtime_error(""); {rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::DOUBLE: return rtg::literal{{rtg::shape::double_type, dims}, t.double_data().begin(), t.double_data().end()}; case onnx::TensorProto::INT16:
case onnx::TensorProto::UINT32: throw std::runtime_error(""); return rtg::literal{
case onnx::TensorProto::UINT64: throw std::runtime_error(""); {rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::COMPLEX64: throw std::runtime_error(""); case onnx::TensorProto::INT32:
case onnx::TensorProto::COMPLEX128: throw std::runtime_error(""); return rtg::literal{
{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT64:
return rtg::literal{
{rtg::shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL:
return rtg::literal{
{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
case onnx::TensorProto::DOUBLE:
return rtg::literal{
{rtg::shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
} }
} }
...@@ -244,26 +257,33 @@ struct onnx_parser ...@@ -244,26 +257,33 @@ struct onnx_parser
rtg::shape::type_t shape_type; rtg::shape::type_t shape_type;
switch(t.tensor_type().elem_type()) switch(t.tensor_type().elem_type())
{ {
case onnx::TensorProto::UNDEFINED: break; //throw std::runtime_error("Unsupported type UNDEFINED"); case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::FLOAT: shape_type = rtg::shape::float_type; break; // throw std::runtime_error("Unsupported type UNDEFINED");
case onnx::TensorProto::UINT8: break; //throw std::runtime_error("Unsupported type UINT8"); case onnx::TensorProto::FLOAT: shape_type = rtg::shape::float_type;
case onnx::TensorProto::INT8: shape_type = rtg::shape::int8_type; case onnx::TensorProto::UINT8:
case onnx::TensorProto::UINT16: shape_type = rtg::shape::uint16_type; break; // throw std::runtime_error("Unsupported type UINT8");
case onnx::TensorProto::INT16: shape_type = rtg::shape::int16_type; case onnx::TensorProto::INT8: shape_type = rtg::shape::int8_type;
case onnx::TensorProto::INT32: shape_type = rtg::shape::int32_type; case onnx::TensorProto::UINT16: shape_type = rtg::shape::uint16_type;
case onnx::TensorProto::INT64: shape_type = rtg::shape::int64_type; case onnx::TensorProto::INT16: shape_type = rtg::shape::int16_type;
case onnx::TensorProto::STRING: break; //throw std::runtime_error("Unsupported type STRING"); case onnx::TensorProto::INT32: shape_type = rtg::shape::int32_type;
case onnx::TensorProto::BOOL: break; //throw std::runtime_error("Unsupported type BOOL"); case onnx::TensorProto::INT64: shape_type = rtg::shape::int64_type;
case onnx::TensorProto::FLOAT16: break; //throw std::runtime_error("Unsupported type FLOAT16"); case onnx::TensorProto::STRING:
case onnx::TensorProto::DOUBLE: shape_type = rtg::shape::double_type; break; // throw std::runtime_error("Unsupported type STRING");
case onnx::TensorProto::UINT32: shape_type = rtg::shape::uint32_type; case onnx::TensorProto::BOOL:
case onnx::TensorProto::UINT64: shape_type = rtg::shape::uint64_type; break; // throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::COMPLEX64: break; //throw std::runtime_error("Unsupported type COMPLEX64"); case onnx::TensorProto::FLOAT16:
case onnx::TensorProto::COMPLEX128: break; //throw std::runtime_error("Unsupported type COMPLEX128"); break; // throw std::runtime_error("Unsupported type FLOAT16");
case onnx::TensorProto::DOUBLE: shape_type = rtg::shape::double_type;
case onnx::TensorProto::UINT32: shape_type = rtg::shape::uint32_type;
case onnx::TensorProto::UINT64: shape_type = rtg::shape::uint64_type;
case onnx::TensorProto::COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64");
case onnx::TensorProto::COMPLEX128:
break; // throw std::runtime_error("Unsupported type COMPLEX128");
} }
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
// TODO: USe std::transform // TODO: USe std::transform
for(auto&& d:t.tensor_type().shape().dim()) for(auto&& d : t.tensor_type().shape().dim())
{ {
dims.push_back(d.dim_value()); dims.push_back(d.dim_value());
} }
...@@ -271,7 +291,7 @@ struct onnx_parser ...@@ -271,7 +291,7 @@ struct onnx_parser
} }
}; };
int main(int argc, char const *argv[]) int main(int argc, char const* argv[])
{ {
if(argc > 1) if(argc > 1)
{ {
...@@ -284,7 +304,8 @@ int main(int argc, char const *argv[]) ...@@ -284,7 +304,8 @@ int main(int argc, char const *argv[])
} }
catch(...) catch(...)
{ {
if(parser.prog) parser.prog->print(); if(parser.prog)
parser.prog->print();
throw; throw;
} }
parser.prog->print(); parser.prog->print();
......
...@@ -9,7 +9,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const ...@@ -9,7 +9,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
{ {
std::unordered_map<const instruction*, argument> results; std::unordered_map<const instruction*, argument> results;
argument result; argument result;
for(auto& ins:instructions) for(auto& ins : instructions)
{ {
if(ins.op.name() == "@literal") if(ins.op.name() == "@literal")
{ {
...@@ -22,9 +22,10 @@ literal program::eval(std::unordered_map<std::string, argument> params) const ...@@ -22,9 +22,10 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
else else
{ {
std::vector<argument> values(ins.arguments.size()); std::vector<argument> values(ins.arguments.size());
std::transform(ins.arguments.begin(), ins.arguments.end(), values.begin(), [&](instruction * i) { std::transform(ins.arguments.begin(),
return results.at(i); ins.arguments.end(),
}); values.begin(),
[&](instruction* i) { return results.at(i); });
result = ins.op.compute(values); result = ins.op.compute(values);
} }
results.emplace(std::addressof(ins), result); results.emplace(std::addressof(ins), result);
...@@ -37,7 +38,7 @@ void program::print() const ...@@ -37,7 +38,7 @@ void program::print() const
std::unordered_map<const instruction*, std::string> names; std::unordered_map<const instruction*, std::string> names;
int count = 0; int count = 0;
for(auto& ins:instructions) for(auto& ins : instructions)
{ {
std::string var_name = "@" + std::to_string(count); std::string var_name = "@" + std::to_string(count);
if(starts_with(ins.op.name(), "@param")) if(starts_with(ins.op.name(), "@param"))
...@@ -51,7 +52,7 @@ void program::print() const ...@@ -51,7 +52,7 @@ void program::print() const
if(ins.op.name() == "@literal") if(ins.op.name() == "@literal")
{ {
if (ins.lit.get_shape().elements() > 10) if(ins.lit.get_shape().elements() > 10)
std::cout << "{ ... }"; std::cout << "{ ... }";
else else
std::cout << "{" << ins.lit << "}"; std::cout << "{" << ins.lit << "}";
...@@ -60,7 +61,7 @@ void program::print() const ...@@ -60,7 +61,7 @@ void program::print() const
if(!ins.arguments.empty()) if(!ins.arguments.empty())
{ {
char delim = '('; char delim = '(';
for(auto&& arg:ins.arguments) for(auto&& arg : ins.arguments)
{ {
assert(this->has_instruction(arg) && "Instruction not found"); assert(this->has_instruction(arg) && "Instruction not found");
std::cout << delim << names.at(arg); std::cout << delim << names.at(arg);
...@@ -78,5 +79,4 @@ void program::print() const ...@@ -78,5 +79,4 @@ void program::print() const
} }
} }
} } // namespace rtg
...@@ -7,21 +7,16 @@ ...@@ -7,21 +7,16 @@
namespace rtg { namespace rtg {
shape::shape() shape::shape() : type_(float_type), lens_(), strides_(), packed_(false) {}
: type_(float_type), lens_(), strides_(), packed_(false)
{}
shape::shape(type_t t) shape::shape(type_t t) : type_(t), lens_({1}), strides_({1}), packed_(true) {}
: type_(t), lens_({1}), strides_({1}), packed_(true) shape::shape(type_t t, std::vector<std::size_t> l) : type_(t), lens_(std::move(l)), packed_(true)
{}
shape::shape(type_t t, std::vector<std::size_t> l)
: type_(t), lens_(std::move(l)), packed_(true)
{ {
this->calculate_strides(); this->calculate_strides();
assert(lens_.size() == strides_.size()); assert(lens_.size() == strides_.size());
} }
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: type_(t), lens_(std::move(l)), strides_(std::move(s)) : type_(t), lens_(std::move(l)), strides_(std::move(s))
{ {
assert(lens_.size() == strides_.size()); assert(lens_.size() == strides_.size());
packed_ = this->elements() == this->element_space(); packed_ = this->elements() == this->element_space();
...@@ -38,18 +33,9 @@ void shape::calculate_strides() ...@@ -38,18 +33,9 @@ void shape::calculate_strides()
lens_.rbegin(), lens_.rend() - 1, strides_.rbegin() + 1, std::multiplies<std::size_t>()); lens_.rbegin(), lens_.rend() - 1, strides_.rbegin() + 1, std::multiplies<std::size_t>());
} }
shape::type_t shape::type() const shape::type_t shape::type() const { return this->type_; }
{ const std::vector<std::size_t>& shape::lens() const { return this->lens_; }
return this->type_; const std::vector<std::size_t>& shape::strides() const { return this->strides_; }
}
const std::vector<std::size_t>& shape::lens() const
{
return this->lens_;
}
const std::vector<std::size_t>& shape::strides() const
{
return this->strides_;
}
std::size_t shape::elements() const std::size_t shape::elements() const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
...@@ -77,13 +63,15 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const ...@@ -77,13 +63,15 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std::size_t shape::index(std::size_t i) const std::size_t shape::index(std::size_t i) const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(this->lens().begin(), this->lens().end(), this->strides().begin(), std::size_t{0}, std::plus<std::size_t>{}, return std::inner_product(
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len)*stride; }); this->lens().begin(),
} this->lens().end(),
bool shape::packed() const this->strides().begin(),
{ std::size_t{0},
return this->packed_; std::plus<std::size_t>{},
} [&](std::size_t len, std::size_t stride) { return ((i / stride) % len) * stride; });
}
bool shape::packed() const { return this->packed_; }
std::size_t shape::element_space() const std::size_t shape::element_space() const
{ {
// TODO: Get rid of intermediate vector // TODO: Get rid of intermediate vector
...@@ -101,11 +89,10 @@ std::size_t shape::element_space() const ...@@ -101,11 +89,10 @@ std::size_t shape::element_space() const
std::string shape::type_string() const std::string shape::type_string() const
{ {
switch(this->type_) switch(this->type_)
{ {
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \ #define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: \ case x: return #x;
return #x;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE) RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE)
#undef RTG_SHAPE_TYPE_STRING_CASE #undef RTG_SHAPE_TYPE_STRING_CASE
} }
...@@ -116,10 +103,7 @@ bool operator==(const shape& x, const shape& y) ...@@ -116,10 +103,7 @@ bool operator==(const shape& x, const shape& y)
{ {
return x.type() == y.type() && x.lens() == y.lens() && x.strides() == y.strides(); return x.type() == y.type() && x.lens() == y.lens() && x.strides() == y.strides();
} }
bool operator!=(const shape& x, const shape& y) bool operator!=(const shape& x, const shape& y) { return !(x == y); }
{
return !(x == y);
}
std::ostream& operator<<(std::ostream& os, const shape& x) std::ostream& operator<<(std::ostream& os, const shape& x)
{ {
...@@ -129,4 +113,4 @@ std::ostream& operator<<(std::ostream& os, const shape& x) ...@@ -129,4 +113,4 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return os; return os;
} }
} } // namespace rtg
...@@ -4,37 +4,37 @@ ...@@ -4,37 +4,37 @@
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include "test.hpp" #include "test.hpp"
struct sum_op struct sum_op
{ {
std::string name() const std::string name() const { return "sum"; }
{
return "sum";
}
rtg::argument compute(std::vector<rtg::argument> args) const rtg::argument compute(std::vector<rtg::argument> args) const
{ {
rtg::argument result; rtg::argument result;
if(args.size() != 2) throw "Wrong args"; if(args.size() != 2)
if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args"; throw "Wrong args";
if(args[0].get_shape().lens().size() != 1) throw "Wrong args"; if(args[0].get_shape() != args[1].get_shape())
if(args[0].get_shape().lens().front() != 1) throw "Wrong args"; throw "Wrong args";
if(args[0].get_shape().lens().size() != 1)
throw "Wrong args";
if(args[0].get_shape().lens().front() != 1)
throw "Wrong args";
args[0].visit_at([&](auto x) { args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { args[1].visit_at([&](auto y) { result = rtg::literal{x + y}.get_argument(); });
result = rtg::literal{x + y}.get_argument();
});
}); });
return result; return result;
} }
rtg::shape compute_shape(std::vector<rtg::shape> inputs) const rtg::shape compute_shape(std::vector<rtg::shape> inputs) const
{ {
if(inputs.size() != 2) throw "Wrong inputs"; if(inputs.size() != 2)
throw "Wrong inputs";
return inputs.front(); return inputs.front();
} }
}; };
void literal_test() { void literal_test()
{
rtg::program p; rtg::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -45,23 +45,22 @@ void literal_test() { ...@@ -45,23 +45,22 @@ void literal_test() {
EXPECT(result != rtg::literal{4}); EXPECT(result != rtg::literal{4});
} }
void param_test() { void param_test()
{
rtg::program p; rtg::program p;
auto x = p.add_parameter("x", {rtg::shape::int64_type}); auto x = p.add_parameter("x", {rtg::shape::int64_type});
auto y = p.add_parameter("y", {rtg::shape::int64_type}); auto y = p.add_parameter("y", {rtg::shape::int64_type});
p.add_instruction(sum_op{}, x, y); p.add_instruction(sum_op{}, x, y);
auto result = p.eval({ auto result =
{"x", rtg::literal{1}.get_argument()}, p.eval({{"x", rtg::literal{1}.get_argument()}, {"y", rtg::literal{2}.get_argument()}});
{"y", rtg::literal{2}.get_argument()}
});
EXPECT(result == rtg::literal{3}); EXPECT(result == rtg::literal{3});
EXPECT(result != rtg::literal{4}); EXPECT(result != rtg::literal{4});
} }
int main() { int main()
{
literal_test(); literal_test();
param_test(); param_test();
} }
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#include <string> #include <string>
#include "test.hpp" #include "test.hpp"
void literal_test() void literal_test()
{ {
EXPECT(rtg::literal{1} == rtg::literal{1}); EXPECT(rtg::literal{1} == rtg::literal{1});
...@@ -23,7 +22,7 @@ void literal_test() ...@@ -23,7 +22,7 @@ void literal_test()
rtg::literal l4{}; rtg::literal l4{};
EXPECT(l3 == l4); EXPECT(l3 == l4);
EXPECT(l3.empty()); EXPECT(l3.empty());
EXPECT(l4.empty()); EXPECT(l4.empty());
} }
void literal_os1() void literal_os1()
...@@ -51,10 +50,9 @@ void literal_os3() ...@@ -51,10 +50,9 @@ void literal_os3()
EXPECT(ss.str() == "1, 2, 3"); EXPECT(ss.str() == "1, 2, 3");
} }
int main() { int main()
{
literal_test(); literal_test();
literal_os1(); literal_os1();
literal_os2(); literal_os2();
} }
int main() { int main() {}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment