Commit c0154dca authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from the develop branch

parents ca170b5c b93f5320
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct asin : unary
struct asin : unary<asin>
{
std::string name() const { return "asin"; }
auto apply() const
{
return [](auto x) { return std::asin(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct atan : unary
struct atan : unary<atan>
{
std::string name() const { return "atan"; }
auto apply() const
{
return [](auto x) { return std::atan(x); };
}
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/name.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct binary
template <class Derived>
struct binary : op_name<Derived>
{
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
auto t = inputs.at(0).type();
auto lens = inputs.at(0).lens();
return {t, lens};
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{
return s0;
}
else
{
return {s0.type(), s0.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard())
{
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
}
else
{
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
}
});
return result;
}
};
......
......@@ -27,45 +27,43 @@ namespace op {
struct broadcast
{
uint64_t axis = 0;
std::vector<std::size_t> broadcast_lens;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims"));
}
shape broadcast_shape;
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto t = inputs.at(0).type();
auto input = inputs.at(0);
std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
return x == 1;
}))
if(std::all_of(
broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
{
if(axis != 0)
MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_shape.lens(), std::move(bcast_strides)};
MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_lens, std::move(bcast_strides)};
}
else
{
assert(broadcast_shape.lens().size() - axis >= input.lens().size());
if(!std::equal(
input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
MIGRAPHX_THROW("when broadcasting success sizes must match");
assert(broadcast_lens.size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_shape.lens(), std::move(bcast_strides)};
return {t, broadcast_lens, std::move(bcast_strides)};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_CLIP_HPP
#define MIGRAPHX_GUARD_OPERATORS_CLIP_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <limits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct clip : unary<clip>
{
float max_val = std::numeric_limits<float>::max();
float min_val = std::numeric_limits<float>::min();
clip() {}
clip(float max, float min) : max_val(max), min_val(min) {}
auto apply() const
{
auto max = max_val;
auto min = min_val;
return [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
};
}
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.max_val, "max"), f(self.min_val, "min"));
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -31,6 +31,8 @@ enum class rnn_direction
bidirectional,
};
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -19,6 +19,13 @@ namespace op {
struct concat
{
std::size_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct cos : unary
struct cos : unary<cos>
{
std::string name() const { return "cos"; }
auto apply() const
{
return [](auto x) { return std::cos(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct cosh : unary
struct cosh : unary<cosh>
{
std::string name() const { return "cosh"; }
auto apply() const
{
return [](auto x) { return std::cosh(x); };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct div : binary
struct div : binary<div>
{
std::string name() const { return "div"; }
auto apply() const
{
return [](auto x, auto y) { return x / y; };
}
};
} // namespace op
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct exp : unary
struct exp : unary<exp>
{
std::string name() const { return "exp"; }
auto apply() const
{
return [](auto x) { return std::exp(x); };
}
};
} // namespace op
......
......@@ -46,7 +46,7 @@ struct flatten
{
return {std::move(output_shape), std::move(args.front().data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -19,11 +19,18 @@ namespace op {
struct gather
{
int axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
check_shapes{inputs, *this}.has(2).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
......
......@@ -27,6 +27,16 @@ struct gru
float clip = 0.0f;
int linear_before_reset = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.linear_before_reset, "linear_before_reset"));
}
std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
......@@ -24,7 +24,7 @@ struct identity
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -18,19 +18,20 @@ namespace op {
struct leaky_relu
{
std::string name() const { return "leaky_relu"; }
float alpha;
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"));
}
std::string name() const { return "leaky_relu"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.front();
}
};
} // namespace op
......
......@@ -39,7 +39,7 @@ struct load
MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset};
}
int output_alias(const std::vector<shape>&) const { return 0; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
friend std::ostream& operator<<(std::ostream& os, const load& op)
{
......
......@@ -17,9 +17,12 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct log : unary
struct log : unary<log>
{
std::string name() const { return "log"; }
auto apply() const
{
return [](auto x) { return std::log(x); };
}
};
} // namespace op
......
......@@ -19,10 +19,17 @@ namespace op {
struct logsoftmax
{
int axis = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
check_shapes{inputs}.has(1).standard();
if(axis < 0 || axis > inputs[0].lens().size())
{
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
......
......@@ -25,6 +25,15 @@ struct lstm
float clip = 0.0f;
int input_forget = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.input_forget, "input_forget"));
}
std::string name() const { return "lstm"; }
shape compute_shape(std::vector<shape> inputs) const
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment