Commit 92803edf authored by Paul's avatar Paul
Browse files

Formatting

parent dfa79e73
...@@ -1118,13 +1118,13 @@ struct scalar ...@@ -1118,13 +1118,13 @@ struct scalar
int output_alias(const std::vector<shape>&) const { return 0; } int output_alias(const std::vector<shape>&) const { return 0; }
}; };
template<class Derived> template <class Derived>
struct binary struct binary
{ {
std::string name() const std::string name() const
{ {
static const std::string& name = get_type_name<Derived>(); static const std::string& name = get_type_name<Derived>();
return name.substr(name.rfind("::")+2); return name.substr(name.rfind("::") + 2);
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
...@@ -1139,14 +1139,17 @@ struct binary ...@@ -1139,14 +1139,17 @@ struct binary
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) { visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard()) if(input1.get_shape().standard() and input2.get_shape().standard())
{ {
std::transform( std::transform(input1.begin(),
input1.begin(), input1.end(), input2.begin(), output.begin(), static_cast<const Derived&>(*this).apply()); input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
} }
else else
{ {
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
static_cast<const Derived&>(*this).apply()(input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end())); input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
}); });
} }
}); });
...@@ -1156,32 +1159,50 @@ struct binary ...@@ -1156,32 +1159,50 @@ struct binary
struct add : binary<add> struct add : binary<add>
{ {
auto apply() const { return [](auto x, auto y) { return x + y; }; } auto apply() const
{
return [](auto x, auto y) { return x + y; };
}
}; };
struct sub : binary<sub> struct sub : binary<sub>
{ {
auto apply() const { return [](auto x, auto y) { return x - y; }; } auto apply() const
{
return [](auto x, auto y) { return x - y; };
}
}; };
struct mul : binary<mul> struct mul : binary<mul>
{ {
auto apply() const { return [](auto x, auto y) { return x * y; }; } auto apply() const
{
return [](auto x, auto y) { return x * y; };
}
}; };
struct div : binary<div> struct div : binary<div>
{ {
auto apply() const { return [](auto x, auto y) { return x / y; }; } auto apply() const
{
return [](auto x, auto y) { return x / y; };
}
}; };
struct max : binary<max> struct max : binary<max>
{ {
auto apply() const { return [](auto x, auto y) { return std::max(x, y); }; } auto apply() const
{
return [](auto x, auto y) { return std::max(x, y); };
}
}; };
struct min : binary<min> struct min : binary<min>
{ {
auto apply() const { return [](auto x, auto y) { return std::min(x, y); }; } auto apply() const
{
return [](auto x, auto y) { return std::min(x, y); };
}
}; };
struct load struct load
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment