Commit 4a39a0f7 authored by Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
......@@ -15,6 +15,7 @@ namespace op {
// 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max
// 5) normalize padding
enum class normalize_attribute
{
use_len,
......@@ -22,7 +23,8 @@ enum class normalize_attribute
clip_max,
clip_min,
include_max,
include_min
include_min,
normalize_padding
};
} // namespace op
......
......@@ -8,6 +8,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/value.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp>
#include <migraphx/config.hpp>
......@@ -40,29 +41,39 @@ struct pooling
void check_attribute_size() const
{
if(not(padding.size() == stride.size() and padding.size() == lengths.size()))
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == lengths.size()))
{
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
}
}
shape compute_shape(std::vector<shape> inputs) const
value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
const shape& input = inputs.at(0);
auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2;
if(kdims != this->kdims())
auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2;
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{
MIGRAPHX_THROW("pooling: input k-dims does not match attribute size");
MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
}
std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
for(size_t i = 0; i < kdims; i++)
{
std::ptrdiff_t dim_size = input_lens[i + 2] + 2 * padding[i] - lengths[i];
std::ptrdiff_t dim_size;
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
dim_size = input_lens[i + 2] + padding_factor - lengths[i];
assert(dim_size >= 0);
std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i])
: floor_divide<std::ptrdiff_t>(dim_size, stride[i]);
......@@ -75,7 +86,7 @@ struct pooling
size_t kdims() const
{
check_attribute_size();
return padding.size();
return stride.size();
}
};
......
......@@ -43,7 +43,7 @@ struct prefix_scan_op : op_name<Derived>
argument compute(const shape&, std::vector<argument> args) const
{
argument result = args[0];
argument result = args[0].copy();
auto s = result.get_shape();
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens();
......
......@@ -36,19 +36,23 @@ struct quant_convolution
f(self.group, "group"));
}
value attributes() const { return {{"general_data_type", "convolution"}}; }
value attributes() const
{
return {{"general_data_type", "convolution"}, {"normalize_padding", "padding"}};
}
std::string name() const { return "quant_convolution"; }
void check_attribute_size() const
{
if(not(padding.size() == stride.size() and padding.size() == dilation.size()))
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{
MIGRAPHX_THROW("quant_convolution: inconsistent attribute sizes");
MIGRAPHX_THROW("QUANT_CONVOLUTION: inconsistent attribute sizes");
}
}
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size();
......@@ -70,13 +74,16 @@ struct quant_convolution
t = shape::int32_type;
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto padding_size = padding.size();
for(size_t i = 0; i < kdims; i++)
{
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
2 * padding[i]) /
padding_factor) /
stride[i] +
1)));
}
......@@ -87,7 +94,7 @@ struct quant_convolution
size_t kdims() const
{
check_attribute_size();
return padding.size();
return stride.size();
}
};
......
......@@ -18,21 +18,12 @@ namespace op {
struct quant_dot
{
int32_t alpha = 1;
int32_t beta = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
value attributes() const { return {{"general_data_type", "dot"}}; }
std::string name() const { return "quant_dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type();
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
......@@ -64,18 +55,6 @@ struct quant_dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
{
MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
}
return {shape::int32_type, out_lens};
}
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct quantizelinear
{
    std::string name() const { return "quantizelinear"; }

    // Output type is the zero-point's type when one is supplied, otherwise
    // uint8; lens/strides always mirror the first input.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.same_dims();
        const auto& x_shape = inputs[0];
        if(inputs.size() == 3)
        {
            return {inputs[2].type(), x_shape.lens(), x_shape.strides()};
        }
        return {shape::uint8_type, x_shape.lens(), x_shape.strides()};
    }

    // Quantize: round(x / scale) + zero_point, clamped to the output type's
    // representable range.
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        const auto& x       = args.at(0);
        const auto& y_scale = args.at(1);
        // Default zero point: a zero-filled buffer viewed through the output
        // shape (all-zero bytes are zero in any integral quant type).
        std::vector<int8_t> zeros(output_shape.bytes(), 0);
        argument y_zero_point{output_shape, zeros.data()};
        if(args.size() == 3)
        {
            y_zero_point = args.at(2);
        }
        argument result{output_shape};
        visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
            x.visit([&](auto input) {
                y_scale.visit([&](auto scales) {
                    using quant_type     = typename decltype(output)::value_type;
                    const auto min_value = std::numeric_limits<quant_type>::min();
                    const auto max_value = std::numeric_limits<quant_type>::max();
                    par_for(output_shape.elements(), [&](auto i) {
                        // Do the arithmetic in int64 so the clamp happens
                        // before narrowing to the quantized type.
                        auto quantized = static_cast<int64_t>(std::round(input[i] / scales[i])) +
                                         static_cast<int64_t>(zero_pts[i]);
                        output[i] = std::max(static_cast<int64_t>(min_value),
                                             std::min(static_cast<int64_t>(max_value), quantized));
                    });
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -71,7 +72,7 @@ struct reshape
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP
#include <algorithm>
#include <vector>
#include <cmath>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Reverses the elements of the input tensor along each of the given axes.
struct reverse
{
    std::vector<int64_t> axes;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"));
    }

    std::string name() const { return "reverse"; }

    // Negative axes are normalized before normalize_compute_shape is called
    // (include_min permits values in [-rank, rank-1]).
    value attributes() const
    {
        value normalize;
        normalize["axes"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    shape normalize_compute_shape(std::vector<shape> inputs) const
    {
        // Validate arity explicitly (consistent with step/topk); previously a
        // wrong input count was silently accepted.
        check_shapes{inputs, *this}.has(1);
        // Output has the same lens as the input; with_lens yields a standard
        // layout for those dimensions.
        return inputs[0].with_lens(inputs[0].lens());
    }

    argument compute(const shape& s, std::vector<argument> args) const
    {
        argument result{s};
        auto lens = s.lens();
        visit_all(result, args.front())([&](auto output, auto input) {
            shape_for_each(s, [&](const auto& out_idx) {
                // Mirror the index on every reversed axis.
                auto in_idx = out_idx;
                for(const auto& axis : axes)
                {
                    in_idx[axis] = lens[axis] - 1 - out_idx[axis];
                }
                output[s.index(out_idx)] = input[s.index(in_idx)];
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -39,7 +40,7 @@ struct scalar
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Writes `update` values into a copy of `data` at positions selected by
// `indices` along `axis` (ONNX Scatter/ScatterElements semantics).
struct scatter
{
    int64_t axis = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    // Negative axis values are normalized before compute runs.
    value attributes() const
    {
        value normalize;
        normalize["axis"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    std::string name() const { return "scatter"; }

    shape normalize_compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(3).standard();
        return inputs.front();
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // Size of the scatter axis, used to wrap negative indices.
        const auto axis_size = output_shape.lens()[axis];
        visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) {
            // Start from a copy of the data tensor, then overwrite the
            // scattered positions.
            std::copy(data.begin(), data.end(), output.begin());
            args[1].visit([&](auto indices) {
                auto ind_s = indices.get_shape();
                shape_for_each(ind_s, [&](const auto& idx) {
                    auto index = indices[ind_s.index(idx)];
                    if(index < 0)
                    {
                        index += axis_size;
                    }
                    auto out_idx  = idx;
                    out_idx[axis] = index;
                    output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -9,6 +9,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -77,7 +78,7 @@ struct squeeze
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#define MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#include "migraphx/stringutils.hpp"
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Downsamples the input by keeping every `steps[i]`-th element along
// `axes[i]`. Implemented as a stride adjustment, so the output aliases the
// input (borrowed lifetime).
struct step
{
    std::vector<int64_t> axes;
    std::vector<int64_t> steps;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"), f(self.steps, "steps"));
    }

    // Negative axes are normalized before normalize_compute_shape is called.
    value attributes() const
    {
        value normalize;
        normalize["axes"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    std::string name() const { return "step"; }

    shape normalize_compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto input   = inputs.at(0);
        auto in_lens = input.lens();
        auto t       = input.type();
        if(axes.size() != steps.size())
        {
            MIGRAPHX_THROW("STEP: attribute axes {" + to_string_range(axes) +
                           "} has different dimensions from step {" + to_string_range(steps) +
                           "}.");
        }
        // Explicit cast avoids a signed/unsigned comparison; axes are already
        // non-negative after normalization.
        if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
               return axis >= static_cast<int64_t>(in_lens.size());
           }))
        {
            MIGRAPHX_THROW("STEP: axis value is out of range!");
        }
        // A step of 0 would divide by zero below and a negative step would
        // produce a nonsensical stride, so reject both up front.
        if(std::any_of(steps.begin(), steps.end(), [](auto step) { return step < 1; }))
        {
            MIGRAPHX_THROW("STEP: step values must be positive!");
        }
        auto lens    = in_lens;
        auto strides = input.strides();
        for(auto i : range(axes.size()))
        {
            auto axis = axes[i];
            auto step = steps[i];
            // ceil(len / step) elements remain; widening the stride skips the
            // dropped elements without moving data.
            lens[axis] = (in_lens[axis] + step - 1) / step;
            strides[axis] *= step;
        }
        return {t, lens, strides};
    }

    argument compute(shape output_shape, std::vector<argument> args) const
    {
        return args[0].reshape(output_shape);
    }

    // The result borrows the input's buffer.
    lifetime get_lifetime() const { return lifetime::borrow; }
    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_TOPK_HPP
#define MIGRAPHX_GUARD_OPERATORS_TOPK_HPP

#include <algorithm>
#include <functional>
#include <numeric>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/value.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

// Returns the k largest (or smallest, when largest=false) elements along
// `axis` as a tuple of (values, indices).
struct topk
{
    int64_t k    = 1;
    int64_t axis = 0;
    bool largest = true;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.k, "k"), f(self.axis, "axis"), f(self.largest, "largest"));
    }

    // Negative axis values are normalized before compute runs.
    value attributes() const
    {
        value normalize;
        normalize["axis"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    std::string name() const { return "topk"; }

    shape normalize_compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        auto lens = inputs.at(0).lens();
        auto type = inputs.at(0).type();
        // Both outputs keep the input dims except the selection axis shrinks
        // to k. NOTE(review): assumes k <= lens[axis]; a larger k would read
        // out of range in compute — confirm upstream validation.
        lens[axis] = k;
        shape s_val{type, lens};
        shape s_ind{shape::int64_type, lens};
        return shape({s_val, s_ind});
    }

    // Fixed-size heap holding the current best-k candidates while streaming
    // over the remaining positions along the axis.
    template <class T, class Compare>
    struct heap_vector
    {
        std::vector<T> data;
        Compare compare;

        heap_vector(const std::vector<T>& val, Compare comp) : data(val), compare(std::move(comp))
        {
            std::make_heap(data.begin(), data.end(), compare);
        }

        // Replace the heap front with val when val orders before it; the
        // front is the "worst" of the kept candidates under `compare`.
        void try_push(T val)
        {
            if(not compare(val, data.front()))
                return;
            std::pop_heap(data.begin(), data.end(), compare);
            data.back() = val;
            std::push_heap(data.begin(), data.end(), compare);
        }

        std::vector<T> sort()
        {
            auto sorted_data = data;
            std::sort_heap(sorted_data.begin(), sorted_data.end(), compare);
            return sorted_data;
        }
    };

    template <class T, class Compare>
    heap_vector<T, Compare> make_heap(std::vector<T> val, Compare compare) const
    {
        return {std::move(val), std::move(compare)};
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        auto vec_ss = output_shape.sub_shapes();
        argument res_val{vec_ss.front()};
        argument res_ind{vec_ss.back()};
        auto in_s      = args.front().get_shape();
        auto out_s     = vec_ss.front();
        auto comp_lens = in_s.lens();
        auto axis_dim  = comp_lens[axis];
        // Collapse the selection axis so comp_s enumerates every slice
        // perpendicular to it.
        comp_lens[axis] = 1;
        shape comp_s{in_s.type(), comp_lens};
        visit_all(res_val, args.front())([&](auto out_val, auto input) {
            auto* out_ind = res_ind.cast<int64_t>();
            par_for(comp_s.elements(), [&](auto i) {
                auto idx = comp_s.multi(i);
                // Seed the heap with the first k positions along the axis.
                std::vector<std::size_t> indices(k);
                std::iota(indices.begin(), indices.end(), 0);
                // Compare two axis positions within this slice by value.
                auto comp = [&](auto i1, auto i2) {
                    auto idx1  = idx;
                    auto idx2  = idx;
                    idx1[axis] = i1;
                    idx2[axis] = i2;
                    return this->largest
                               ? std::greater<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)])
                               : std::less<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)]);
                };
                auto hp = this->make_heap(indices, comp);
                for(std::size_t ii = indices.size(); ii < axis_dim; ++ii)
                {
                    hp.try_push(ii);
                }
                // Emit values and their source indices in sorted order.
                auto sorted_indices = hp.sort();
                auto out_idx        = idx;
                auto in_idx         = idx;
                for(auto j : range(sorted_indices.size()))
                {
                    out_idx[axis]                 = j;
                    in_idx[axis]                  = sorted_indices[j];
                    out_val[out_s.index(out_idx)] = input[in_s.index(in_idx)];
                    out_ind[out_s.index(out_idx)] = sorted_indices[j];
                }
            });
        });
        return argument({res_val, res_ind});
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
......@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -20,7 +21,7 @@ struct transpose
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
return pack(f(self.dims, "permutation"));
}
std::string name() const { return "transpose"; }
......@@ -31,31 +32,23 @@ struct transpose
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
auto tuned_dims = dims;
// if not perm provided, reverse the dims
if(tuned_dims.empty())
{
tuned_dims.resize(input_lens.size());
std::iota(tuned_dims.begin(), tuned_dims.end(), 0);
std::reverse(tuned_dims.begin(), tuned_dims.end());
}
if(tuned_dims.size() != input_lens.size())
if(dims.size() != input_lens.size())
{
MIGRAPHX_THROW("Permutation has wrong number of axes");
}
std::vector<int64_t> axes(tuned_dims.size());
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), tuned_dims.begin()))
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
MIGRAPHX_THROW("Invalid permutation");
MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
}
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size());
for(std::size_t i = 0; i < output_lens.size(); i++)
{
output_lens[i] = input_lens[tuned_dims[i]];
output_strides[i] = input_strides[tuned_dims[i]];
output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]];
}
return {t, output_lens, output_strides};
}
......@@ -63,7 +56,7 @@ struct transpose
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -41,7 +41,11 @@ struct unary : op_name<Derived>
{
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
auto s = inputs.at(0);
if(s.broadcasted())
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
......
......@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -70,7 +71,7 @@ struct unsqueeze
{
return args[0].reshape(output_shape);
}
bool is_borrowed() const { return true; }
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Element-wise select: output[i] = condition[i] ? x[i] : y[i].
struct where
{
    std::string name() const { return "where"; }

    value attributes() const { return {{"pointwise", true}, {"point_op", "${0} ? ${1} : ${2}"}}; }

    // Pick the result shape from the two value operands, preferring (in
    // order): an identical packed shape, the packed operand, the
    // non-broadcasted operand, and finally a standard shape with the common
    // lengths. The branch order matters and is preserved exactly.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(3).same_dims();
        auto lhs = inputs.at(1);
        auto rhs = inputs.at(2);
        if(lhs == rhs and lhs.packed())
        {
            return lhs;
        }
        if(lhs.packed() != rhs.packed())
        {
            return lhs.packed() ? lhs : rhs;
        }
        if(lhs.broadcasted() != rhs.broadcasted())
        {
            return lhs.broadcasted() ? rhs.with_lens(lhs.lens()) : lhs.with_lens(lhs.lens());
        }
        return {lhs.type(), lhs.lens()};
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[1], args[2])([&](auto out, const auto xs, const auto ys) {
            args[0].visit([&](const auto cond) {
                par_for(output_shape.elements(),
                        [&](auto i) { out[i] = cond[i] ? xs[i] : ys[i]; });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -15,6 +15,7 @@
#include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......@@ -178,7 +179,7 @@ shape normalize_compute_shape_op(const T& x,
}
template <class T>
auto compute_op(rank<2>,
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output_shape,
......@@ -188,14 +189,6 @@ auto compute_op(rank<2>,
return x.compute(auto_any_cast(ctx), output_shape, input);
}
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
......@@ -207,50 +200,118 @@ template <class T>
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<2>{}, x, ctx, output_shape, input);
return compute_op(rank<1>{}, x, ctx, output_shape, input);
}
template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name);
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<1>{}, x, output_shape, input);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
template <class T, class F>
argument compute_op(const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<2>{}, x, output_shape, input);
return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<1>,
auto compute_op(rank<4>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(inputs, module_args, f))
F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
{
return x.compute(inputs, module_args, f);
return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
}
template <class T, class F>
argument
compute_op(rank<0>, const T& x, const std::vector<argument>&, const std::vector<module_ref>&, F)
auto compute_op(rank<3>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<2>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
{
return x.compute(output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
context&,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
......@@ -258,11 +319,13 @@ argument
template <class T, class F>
argument compute_op(const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{
return compute_op(rank<1>{}, x, inputs, module_args, f);
return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f);
}
template <class T>
......@@ -385,9 +448,9 @@ void from_value_op(T& x, const value& v)
}
template <class T>
bool is_borrowed_op(const T&)
lifetime get_lifetime_op(const T&)
{
return false;
return lifetime::local;
}
} // namespace detail
......@@ -401,7 +464,7 @@ bool is_borrowed_op(const T&)
* bool is_context_free() const;
* bool need_normalization() const;
* bool has_finalize() const;
* bool is_borrowed() const;
* lifetime get_lifetime() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
......@@ -409,9 +472,12 @@ bool is_borrowed_op(const T&)
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
* input) const; argument compute(const shape& output,const std::vector<argument>& input)
* const; argument compute(const std::vector<argument>& input,const std::vector<module_ref>&
* module_args,std::function<std::vector<argument>(module_ref& mdl, const
* std::unordered_map<std::string, argument>& inputs)> run) const; value to_value() const; void
* const; argument compute(const shape& output,const std::vector<argument>& input,const
* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const
* shape& output,const std::vector<argument>& input,const std::vector<module_ref>&
* module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void
* from_value(const value& v) ; value attributes() const; friend std::ostream &
* operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation &
* x,const operation & y) ;
......@@ -506,10 +572,10 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize();
}
bool is_borrowed() const
lifetime get_lifetime() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_borrowed();
return (*this).private_detail_te_get_handle().get_lifetime();
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const
......@@ -555,14 +621,27 @@ struct operation
return (*this).private_detail_te_get_handle().compute(output, input);
}
argument compute(
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) const
argument compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(
output, input, module_args, std::move(run));
}
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(input, module_args, std::move(run));
return (*this).private_detail_te_get_handle().compute(
ctx, output, input, module_args, std::move(run));
}
value to_value() const
......@@ -612,7 +691,7 @@ struct operation
virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0;
virtual bool is_borrowed() const = 0;
virtual lifetime get_lifetime() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
......@@ -625,16 +704,23 @@ struct operation
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual argument
compute(const std::vector<argument>& input,
compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual value attributes() const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
virtual argument
compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual value attributes() const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
};
template <class T>
......@@ -677,16 +763,16 @@ struct operation
}
template <class T>
static auto private_detail_te_default_is_borrowed(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_borrowed())
static auto private_detail_te_default_get_lifetime(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.get_lifetime())
{
return private_detail_te_self.is_borrowed();
return private_detail_te_self.get_lifetime();
}
template <class T>
static bool private_detail_te_default_is_borrowed(float, T&& private_detail_te_self)
static lifetime private_detail_te_default_get_lifetime(float, T&& private_detail_te_self)
{
return detail::is_borrowed_op(private_detail_te_self);
return detail::get_lifetime_op(private_detail_te_self);
}
template <class T>
......@@ -828,25 +914,58 @@ struct operation
static auto private_detail_te_default_compute(
char,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
-> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run)))
{
return private_detail_te_self.compute(output, input, module_args, std::move(run));
}
template <class T>
static argument private_detail_te_default_compute(
float,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
{
return detail::compute_op(
private_detail_te_self, output, input, module_args, std::move(run));
}
template <class T>
static auto private_detail_te_default_compute(
char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
-> decltype(private_detail_te_self.compute(input, module_args, std::move(run)))
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
-> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)))
{
return private_detail_te_self.compute(input, module_args, std::move(run));
return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run));
}
template <class T>
static argument private_detail_te_default_compute(
float,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
{
return detail::compute_op(private_detail_te_self, input, module_args, std::move(run));
return detail::compute_op(
private_detail_te_self, ctx, output, input, module_args, std::move(run));
}
template <class T>
......@@ -938,10 +1057,10 @@ struct operation
return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
}
bool is_borrowed() const override
lifetime get_lifetime() const override
{
return private_detail_te_default_is_borrowed(char(0), private_detail_te_value);
return private_detail_te_default_get_lifetime(char(0), private_detail_te_value);
}
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
......@@ -994,16 +1113,29 @@ struct operation
char(0), private_detail_te_value, output, input);
}
argument
compute(const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run)
const override
argument compute(
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
{
return private_detail_te_default_compute(
char(0), private_detail_te_value, output, input, module_args, std::move(run));
}
argument compute(
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
{
return private_detail_te_default_compute(
char(0), private_detail_te_value, input, module_args, std::move(run));
char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run));
}
value to_value() const override
......
......@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp>
......@@ -48,6 +49,7 @@
#include <migraphx/op/logical_or.hpp>
#include <migraphx/op/logical_xor.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp>
......@@ -55,6 +57,7 @@
#include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/neg.hpp>
#include <migraphx/op/nonzero.hpp>
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
......@@ -71,6 +74,7 @@
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_hs_output.hpp>
......@@ -79,6 +83,7 @@
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
......@@ -88,14 +93,17 @@
#include <migraphx/op/sqrt.hpp>
#include <migraphx/op/sqdiff.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/step.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#endif
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
// Portability shim: exposes migraphx::optional / migraphx::nullopt, selecting
// std::optional (C++17) or std::experimental::optional (Library Fundamentals
// TS) at preprocessing time, and failing the build if neither is available.
#include <migraphx/config.hpp>
// Feature detection only runs when the compiler provides __has_include.
// CPPCHECK is excluded here — presumably the static analyzer mishandles
// __has_include inside #if expressions; confirm before removing the guard.
#if defined(__has_include) && !defined(CPPCHECK)
// std::optional needs both the <optional> header and C++17 language mode.
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#endif
// TS fallback: <experimental/optional> requires at least C++11.
#if __has_include(<experimental/optional>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#else
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#else
// No __has_include (or running under cppcheck): assume neither is usable.
#define MIGRAPHX_HAS_OPTIONAL 0
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
// Pull in whichever implementation was detected; the standard header wins
// over the experimental one when both are present.
#if MIGRAPHX_HAS_OPTIONAL
#include <optional>
#elif MIGRAPHX_HAS_OPTIONAL_TS
#include <experimental/optional>
#else
#error "No optional include available"
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Aliases so the rest of the codebase uses migraphx::optional, nullopt_t and
// nullopt uniformly, regardless of which implementation was selected above.
#if MIGRAPHX_HAS_OPTIONAL
template <class T>
using optional = std::optional<T>;
using nullopt_t = std::nullopt_t;
constexpr auto nullopt = std::nullopt;
#elif MIGRAPHX_HAS_OPTIONAL_TS
template <class T>
using optional = std::experimental::optional<T>;
using nullopt_t = std::experimental::nullopt_t;
constexpr auto nullopt = std::experimental::nullopt;
#endif
// Reports whether the given optional currently contains a value. Works for
// both the std:: and std::experimental:: implementations, since each defines
// equality comparison against nullopt.
template <class T>
bool has_value(const optional<T>& x)
{
    return not(x == nullopt);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment