"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "da728bcdd6fd947246066584947d37a0fa5251b9"
Commit 14d5666b authored by Paul's avatar Paul
Browse files

Add clang formatting

parent 2305ac81
---
Language: Cpp
AccessModifierOffset: 0
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: true
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: true
AfterControlStatement: true
AfterEnum: true
AfterFunction: true
AfterNamespace: false
AfterObjCDeclaration: true
AfterStruct: true
AfterUnion: true
BeforeCatch: true
BeforeElse: true
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
- Regex: '^(<|"(gtest|isl|json)/)'
Priority: 3
- Regex: '.*'
Priority: 1
IndentCaseLabels: false
IndentWidth: 4
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: true
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
# SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: Never
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
#!/usr/bin/env bash
cd $(git rev-parse --git-dir)
echo "Installing hooks..."
ln -s ../.githooks hooks
echo "Done!"
#!/bin/sh
#
# This pre-commit hook checks if any versions of clang-format
# are installed, and if so, uses the installed version to format
# the staged changes.
base=clang-format-5.0
format=""
# Redirect output to stderr.
exec 1>&2
# check if clang-format is installed
type "$base" >/dev/null 2>&1 && format="$base"
# no versions of clang-format are installed
if [ -z "$format" ]
then
echo "$base is not installed. Pre-commit hook will not be executed."
exit 0
fi
# Do everything from top - level
cd $(git rev-parse --show-toplevel)
if git rev-parse --verify HEAD >/dev/null 2>&1
then
against=HEAD
else
# Initial commit: diff against an empty tree object
against=16bbb57
fi
# do the formatting
for file in $(git diff-index --cached --name-only $against | grep -E '\.h$|\.hpp$|\.cpp$|\.cl$|\.h\.in$|\.hpp\.in$|\.cpp\.in$')
do
if [ -e "$file" ]
then
echo "$format $file"
"$format" -i -style=file "$file"
fi
done
......@@ -9,28 +9,20 @@ namespace rtg {
struct argument : raw_data<argument>
{
argument()
{}
argument() {}
argument(shape s, std::function<char*()> d)
: data(d), shape_(s)
{}
argument(shape s, std::function<char*()> d) : data(d), shape_(s) {}
std::function<char*()> data;
bool empty() const
{
return not data;
}
bool empty() const { return not data; }
const shape& get_shape() const
{
return this->shape_;
}
private:
const shape& get_shape() const { return this->shape_; }
private:
shape shape_;
};
}
} // namespace rtg
#endif
......@@ -9,38 +9,20 @@ namespace builtin {
struct literal
{
std::string name() const
{
return "@literal";
}
shape compute_shape(std::vector<shape>) const
{
throw "builtin";
}
argument compute(std::vector<argument>) const
{
throw "builtin";
}
std::string name() const { return "@literal"; }
shape compute_shape(std::vector<shape>) const { throw "builtin"; }
argument compute(std::vector<argument>) const { throw "builtin"; }
};
struct param
{
std::string parameter;
std::string name() const
{
return "@param:" + parameter;
}
shape compute_shape(std::vector<shape>) const
{
throw "builtin";
}
argument compute(std::vector<argument>) const
{
throw "builtin";
}
std::string name() const { return "@param:" + parameter; }
shape compute_shape(std::vector<shape>) const { throw "builtin"; }
argument compute(std::vector<argument>) const { throw "builtin"; }
};
}
} // namespace builtin
} // namespace rtg
......
......@@ -13,12 +13,14 @@ struct instruction
instruction() {}
instruction(operand o, shape r, std::vector<instruction*> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)), lit()
{}
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)), lit()
{
}
instruction(literal l)
: op(builtin::literal{}), result(l.get_shape()), arguments(), lit(std::move(l))
{}
: op(builtin::literal{}), result(l.get_shape()), arguments(), lit(std::move(l))
{
}
operand op;
shape result;
......@@ -26,6 +28,6 @@ struct instruction
literal lit;
};
}
} // namespace rtg
#endif
......@@ -10,68 +10,45 @@ namespace rtg {
struct literal : raw_data<literal>
{
literal()
: buffer(), shape_()
{}
literal() : buffer(), shape_() {}
template<class T>
literal(T x)
: buffer(sizeof(T), 0), shape_(shape::get_type<T>{})
template <class T>
literal(T x) : buffer(sizeof(T), 0), shape_(shape::get_type<T>{})
{
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
*(reinterpret_cast<T*>(buffer.data())) = x;
}
template<class T>
literal(shape s, const std::vector<T>& x)
: buffer(s.bytes(), 0), shape_(s)
template <class T>
literal(shape s, const std::vector<T>& x) : buffer(s.bytes(), 0), shape_(s)
{
assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) {
std::copy(x.begin(), x.end(), as.from(buffer.data()));
});
s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); });
}
template<class T>
literal(shape s, const std::initializer_list<T>& x)
: buffer(s.bytes(), 0), shape_(s)
template <class T>
literal(shape s, const std::initializer_list<T>& x) : buffer(s.bytes(), 0), shape_(s)
{
assert(s.packed());
static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
s.visit_type([&](auto as) {
std::copy(x.begin(), x.end(), as.from(buffer.data()));
});
s.visit_type([&](auto as) { std::copy(x.begin(), x.end(), as.from(buffer.data())); });
}
template<class Iterator>
literal(shape s, Iterator start, Iterator end)
: buffer(s.bytes(), 0), shape_(s)
template <class Iterator>
literal(shape s, Iterator start, Iterator end) : buffer(s.bytes(), 0), shape_(s)
{
assert(s.packed());
s.visit_type([&](auto as) {
std::copy(start, end, as.from(buffer.data()));
});
s.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.data())); });
}
literal(shape s, const char* x)
: buffer(x, x+s.bytes()), shape_(s)
{}
bool empty() const
{
return this->buffer.empty();
}
literal(shape s, const char* x) : buffer(x, x + s.bytes()), shape_(s) {}
const char* data() const
{
return this->buffer.data();
}
bool empty() const { return this->buffer.empty(); }
const shape& get_shape() const
{
return this->shape_;
}
const char* data() const { return this->buffer.data(); }
const shape& get_shape() const { return this->shape_; }
argument get_argument() const
{
......@@ -79,11 +56,11 @@ struct literal : raw_data<literal>
return {shape_, [b]() mutable { return b.data(); }};
}
private:
private:
std::vector<char> buffer;
shape shape_;
};
}
} // namespace rtg
#endif
......@@ -12,16 +12,16 @@
namespace rtg {
/*
* Type-erased interface for:
*
* struct operand
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* };
*
*/
* Type-erased interface for:
*
* struct operand
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* };
*
*/
struct operand
{
......@@ -80,8 +80,9 @@ struct operand
struct handle_type_ : handle_base_type_
{
template <typename TypeErased_U_ = TypeErased_T_>
handle_type_(TypeErased_T_ value,
typename std::enable_if<std::is_reference<TypeErased_U_>::value>::type* = nullptr)
handle_type_(
TypeErased_T_ value,
typename std::enable_if<std::is_reference<TypeErased_U_>::value>::type* = nullptr)
: value_(value)
{
}
......@@ -89,7 +90,8 @@ struct operand
template <typename TypeErased_U_ = TypeErased_T_>
handle_type_(TypeErased_T_ value,
typename std::enable_if<!std::is_reference<TypeErased_U_>::value, int>::type* =
nullptr) noexcept : value_(std::move(value))
nullptr) noexcept
: value_(std::move(value))
{
}
......@@ -134,6 +136,6 @@ struct operand
std::shared_ptr<handle_base_type_> handle_mem_var_;
};
}
} // namespace rtg
#endif
......@@ -9,120 +9,120 @@ namespace rtg {
struct not_computable
{
argument compute(std::vector<argument>) const
{
throw std::runtime_error("not computable");
}
argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
};
struct convolution
{
std::array<std::size_t, 2> padding = {0, 0};
std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> padding = {0, 0};
std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> dilation = {1, 1};
std::string name() const
{
return "convolution[padding={" + to_string(padding) +
"}, stride={" + to_string(stride) +
"}, dilation={" + to_string(dilation) +
"}]";
return "convolution[padding={" + to_string(padding) + "}, stride={" + to_string(stride) +
"}, dilation={" + to_string(dilation) + "}]";
}
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.size() != 2) throw std::runtime_error("Wrong number of arguments");
const shape& input = inputs.at(0);
if(inputs.size() != 2)
throw std::runtime_error("Wrong number of arguments");
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
if(input.type() != weights.type()) throw std::runtime_error("Type doesn't match");
if(input.lens().size() != weights.lens().size()) throw std::runtime_error("Dimensions don't match");
if(input.lens().size() != 4) throw std::runtime_error("Only 4d convolution supported");
if(input.type() != weights.type())
throw std::runtime_error("Type doesn't match");
if(input.lens().size() != weights.lens().size())
throw std::runtime_error("Dimensions don't match");
if(input.lens().size() != 4)
throw std::runtime_error("Only 4d convolution supported");
auto t = input.type();
return {t, {
input.lens()[0],
weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>(
1, (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + 2 * padding[0]) / stride[0] + 1)),
std::size_t(std::max<std::ptrdiff_t>(
1, (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) + 2 * padding[1]) / stride[1] + 1)),
}};
return {t,
{
input.lens()[0],
weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
2 * padding[0]) /
stride[0] +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
2 * padding[1]) /
stride[1] +
1)),
}};
}
argument compute(std::vector<argument>) const
{
throw std::runtime_error("not computable");
}
argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
};
struct pooling
{
std::string mode;
std::array<std::size_t, 2> padding = {0, 0};
std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> lengths = {1, 1};
std::string name() const
{
return "pooling:" + mode + "[padding={" + to_string(padding) +
"}, stride={" + to_string(stride) +
"}, lengths={" + to_string(lengths) +
"}]";
return "pooling:" + mode + "[padding={" + to_string(padding) + "}, stride={" +
to_string(stride) + "}, lengths={" + to_string(lengths) + "}]";
}
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty()) throw std::runtime_error("Wrong number of arguments");
const shape& input = inputs.at(0);
if(input.lens().size() != 4) throw std::runtime_error("Only 4d pooling supported");
if(inputs.empty())
throw std::runtime_error("Wrong number of arguments");
const shape& input = inputs.at(0);
if(input.lens().size() != 4)
throw std::runtime_error("Only 4d pooling supported");
auto t = input.type();
return {t, {
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1, std::ceil((input.lens()[3] + 2 * padding[0] - lengths[0]) / static_cast<float>(stride[0])) + 1)),
std::size_t(std::max<std::ptrdiff_t>(
1, std::ceil((input.lens()[4] + 2 * padding[1] - lengths[1]) / static_cast<float>(stride[1])) + 1)),
}};
return {t,
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ceil((input.lens()[3] + 2 * padding[0] - lengths[0]) /
static_cast<float>(stride[0])) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ceil((input.lens()[4] + 2 * padding[1] - lengths[1]) /
static_cast<float>(stride[1])) +
1)),
}};
}
argument compute(std::vector<argument>) const
{
throw std::runtime_error("not computable");
}
argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
};
struct activation
{
std::string mode;
std::string name() const
{
return "activation:" + mode;
}
std::string name() const { return "activation:" + mode; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty()) throw std::runtime_error("Wrong number of arguments");
if(inputs.empty())
throw std::runtime_error("Wrong number of arguments");
return inputs.front();
}
argument compute(std::vector<argument>) const
{
throw std::runtime_error("not computable");
}
argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
};
struct reshape
{
std::vector<int64_t> dims;
std::string name() const
{
return "reshape[dims={" + to_string(dims) +
"}]";
}
std::string name() const { return "reshape[dims={" + to_string(dims) + "}]"; }
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty()) throw std::runtime_error("Wrong number of arguments");
if(inputs.empty())
throw std::runtime_error("Wrong number of arguments");
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0;i < dims.size();i++)
for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == 0)
rdims[i] = idims[i];
......@@ -130,18 +130,14 @@ struct reshape
if(dims.back() == -1)
{
rdims.pop_back();
std::copy(idims.begin()+rdims.size(), idims.end(), std::back_inserter(rdims));
std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
}
return {inputs.front().type(), rdims};
}
argument compute(std::vector<argument>) const
{
throw std::runtime_error("not computable");
}
argument compute(std::vector<argument>) const { throw std::runtime_error("not computable"); }
};
} // namespace rtg
#endif
......@@ -13,35 +13,38 @@ namespace rtg {
struct program
{
// TODO: A program should be copyable
program() = default;
program() = default;
program(const program&) = delete;
program& operator=(const program&) = delete;
template<class... Ts>
instruction * add_instruction(operand op, Ts*... args)
template <class... Ts>
instruction* add_instruction(operand op, Ts*... args)
{
shape r = op.compute_shape({args->result...});
instructions.push_back({op, r, {args...}});
return std::addressof(instructions.back());
}
instruction * add_instruction(operand op, std::vector<instruction*> args)
instruction* add_instruction(operand op, std::vector<instruction*> args)
{
assert(std::all_of(args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) && "Argument is not an exisiting instruction");
assert(std::all_of(
args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size());
std::transform(args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
shape r = op.compute_shape(shapes);
instructions.push_back({op, r, args});
assert(instructions.back().arguments == args);
return std::addressof(instructions.back());
}
template<class... Ts>
instruction * add_literal(Ts&&... xs)
template <class... Ts>
instruction* add_literal(Ts&&... xs)
{
instructions.emplace_back(literal{std::forward<Ts>(xs)...});
return std::addressof(instructions.back());
}
instruction * add_parameter(std::string name, shape s)
instruction* add_parameter(std::string name, shape s)
{
instructions.push_back({builtin::param{std::move(name)}, s, {}});
return std::addressof(instructions.back());
......@@ -52,16 +55,18 @@ struct program
// TODO: Change to stream operator
void print() const;
bool has_instruction(const instruction * ins) const
bool has_instruction(const instruction* ins) const
{
return std::find_if(instructions.begin(), instructions.end(), [&](const instruction& x) {return ins == std::addressof(x); }) != instructions.end();
return std::find_if(instructions.begin(), instructions.end(), [&](const instruction& x) {
return ins == std::addressof(x);
}) != instructions.end();
}
private:
private:
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
};
}
} // namespace rtg
#endif
......@@ -6,14 +6,14 @@
namespace rtg {
template<class Derived>
template <class Derived>
struct raw_data
{
friend bool operator==(const Derived& x, const Derived& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty();
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{
auto&& xbuffer = x.data();
......@@ -22,59 +22,48 @@ struct raw_data
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
result = xview == yview;
});
}
return result;
}
friend bool operator!=(const Derived& x, const Derived& y)
{
return !(x == y);
}
template<class Stream>
friend bool operator!=(const Derived& x, const Derived& y) { return !(x == y); }
template <class Stream>
friend Stream& operator<<(Stream& os, const Derived& d)
{
d.visit([&](auto x) {
os << x;
});
d.visit([&](auto x) { os << x; });
return os;
}
template<class Visitor>
void visit_at(Visitor v, std::size_t n=0) const
template <class Visitor>
void visit_at(Visitor v, std::size_t n = 0) const
{
auto && s = static_cast<const Derived&>(*this).get_shape();
auto && buffer = static_cast<const Derived&>(*this).data();
s.visit_type([&](auto as) {
v(*(as.from(buffer)+s.index(n)));
});
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); });
}
template<class Visitor>
template <class Visitor>
void visit(Visitor v) const
{
auto && s = static_cast<const Derived&>(*this).get_shape();
auto && buffer = static_cast<const Derived&>(*this).data();
s.visit_type([&](auto as) {
v(make_view(s, as.from(buffer)));
});
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
}
bool single() const
{
auto && s = static_cast<const Derived&>(*this).get_shape();
auto&& s = static_cast<const Derived&>(*this).get_shape();
return s.elements() == 1;
}
template<class T>
T at(std::size_t n=0) const
template <class T>
T at(std::size_t n = 0) const
{
T result;
this->visit_at([&](auto x) {
result = x;
}, n);
this->visit_at([&](auto x) { result = x; }, n);
return result;
}
};
......
......@@ -11,6 +11,7 @@ struct shape
{
// Add new types here
// clang-format off
#define RTG_SHAPE_VISIT_TYPES(m) \
m(float_type, float) \
m(double_type, double) \
......@@ -23,6 +24,7 @@ struct shape
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
// clang-format on
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
enum type_t
{
......@@ -30,12 +32,13 @@ struct shape
};
#undef RTG_SHAPE_ENUM_TYPES
template<class T, class=void>
template <class T, class = void>
struct get_type;
#define RTG_SHAPE_GET_TYPE(x, t) \
template<class T> \
#define RTG_SHAPE_GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{};
{ \
};
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE
......@@ -44,7 +47,6 @@ struct shape
shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
type_t type() const;
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
......@@ -63,67 +65,60 @@ struct shape
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
template<class T>
template <class T>
struct as
{
using type = T;
template<class U>
template <class U>
T operator()(U u) const
{
return T(u);
}
template<class U>
template <class U>
T* operator()(U* u) const
{
return static_cast<T*>(u);
}
template<class U>
template <class U>
const T* operator()(const U* u) const
{
return static_cast<T*>(u);
}
T operator()() const
{
return {};
}
T operator()() const { return {}; }
std::size_t size(std::size_t n=1) const
{
return sizeof(T)*n;
}
std::size_t size(std::size_t n = 1) const { return sizeof(T) * n; }
template<class U>
T* from(U* buffer, std::size_t n=0) const
template <class U>
T* from(U* buffer, std::size_t n = 0) const
{
return reinterpret_cast<T*>(buffer)+n;
return reinterpret_cast<T*>(buffer) + n;
}
template<class U>
const T* from(const U* buffer, std::size_t n=0) const
template <class U>
const T* from(const U* buffer, std::size_t n = 0) const
{
return reinterpret_cast<const T*>(buffer)+n;
return reinterpret_cast<const T*>(buffer) + n;
}
};
template<class Visitor>
template <class Visitor>
void visit_type(Visitor v) const
{
switch(this->type_)
switch(this->type_)
{
#define RTG_SHAPE_VISITOR_CASE(x, t) \
case x: \
v(as<t>()); \
return;
case x: v(as<t>()); return;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE
}
assert(true);
}
private:
private:
type_t type_;
std::vector<std::size_t> lens_;
std::vector<std::size_t> strides_;
......@@ -134,6 +129,6 @@ private:
std::string type_string() const;
};
}
} // namespace rtg
#endif
......@@ -65,17 +65,14 @@ inline std::string remove_prefix(std::string s, std::string prefix)
return s;
}
template<class Range>
template <class Range>
inline std::string to_string(const Range& r)
{
std::stringstream ss;
if(!r.empty())
{
ss << r.front();
std::for_each(std::next(r.begin()), r.end(), [&](auto&& x)
{
ss << ", " << x;
});
std::for_each(std::next(r.begin()), r.end(), [&](auto&& x) { ss << ", " << x; });
}
return ss.str();
}
......
......@@ -8,48 +8,29 @@
namespace rtg {
template<class T>
template <class T>
struct tensor_view
{
tensor_view()
: data_(nullptr), shape_()
{}
tensor_view(shape s, T* d)
: data_(d), shape_(s)
{}
const shape& get_shape() const
{
return this->shape_;
}
tensor_view() : data_(nullptr), shape_() {}
tensor_view(shape s, T* d) : data_(d), shape_(s) {}
bool empty() const
{
return data_ == nullptr || shape_.lens().size() == 0;
}
const shape& get_shape() const { return this->shape_; }
std::size_t size() const
{
return shape_.elements();
}
bool empty() const { return data_ == nullptr || shape_.lens().size() == 0; }
T* data()
{
return this->data_;
}
std::size_t size() const { return shape_.elements(); }
const T* data() const
{
return this->data_;
}
T* data() { return this->data_; }
template<class... Ts>
const T* data() const { return this->data_; }
template <class... Ts>
const T& operator()(Ts... xs) const
{
return data_[shape_.index({xs...})];
}
template<class... Ts>
template <class... Ts>
T& operator()(Ts... xs)
{
return data_[shape_.index({xs...})];
......@@ -82,13 +63,13 @@ struct tensor_view
T& back()
{
assert(!this->empty());
return data_[shape_.index(this->size()-1)];
return data_[shape_.index(this->size() - 1)];
}
const T& back() const
{
assert(!this->empty());
return data_[shape_.index(this->size()-1)];
return data_[shape_.index(this->size() - 1)];
}
// TODO: Add iterators so it can handle nonpacked tensors
......@@ -101,8 +82,10 @@ struct tensor_view
T* end()
{
assert(this->shape_.packed());
if(this->empty()) return data_;
else return data_+this->size();
if(this->empty())
return data_;
else
return data_ + this->size();
}
const T* begin() const
......@@ -114,34 +97,34 @@ struct tensor_view
const T* end() const
{
assert(this->shape_.packed());
if(this->empty()) return data_;
else return data_+this->size();
if(this->empty())
return data_;
else
return data_ + this->size();
}
friend bool operator==(const tensor_view<T>& x, const tensor_view<T>& y)
{
if(x.shape_ == y.shape_)
{
for(std::size_t i = 0;i < x.shape_.elements();i++)
for(std::size_t i = 0; i < x.shape_.elements(); i++)
{
if(!float_equal(x[i], y[i])) return false;
if(!float_equal(x[i], y[i]))
return false;
}
return true;
}
return false;
}
friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y)
{
return !(x == y);
}
friend bool operator!=(const tensor_view<T>& x, const tensor_view<T>& y) { return !(x == y); }
friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
{
if(!x.empty())
{
os << x.front();
for(std::size_t i = 1;i < x.shape_.elements();i++)
for(std::size_t i = 1; i < x.shape_.elements(); i++)
{
os << ", " << x.data_[x.shape_.index(i)];
}
......@@ -149,12 +132,12 @@ struct tensor_view
return os;
}
private:
private:
T* data_;
shape shape_;
};
template<class T>
template <class T>
tensor_view<T> make_view(shape s, T* data)
{
return {s, data};
......
......@@ -13,43 +13,41 @@
struct unknown
{
std::string op;
std::string name() const
{
return "unknown:"+op;
}
std::string name() const { return "unknown:" + op; }
rtg::shape compute_shape(std::vector<rtg::shape> input) const
{
if(input.empty()) return {};
else return input.front();
}
rtg::argument compute(std::vector<rtg::argument> input) const
{
throw "not computable";
if(input.empty())
return {};
else
return input.front();
}
rtg::argument compute(std::vector<rtg::argument> input) const { throw "not computable"; }
};
template<class C, class T>
template <class C, class T>
bool contains(C&& c, T&& x)
{
return c.find(x) != c.end();
}
template<class Range, class Iterator>
template <class Range, class Iterator>
void copy(Range&& r, Iterator it)
{
std::copy(r.begin(), r.end(), it);
}
struct onnx_parser
struct onnx_parser
{
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
node_map nodes;
std::unordered_map<std::string, rtg::instruction*> instructions;
std::shared_ptr<rtg::program> prog = std::make_shared<rtg::program>();
std::unordered_map<std::string, std::function<rtg::instruction*(attribute_map, std::vector<rtg::instruction*>)>> ops;
std::unordered_map<
std::string,
std::function<rtg::instruction*(attribute_map, std::vector<rtg::instruction*>)>>
ops;
onnx_parser()
{
......@@ -92,10 +90,7 @@ struct onnx_parser
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
rtg::reshape op;
rtg::literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v)
{
copy(v, std::back_inserter(op.dims));
});
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog->add_instruction(op, args);
});
add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) {
......@@ -104,7 +99,7 @@ struct onnx_parser
});
}
template<class F>
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, f);
......@@ -113,14 +108,14 @@ struct onnx_parser
void parse_from(std::istream& is)
{
onnx::ModelProto model;
if(model.ParseFromIstream(&is))
if(model.ParseFromIstream(&is))
{
if(model.has_graph())
if(model.has_graph())
{
this->parse_graph(model.graph());
}
}
else
}
else
{
throw std::runtime_error("Failed reading");
}
......@@ -129,14 +124,14 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph)
{
nodes = get_nodes(graph);
for(auto&& input:graph.input())
for(auto&& input : graph.input())
{
std::string name = input.name();
// TODO: Get shape of input parameter
rtg::shape s = parse_type(input.type());
rtg::shape s = parse_type(input.type());
instructions[name] = prog->add_parameter(name, s);
}
for(auto&& p:nodes)
for(auto&& p : nodes)
{
this->parse_node(p.second.name());
}
......@@ -144,11 +139,11 @@ struct onnx_parser
void parse_node(std::string name)
{
if (instructions.count(name) == 0)
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
std::vector<rtg::instruction*> args;
for(auto&& input:node.input())
for(auto&& input : node.input())
{
if(nodes.count(input) > 0)
{
......@@ -161,7 +156,7 @@ struct onnx_parser
args.push_back(instructions.at(input));
}
}
if (ops.count(node.op_type()) == 0)
if(ops.count(node.op_type()) == 0)
{
instructions[name] = prog->add_instruction(unknown{node.op_type()}, args);
}
......@@ -175,7 +170,7 @@ struct onnx_parser
static attribute_map get_attributes(const onnx::NodeProto& node)
{
std::unordered_map<std::string, onnx::AttributeProto> result;
for(auto&& attr:node.attribute())
for(auto&& attr : node.attribute())
{
result[attr.name()] = attr;
}
......@@ -185,14 +180,13 @@ struct onnx_parser
static node_map get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node:graph.node())
for(auto&& node : graph.node())
{
result[node.name()] = node;
for(auto&& output:node.output())
for(auto&& output : node.output())
{
result[output] = node;
}
}
return result;
}
......@@ -201,17 +195,20 @@ struct onnx_parser
{
switch(attr.type())
{
case onnx::AttributeProto::UNDEFINED: return {};
case onnx::AttributeProto::FLOAT: return rtg::literal{attr.f()};
case onnx::AttributeProto::INT: return rtg::literal{attr.i()};
case onnx::AttributeProto::STRING: return {};
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::GRAPH: return {};
case onnx::AttributeProto::FLOATS: return rtg::literal{rtg::shape::float_type, attr.floats().begin(), attr.floats().end()};
case onnx::AttributeProto::INTS: return rtg::literal{rtg::shape::int32_type, attr.ints().begin(), attr.ints().end()};;
case onnx::AttributeProto::STRINGS: return {};
case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::GRAPHS: return {};
case onnx::AttributeProto::UNDEFINED: return {};
case onnx::AttributeProto::FLOAT: return rtg::literal{attr.f()};
case onnx::AttributeProto::INT: return rtg::literal{attr.i()};
case onnx::AttributeProto::STRING: return {};
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::GRAPH: return {};
case onnx::AttributeProto::FLOATS:
return rtg::literal{rtg::shape::float_type, attr.floats().begin(), attr.floats().end()};
case onnx::AttributeProto::INTS:
return rtg::literal{rtg::shape::int32_type, attr.ints().begin(), attr.ints().end()};
;
case onnx::AttributeProto::STRINGS: return {};
case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::GRAPHS: return {};
}
}
......@@ -220,22 +217,38 @@ struct onnx_parser
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
switch(t.data_type())
{
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: return rtg::literal{{rtg::shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::UINT16: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT16: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT32: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT64: return rtg::literal{{rtg::shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
case onnx::TensorProto::DOUBLE: return rtg::literal{{rtg::shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT:
return rtg::literal{
{rtg::shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8:
return rtg::literal{
{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::UINT16:
return rtg::literal{
{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT16:
return rtg::literal{
{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT32:
return rtg::literal{
{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT64:
return rtg::literal{
{rtg::shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL:
return rtg::literal{
{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
case onnx::TensorProto::DOUBLE:
return rtg::literal{
{rtg::shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
}
......@@ -244,26 +257,33 @@ struct onnx_parser
rtg::shape::type_t shape_type;
switch(t.tensor_type().elem_type())
{
case onnx::TensorProto::UNDEFINED: break; //throw std::runtime_error("Unsupported type UNDEFINED");
case onnx::TensorProto::FLOAT: shape_type = rtg::shape::float_type;
case onnx::TensorProto::UINT8: break; //throw std::runtime_error("Unsupported type UINT8");
case onnx::TensorProto::INT8: shape_type = rtg::shape::int8_type;
case onnx::TensorProto::UINT16: shape_type = rtg::shape::uint16_type;
case onnx::TensorProto::INT16: shape_type = rtg::shape::int16_type;
case onnx::TensorProto::INT32: shape_type = rtg::shape::int32_type;
case onnx::TensorProto::INT64: shape_type = rtg::shape::int64_type;
case onnx::TensorProto::STRING: break; //throw std::runtime_error("Unsupported type STRING");
case onnx::TensorProto::BOOL: break; //throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::FLOAT16: break; //throw std::runtime_error("Unsupported type FLOAT16");
case onnx::TensorProto::DOUBLE: shape_type = rtg::shape::double_type;
case onnx::TensorProto::UINT32: shape_type = rtg::shape::uint32_type;
case onnx::TensorProto::UINT64: shape_type = rtg::shape::uint64_type;
case onnx::TensorProto::COMPLEX64: break; //throw std::runtime_error("Unsupported type COMPLEX64");
case onnx::TensorProto::COMPLEX128: break; //throw std::runtime_error("Unsupported type COMPLEX128");
case onnx::TensorProto::UNDEFINED:
break; // throw std::runtime_error("Unsupported type UNDEFINED");
case onnx::TensorProto::FLOAT: shape_type = rtg::shape::float_type;
case onnx::TensorProto::UINT8:
break; // throw std::runtime_error("Unsupported type UINT8");
case onnx::TensorProto::INT8: shape_type = rtg::shape::int8_type;
case onnx::TensorProto::UINT16: shape_type = rtg::shape::uint16_type;
case onnx::TensorProto::INT16: shape_type = rtg::shape::int16_type;
case onnx::TensorProto::INT32: shape_type = rtg::shape::int32_type;
case onnx::TensorProto::INT64: shape_type = rtg::shape::int64_type;
case onnx::TensorProto::STRING:
break; // throw std::runtime_error("Unsupported type STRING");
case onnx::TensorProto::BOOL:
break; // throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::FLOAT16:
break; // throw std::runtime_error("Unsupported type FLOAT16");
case onnx::TensorProto::DOUBLE: shape_type = rtg::shape::double_type;
case onnx::TensorProto::UINT32: shape_type = rtg::shape::uint32_type;
case onnx::TensorProto::UINT64: shape_type = rtg::shape::uint64_type;
case onnx::TensorProto::COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64");
case onnx::TensorProto::COMPLEX128:
break; // throw std::runtime_error("Unsupported type COMPLEX128");
}
std::vector<std::size_t> dims;
// TODO: USe std::transform
for(auto&& d:t.tensor_type().shape().dim())
for(auto&& d : t.tensor_type().shape().dim())
{
dims.push_back(d.dim_value());
}
......@@ -271,7 +291,7 @@ struct onnx_parser
}
};
int main(int argc, char const *argv[])
int main(int argc, char const* argv[])
{
if(argc > 1)
{
......@@ -284,7 +304,8 @@ int main(int argc, char const *argv[])
}
catch(...)
{
if(parser.prog) parser.prog->print();
if(parser.prog)
parser.prog->print();
throw;
}
parser.prog->print();
......
......@@ -9,7 +9,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
{
std::unordered_map<const instruction*, argument> results;
argument result;
for(auto& ins:instructions)
for(auto& ins : instructions)
{
if(ins.op.name() == "@literal")
{
......@@ -22,9 +22,10 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
else
{
std::vector<argument> values(ins.arguments.size());
std::transform(ins.arguments.begin(), ins.arguments.end(), values.begin(), [&](instruction * i) {
return results.at(i);
});
std::transform(ins.arguments.begin(),
ins.arguments.end(),
values.begin(),
[&](instruction* i) { return results.at(i); });
result = ins.op.compute(values);
}
results.emplace(std::addressof(ins), result);
......@@ -37,7 +38,7 @@ void program::print() const
std::unordered_map<const instruction*, std::string> names;
int count = 0;
for(auto& ins:instructions)
for(auto& ins : instructions)
{
std::string var_name = "@" + std::to_string(count);
if(starts_with(ins.op.name(), "@param"))
......@@ -51,7 +52,7 @@ void program::print() const
if(ins.op.name() == "@literal")
{
if (ins.lit.get_shape().elements() > 10)
if(ins.lit.get_shape().elements() > 10)
std::cout << "{ ... }";
else
std::cout << "{" << ins.lit << "}";
......@@ -60,7 +61,7 @@ void program::print() const
if(!ins.arguments.empty())
{
char delim = '(';
for(auto&& arg:ins.arguments)
for(auto&& arg : ins.arguments)
{
assert(this->has_instruction(arg) && "Instruction not found");
std::cout << delim << names.at(arg);
......@@ -78,5 +79,4 @@ void program::print() const
}
}
}
} // namespace rtg
......@@ -7,21 +7,16 @@
namespace rtg {
shape::shape()
: type_(float_type), lens_(), strides_(), packed_(false)
{}
shape::shape() : type_(float_type), lens_(), strides_(), packed_(false) {}
shape::shape(type_t t)
: type_(t), lens_({1}), strides_({1}), packed_(true)
{}
shape::shape(type_t t, std::vector<std::size_t> l)
: type_(t), lens_(std::move(l)), packed_(true)
shape::shape(type_t t) : type_(t), lens_({1}), strides_({1}), packed_(true) {}
shape::shape(type_t t, std::vector<std::size_t> l) : type_(t), lens_(std::move(l)), packed_(true)
{
this->calculate_strides();
assert(lens_.size() == strides_.size());
}
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: type_(t), lens_(std::move(l)), strides_(std::move(s))
: type_(t), lens_(std::move(l)), strides_(std::move(s))
{
assert(lens_.size() == strides_.size());
packed_ = this->elements() == this->element_space();
......@@ -38,18 +33,9 @@ void shape::calculate_strides()
lens_.rbegin(), lens_.rend() - 1, strides_.rbegin() + 1, std::multiplies<std::size_t>());
}
shape::type_t shape::type() const
{
return this->type_;
}
const std::vector<std::size_t>& shape::lens() const
{
return this->lens_;
}
const std::vector<std::size_t>& shape::strides() const
{
return this->strides_;
}
shape::type_t shape::type() const { return this->type_; }
const std::vector<std::size_t>& shape::lens() const { return this->lens_; }
const std::vector<std::size_t>& shape::strides() const { return this->strides_; }
std::size_t shape::elements() const
{
assert(this->lens().size() == this->strides().size());
......@@ -77,13 +63,15 @@ std::size_t shape::index(const std::vector<std::size_t>& l) const
std::size_t shape::index(std::size_t i) const
{
assert(this->lens().size() == this->strides().size());
return std::inner_product(this->lens().begin(), this->lens().end(), this->strides().begin(), std::size_t{0}, std::plus<std::size_t>{},
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len)*stride; });
}
bool shape::packed() const
{
return this->packed_;
}
return std::inner_product(
this->lens().begin(),
this->lens().end(),
this->strides().begin(),
std::size_t{0},
std::plus<std::size_t>{},
[&](std::size_t len, std::size_t stride) { return ((i / stride) % len) * stride; });
}
bool shape::packed() const { return this->packed_; }
std::size_t shape::element_space() const
{
// TODO: Get rid of intermediate vector
......@@ -101,11 +89,10 @@ std::size_t shape::element_space() const
std::string shape::type_string() const
{
switch(this->type_)
switch(this->type_)
{
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: \
return #x;
case x: return #x;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE)
#undef RTG_SHAPE_TYPE_STRING_CASE
}
......@@ -116,10 +103,7 @@ bool operator==(const shape& x, const shape& y)
{
return x.type() == y.type() && x.lens() == y.lens() && x.strides() == y.strides();
}
bool operator!=(const shape& x, const shape& y)
{
return !(x == y);
}
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x)
{
......@@ -129,4 +113,4 @@ std::ostream& operator<<(std::ostream& os, const shape& x)
return os;
}
}
} // namespace rtg
......@@ -4,37 +4,37 @@
#include <rtg/shape.hpp>
#include "test.hpp"
struct sum_op
{
std::string name() const
{
return "sum";
}
std::string name() const { return "sum"; }
rtg::argument compute(std::vector<rtg::argument> args) const
{
rtg::argument result;
if(args.size() != 2) throw "Wrong args";
if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args";
if(args[0].get_shape().lens().size() != 1) throw "Wrong args";
if(args[0].get_shape().lens().front() != 1) throw "Wrong args";
if(args.size() != 2)
throw "Wrong args";
if(args[0].get_shape() != args[1].get_shape())
throw "Wrong args";
if(args[0].get_shape().lens().size() != 1)
throw "Wrong args";
if(args[0].get_shape().lens().front() != 1)
throw "Wrong args";
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) {
result = rtg::literal{x + y}.get_argument();
});
args[1].visit_at([&](auto y) { result = rtg::literal{x + y}.get_argument(); });
});
return result;
}
rtg::shape compute_shape(std::vector<rtg::shape> inputs) const
{
if(inputs.size() != 2) throw "Wrong inputs";
if(inputs.size() != 2)
throw "Wrong inputs";
return inputs.front();
}
};
void literal_test() {
void literal_test()
{
rtg::program p;
auto one = p.add_literal(1);
......@@ -45,23 +45,22 @@ void literal_test() {
EXPECT(result != rtg::literal{4});
}
void param_test() {
void param_test()
{
rtg::program p;
auto x = p.add_parameter("x", {rtg::shape::int64_type});
auto y = p.add_parameter("y", {rtg::shape::int64_type});
p.add_instruction(sum_op{}, x, y);
auto result = p.eval({
{"x", rtg::literal{1}.get_argument()},
{"y", rtg::literal{2}.get_argument()}
});
auto result =
p.eval({{"x", rtg::literal{1}.get_argument()}, {"y", rtg::literal{2}.get_argument()}});
EXPECT(result == rtg::literal{3});
EXPECT(result != rtg::literal{4});
}
int main() {
int main()
{
literal_test();
param_test();
}
......@@ -4,7 +4,6 @@
#include <string>
#include "test.hpp"
void literal_test()
{
EXPECT(rtg::literal{1} == rtg::literal{1});
......@@ -23,7 +22,7 @@ void literal_test()
rtg::literal l4{};
EXPECT(l3 == l4);
EXPECT(l3.empty());
EXPECT(l4.empty());
EXPECT(l4.empty());
}
void literal_os1()
......@@ -51,10 +50,9 @@ void literal_os3()
EXPECT(ss.str() == "1, 2, 3");
}
int main() {
int main()
{
literal_test();
literal_os1();
literal_os2();
}
int main() {
}
int main() {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment