Commit 5c7bb1f8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into ins_fp32_fp16

parents c5941d87 6f115a0f
...@@ -5,7 +5,7 @@ include(ROCMPackageConfigHelpers) ...@@ -5,7 +5,7 @@ include(ROCMPackageConfigHelpers)
add_library(migraphx add_library(migraphx
auto_contiguous.cpp auto_contiguous.cpp
common_subexpression_elimination.cpp common_subexpression_elimination.cpp
constant_propagate.cpp propagate_constant.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
eliminate_contiguous.cpp eliminate_contiguous.cpp
......
#include <migraphx/constant_propagate.hpp>
#include <migraphx/program.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct match_const_add
{
auto matcher() const
{
return match::name("add")(match::args(match::name("@literal"), match::name("@literal")));
}
void apply(program& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto arg1 = ins->inputs().at(0)->get_literal();
auto arg2 = ins->inputs().at(1)->get_literal();
auto sum = p.add_literal(transform(arg1, arg2, [](auto x, auto y) { return x + y; }));
p.replace_instruction(ins, sum);
}
};
void constant_propagate::apply(program& p) const { match::find_matches(p, match_const_add{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_MAKE_SIGNED_HPP
#define MIGRAPHX_GUARD_RTGLIB_MAKE_SIGNED_HPP
#include <migraphx/config.hpp>
#include <type_traits>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
typename std::conditional_t<std::is_integral<T>{}, std::make_signed<T>, std::enable_if<true, T>>::
type
make_signed(T x)
{
return x;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/make_signed.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -17,9 +18,12 @@ namespace migraphx { ...@@ -17,9 +18,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct abs : unary struct abs : unary<abs>
{ {
std::string name() const { return "abs"; } auto apply() const
{
return [](auto x) { return std::abs(make_signed(x)); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct acos : unary struct acos : unary<acos>
{ {
std::string name() const { return "acos"; } auto apply() const
{
return [](auto x) { return std::acos(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct add : binary struct add : binary<add>
{ {
std::string name() const { return "add"; } auto apply() const
{
return [](auto x, auto y) { return x + y; };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct asin : unary struct asin : unary<asin>
{ {
std::string name() const { return "asin"; } auto apply() const
{
return [](auto x) { return std::asin(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct atan : unary struct atan : unary<atan>
{ {
std::string name() const { return "atan"; } auto apply() const
{
return [](auto x) { return std::atan(x); };
}
}; };
} // namespace op } // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP #define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <array> #include <migraphx/op/name.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct binary template <class Derived>
struct binary : op_name<Derived>
{ {
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
auto t = inputs.at(0).type(); const auto& s = inputs.front();
auto lens = inputs.at(0).lens(); if(s.scalar() and s.elements() == 1)
return {t, lens}; return {s.type()};
return {s.type(), s.lens()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard())
{
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
}
});
return result;
} }
}; };
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct cos : unary struct cos : unary<cos>
{ {
std::string name() const { return "cos"; } auto apply() const
{
return [](auto x) { return std::cos(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct cosh : unary struct cosh : unary<cosh>
{ {
std::string name() const { return "cosh"; } auto apply() const
{
return [](auto x) { return std::cosh(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct div : binary struct div : binary<div>
{ {
std::string name() const { return "div"; } auto apply() const
{
return [](auto x, auto y) { return x / y; };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct exp : unary struct exp : unary<exp>
{ {
std::string name() const { return "exp"; } auto apply() const
{
return [](auto x) { return std::exp(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct log : unary struct log : unary<log>
{ {
std::string name() const { return "log"; } auto apply() const
{
return [](auto x) { return std::log(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct max : binary struct max : binary<max>
{ {
std::string name() const { return "max"; } auto apply() const
{
return [](auto x, auto y) { return std::max(x, y); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct min : binary struct min : binary<min>
{ {
std::string name() const { return "min"; } auto apply() const
{
return [](auto x, auto y) { return std::min(x, y); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct mul : binary struct mul : binary<mul>
{ {
std::string name() const { return "mul"; } auto apply() const
{
return [](auto x, auto y) { return x * y; };
}
}; };
} // namespace op } // namespace op
......
#ifndef MIGRAPHX_GUARD_RTGLIB_NAME_HPP
#define MIGRAPHX_GUARD_RTGLIB_NAME_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// Create name from class
template <class Derived>
struct op_name
{
std::string name() const
{
static const std::string& name = get_type_name<Derived>();
return name.substr(name.rfind("::") + 2);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct neg : unary struct neg : unary<neg>
{ {
std::string name() const { return "neg"; } auto apply() const
{
return [](auto x) { return -x; };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct relu : unary struct relu : unary<relu>
{ {
std::string name() const { return "relu"; } auto apply() const
{
return [](auto x) { return std::max(decltype(x){0}, x); };
}
}; };
} // namespace op } // namespace op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment