Commit 92803edf authored by Paul's avatar Paul
Browse files

Formatting

parent dfa79e73
......@@ -1118,13 +1118,13 @@ struct scalar
int output_alias(const std::vector<shape>&) const { return 0; }
};
template<class Derived>
template <class Derived>
struct binary
{
std::string name() const
{
static const std::string& name = get_type_name<Derived>();
return name.substr(name.rfind("::")+2);
return name.substr(name.rfind("::") + 2);
}
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -1139,14 +1139,17 @@ struct binary
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard())
{
std::transform(
input1.begin(), input1.end(), input2.begin(), output.begin(), static_cast<const Derived&>(*this).apply());
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) =
static_cast<const Derived&>(*this).apply()(input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
}
});
......@@ -1156,32 +1159,50 @@ struct binary
struct add : binary<add>
{
auto apply() const { return [](auto x, auto y) { return x + y; }; }
auto apply() const
{
return [](auto x, auto y) { return x + y; };
}
};
struct sub : binary<sub>
{
auto apply() const { return [](auto x, auto y) { return x - y; }; }
auto apply() const
{
return [](auto x, auto y) { return x - y; };
}
};
struct mul : binary<mul>
{
auto apply() const { return [](auto x, auto y) { return x * y; }; }
auto apply() const
{
return [](auto x, auto y) { return x * y; };
}
};
struct div : binary<div>
{
auto apply() const { return [](auto x, auto y) { return x / y; }; }
auto apply() const
{
return [](auto x, auto y) { return x / y; };
}
};
struct max : binary<max>
{
auto apply() const { return [](auto x, auto y) { return std::max(x, y); }; }
auto apply() const
{
return [](auto x, auto y) { return std::max(x, y); };
}
};
struct min : binary<min>
{
auto apply() const { return [](auto x, auto y) { return std::min(x, y); }; }
auto apply() const
{
return [](auto x, auto y) { return std::min(x, y); };
}
};
struct load
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment