Commit 4a39a0f7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
...@@ -15,6 +15,7 @@ namespace op { ...@@ -15,6 +15,7 @@ namespace op {
// 3.1) include_min(default)/exclude_min // 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max // 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max // 4.1) exclude_max(default)/include_max
// 5) normalize padding
enum class normalize_attribute enum class normalize_attribute
{ {
use_len, use_len,
...@@ -22,7 +23,8 @@ enum class normalize_attribute ...@@ -22,7 +23,8 @@ enum class normalize_attribute
clip_max, clip_max,
clip_min, clip_min,
include_max, include_max,
include_min include_min,
normalize_padding
}; };
} // namespace op } // namespace op
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/value.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp> #include <migraphx/int_divide.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -40,29 +41,39 @@ struct pooling ...@@ -40,29 +41,39 @@ struct pooling
void check_attribute_size() const void check_attribute_size() const
{ {
if(not(padding.size() == stride.size() and padding.size() == lengths.size())) if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == lengths.size()))
{ {
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes"); MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
} }
} }
shape compute_shape(std::vector<shape> inputs) const value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2; auto input_lens = input.lens();
if(kdims != this->kdims()) size_t kdims = input_lens.size() - 2;
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{ {
MIGRAPHX_THROW("pooling: input k-dims does not match attribute size"); MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
} }
std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2); std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
for(size_t i = 0; i < kdims; i++) for(size_t i = 0; i < kdims; i++)
{ {
std::ptrdiff_t dim_size = input_lens[i + 2] + 2 * padding[i] - lengths[i]; std::ptrdiff_t dim_size;
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
dim_size = input_lens[i + 2] + padding_factor - lengths[i];
assert(dim_size >= 0); assert(dim_size >= 0);
std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i]) std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i])
: floor_divide<std::ptrdiff_t>(dim_size, stride[i]); : floor_divide<std::ptrdiff_t>(dim_size, stride[i]);
...@@ -75,7 +86,7 @@ struct pooling ...@@ -75,7 +86,7 @@ struct pooling
size_t kdims() const size_t kdims() const
{ {
check_attribute_size(); check_attribute_size();
return padding.size(); return stride.size();
} }
}; };
......
...@@ -43,7 +43,7 @@ struct prefix_scan_op : op_name<Derived> ...@@ -43,7 +43,7 @@ struct prefix_scan_op : op_name<Derived>
argument compute(const shape&, std::vector<argument> args) const argument compute(const shape&, std::vector<argument> args) const
{ {
argument result = args[0]; argument result = args[0].copy();
auto s = result.get_shape(); auto s = result.get_shape();
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}}; auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens(); auto lens = s.lens();
......
...@@ -36,19 +36,23 @@ struct quant_convolution ...@@ -36,19 +36,23 @@ struct quant_convolution
f(self.group, "group")); f(self.group, "group"));
} }
value attributes() const { return {{"general_data_type", "convolution"}}; } value attributes() const
{
return {{"general_data_type", "convolution"}, {"normalize_padding", "padding"}};
}
std::string name() const { return "quant_convolution"; } std::string name() const { return "quant_convolution"; }
void check_attribute_size() const void check_attribute_size() const
{ {
if(not(padding.size() == stride.size() and padding.size() == dilation.size())) if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{ {
MIGRAPHX_THROW("quant_convolution: inconsistent attribute sizes"); MIGRAPHX_THROW("QUANT_CONVOLUTION: inconsistent attribute sizes");
} }
} }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3); check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size(); check_attribute_size();
...@@ -70,13 +74,16 @@ struct quant_convolution ...@@ -70,13 +74,16 @@ struct quant_convolution
t = shape::int32_type; t = shape::int32_type;
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]}; std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto padding_size = padding.size();
for(size_t i = 0; i < kdims; i++) for(size_t i = 0; i < kdims; i++)
{ {
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>( output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1, 1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) + (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
2 * padding[i]) / padding_factor) /
stride[i] + stride[i] +
1))); 1)));
} }
...@@ -87,7 +94,7 @@ struct quant_convolution ...@@ -87,7 +94,7 @@ struct quant_convolution
size_t kdims() const size_t kdims() const
{ {
check_attribute_size(); check_attribute_size();
return padding.size(); return stride.size();
} }
}; };
......
...@@ -18,21 +18,12 @@ namespace op { ...@@ -18,21 +18,12 @@ namespace op {
struct quant_dot struct quant_dot
{ {
int32_t alpha = 1;
int32_t beta = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
value attributes() const { return {{"general_data_type", "dot"}}; } value attributes() const { return {{"general_data_type", "dot"}}; }
std::string name() const { return "quant_dot"; } std::string name() const { return "quant_dot"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type(); check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type().has(2);
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
...@@ -64,18 +55,6 @@ struct quant_dot ...@@ -64,18 +55,6 @@ struct quant_dot
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("QUANT_DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
if(inputs.size() == 3 && inputs.at(2).type() != shape::int32_type)
{
MIGRAPHX_THROW("QUANT_DOT: operand C type must be int32");
}
return {shape::int32_type, out_lens}; return {shape::int32_type, out_lens};
} }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct quantizelinear
{
std::string name() const { return "quantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_dims();
if(inputs.size() == 3)
{
return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
}
return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto x = args.at(0);
auto y_scale = args.at(1);
std::vector<int8_t> zeros(output_shape.bytes(), 0);
argument y_zero_point{output_shape, zeros.data()};
if(args.size() == 3)
{
y_zero_point = args.at(2);
}
argument result{output_shape};
visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
x.visit([&](auto input) {
y_scale.visit([&](auto scales) {
using quant_type = typename decltype(output)::value_type;
auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) {
int64_t quantized = static_cast<int64_t>(std::round(input[i] / scales[i])) +
static_cast<int64_t>(zero_pts[i]);
output[i] = std::max(static_cast<int64_t>(min_value),
std::min(static_cast<int64_t>(max_value), quantized));
});
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -71,7 +72,7 @@ struct reshape ...@@ -71,7 +72,7 @@ struct reshape
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP
#define MIGRAPHX_GUARD_OPERATORS_REVERSE_HPP
#include <algorithm>
#include <vector>
#include <cmath>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reverse
{
std::vector<int64_t> axes;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "reverse"; }
value attributes() const
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
shape normalize_compute_shape(std::vector<shape> inputs) const
{
return inputs[0].with_lens(inputs[0].lens());
}
argument compute(const shape& s, std::vector<argument> args) const
{
argument result{s};
auto lens = s.lens();
visit_all(result, args.front())([&](auto output, auto input) {
shape_for_each(s, [&](const auto& out_idx) {
auto in_idx = out_idx;
for(const auto& axis : axes)
{
in_idx[axis] = lens[axis] - 1 - out_idx[axis];
}
output[s.index(out_idx)] = input[s.index(in_idx)];
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -39,7 +40,7 @@ struct scalar ...@@ -39,7 +40,7 @@ struct scalar
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter
{
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "scatter"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).standard();
return inputs.front();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// max dimension in axis
auto axis_dim_size = output_shape.lens()[axis];
visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) {
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto ind_s = indices.get_shape();
shape_for_each(ind_s, [&](const auto& idx) {
auto out_idx = idx;
auto index = indices[ind_s.index(idx)];
index = (index < 0) ? index + axis_dim_size : index;
out_idx[axis] = index;
output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -77,7 +78,7 @@ struct squeeze ...@@ -77,7 +78,7 @@ struct squeeze
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#define MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#include "migraphx/stringutils.hpp"
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct step
{
std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"), f(self.steps, "steps"));
}
value attributes() const
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "step"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0);
auto in_lens = input.lens();
auto t = input.type();
if(axes.size() != steps.size())
{
MIGRAPHX_THROW("STEP: attribute axes {" + to_string_range(axes) +
"} has different dimensions from step {" + to_string_range(steps) +
"}.");
}
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return axis >= in_lens.size(); }))
{
MIGRAPHX_THROW("STEP: axis value is out of range!");
}
auto lens = in_lens;
auto strides = input.strides();
for(auto i : range(axes.size()))
{
auto axis = axes[i];
auto step = steps[i];
lens[axis] = (in_lens[axis] + step - 1) / step;
strides[axis] *= step;
}
return {t, lens, strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHER_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHER_HPP
#include <algorithm>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct topk
{
int64_t k = 1;
int64_t axis = 0;
bool largest = true;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.k, "k"), f(self.axis, "axis"), f(self.largest, "largest"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "topk"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.at(0).lens();
auto type = inputs.at(0).type();
lens[axis] = k;
shape s_val{type, lens};
shape s_ind{shape::int64_type, lens};
return shape({s_val, s_ind});
}
template <class T, class Compare>
struct heap_vector
{
std::vector<T> data;
Compare compare;
heap_vector(const std::vector<T>& val, Compare comp) : data(val), compare(std::move(comp))
{
std::make_heap(data.begin(), data.end(), compare);
}
void try_push(T val)
{
if(not compare(val, data.front()))
return;
std::pop_heap(data.begin(), data.end(), compare);
data.back() = val;
std::push_heap(data.begin(), data.end(), compare);
}
std::vector<T> sort()
{
auto sorted_data = data;
std::sort_heap(sorted_data.begin(), sorted_data.end(), compare);
return sorted_data;
}
};
template <class T, class Compare>
heap_vector<T, Compare> make_heap(std::vector<T> val, Compare compare) const
{
return {std::move(val), std::move(compare)};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto vec_ss = output_shape.sub_shapes();
argument res_val{vec_ss.front()};
argument res_ind{vec_ss.back()};
auto in_s = args.front().get_shape();
auto out_s = vec_ss.front();
auto comp_lens = in_s.lens();
auto axis_dim = comp_lens[axis];
// compute shape
comp_lens[axis] = 1;
shape comp_s{in_s.type(), comp_lens};
visit_all(res_val, args.front())([&](auto out_val, auto input) {
auto* out_ind = res_ind.cast<int64_t>();
par_for(comp_s.elements(), [&](auto i) {
auto idx = comp_s.multi(i);
std::vector<std::size_t> indices(k);
std::iota(indices.begin(), indices.end(), 0);
auto comp = [&](auto i1, auto i2) {
auto idx1 = idx;
auto idx2 = idx;
idx1[axis] = i1;
idx2[axis] = i2;
return this->largest
? std::greater<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)])
: std::less<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)]);
};
auto hp = this->make_heap(indices, comp);
for(std::size_t ii = indices.size(); ii < axis_dim; ++ii)
{
hp.try_push(ii);
}
auto sorted_indices = hp.sort();
auto out_idx = idx;
auto in_idx = idx;
for(auto j : range(sorted_indices.size()))
{
out_idx[axis] = j;
in_idx[axis] = sorted_indices[j];
out_val[out_s.index(out_idx)] = input[in_s.index(in_idx)];
out_ind[out_s.index(out_idx)] = sorted_indices[j];
}
});
});
return argument({res_val, res_ind});
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -20,7 +21,7 @@ struct transpose ...@@ -20,7 +21,7 @@ struct transpose
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.dims, "dims")); return pack(f(self.dims, "permutation"));
} }
std::string name() const { return "transpose"; } std::string name() const { return "transpose"; }
...@@ -31,31 +32,23 @@ struct transpose ...@@ -31,31 +32,23 @@ struct transpose
auto input_lens = input.lens(); auto input_lens = input.lens();
auto input_strides = input.strides(); auto input_strides = input.strides();
auto t = input.type(); auto t = input.type();
auto tuned_dims = dims;
// if not perm provided, reverse the dims
if(tuned_dims.empty())
{
tuned_dims.resize(input_lens.size());
std::iota(tuned_dims.begin(), tuned_dims.end(), 0);
std::reverse(tuned_dims.begin(), tuned_dims.end());
}
if(tuned_dims.size() != input_lens.size()) if(dims.size() != input_lens.size())
{ {
MIGRAPHX_THROW("Permutation has wrong number of axes"); MIGRAPHX_THROW("Permutation has wrong number of axes");
} }
std::vector<int64_t> axes(tuned_dims.size()); std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), tuned_dims.begin())) if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{ {
MIGRAPHX_THROW("Invalid permutation"); MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
} }
std::vector<size_t> output_lens(input_lens.size()); std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size()); std::vector<size_t> output_strides(input_lens.size());
for(std::size_t i = 0; i < output_lens.size(); i++) for(std::size_t i = 0; i < output_lens.size(); i++)
{ {
output_lens[i] = input_lens[tuned_dims[i]]; output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[tuned_dims[i]]; output_strides[i] = input_strides[dims[i]];
} }
return {t, output_lens, output_strides}; return {t, output_lens, output_strides};
} }
...@@ -63,7 +56,7 @@ struct transpose ...@@ -63,7 +56,7 @@ struct transpose
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -41,7 +41,11 @@ struct unary : op_name<Derived> ...@@ -41,7 +41,11 @@ struct unary : op_name<Derived>
{ {
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
if(s.broadcasted()) if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{ {
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -70,7 +71,7 @@ struct unsqueeze ...@@ -70,7 +71,7 @@ struct unsqueeze
{ {
return args[0].reshape(output_shape); return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#define MIGRAPHX_GUARD_OPERATORS_WHERE_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct where
{
std::string name() const { return "where"; }
value attributes() const { return {{"pointwise", true}, {"point_op", "${0} ? ${1} : ${2}"}}; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).same_dims();
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s1 == s2 and s1.packed())
{
return s1;
}
else if(s1.packed() != s2.packed())
{
return s1.packed() ? s1 : s2;
}
else if(s1.broadcasted() != s2.broadcasted())
{
return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens());
}
else
{
return {s1.type(), s1.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) {
args[0].visit([&](const auto condition) {
par_for(output_shape.elements(),
[&](auto i) { output[i] = condition[i] ? x[i] : y[i]; });
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <migraphx/module_ref.hpp> #include <migraphx/module_ref.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -178,7 +179,7 @@ shape normalize_compute_shape_op(const T& x, ...@@ -178,7 +179,7 @@ shape normalize_compute_shape_op(const T& x,
} }
template <class T> template <class T>
auto compute_op(rank<2>, auto compute_op(rank<1>,
const T& x, const T& x,
context& ctx, context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -188,14 +189,6 @@ auto compute_op(rank<2>, ...@@ -188,14 +189,6 @@ auto compute_op(rank<2>,
return x.compute(auto_any_cast(ctx), output_shape, input); return x.compute(auto_any_cast(ctx), output_shape, input);
} }
template <class T>
auto compute_op(
rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input))
{
return x.compute(output_shape, input);
}
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&) argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{ {
...@@ -207,50 +200,118 @@ template <class T> ...@@ -207,50 +200,118 @@ template <class T>
argument argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input) compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
{ {
return compute_op(rank<2>{}, x, ctx, output_shape, input); return compute_op(rank<1>{}, x, ctx, output_shape, input);
} }
template <class T> template <class T>
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input) auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-> decltype(x.compute(output_shape, input)) -> decltype(x.compute(output_shape, input))
{ {
return x.compute(output_shape, input); return x.compute(output_shape, input);
} }
template <class T> template <class T>
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
-> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable without a context: " + name); MIGRAPHX_THROW("Not computable: " + name);
} }
template <class T> template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&) argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
return compute_op(rank<1>{}, x, output_shape, input);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
} }
template <class T> template <class T, class F>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input) argument compute_op(const T& x,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f)
{ {
return compute_op(rank<2>{}, x, output_shape, input); return compute_op(rank<1>{}, x, output, inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
auto compute_op(rank<1>, auto compute_op(rank<4>,
const T& x, const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(inputs, module_args, f)) F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
{ {
return x.compute(inputs, module_args, f); return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
} }
template <class T, class F> template <class T, class F>
argument auto compute_op(rank<3>,
compute_op(rank<0>, const T& x, const std::vector<argument>&, const std::vector<module_ref>&, F) const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f))
{
return x.compute(output, inputs, module_args, f);
}
template <class T, class F>
auto compute_op(rank<2>,
const T& x,
context&,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(output, inputs))
{
return x.compute(output, inputs);
}
template <class T, class F>
auto compute_op(rank<1>,
const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs,
const std::vector<module_ref>&,
F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
return x.compute(auto_any_cast(ctx), output, inputs);
}
template <class T, class F>
argument compute_op(rank<0>,
const T& x,
context&,
const shape&,
const std::vector<argument>&,
const std::vector<module_ref>&,
F)
{ {
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
...@@ -258,11 +319,13 @@ argument ...@@ -258,11 +319,13 @@ argument
template <class T, class F> template <class T, class F>
argument compute_op(const T& x, argument compute_op(const T& x,
context& ctx,
const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) F f)
{ {
return compute_op(rank<1>{}, x, inputs, module_args, f); return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f);
} }
template <class T> template <class T>
...@@ -385,9 +448,9 @@ void from_value_op(T& x, const value& v) ...@@ -385,9 +448,9 @@ void from_value_op(T& x, const value& v)
} }
template <class T> template <class T>
bool is_borrowed_op(const T&) lifetime get_lifetime_op(const T&)
{ {
return false; return lifetime::local;
} }
} // namespace detail } // namespace detail
...@@ -401,7 +464,7 @@ bool is_borrowed_op(const T&) ...@@ -401,7 +464,7 @@ bool is_borrowed_op(const T&)
* bool is_context_free() const; * bool is_context_free() const;
* bool need_normalization() const; * bool need_normalization() const;
* bool has_finalize() const; * bool has_finalize() const;
* bool is_borrowed() const; * lifetime get_lifetime() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; * std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* value compile(context& ctx,const shape& output,const std::vector<shape>& input) ; * value compile(context& ctx,const shape& output,const std::vector<shape>& input) ;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; * void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
...@@ -409,9 +472,12 @@ bool is_borrowed_op(const T&) ...@@ -409,9 +472,12 @@ bool is_borrowed_op(const T&)
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>& * shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>& * mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
* input) const; argument compute(const shape& output,const std::vector<argument>& input) * input) const; argument compute(const shape& output,const std::vector<argument>& input)
* const; argument compute(const std::vector<argument>& input,const std::vector<module_ref>& * const; argument compute(const shape& output,const std::vector<argument>& input,const
* module_args,std::function<std::vector<argument>(module_ref& mdl, const * std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>& inputs)> run) const; value to_value() const; void * std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const
* shape& output,const std::vector<argument>& input,const std::vector<module_ref>&
* module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>&)> run) const; value to_value() const; void
* from_value(const value& v) ; value attributes() const; friend std::ostream & * from_value(const value& v) ; value attributes() const; friend std::ostream &
* operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation & * operator<<(std::ostream & os,const operation & op) ; friend bool operator==(const operation &
* x,const operation & y) ; * x,const operation & y) ;
...@@ -506,10 +572,10 @@ struct operation ...@@ -506,10 +572,10 @@ struct operation
return (*this).private_detail_te_get_handle().has_finalize(); return (*this).private_detail_te_get_handle().has_finalize();
} }
bool is_borrowed() const lifetime get_lifetime() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_borrowed(); return (*this).private_detail_te_get_handle().get_lifetime();
} }
std::ptrdiff_t output_alias(const std::vector<shape>& input) const std::ptrdiff_t output_alias(const std::vector<shape>& input) const
...@@ -555,14 +621,27 @@ struct operation ...@@ -555,14 +621,27 @@ struct operation
return (*this).private_detail_te_get_handle().compute(output, input); return (*this).private_detail_te_get_handle().compute(output, input);
} }
argument compute( argument compute(const shape& output,
const std::vector<argument>& input, const std::vector<argument>& input,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
std::function<std::vector<argument>( std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) const module_ref&, const std::unordered_map<std::string, argument>&)> run) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(
output, input, module_args, std::move(run));
}
argument compute(context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().compute(input, module_args, std::move(run)); return (*this).private_detail_te_get_handle().compute(
ctx, output, input, module_args, std::move(run));
} }
value to_value() const value to_value() const
...@@ -612,7 +691,7 @@ struct operation ...@@ -612,7 +691,7 @@ struct operation
virtual bool is_context_free() const = 0; virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0; virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0; virtual bool has_finalize() const = 0;
virtual bool is_borrowed() const = 0; virtual lifetime get_lifetime() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0; virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual value virtual value
compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0; compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
...@@ -625,16 +704,23 @@ struct operation ...@@ -625,16 +704,23 @@ struct operation
compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0; compute(context& ctx, const shape& output, const std::vector<argument>& input) const = 0;
virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0; virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
virtual argument virtual argument
compute(const std::vector<argument>& input, compute(const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
std::function<std::vector<argument>( std::function<std::vector<argument>(
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
const = 0; virtual argument
virtual value to_value() const = 0; compute(context& ctx,
virtual void from_value(const value& v) = 0; const shape& output,
virtual value attributes() const = 0; const std::vector<argument>& input,
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0; const std::vector<module_ref>& module_args,
virtual bool operator==(const operation& y) const = 0; std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual value attributes() const = 0;
virtual std::ostream& operator_shift_left(std::ostream& os) const = 0;
virtual bool operator==(const operation& y) const = 0;
}; };
template <class T> template <class T>
...@@ -677,16 +763,16 @@ struct operation ...@@ -677,16 +763,16 @@ struct operation
} }
template <class T> template <class T>
static auto private_detail_te_default_is_borrowed(char, T&& private_detail_te_self) static auto private_detail_te_default_get_lifetime(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.is_borrowed()) -> decltype(private_detail_te_self.get_lifetime())
{ {
return private_detail_te_self.is_borrowed(); return private_detail_te_self.get_lifetime();
} }
template <class T> template <class T>
static bool private_detail_te_default_is_borrowed(float, T&& private_detail_te_self) static lifetime private_detail_te_default_get_lifetime(float, T&& private_detail_te_self)
{ {
return detail::is_borrowed_op(private_detail_te_self); return detail::get_lifetime_op(private_detail_te_self);
} }
template <class T> template <class T>
...@@ -828,25 +914,58 @@ struct operation ...@@ -828,25 +914,58 @@ struct operation
static auto private_detail_te_default_compute( static auto private_detail_te_default_compute(
char, char,
T&& private_detail_te_self, T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
-> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run)))
{
return private_detail_te_self.compute(output, input, module_args, std::move(run));
}
template <class T>
static argument private_detail_te_default_compute(
float,
T&& private_detail_te_self,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(module_ref&,
const std::unordered_map<std::string, argument>&)> run)
{
return detail::compute_op(
private_detail_te_self, output, input, module_args, std::move(run));
}
template <class T>
static auto private_detail_te_default_compute(
char,
T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input, const std::vector<argument>& input,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
std::function<std::vector<argument>( std::function<std::vector<argument>(module_ref&,
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) const std::unordered_map<std::string, argument>&)> run)
-> decltype(private_detail_te_self.compute(input, module_args, std::move(run))) -> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)))
{ {
return private_detail_te_self.compute(input, module_args, std::move(run)); return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run));
} }
template <class T> template <class T>
static argument private_detail_te_default_compute( static argument private_detail_te_default_compute(
float, float,
T&& private_detail_te_self, T&& private_detail_te_self,
context& ctx,
const shape& output,
const std::vector<argument>& input, const std::vector<argument>& input,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
std::function<std::vector<argument>( std::function<std::vector<argument>(module_ref&,
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) const std::unordered_map<std::string, argument>&)> run)
{ {
return detail::compute_op(private_detail_te_self, input, module_args, std::move(run)); return detail::compute_op(
private_detail_te_self, ctx, output, input, module_args, std::move(run));
} }
template <class T> template <class T>
...@@ -938,10 +1057,10 @@ struct operation ...@@ -938,10 +1057,10 @@ struct operation
return private_detail_te_default_has_finalize(char(0), private_detail_te_value); return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
} }
bool is_borrowed() const override lifetime get_lifetime() const override
{ {
return private_detail_te_default_is_borrowed(char(0), private_detail_te_value); return private_detail_te_default_get_lifetime(char(0), private_detail_te_value);
} }
std::ptrdiff_t output_alias(const std::vector<shape>& input) const override std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
...@@ -994,16 +1113,29 @@ struct operation ...@@ -994,16 +1113,29 @@ struct operation
char(0), private_detail_te_value, output, input); char(0), private_detail_te_value, output, input);
} }
argument argument compute(
compute(const std::vector<argument>& input, const shape& output,
const std::vector<module_ref>& module_args, const std::vector<argument>& input,
std::function<std::vector<argument>( const std::vector<module_ref>& module_args,
module_ref& mdl, const std::unordered_map<std::string, argument>& inputs)> run) std::function<std::vector<argument>(
const override module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
{
return private_detail_te_default_compute(
char(0), private_detail_te_value, output, input, module_args, std::move(run));
}
argument compute(
context& ctx,
const shape& output,
const std::vector<argument>& input,
const std::vector<module_ref>& module_args,
std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
{ {
return private_detail_te_default_compute( return private_detail_te_default_compute(
char(0), private_detail_te_value, input, module_args, std::move(run)); char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run));
} }
value to_value() const override value to_value() const override
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp> #include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp> #include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp> #include <migraphx/op/gather.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp> #include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp> #include <migraphx/op/gru.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
...@@ -48,6 +49,7 @@ ...@@ -48,6 +49,7 @@
#include <migraphx/op/logical_or.hpp> #include <migraphx/op/logical_or.hpp>
#include <migraphx/op/logical_xor.hpp> #include <migraphx/op/logical_xor.hpp>
#include <migraphx/op/logsoftmax.hpp> #include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/loop.hpp>
#include <migraphx/op/lrn.hpp> #include <migraphx/op/lrn.hpp>
#include <migraphx/op/lstm.hpp> #include <migraphx/op/lstm.hpp>
#include <migraphx/op/max.hpp> #include <migraphx/op/max.hpp>
...@@ -55,6 +57,7 @@ ...@@ -55,6 +57,7 @@
#include <migraphx/op/mul.hpp> #include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp> #include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/neg.hpp> #include <migraphx/op/neg.hpp>
#include <migraphx/op/nonzero.hpp>
#include <migraphx/op/outline.hpp> #include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
...@@ -71,6 +74,7 @@ ...@@ -71,6 +74,7 @@
#include <migraphx/op/reduce_sum.hpp> #include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp> #include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/op/rnn.hpp> #include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp> #include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_hs_output.hpp> #include <migraphx/op/rnn_last_hs_output.hpp>
...@@ -79,6 +83,7 @@ ...@@ -79,6 +83,7 @@
#include <migraphx/op/round.hpp> #include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp> #include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
...@@ -88,14 +93,17 @@ ...@@ -88,14 +93,17 @@
#include <migraphx/op/sqrt.hpp> #include <migraphx/op/sqrt.hpp>
#include <migraphx/op/sqdiff.hpp> #include <migraphx/op/sqdiff.hpp>
#include <migraphx/op/squeeze.hpp> #include <migraphx/op/squeeze.hpp>
#include <migraphx/op/step.hpp>
#include <migraphx/op/sub.hpp> #include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp> #include <migraphx/op/tanh.hpp>
#include <migraphx/op/tan.hpp> #include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp> #include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp> #include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp> #include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/where.hpp>
#endif #endif
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#endif
#if __has_include(<experimental/optional>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#else
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#if MIGRAPHX_HAS_OPTIONAL
#include <optional>
#elif MIGRAPHX_HAS_OPTIONAL_TS
#include <experimental/optional>
#else
#error "No optional include available"
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_OPTIONAL
template <class T>
using optional = std::optional<T>;
using nullopt_t = std::nullopt_t;
constexpr auto nullopt = std::nullopt;
#elif MIGRAPHX_HAS_OPTIONAL_TS
template <class T>
using optional = std::experimental::optional<T>;
using nullopt_t = std::experimental::nullopt_t;
constexpr auto nullopt = std::experimental::nullopt;
#endif
template <class T>
bool has_value(const optional<T>& x)
{
return x != nullopt;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment