Commit 31065c7d authored by charlie

Merge branch 'dyn_squeeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents 6bec381f 6acbd4e4
@@ -31,9 +31,9 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
 template <class Iterator, class EndIterator>
-auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(!it._M_dereferenceable())
+auto is_end(rank<2>, Iterator it, EndIterator) -> decltype(not it._M_dereferenceable())
 {
-    return !it._M_dereferenceable();
+    return not it._M_dereferenceable();
 }
 
 template <class Iterator, class EndIterator>
...
@@ -27,6 +27,8 @@
 #include <migraphx/config.hpp>
 #include <migraphx/operation.hpp>
 #include <migraphx/value.hpp>
+#include <migraphx/json.hpp>
+#include <migraphx/convert_to_json.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -46,6 +48,8 @@ operation make_op(const std::string& name, const Value& v)
     return make_op_from_value(name, v);
 }
 
+operation make_json_op(const std::string& name, const std::string& s);
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
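A hedged sketch of how the new `make_json_op` entry point might be called. The implementation is not part of this diff; the declaration plus the newly included JSON headers suggest it builds an operation from a name and a JSON-encoded attribute string. The operator name and attribute string below are illustrative only:

```cpp
#include <migraphx/make_op.hpp>

int main()
{
    // Assumed usage: equivalent to make_op("convolution", {{"padding", {1, 1}}}),
    // but with the attributes supplied as a JSON string
    auto op = migraphx::make_json_op("convolution", "{\"padding\": [1, 1]}");
    (void)op;
}
```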
@@ -181,7 +181,7 @@ struct marker
     template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
     private_detail_te_handle_type(
         PrivateDetailTypeErasedT value,
-        typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
+        typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                 int>::type* = nullptr) noexcept
         : private_detail_te_value(std::move(value))
     {
@@ -233,7 +233,7 @@ struct marker
     private_detail_te_handle_base_type& private_detail_te_get_handle()
     {
         assert(private_detail_te_handle_mem_var != nullptr);
-        if(!private_detail_te_handle_mem_var.unique())
+        if(not private_detail_te_handle_mem_var.unique())
             private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
         return *private_detail_te_handle_mem_var;
     }
...
@@ -38,11 +38,11 @@ struct gelu_erf_matcher
     F f;
     auto erf_fn() const
     {
-        return f("erf")(
-            used_once(),
-            arg(0)(used_once(),
-                   f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
-                                             has_value(M_SQRT1_2, 1e-3)))));
+        auto mul_1_sqrt_2 = f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
+                                                      has_value(M_SQRT1_2, 1e-3)));
+        auto div_sqrt_2 =
+            f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2, 1e-3)));
+        return f("erf")(used_once(), arg(0)(used_once(), any_of(mul_1_sqrt_2, div_sqrt_2)));
     }
     auto add_erf() const
...
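The matcher now accepts both spellings of the erf argument in GELU, `mul(x, 1/sqrt(2))` and `div(x, sqrt(2))`, which are algebraically identical. A quick standalone check of that identity, assuming the POSIX math constants `M_SQRT1_2` and `M_SQRT2` from `<cmath>`:

```cpp
#include <cassert>
#include <cmath>

int main()
{
    // x * (1/sqrt(2)) and x / sqrt(2) feed erf the same value
    for(double x : {-2.0, 0.5, 3.0})
        assert(std::abs(std::erf(x * M_SQRT1_2) - std::erf(x / M_SQRT2)) < 1e-12);
}
```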
@@ -50,8 +50,8 @@ struct layernorm_matcher
     {
         return f("div")(arg(0)(x_minus_mean()),
-                        arg(1)(skip_broadcasts(f("sqrt")(
-                            arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f))))))));
+                        arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
+                            f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
     }
 
     auto matcher() const { return layernorm_onnx(); }
...
@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
     return nullopt;
 }
 
+MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
+{
+    return contains({"broadcast", "multibroadcast"}, ins->name());
+}
+
 template <class... Ms>
 auto skip(Ms... ms)
 {
@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
 template <class... Ms>
 auto pointwise(Ms... ms)
 {
-    return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
-                                             ms...);
+    return match::has_attribute("pointwise")(ms...);
 }
 
 } // namespace match
...
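The new `MIGRAPHX_PRED_MATCHER` defines `match::broadcast`, a predicate matcher that fires on any instruction named `broadcast` or `multibroadcast`. A hedged sketch of how such a predicate composes with the other matchers in this header; the particular combination is illustrative, not from this diff:

```cpp
#include <migraphx/matcher.hpp>

namespace match = migraphx::match;

// Illustrative composition: matches an add instruction whose first
// argument is a broadcast or multibroadcast instruction.
inline auto add_of_broadcast()
{
    return match::name("add")(match::arg(0)(match::broadcast()));
}
```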
@@ -219,7 +219,7 @@ struct module
     friend std::ostream& operator<<(std::ostream& os, const module& m);
 
     friend bool operator==(const module& x, const module& y);
-    friend bool operator!=(const module& x, const module& y) { return !(x == y); }
+    friend bool operator!=(const module& x, const module& y) { return not(x == y); }
 
     private:
     void assign(const module& m);
...
@@ -35,17 +35,13 @@ struct onnx_options
 {
     /// Old way to set default fixed dimension size
     std::size_t default_dim_value = 0;
-    /*!
-     * Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value
-     * set parser throws)
-     */
+    /// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value
+    /// are set, the parser throws)
     shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
     /// Explicitly specify the dims of an input
     std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
-    /*!
-     * Explicitly specify dynamic dims of an input (if both map_input_dims and
-     * map_dyn_input_dims set parser throws)
-     */
+    /// Explicitly specify dynamic dims of an input (if both map_input_dims and
+    /// map_dyn_input_dims are set, the parser throws)
    std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims = {};
     /// Continue parsing onnx file if an unknown operator is found
     bool skip_unknown_operators = false;
@@ -53,6 +49,8 @@ struct onnx_options
     bool print_program_on_error = false;
     /// Max iter num for the loop operator
     int64_t max_loop_iterations = 10;
+    /// Use dynamic output for operators when available
+    bool use_dyn_output = false;
 };
 
 /// Create a program from an onnx file
...
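A hedged sketch of how the new options fit together when parsing a model with a dynamic batch dimension. The file name and dimension values are placeholders, and it assumes this branch's `{min, max, opt}` form of `shape::dynamic_dimension` (matching the `{1, 1, 0}` default above):

```cpp
#include <migraphx/onnx.hpp>

int main()
{
    migraphx::onnx_options options;
    // batch may range from 1 to 4; remaining dims are fixed (values illustrative)
    options.map_dyn_input_dims["input"] = {
        {1, 4, 0}, {3, 3, 0}, {224, 224, 0}, {224, 224, 0}};
    // let operators produce dynamic output shapes where supported
    options.use_dyn_output = true;
    auto prog = migraphx::parse_onnx("model.onnx", options);
}
```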
@@ -28,6 +28,7 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/value.hpp>
+#include <migraphx/dyn_output.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -60,10 +61,19 @@ struct binary : op_name<Derived>
     value attributes() const { return base_attributes(); }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
+        check_shapes{inputs, static_cast<const Derived&>(*this), true}
+            .has(2)
+            .same_type()
+            .same_dims();
         auto s0 = inputs.at(0);
         auto s1 = inputs.at(1);
-        if(s0 == s1 and s0.packed())
+        if(s0.dynamic() or s1.dynamic())
+        {
+            if(s0 == s1)
+                return s0;
+            MIGRAPHX_THROW("BINARY: " + point_function() + ": fixed-dyn shape for inputs");
+        }
+        else if(s0 == s1 and s0.packed())
         {
             return s0;
         }
@@ -81,9 +91,9 @@ struct binary : op_name<Derived>
         }
     }
 
-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
             std::transform(input1.begin(),
                            input1.end(),
...
@@ -27,23 +27,30 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/dyn_output.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
-/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
-/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
-/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
-/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
-/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
-/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
-/// obvious from there that we can negate the effects of a given axis by setting the stride of that
-/// axis to zero.
+/**
+ * 1 input version:
+ * Broadcasts a tensor from the original shape to broadcast_lens by setting the stride of
+ * broadcasted dimensions to zero. For a 1D input shape, the `axis` attribute is the output
+ * dimension that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1.
+ * For higher rank input shapes, axis is an offset parameter for the broadcasting, such that
+ * this operator works in the opposite direction of NumPy broadcasting. ex: broadcasting
+ * shape [2, 2] -> [2, 2, 3] has axis = 0.
+ *
+ * 2 input version:
+ * Broadcasts the first input (a 1D shape) into the second input shape based on the axis
+ * attribute. Handles broadcasting a 1D static shape into a higher rank dynamic shape.
+ * broadcast_lens is not used.
+ */
 struct broadcast
 {
     uint64_t axis = 0;
-    std::vector<std::size_t> broadcast_lens;
+    std::vector<std::size_t> broadcast_lens = {};
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -54,37 +61,88 @@ struct broadcast
     std::string name() const { return "broadcast"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        auto input = inputs.at(0);
-        auto t     = input.type();
-
-        std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
-        // the broacast op is deprecated now, so not handling the negative
-        // value of axis anymore
-        if(axis >= broadcast_lens.size())
-        {
-            MIGRAPHX_THROW("BROADCAST : axis is out of range");
-        }
-
-        if(broadcast_lens.size() - axis < input.lens().size())
-        {
-            MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
-        }
-
-        if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
-        {
-            MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
-        }
-        std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
-        shape output{t, broadcast_lens, std::move(bcast_strides)};
-        if(output.elements() < input.elements())
-            MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
-        return output;
+        check_shapes{inputs, *this, true}.has(1, 2);
+        auto s0 = inputs.at(0);
+        auto t  = s0.type();
+        if(inputs.size() == 1)
+        {
+            // the ONNX broadcast op is deprecated now, so not handling the negative
+            // value of axis anymore
+            if(axis >= broadcast_lens.size())
+            {
+                MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
+                               " is out of range");
+            }
+            if(broadcast_lens.size() - axis < s0.lens().size())
+            {
+                MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than s0 ndims");
+            }
+            if(not std::equal(s0.lens().begin(), s0.lens().end(), broadcast_lens.begin() + axis))
+            {
+                MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
+            }
+            std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
+            std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
+            shape output{t, broadcast_lens, std::move(bcast_strides)};
+            if(output.elements() < s0.elements())
+            {
+                // don't think this can occur?
+                MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
+            }
+            return output;
+        }
+        else
+        {
+            // two inputs
+            auto s1 = inputs.at(1);
+            if(s0.dynamic())
+            {
+                MIGRAPHX_THROW("BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
+                               "a dynamic shape");
+            }
+            if(s0.ndim() != 1)
+            {
+                MIGRAPHX_THROW("BROADCAST_2in: s0 has ndim " + migraphx::to_string(s0.ndim()) +
+                               ", only handle ndim = 1");
+            }
+            if(axis >= s1.ndim())
+            {
+                MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) +
+                               " is out of range");
+            }
+            if(s1.dynamic())
+            {
+                s0 = s0.to_dynamic();
+                if(s0.dyn_dims()[0] != s1.dyn_dims()[axis])
+                {
+                    MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
                                   "dimension length (" +
+                                   migraphx::to_string(s0.dyn_dims()[0]) +
+                                   " != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
+                }
+                return s1;
+            }
+            if(s0.lens()[0] != s1.lens()[axis])
+            {
+                MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
                               "dimension length (" +
+                               migraphx::to_string(s0.lens()[0]) +
+                               " != " + migraphx::to_string(s1.lens()[axis]) + ")");
+            }
+            std::vector<size_t> bcast_strides(s1.ndim(), 0);
+            std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
+            shape output{t, s1.lens(), std::move(bcast_strides)};
+            return output;
+        }
+    }
 
-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
...
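A hedged sketch of the 1-input form from the doc comment above: placing a length-1024 tensor at output dimension 1 of a [4, 1024, 3] result, so the broadcast dimensions get stride 0. The program setup is illustrative; it assumes the `broadcast_lens` attribute name used on this branch:

```cpp
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter(
        "x", migraphx::shape{migraphx::shape::float_type, {1024}});
    // broadcast [1024] -> [4, 1024, 3]; axis = 1 keeps the 1024 axis in place
    mm->add_instruction(
        migraphx::make_op("broadcast",
                          {{"axis", 1}, {"broadcast_lens", {4, 1024, 3}}}),
        x);
}
```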
@@ -33,11 +33,11 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
+// Padding mode is default_ for fixed shape padding.
+// same_lower and same_upper are used for dynamic padding.
 enum padding_mode_t
 {
     default_, // NOLINT
-    same,
-    valid,
     same_lower,
     same_upper
 };
...
@@ -86,7 +86,7 @@ struct concat
     {
         if(l != axis)
         {
-            if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
+            if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
                    return s.lens()[l] == first_shape_lens[l];
                }))
             {
...
@@ -44,8 +44,16 @@ struct convert : unary<convert>
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
+        check_shapes{inputs, *this, true}.has(1);
+        auto input = inputs.at(0);
+        if(input.dynamic())
+        {
+            return {target_type, input.dyn_dims()};
+        }
+        else
+        {
+            return {target_type, input.lens(), input.strides()};
+        }
     }
 
     std::string point_op() const
...
@@ -43,7 +43,6 @@ struct convolution
     int group = 1;
     padding_mode_t padding_mode = default_;
-    bool use_dynamic_same_auto_pad = false;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -52,16 +51,15 @@ struct convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.group, "group"),
-                    f(self.padding_mode, "padding_mode"),
-                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
+                    f(self.padding_mode, "padding_mode"));
     }
 
     std::string name() const { return "convolution"; }
 
     void check_attribute_size() const
     {
-        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
-               stride.size() == dilation.size()))
+        if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
+           stride.size() != dilation.size())
         {
             MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
         }
@@ -76,7 +74,8 @@ struct convolution
         // num of dims of input and attribute should match
         const auto input_size   = inputs[0].max_lens().size();
         const auto padding_size = padding.size();
-        if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
+        if(input_size != padding_size / 2 + 2 && input_size != padding_size + 2)
        {
             MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
         }
@@ -93,13 +92,6 @@ struct convolution
            x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
             MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
 
-        std::vector<op::padding_mode_t> dyn_pad_modes = {op::padding_mode_t::same_upper,
-                                                         op::padding_mode_t::same_lower};
-        if(use_dynamic_same_auto_pad and not contains(dyn_pad_modes, padding_mode))
-        {
-            MIGRAPHX_THROW("CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode");
-        }
-
         if(x_shape.dynamic() or w_shape.dynamic())
         {
             return dynamic_compute_shape(x_shape, w_shape);
@@ -161,7 +153,7 @@ struct convolution
         dynamic_shape_push_back(w_shape);
         const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
 
-        if(use_dynamic_same_auto_pad)
+        if(padding_mode != default_)
        {
             for(std::size_t i = 0; i < num_spatial_dims; ++i)
             {
...
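With the `use_dynamic_same_auto_pad` flag removed, any non-`default_` padding mode now triggers SAME-style auto-padding on the dynamic path. For reference, SAME padding chooses pads so the output spatial size depends only on input size and stride. A standalone sketch of that arithmetic, not the MIGraphX implementation:

```cpp
#include <cassert>
#include <cstddef>

// SAME-style output size: ceil(input / stride)
std::size_t same_output_size(std::size_t input, std::size_t stride)
{
    return (input + stride - 1) / stride;
}

// Total padding needed so the (dilated) kernel fits that output size;
// same_lower vs same_upper decides where the odd pixel of padding goes.
std::size_t same_total_padding(std::size_t input,
                               std::size_t stride,
                               std::size_t kernel,
                               std::size_t dilation)
{
    std::size_t effective_kernel = (kernel - 1) * dilation + 1;
    std::size_t needed = (same_output_size(input, stride) - 1) * stride + effective_kernel;
    return needed > input ? needed - input : 0;
}

int main()
{
    assert(same_output_size(224, 2) == 112);
    assert(same_total_padding(224, 2, 3, 1) == 1);
}
```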
@@ -61,8 +61,8 @@ struct deconvolution
     void check_attribute_size() const
     {
-        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
-               stride.size() == dilation.size()))
+        if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
+           stride.size() != dilation.size())
         {
             MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
         }
...
@@ -43,13 +43,14 @@ struct dot
         const shape& b = inputs.at(1);
         auto t         = a.type();
 
-        if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
+        if(not std::all_of(
+               inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
         {
             MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
         }
 
         // only handle the case that the batch size of a and b are the same
-        if(!std::equal(
+        if(not std::equal(
                a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
         {
             MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
...
@@ -32,14 +32,13 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
-struct elu
+struct elu : unary<elu>
 {
-    std::string name() const { return "elu"; }
     float alpha = 1;
 
-    shape compute_shape(std::vector<shape> inputs) const
+    std::string point_op() const
     {
-        check_shapes{inputs, *this}.has(1);
-        return inputs.front();
+        return "${function:where}(${0} > 0, ${0}, ${alpha} * (${function:exp}(${0}) - 1))";
     }
 
     template <class Self, class F>
@@ -47,6 +46,11 @@ struct elu
     {
         return pack(f(self.alpha, "alpha"));
     }
+
+    auto apply() const
+    {
+        return [&](auto x) { return x > 0 ? x : alpha * std::expm1(x); };
+    }
 };
 
 } // namespace op
...
@@ -24,17 +24,8 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
 #define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
 
-#include <array>
 #include <migraphx/op/binary.hpp>
-#include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
-#include <migraphx/config.hpp>
 #include <cmath>
-#include <utility>
-#include <type_traits>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -49,7 +40,6 @@ struct fmod : binary<fmod>
         a["commutative"] = false;
         return a;
     }
-    std::string point_function() const { return "fmod"; }
 
     auto apply() const
     {
         return [](auto x, auto y) { return std::fmod(x, y); };
...
@@ -65,7 +65,7 @@ struct gather
         auto lens = inputs[0].lens();
         auto type = inputs[0].type();
         lens.erase(lens.begin() + axis);
-        if(!inputs[1].scalar())
+        if(not inputs[1].scalar())
         {
             auto ind_lens = inputs[1].lens();
             lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
...
@@ -26,12 +26,13 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/op/unary.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
-struct leaky_relu
+struct leaky_relu : unary<leaky_relu>
 {
     float alpha = 0.01;
 
@@ -41,11 +42,13 @@ struct leaky_relu
         return pack(f(self.alpha, "alpha"));
     }
 
+    std::string point_op() const { return "${function:where}(${0} > 0, ${0}, ${alpha} * ${0})"; }
+
     std::string name() const { return "leaky_relu"; }
-    shape compute_shape(std::vector<shape> inputs) const
+
+    auto apply() const
     {
-        check_shapes{inputs, *this}.has(1);
-        return inputs.front();
+        return [&](auto x) { return x > 0 ? x : x * alpha; };
     }
 };
...
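Both `elu` and `leaky_relu` now derive from the `unary` CRTP base: each op supplies `apply()` (the scalar function used on the reference path) and `point_op()` (the string template used by pointwise codegen), and inherits shape checking and `compute` instead of hand-writing them. A standalone reference for the two `apply()` lambdas, not MIGraphX code:

```cpp
#include <cassert>
#include <cmath>

// Scalar semantics of the two rewritten activations
double elu_ref(double x, double alpha = 1.0) { return x > 0 ? x : alpha * std::expm1(x); }
double leaky_relu_ref(double x, double alpha = 0.01) { return x > 0 ? x : x * alpha; }

int main()
{
    assert(elu_ref(2.0) == 2.0);
    assert(std::abs(elu_ref(-1.0) - (std::exp(-1.0) - 1.0)) < 1e-12);
    assert(leaky_relu_ref(-2.0) == -0.02);
}
```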