Commit 3a848f0d authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into doc2

parents 64e8e30a d1e945da
...@@ -93,7 +93,7 @@ struct instruction ...@@ -93,7 +93,7 @@ struct instruction
void replace(const shape& r); void replace(const shape& r);
operation op; operation op;
shape result; shape result{};
std::vector<instruction_ref> output; std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
literal lit; literal lit;
......
...@@ -7,8 +7,20 @@ ...@@ -7,8 +7,20 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser
struct onnx_options
{
unsigned int batch_size = 1;
};
/// Create a program from an onnx file /// Create a program from an onnx file
program parse_onnx(const std::string& name); program parse_onnx(const std::string& name, onnx_options = onnx_options{});
/// Create a program from an onnx buffer
program parse_onnx_buffer(const std::string& buffer, onnx_options options);
/// Create a program from an onnx buffer
program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct acosh : unary<acosh>
{
auto apply() const
{
return [](auto x) { return std::acosh(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -27,25 +27,30 @@ struct argmax ...@@ -27,25 +27,30 @@ struct argmax
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size()); int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0) if(axis >= n_dim || axis < -n_dim)
{ {
MIGRAPHX_THROW("ARGMAX: axis is out of range."); MIGRAPHX_THROW("ARGMAX: axis is out of range.");
} }
lens[axis] = 1; int64_t tuned_axis = (axis < 0) ? axis + n_dim : axis;
lens[tuned_axis] = 1;
return {shape::int64_type, lens}; return {shape::int64_type, lens};
} }
template <class T> template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const int64_t calc_argmax(T& input,
int64_t tuned_axis,
std::vector<std::size_t>& indices,
size_t item_num) const
{ {
auto max_val = input(indices.begin(), indices.end()); auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0; int64_t max_index = 0;
for(std::size_t i = 1; i < item_num; ++i) for(std::size_t i = 1; i < item_num; ++i)
{ {
indices[axis] = i; indices[tuned_axis] = i;
auto cur_val = input(indices.begin(), indices.end()); auto cur_val = input(indices.begin(), indices.end());
if(max_val < cur_val) if(max_val < cur_val)
{ {
max_val = cur_val; max_val = cur_val;
...@@ -59,13 +64,15 @@ struct argmax ...@@ -59,13 +64,15 @@ struct argmax
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_item_num = args.front().get_shape().lens()[axis]; auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = axis < 0 ? axis + n_dim : axis;
auto batch_item_num = args.front().get_shape().lens()[tuned_axis];
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i); auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num); output[i] = this->calc_argmax(input, tuned_axis, data_idx, batch_item_num);
}); });
}); });
}); });
......
...@@ -27,25 +27,29 @@ struct argmin ...@@ -27,25 +27,29 @@ struct argmin
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size()); int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0) if(axis >= n_dim || axis < -n_dim)
{ {
MIGRAPHX_THROW("ARGMIN: axis is out of range."); MIGRAPHX_THROW("ARGMIN: axis is out of range.");
} }
lens[axis] = 1; int64_t tuned_axis = (axis < 0) ? axis + n_dim : axis;
lens[tuned_axis] = 1;
return {shape::int64_type, lens}; return {shape::int64_type, lens};
} }
template <class T> template <class T>
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const int64_t calc_argmin(T& input,
int64_t tuned_axis,
std::vector<std::size_t>& indices,
size_t item_num) const
{ {
auto min_val = input(indices.begin(), indices.end()); auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0; int64_t min_index = 0;
for(std::size_t i = 1; i < item_num; ++i) for(std::size_t i = 1; i < item_num; ++i)
{ {
indices[axis] = i; indices[tuned_axis] = i;
auto cur_val = input(indices.begin(), indices.end()); auto cur_val = input(indices.begin(), indices.end());
if(min_val > cur_val) if(min_val > cur_val)
{ {
min_val = cur_val; min_val = cur_val;
...@@ -59,13 +63,15 @@ struct argmin ...@@ -59,13 +63,15 @@ struct argmin
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
std::size_t batch_item_num = args.front().get_shape().lens()[axis]; auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = axis < 0 ? axis + n_dim : axis;
std::size_t batch_item_num = args.front().get_shape().lens()[tuned_axis];
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i); auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num); output[i] = this->calc_argmin(input, tuned_axis, data_idx, batch_item_num);
}); });
}); });
}); });
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct asinh : unary<asinh>
{
auto apply() const
{
return [](auto x) { return std::asinh(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct atanh : unary<atanh>
{
auto apply() const
{
return [](auto x) { return std::atanh(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct deconvolution
{
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
padding_mode_t padding_mode = default_;
int group = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"),
f(self.group, "group"));
}
std::string name() const { return "deconvolution"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
auto t = input.type();
return {t,
{
input.lens()[0],
weights.lens()[1],
std::size_t(std::max<std::ptrdiff_t>(
1,
stride[0] * (input.lens()[2] - 1) +
((weights.lens()[2] - 1) * dilation[0] + 1) - 2 * padding[0])),
std::size_t(std::max<std::ptrdiff_t>(
1,
stride[1] * (input.lens()[3] - 1) +
((weights.lens()[3] - 1) * dilation[1] + 1) - 2 * padding[1])),
}};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -18,7 +18,7 @@ namespace op { ...@@ -18,7 +18,7 @@ namespace op {
struct flatten struct flatten
{ {
uint64_t axis = 0; int64_t axis = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -30,16 +30,19 @@ struct flatten ...@@ -30,16 +30,19 @@ struct flatten
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
auto&& lens = inputs.front().lens(); auto&& lens = inputs.front().lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis > lens.size()) if(axis > n_dim or axis < -n_dim)
{ {
MIGRAPHX_THROW("axis for flatten must be less than tensor rank"); MIGRAPHX_THROW("FLATTEN: axis for flatten is out of range");
} }
auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{}); auto tuned_axis = (axis < 0) ? axis + n_dim : axis;
auto y =
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{}); auto x = std::accumulate(
lens.begin(), lens.begin() + tuned_axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate(
lens.begin() + tuned_axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}}; return {inputs.at(0).type(), {x, y}};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -11,7 +11,7 @@ namespace op { ...@@ -11,7 +11,7 @@ namespace op {
struct logsoftmax struct logsoftmax
{ {
int axis = 1; int64_t axis = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -23,7 +23,8 @@ struct logsoftmax ...@@ -23,7 +23,8 @@ struct logsoftmax
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).standard(); check_shapes{inputs}.has(1).standard();
if(axis < 0 || axis >= inputs[0].lens().size()) int64_t n_dim = static_cast<int64_t>(inputs[0].lens().size());
if(axis < -n_dim || axis >= n_dim)
{ {
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) + MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range"); " is out of range");
......
#ifndef MIGRAPHX_GUARD_OPERATORS_PRELU_HPP
#define MIGRAPHX_GUARD_OPERATORS_PRELU_HPP
#include <migraphx/op/binary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct prelu : binary<prelu>
{
auto apply() const
{
return [](auto x, auto slope) { return ((x < 0) ? (x * slope) : x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -40,6 +40,15 @@ struct zero ...@@ -40,6 +40,15 @@ struct zero
} }
}; };
struct one
{
template <class T>
operator T() const
{
return T{1};
}
};
template <class Derived> template <class Derived>
struct reduce_op : op_name<Derived> struct reduce_op : op_name<Derived>
{ {
......
#ifndef MIGRAPHX_GUARD_OPERATORS_REDUCE_PROD_HPP
#define MIGRAPHX_GUARD_OPERATORS_REDUCE_PROD_HPP
#include <migraphx/op/reduce_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reduce_prod : reduce_op<reduce_prod>
{
reduce_prod() {}
reduce_prod(std::vector<int64_t> ax) : reduce_op(std::move(ax)) {}
auto op() const
{
return [=](auto x, auto y) { return x * y; };
}
auto init() const { return one(); }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -11,7 +11,7 @@ namespace op { ...@@ -11,7 +11,7 @@ namespace op {
struct softmax struct softmax
{ {
int axis = 1; int64_t axis = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -23,7 +23,8 @@ struct softmax ...@@ -23,7 +23,8 @@ struct softmax
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).standard(); check_shapes{inputs}.has(1).standard();
if(axis < 0 || axis >= inputs[0].lens().size()) int64_t n_dim = inputs[0].lens().size();
if(axis < -n_dim || axis >= n_dim)
{ {
MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) + MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) +
" is out of range"); " is out of range");
......
...@@ -33,13 +33,21 @@ struct squeeze ...@@ -33,13 +33,21 @@ struct squeeze
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; })) // change to support negative axis value
std::vector<int64_t> tuned_axes(axes.size());
std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [&](auto i) {
return i >= 0 ? i : i + old_lens.size();
});
if(std::any_of(tuned_axes.begin(), tuned_axes.end(), [&](auto axis) {
return old_lens[axis] != 1;
}))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
if(axes.empty()) if(tuned_axes.empty())
{ {
std::copy_if(old_lens.begin(), std::copy_if(old_lens.begin(),
old_lens.end(), old_lens.end(),
...@@ -50,7 +58,7 @@ struct squeeze ...@@ -50,7 +58,7 @@ struct squeeze
{ {
for(std::size_t i = 0; i < old_lens.size(); i++) for(std::size_t i = 0; i < old_lens.size(); i++)
{ {
if(std::find(axes.begin(), axes.end(), i) == axes.end()) if(std::find(tuned_axes.begin(), tuned_axes.end(), i) == tuned_axes.end())
{ {
new_lens.push_back(old_lens[i]); new_lens.push_back(old_lens[i]);
} }
......
...@@ -34,13 +34,22 @@ struct transpose ...@@ -34,13 +34,22 @@ struct transpose
auto input_lens = input.lens(); auto input_lens = input.lens();
auto input_strides = input.strides(); auto input_strides = input.strides();
auto t = input.type(); auto t = input.type();
if(dims.size() != input_lens.size()) auto tuned_dims = dims;
// if not perm provided, reverse the dims
if(tuned_dims.empty())
{
tuned_dims.resize(input_lens.size());
std::iota(tuned_dims.begin(), tuned_dims.end(), 0);
std::reverse(tuned_dims.begin(), tuned_dims.end());
}
if(tuned_dims.size() != input_lens.size())
{ {
MIGRAPHX_THROW("Permutation has wrong number of axes"); MIGRAPHX_THROW("Permutation has wrong number of axes");
} }
std::vector<int64_t> axes(dims.size()); std::vector<int64_t> axes(tuned_dims.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin())) if(!std::is_permutation(axes.begin(), axes.end(), tuned_dims.begin()))
{ {
MIGRAPHX_THROW("Invalid permutation"); MIGRAPHX_THROW("Invalid permutation");
} }
...@@ -48,8 +57,8 @@ struct transpose ...@@ -48,8 +57,8 @@ struct transpose
std::vector<size_t> output_strides(input_lens.size()); std::vector<size_t> output_strides(input_lens.size());
for(std::size_t i = 0; i < output_lens.size(); i++) for(std::size_t i = 0; i < output_lens.size(); i++)
{ {
output_lens[i] = input_lens[dims[i]]; output_lens[i] = input_lens[tuned_dims[i]];
output_strides[i] = input_strides[dims[i]]; output_strides[i] = input_strides[tuned_dims[i]];
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
} }
......
...@@ -38,11 +38,18 @@ struct unsqueeze ...@@ -38,11 +38,18 @@ struct unsqueeze
return shape{type, old_lens}; return shape{type, old_lens};
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
// in case of axes to be negative, tune to positive
std::vector<int64_t> tuned_axes(axes.size());
std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [new_size](auto i) {
return i >= 0 ? i : i + new_size;
});
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::size_t p = 0; std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++) for(std::size_t i = 0; i < new_size; i++)
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) if(std::find(tuned_axes.begin(), tuned_axes.end(), i) != tuned_axes.end())
{ {
new_lens[i] = 1; new_lens[i] = 1;
} }
......
...@@ -257,11 +257,17 @@ struct operation ...@@ -257,11 +257,17 @@ struct operation
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
operation& operator=(PrivateDetailTypeErasedT value) operation& operator=(PrivateDetailTypeErasedT value)
{ {
if(private_detail_te_handle_mem_var.unique()) using std::swap;
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
else if(!private_detail_te_handle_mem_var) if(derived and private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>( {
std::forward<PrivateDetailTypeErasedT>(value)); *derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
operation rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this; return *this;
} }
...@@ -269,7 +275,7 @@ struct operation ...@@ -269,7 +275,7 @@ struct operation
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast() PrivateDetailTypeErasedT* any_cast()
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type< ? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
...@@ -280,7 +286,7 @@ struct operation ...@@ -280,7 +286,7 @@ struct operation
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type< ? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
......
...@@ -4,12 +4,15 @@ ...@@ -4,12 +4,15 @@
#include <migraphx/op/abnormal_ops.hpp> #include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/abs.hpp> #include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp> #include <migraphx/op/acos.hpp>
#include <migraphx/op/acosh.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp> #include <migraphx/op/asin.hpp>
#include <migraphx/op/asinh.hpp>
#include <migraphx/op/as_shape.hpp> #include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp> #include <migraphx/op/atan.hpp>
#include <migraphx/op/atanh.hpp>
#include <migraphx/op/batch_norm.hpp> #include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/binary.hpp> #include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp> #include <migraphx/op/broadcast.hpp>
...@@ -23,6 +26,7 @@ ...@@ -23,6 +26,7 @@
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/cosh.hpp> #include <migraphx/op/cosh.hpp>
#include <migraphx/op/cos.hpp> #include <migraphx/op/cos.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/div.hpp> #include <migraphx/op/div.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp> #include <migraphx/op/elu.hpp>
...@@ -48,13 +52,15 @@ ...@@ -48,13 +52,15 @@
#include <migraphx/op/outline.hpp> #include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/prelu.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp> #include <migraphx/op/pow.hpp>
#include <migraphx/op/reduce_sum.hpp> #include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_mean.hpp> #include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_min.hpp> #include <migraphx/op/reduce_min.hpp>
#include <migraphx/op/reduce_max.hpp> #include <migraphx/op/reduce_prod.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp> #include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp> #include <migraphx/op/rnn.hpp>
......
...@@ -57,11 +57,17 @@ struct pass ...@@ -57,11 +57,17 @@ struct pass
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
pass& operator=(PrivateDetailTypeErasedT value) pass& operator=(PrivateDetailTypeErasedT value)
{ {
if(private_detail_te_handle_mem_var.unique()) using std::swap;
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
else if(!private_detail_te_handle_mem_var) if(derived and private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>( {
std::forward<PrivateDetailTypeErasedT>(value)); *derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
pass rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this; return *this;
} }
...@@ -69,7 +75,7 @@ struct pass ...@@ -69,7 +75,7 @@ struct pass
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast() PrivateDetailTypeErasedT* any_cast()
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type< ? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
...@@ -80,7 +86,7 @@ struct pass ...@@ -80,7 +86,7 @@ struct pass
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{ {
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT) return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type< ? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>( typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle()) private_detail_te_get_handle())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment