Commit c0154dca authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from the develop branch

parents ca170b5c b93f5320
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct asin : unary struct asin : unary<asin>
{ {
std::string name() const { return "asin"; } auto apply() const
{
return [](auto x) { return std::asin(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct atan : unary struct atan : unary<atan>
{ {
std::string name() const { return "atan"; } auto apply() const
{
return [](auto x) { return std::atan(x); };
}
}; };
} // namespace op } // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP #define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <array> #include <migraphx/op/name.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct binary template <class Derived>
struct binary : op_name<Derived>
{ {
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs}.has(2).same_type().same_dims();
auto t = inputs.at(0).type(); auto s0 = inputs.at(0);
auto lens = inputs.at(0).lens(); auto s1 = inputs.at(1);
return {t, lens}; if(s0 == s1 and s0.packed())
{
return s0;
}
else
{
return {s0.type(), s0.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard())
{
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
}
});
return result;
} }
}; };
......
...@@ -27,45 +27,43 @@ namespace op { ...@@ -27,45 +27,43 @@ namespace op {
struct broadcast struct broadcast
{ {
uint64_t axis = 0; uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims"));
} }
shape broadcast_shape;
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto input = inputs.at(0); auto input = inputs.at(0);
std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0); std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) { if(std::all_of(
return x == 1; broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
}))
{ {
if(axis != 0) if(axis != 0)
MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0"); MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_shape.lens(), std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
} }
else else
{ {
assert(broadcast_shape.lens().size() - axis >= input.lens().size()); assert(broadcast_lens.size() - axis >= input.lens().size());
if(!std::equal( if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis)) MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
MIGRAPHX_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_shape.lens(), std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
int output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
} // namespace op } // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_CLIP_HPP
#define MIGRAPHX_GUARD_OPERATORS_CLIP_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <limits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct clip : unary<clip>
{
float max_val = std::numeric_limits<float>::max();
float min_val = std::numeric_limits<float>::min();
clip() {}
clip(float max, float min) : max_val(max), min_val(min) {}
auto apply() const
{
auto max = max_val;
auto min = min_val;
return [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
};
}
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.max_val, "max"), f(self.min_val, "min"));
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -31,6 +31,8 @@ enum class rnn_direction ...@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -19,6 +19,13 @@ namespace op { ...@@ -19,6 +19,13 @@ namespace op {
struct concat struct concat
{ {
std::size_t axis = 0; std::size_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "concat"; } std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape, std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct cos : unary struct cos : unary<cos>
{ {
std::string name() const { return "cos"; } auto apply() const
{
return [](auto x) { return std::cos(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct cosh : unary struct cosh : unary<cosh>
{ {
std::string name() const { return "cosh"; } auto apply() const
{
return [](auto x) { return std::cosh(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct div : binary struct div : binary<div>
{ {
std::string name() const { return "div"; } auto apply() const
{
return [](auto x, auto y) { return x / y; };
}
}; };
} // namespace op } // namespace op
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct exp : unary struct exp : unary<exp>
{ {
std::string name() const { return "exp"; } auto apply() const
{
return [](auto x) { return std::exp(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -46,7 +46,7 @@ struct flatten ...@@ -46,7 +46,7 @@ struct flatten
{ {
return {std::move(output_shape), std::move(args.front().data)}; return {std::move(output_shape), std::move(args.front().data)};
} }
int output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
} // namespace op } // namespace op
......
...@@ -19,11 +19,18 @@ namespace op { ...@@ -19,11 +19,18 @@ namespace op {
struct gather struct gather
{ {
int axis = 0; int axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size()); int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim) if(axis >= n_dim || axis < -n_dim)
......
...@@ -27,6 +27,16 @@ struct gru ...@@ -27,6 +27,16 @@ struct gru
float clip = 0.0f; float clip = 0.0f;
int linear_before_reset = 0; int linear_before_reset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.linear_before_reset, "linear_before_reset"));
}
std::string name() const { return "gru"; } std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -24,7 +24,7 @@ struct identity ...@@ -24,7 +24,7 @@ struct identity
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return {std::move(output_shape), std::move(args.at(0).data)};
} }
int output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
} // namespace op } // namespace op
......
...@@ -18,19 +18,20 @@ namespace op { ...@@ -18,19 +18,20 @@ namespace op {
struct leaky_relu struct leaky_relu
{ {
std::string name() const { return "leaky_relu"; }
float alpha; float alpha;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.alpha, "alpha")); return pack(f(self.alpha, "alpha"));
} }
std::string name() const { return "leaky_relu"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
}; };
} // namespace op } // namespace op
......
...@@ -39,7 +39,7 @@ struct load ...@@ -39,7 +39,7 @@ struct load
MIGRAPHX_THROW("Load access is out of bounds"); MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset}; return {s, args[0].data() + offset};
} }
int output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op) friend std::ostream& operator<<(std::ostream& os, const load& op)
{ {
......
...@@ -17,9 +17,12 @@ namespace migraphx { ...@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct log : unary struct log : unary<log>
{ {
std::string name() const { return "log"; } auto apply() const
{
return [](auto x) { return std::log(x); };
}
}; };
} // namespace op } // namespace op
......
...@@ -19,10 +19,17 @@ namespace op { ...@@ -19,10 +19,17 @@ namespace op {
struct logsoftmax struct logsoftmax
{ {
int axis = 1; int axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "logsoftmax"; } std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1).standard();
if(axis < 0 || axis > inputs[0].lens().size()) if(axis < 0 || axis > inputs[0].lens().size())
{ {
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) + MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
......
...@@ -25,6 +25,15 @@ struct lstm ...@@ -25,6 +25,15 @@ struct lstm
float clip = 0.0f; float clip = 0.0f;
int input_forget = 0; int input_forget = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.input_forget, "input_forget"));
}
std::string name() const { return "lstm"; } std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment