Commit ca2cb168 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

temp changes.

parents a443007e d4d2335a
......@@ -30,6 +30,29 @@ struct binary
return {inputs.at(0).type(), inputs.at(0).lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard())
{
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
}
});
return result;
}
};
} // namespace op
......
......@@ -2,7 +2,7 @@
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/sin.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <basic_ops.hpp>
......@@ -61,7 +61,7 @@ TEST_CASE(transpose_standard_op)
auto l = p.add_literal(get_2x2());
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c = p.add_instruction(migraphx::op::contiguous{}, t);
auto sum = p.add_instruction(migraphx::op::add{}, c, c);
auto sum = p.add_instruction(migraphx::op::sin{}, c);
p.add_instruction(pass_standard_op{}, sum);
auto count = std::distance(p.begin(), p.end());
p.compile(eliminate_contiguous_target{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment