Commit b797627a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into copy_program

parents 34c2889a 6f115a0f
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct sigmoid : unary struct sigmoid : unary<sigmoid>
{ {
std::string name() const { return "sigmoid"; } auto apply() const
{
return [](auto x) { return 1.f / (1.f + std::exp(-x)); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct sin : unary struct sin : unary<sin>
{ {
std::string name() const { return "sin"; } auto apply() const
{
return [](auto x) { return std::sin(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct sinh : unary struct sinh : unary<sinh>
{ {
std::string name() const { return "sinh"; } auto apply() const
{
return [](auto x) { return std::sinh(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct sub : binary struct sub : binary<sub>
{ {
std::string name() const { return "sub"; } auto apply() const
{
return [](auto x, auto y) { return x - y; };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct tan : unary struct tan : unary<tan>
{ {
std::string name() const { return "tan"; } auto apply() const
{
return [](auto x) { return std::tan(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct tanh : unary struct tanh : unary<tanh>
{ {
std::string name() const { return "tanh"; } auto apply() const
{
return [](auto x) { return std::tanh(x); };
}
}; };
} // namespace op } // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP #define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#include <array> #include <migraphx/op/name.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct unary template <class Derived>
struct unary : op_name<Derived>
{ {
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
return inputs.at(0); return inputs.at(0);
} }
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
if(input.get_shape().standard())
{
std::transform(input.begin(),
input.end(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) =
static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
});
}
});
return result;
}
}; };
} // namespace op } // namespace op
......
#ifndef MIGRAPHX_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_PROPAGATE_CONSTANT_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP #define MIGRAPHX_GUARD_RTGLIB_PROPAGATE_CONSTANT_HPP
#include <string> #include <string>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -12,9 +12,9 @@ struct program; ...@@ -12,9 +12,9 @@ struct program;
/** /**
* Replace instructions which take all literals with a literal of the computation. * Replace instructions which take all literals with a literal of the computation.
*/ */
struct constant_propagate struct propagate_constant
{ {
std::string name() const { return "constant_propagate"; } std::string name() const { return "propagate_constant"; }
void apply(program& p) const; void apply(program& p) const;
}; };
......
#include <migraphx/propagate_constant.hpp>
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool skip_propogate(instruction_ref ins)
{
if(ins->name() == "@literal")
return true;
auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar())
return true;
if(s.scalar() and s.elements() != 1)
return true;
return false;
}
void propagate_constant::apply(program& p) const
{
fix([&](auto self, auto ins) {
if(not skip_propogate(ins))
{
auto r = ins->eval();
if(not r.empty())
{
assert(r.get_shape() == ins->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(ins, l);
return;
}
}
std::unordered_set<instruction_ref> children(ins->inputs().begin(), ins->inputs().end());
for(auto child : children)
self(child);
})(std::prev(p.end()));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/constant_propagate.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
...@@ -48,7 +48,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -48,7 +48,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
//dead_code_elimination{}, //dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
dead_code_elimination{}, dead_code_elimination{},
constant_propagate{}, propagate_constant{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
//simplify_reshapes{}, //simplify_reshapes{},
......
...@@ -236,8 +236,7 @@ struct test_exp : verify_program<test_exp> ...@@ -236,8 +236,7 @@ struct test_exp : verify_program<test_exp>
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 10.f}; auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
auto x = p.add_literal(s, data);
p.add_instruction(migraphx::op::exp{}, x); p.add_instruction(migraphx::op::exp{}, x);
return p; return p;
} }
...@@ -249,8 +248,7 @@ struct test_log : verify_program<test_log> ...@@ -249,8 +248,7 @@ struct test_log : verify_program<test_log>
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{migraphx::shape::float_type, {6}};
std::vector<float> data{0.1f, 0.2f, 1.f, 2.f, 0.6f, 100.f}; auto x = p.add_instruction(migraphx::op::abs{}, p.add_parameter("x", s));
auto x = p.add_literal(s, data);
p.add_instruction(migraphx::op::log{}, x); p.add_instruction(migraphx::op::log{}, x);
return p; return p;
} }
......
#include <migraphx/constant_propagate.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/op/add.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/mul.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -9,12 +11,12 @@ struct const_prop_target ...@@ -9,12 +11,12 @@ struct const_prop_target
std::string name() const { return "const_prop"; } std::string name() const { return "const_prop"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraphx::constant_propagate{}, migraphx::dead_code_elimination{}}; return {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}};
} }
migraphx::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
TEST_CASE(const_add1) TEST_CASE(const_add)
{ {
migraphx::program p1; migraphx::program p1;
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
...@@ -29,7 +31,7 @@ TEST_CASE(const_add1) ...@@ -29,7 +31,7 @@ TEST_CASE(const_add1)
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
TEST_CASE(const_add2) TEST_CASE(const_add_parameter)
{ {
migraphx::program p1; migraphx::program p1;
auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}}); auto one = p1.add_parameter("one", {migraphx::shape::int32_type, {1}});
...@@ -44,7 +46,7 @@ TEST_CASE(const_add2) ...@@ -44,7 +46,7 @@ TEST_CASE(const_add2)
EXPECT(p1 != p2); EXPECT(p1 != p2);
} }
TEST_CASE(const_add3) TEST_CASE(const_multiadd)
{ {
migraphx::program p1; migraphx::program p1;
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
...@@ -60,4 +62,54 @@ TEST_CASE(const_add3) ...@@ -60,4 +62,54 @@ TEST_CASE(const_add3)
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
TEST_CASE(const_add_mul)
{
migraphx::program p1;
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto mul = p1.add_instruction(migraphx::op::mul{}, two, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, mul);
auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, two);
p1.add_instruction(pass_op{}, sum2);
p1.compile(const_prop_target{});
migraphx::program p2;
auto total = p2.add_literal(7);
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
TEST_CASE(const_add_scalar)
{
migraphx::program p1;
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
auto two = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(2));
auto sum = p1.add_instruction(migraphx::op::add{}, one, two);
p1.add_instruction(pass_op{}, sum);
p1.compile(const_prop_target{});
migraphx::program p2;
auto total =
p2.add_literal(migraphx::literal{{migraphx::shape::int32_type, {2, 2}}, {3, 3, 3, 3}});
p2.add_instruction(pass_op{}, total);
EXPECT(p1 == p2);
}
TEST_CASE(const_scalar)
{
migraphx::program p1;
{
auto one = p1.add_instruction(migraphx::op::scalar{{2, 2}}, p1.add_literal(1));
p1.add_instruction(pass_op{}, one);
}
p1.compile(const_prop_target{});
migraphx::program p2;
{
auto one = p2.add_instruction(migraphx::op::scalar{{2, 2}}, p2.add_literal(1));
p2.add_instruction(pass_op{}, one);
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment