Unverified Commit e0fe2f55 authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge branch 'develop' into copy_program

parents 34c2889a 6f115a0f
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sigmoid : unary
struct sigmoid : unary<sigmoid>
{
std::string name() const { return "sigmoid"; }
auto apply() const
{
return [](auto x) { return 1.f / (1.f + std::exp(-x)); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sin : unary
struct sin : unary<sin>
{
std::string name() const { return "sin"; }
auto apply() const
{
return [](auto x) { return std::sin(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sinh : unary
struct sinh : unary<sinh>
{
std::string name() const { return "sinh"; }
auto apply() const
{
return [](auto x) { return std::sinh(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sub : binary
struct sub : binary<sub>
{
std::string name() const { return "sub"; }
auto apply() const
{
return [](auto x, auto y) { return x - y; };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct tan : unary
struct tan : unary<tan>
{
std::string name() const { return "tan"; }
auto apply() const
{
return [](auto x) { return std::tan(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct tanh : unary
struct tanh : unary<tanh>
{
std::string name() const { return "tanh"; }
auto apply() const
{
return [](auto x) { return std::tanh(x); };
}
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/name.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct unary
template <class Derived>
struct unary : op_name<Derived>
{
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
return inputs.at(0);
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
if(input.get_shape().standard())
{
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) =
static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
});
}
});
return result;
}
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_PROPAGATE_CONSTANT_HPP
#define MIGRAPHX_GUARD_RTGLIB_PROPAGATE_CONSTANT_HPP
#include <string>
#include <migraphx/config.hpp>
......@@ -12,9 +12,9 @@ struct program;
/**
* Replace instructions which take all literals with a literal of the computation.
*/
struct constant_propagate
struct propagate_constant
{
std::string name() const { return "constant_propagate"; }
std::string name() const { return "propagate_constant"; }
void apply(program& p) const;
};
......
#include <migraphx/propagate_constant.hpp>
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool skip_propogate(instruction_ref ins)
{
if(ins->name() == "@literal")
return true;
auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar())
return true;
if(s.scalar() and s.elements() != 1)
return true;
return false;
}
void propagate_constant::apply(program& p) const
{
fix([&](auto self, auto ins) {
if(not skip_propogate(ins))
{
auto r = ins->eval();
if(not r.empty())
{
assert(r.get_shape() == ins->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(ins, l);
return;
}
}
std::unordered_set<instruction_ref> children(ins->inputs().begin(), ins->inputs().end());
for(auto child : children)
self(child);
})(std::prev(p.end()));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -11,7 +11,7 @@
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/constant_propagate.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
......@@ -48,7 +48,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
//dead_code_elimination{},
simplify_algebra{},
dead_code_elimination{},
constant_propagate{},
propagate_constant{},
dead_code_elimination{},
auto_contiguous{},
//simplify_reshapes{},
......
......@@ -236,8 +236,7 @@ struct test_exp : verify_program<test_exp>
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 10.f};
auto x = p.add_literal(s, data);
auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
p.add_instruction(migraphx::op::exp{}, x);
return p;
}
......@@ -249,8 +248,7 @@ struct test_log : verify_program<test_log>
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 100.f};
auto x = p.add_literal(s, data);
auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
p.add_instruction(migraphx::op::log{}, x);
return p;
}
......
#include <migraphx/constant_propagate.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/mul.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -9,12 +11,12 @@ struct const_prop_target
std::string name() const { return "const_prop"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::constant_propagate{}, migraphx::dead_code_elimination{}};
return {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(const_add1)
TEST_CASE(const_add)
{
migraphx::program p1;
auto one = p1.add_literal(1);
......@@ -29,7 +31,7 @@ TEST_CASE(const_add1)
EXPECT(p1 == p2);
}
TEST_CASE(const_add2)
TEST_CASE(const_add_parameter)
{
migraphx::program p1;
auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}});
......@@ -44,7 +46,7 @@ TEST_CASE(const_add2)
EXPECT(p1 != p2);
}
TEST_CASE(const_add3)
TEST_CASE(const_multiadd)
{
migraphx::program p1;
auto one = p1.add_literal(1);
......@@ -60,4 +62,54 @@ TEST_CASE(const_add3)
EXPECT(p1 == p2);
}
TEST_CASE(const_add_mul)
{
migraphx::program p1;
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto mul = p1.add_instruction(migraphx::op::mul{}, two, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, mul);
auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{});
migraphx::program p2;
auto total = p2.add_literal(7);
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
TEST_CASE(const_add_scalar)
{
migraphx::program p1;
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
auto two = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(2));
auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{});
migraphx::program p2;
auto total =
p2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}});
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
TEST_CASE(const_scalar)
{
migraphx::program p1;
{
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
p1.add_instruction(pass_op{}, one);
}
p1.compile(const_prop_target{});
migraphx::program p2;
{
auto one = p2.add_instruction(migraphx::op::scalar{{2, 2}}, p2.add_literal(1));
p2.add_instruction(pass_op{}, one);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment