Commit c0154dca authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from the develop branch

parents ca170b5c b93f5320
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct max : binary
struct max : binary<max>
{
std::string name() const { return "max"; }
auto apply() const
{
return [](auto x, auto y) { return std::max(x, y); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct min : binary
struct min : binary<min>
{
std::string name() const { return "min"; }
auto apply() const
{
return [](auto x, auto y) { return std::min(x, y); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct mul : binary
struct mul : binary<mul>
{
std::string name() const { return "mul"; }
auto apply() const
{
return [](auto x, auto y) { return x * y; };
}
};
} // namespace op
......
......@@ -42,7 +42,7 @@ struct multibroadcast
std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size() - input.lens().size();
for(int i = input.lens().size() - 1; i >= 0; i--)
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] == input.lens()[i])
{
......@@ -55,7 +55,7 @@ struct multibroadcast
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_RTGLIB_NAME_HPP
#define MIGRAPHX_GUARD_RTGLIB_NAME_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// Create name from class
template <class Derived>
struct op_name
{
std::string name() const
{
static const std::string& name = get_type_name<Derived>();
return name.substr(name.rfind("::") + 2);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct neg : unary
struct neg : unary<neg>
{
std::string name() const { return "neg"; }
auto apply() const
{
return [](auto x) { return -x; };
}
};
} // namespace op
......
......@@ -9,6 +9,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......@@ -48,35 +49,6 @@ struct pooling
assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
if(padding_mode == default_)
{
return {
t,
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
static_cast<float>(stride[0]))) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
static_cast<float>(stride[1]))) +
1)),
}};
}
else if(padding_mode == same)
{
return {t,
{input.lens()[0],
input.lens()[1],
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
static_cast<std::size_t>(
std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
}
else if(padding_mode == valid)
{
return {t,
{
......@@ -84,16 +56,39 @@ struct pooling
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
static_cast<float>(stride[0]))) +
floor_divide<std::ptrdiff_t>(
input.lens()[2] + 2 * padding[0] - lengths[0], stride[0]) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
static_cast<float>(stride[1]))) +
floor_divide<std::ptrdiff_t>(
input.lens()[3] + 2 * padding[1] - lengths[1], stride[1]) +
1)),
}};
}
else if(padding_mode == same)
{
return {t,
{input.lens()[0],
input.lens()[1],
ceil_divide<std::size_t>(input.lens()[2], stride[0]),
ceil_divide<std::size_t>(input.lens()[3], stride[1])}};
}
else if(padding_mode == valid)
{
return {
t,
{
input.lens()[0],
input.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) + 1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) + 1)),
}};
}
else
{
MIGRAPHX_THROW("Invalid padding mode");
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct relu : unary
struct relu : unary<relu>
{
std::string name() const { return "relu"; }
auto apply() const
{
return [](auto x) { return std::max(decltype(x){0}, x); };
}
};
} // namespace op
......
......@@ -66,7 +66,7 @@ struct reshape
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -25,6 +25,15 @@ struct rnn
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"));
}
std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -18,7 +18,13 @@ namespace op {
struct scalar
{
shape scalar_bcast;
std::vector<std::size_t> scalar_bcast_lens;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.scalar_bcast_lens, "scalar_bcst_dims"));
}
std::string name() const { return "scalar"; }
......@@ -26,15 +32,15 @@ struct scalar
{
assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
auto t = inputs.at(0).type();
std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0);
return {t, scalar_bcast.lens(), strides};
std::vector<std::size_t> strides(scalar_bcast_lens.size(), 0);
return {t, scalar_bcast_lens, strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sigmoid : unary
struct sigmoid : unary<sigmoid>
{
std::string name() const { return "sigmoid"; }
auto apply() const
{
return [](auto x) { return 1.f / (1.f + std::exp(-x)); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sin : unary
struct sin : unary<sin>
{
std::string name() const { return "sin"; }
auto apply() const
{
return [](auto x) { return std::sin(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sinh : unary
struct sinh : unary<sinh>
{
std::string name() const { return "sinh"; }
auto apply() const
{
return [](auto x) { return std::sinh(x); };
}
};
} // namespace op
......
......@@ -86,7 +86,7 @@ struct slice
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -69,7 +69,7 @@ struct squeeze
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct sub : binary
struct sub : binary<sub>
{
std::string name() const { return "sub"; }
auto apply() const
{
return [](auto x, auto y) { return x - y; };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct tan : unary
struct tan : unary<tan>
{
std::string name() const { return "tan"; }
auto apply() const
{
return [](auto x) { return std::tan(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct tanh : unary
struct tanh : unary<tanh>
{
std::string name() const { return "tanh"; }
auto apply() const
{
return [](auto x) { return std::tanh(x); };
}
};
} // namespace op
......
......@@ -57,7 +57,7 @@ struct transpose
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment