Commit dfa79e73 authored by Paul's avatar Paul
Browse files

Add evaluation of binary operators

parent cfbdef6b
......@@ -8,6 +8,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......@@ -1117,8 +1118,14 @@ struct scalar
int output_alias(const std::vector<shape>&) const { return 0; }
};
template<class Derived>
struct binary
{
std::string name() const
{
static const std::string& name = get_type_name<Derived>();
return name.substr(name.rfind("::")+2);
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
......@@ -1126,36 +1133,55 @@ struct binary
auto lens = inputs.at(0).lens();
return {t, lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard())
{
std::transform(
input1.begin(), input1.end(), input2.begin(), output.begin(), static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) =
static_cast<const Derived&>(*this).apply()(input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
}
});
return result;
}
};
struct add : binary
struct add : binary<add>
{
std::string name() const { return "add"; }
auto apply() const { return [](auto x, auto y) { return x + y; }; }
};
struct sub : binary
struct sub : binary<sub>
{
std::string name() const { return "sub"; }
auto apply() const { return [](auto x, auto y) { return x - y; }; }
};
struct mul : binary
struct mul : binary<mul>
{
std::string name() const { return "mul"; }
auto apply() const { return [](auto x, auto y) { return x * y; }; }
};
struct div : binary
struct div : binary<div>
{
std::string name() const { return "div"; }
auto apply() const { return [](auto x, auto y) { return x / y; }; }
};
struct max : binary
struct max : binary<max>
{
std::string name() const { return "max"; }
auto apply() const { return [](auto x, auto y) { return std::max(x, y); }; }
};
struct min : binary
struct min : binary<min>
{
std::string name() const { return "min"; }
auto apply() const { return [](auto x, auto y) { return std::min(x, y); }; }
};
struct load
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment