Commit b797627a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into copy_program

parents 34c2889a 6f115a0f
......@@ -5,7 +5,7 @@ include(ROCMPackageConfigHelpers)
add_library(migraphx
auto_contiguous.cpp
common_subexpression_elimination.cpp
constant_propagate.cpp
propagate_constant.cpp
dead_code_elimination.cpp
eliminate_allocation.cpp
eliminate_contiguous.cpp
......
#include <migraphx/constant_propagate.hpp>
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct match_const_add
{
auto matcher() const
{
return match::name("add")(match::args(match::name("@literal"), match::name("@literal")));
}
void apply(program& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto arg1 = ins->inputs().at(0)->get_literal();
auto arg2 = ins->inputs().at(1)->get_literal();
auto sum = p.add_literal(transform(arg1, arg2, [](auto x, auto y) { return x + y; }));
p.replace_instruction(ins, sum);
}
};
void constant_propagate::apply(program& p) const { match::find_matches(p, match_const_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_MAKE_SIGNED_HPP
#define MIGRAPHX_GUARD_RTGLIB_MAKE_SIGNED_HPP
#include <migraphx/config.hpp>
#include <type_traits>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::enable_if<true, T>>::
type
make_signed(T x)
{
return x;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -10,6 +10,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/make_signed.hpp>
#include <cmath>
#include <utility>
......@@ -17,9 +18,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct abs : unary
struct abs : unary<abs>
{
std::string name() const { return "abs"; }
auto apply() const
{
return [](auto x) { return std::abs(make_signed(x)); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct acos : unary
struct acos : unary<acos>
{
std::string name() const { return "acos"; }
auto apply() const
{
return [](auto x) { return std::acos(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct add : binary
struct add : binary<add>
{
std::string name() const { return "add"; }
auto apply() const
{
return [](auto x, auto y) { return x + y; };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct asin : unary
struct asin : unary<asin>
{
std::string name() const { return "asin"; }
auto apply() const
{
return [](auto x) { return std::asin(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct atan : unary
struct atan : unary<atan>
{
std::string name() const { return "atan"; }
auto apply() const
{
return [](auto x) { return std::atan(x); };
}
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/name.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct binary
template <class Derived>
struct binary : op_name<Derived>
{
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
auto t = inputs.at(0).type();
auto lens = inputs.at(0).lens();
return {t, lens};
const auto& s = inputs.front();
if(s.scalar() and s.elements() == 1)
return {s.type()};
return {s.type(), s.lens()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard())
{
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
}
});
return result;
}
};
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct cos : unary
struct cos : unary<cos>
{
std::string name() const { return "cos"; }
auto apply() const
{
return [](auto x) { return std::cos(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct cosh : unary
struct cosh : unary<cosh>
{
std::string name() const { return "cosh"; }
auto apply() const
{
return [](auto x) { return std::cosh(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct div : binary
struct div : binary<div>
{
std::string name() const { return "div"; }
auto apply() const
{
return [](auto x, auto y) { return x / y; };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct exp : unary
struct exp : unary<exp>
{
std::string name() const { return "exp"; }
auto apply() const
{
return [](auto x) { return std::exp(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct log : unary
struct log : unary<log>
{
std::string name() const { return "log"; }
auto apply() const
{
return [](auto x) { return std::log(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct max : binary
struct max : binary<max>
{
std::string name() const { return "max"; }
auto apply() const
{
return [](auto x, auto y) { return std::max(x, y); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct min : binary
struct min : binary<min>
{
std::string name() const { return "min"; }
auto apply() const
{
return [](auto x, auto y) { return std::min(x, y); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct mul : binary
struct mul : binary<mul>
{
std::string name() const { return "mul"; }
auto apply() const
{
return [](auto x, auto y) { return x * y; };
}
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_RTGLIB_NAME_HPP
#define MIGRAPHX_GUARD_RTGLIB_NAME_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// Create name from class
template <class Derived>
struct op_name
{
std::string name() const
{
static const std::string& name = get_type_name<Derived>();
return name.substr(name.rfind("::") + 2);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct neg : unary
struct neg : unary<neg>
{
std::string name() const { return "neg"; }
auto apply() const
{
return [](auto x) { return -x; };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct relu : unary
struct relu : unary<relu>
{
std::string name() const { return "relu"; }
auto apply() const
{
return [](auto x) { return std::max(decltype(x){0}, x); };
}
};
} // namespace op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment