Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
......@@ -2,13 +2,13 @@
#define MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -34,7 +34,7 @@ struct as_shape
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.front().data)};
return args.front().reshape(output_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Elementwise inverse hyperbolic sine operator (CRTP over the unary base).
struct asinh : unary<asinh>
{
    // Returns the per-element functor the unary base applies to each input
    // value; forwards directly to std::asinh for the element's type.
    auto apply() const
    {
        auto elementwise = [](auto x) { return std::asinh(x); };
        return elementwise;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Elementwise inverse hyperbolic tangent operator (CRTP over the unary base).
struct atanh : unary<atanh>
{
    // Returns the per-element functor the unary base applies to each input
    // value; forwards directly to std::atanh for the element's type.
    auto apply() const
    {
        auto elementwise = [](auto x) { return std::atanh(x); };
        return elementwise;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......@@ -42,9 +41,8 @@ struct batch_norm_inference
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(5);
check_shapes{inputs.data(), inputs.data() + 1, *this}.only_dims(4);
check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape().elements(
inputs.front().lens()[1]);
check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims();
check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape();
return inputs.front();
}
};
......
......@@ -2,6 +2,11 @@
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -10,15 +15,45 @@ namespace op {
template <class Derived>
struct binary : op_name<Derived>
{
std::string point_function() const { return this->name(); }
std::string point_op() const
{
const auto& self = static_cast<const Derived&>(*this);
auto pf = self.point_function();
if(pf.empty())
return {};
if(with_char(::ispunct)(pf.front()))
{
return "${0} " + pf + " ${1}";
}
else
{
return "${function:" + pf + "}(${0}, ${1})";
}
}
value base_attributes() const
{
const auto& self = static_cast<const Derived&>(*this);
return {{"pointwise", true}, {"point_op", self.point_op()}};
}
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(2).same_type().same_dims();
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed())
{
return s0;
}
else if(s0.packed() != s1.packed())
{
return s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
return {s0.type(), s0.lens()};
......@@ -28,32 +63,13 @@ struct binary : op_name<Derived>
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto s1 = args[0].get_shape();
auto s2 = args[1].get_shape();
if(s1 == s2 and s1.packed())
{
shape std_shape{s1.type(), s1.lens()};
argument std_result{std_shape, result.data()};
argument std_arg0{std_shape, args[0].data()};
argument std_arg1{std_shape, args[1].data()};
visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
}
else
{
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
});
}
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
return result;
}
};
......
......@@ -2,13 +2,11 @@
#define MIGRAPHX_GUARD_OPERATORS_BROADCAST_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath>
#include <utility>
......@@ -32,36 +30,42 @@ struct broadcast
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims"));
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens"));
}
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto t = inputs.at(0).type();
auto input = inputs.at(0);
auto t = input.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broacast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(std::all_of(
broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
if(broadcast_lens.size() - axis < input.lens().size())
{
if(axis != 0)
MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_lens, std::move(bcast_strides)};
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
}
else
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{
assert(broadcast_lens.size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_lens, std::move(bcast_strides)};
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < input.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
return output;
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return {std::move(output_shape), std::move(args.at(0).data)};
return args[0].reshape(output_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -2,13 +2,13 @@
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <cmath>
#include <utility>
......@@ -30,7 +30,9 @@ struct capture
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(const shape&, std::vector<argument> args) const
// the context argument is added to prevent the op from be eliminated by
// constant propagation
argument compute(context&, const shape&, const std::vector<argument>& args) const
{
if(f)
{
......@@ -43,6 +45,8 @@ struct capture
return args.front();
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
......
......@@ -3,12 +3,11 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......@@ -18,29 +17,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct clip : unary<clip>
struct clip
{
float max_val = std::numeric_limits<float>::max();
float min_val = std::numeric_limits<float>::min();
std::string name() const { return "clip"; }
clip() {}
clip(float max, float min) : max_val(max), min_val(min) {}
value attributes() const
{
return {{"pointwise", true},
{"point_op", "${function:min}(${function:max}(${1}, ${0}), ${2})"}};
}
auto apply() const
shape compute_shape(std::vector<shape> inputs) const
{
auto max = max_val;
auto min = min_val;
return [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
};
check_shapes{inputs, *this}.has(3).same_type().same_dims();
return inputs.front();
}
template <class Self, class F>
static auto reflect(Self& self, F f)
argument compute(const shape& output_shape, std::vector<argument> args) const
{
return pack(f(self.max_val, "max"), f(self.min_val, "min"));
argument result{output_shape};
visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) {
par_for(output_shape.elements(),
[&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); });
});
return result;
}
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <ostream>
#include <vector>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
......@@ -23,6 +17,15 @@ enum padding_mode_t
valid
};
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max,
lpnorm
};
// indicate rnn computation direction
enum class rnn_direction
{
......@@ -31,6 +34,7 @@ enum class rnn_direction
bidirectional,
};
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op
......
......@@ -2,15 +2,18 @@
#define MIGRAPHX_GUARD_OPERATORS_CONCAT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -26,23 +29,29 @@ struct concat
return pack(f(self.axis, "axis"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const
{
auto n_dims = args[0].get_shape().lens().size();
std::size_t axis_index = (axis < 0) ? axis + n_dims : axis;
auto n_dims = args[0].get_shape().lens().size();
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(n_dims, 0);
offset[axis_index] = 0;
offset[axis] = 0;
for(const auto& arg : args)
{
offsets.push_back(output_shape.index(offset));
offset[axis_index] += arg.get_shape().lens()[axis_index];
offset[axis] += arg.get_shape().lens()[axis];
}
return offsets;
}
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
{
......@@ -51,10 +60,9 @@ struct concat
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
std::size_t axis_index = (axis < 0) ? (first_shape_lens.size() + axis) : axis;
for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{
if(l != axis_index)
if(l != axis)
{
if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l];
......@@ -68,12 +76,12 @@ struct concat
for(const auto& input : inputs)
{
const auto& lens = input.lens();
new_dim_axis += lens[axis_index];
new_dim_axis += lens[axis];
}
std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis_index] = new_dim_axis;
return {type, new_lens};
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
......@@ -81,17 +89,12 @@ struct concat
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++)
{
auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
auto argl = args[l];
visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
std::copy(input.begin(), input.end(), slice.begin());
});
}
return result;
......
......@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_CONTIGUOUS_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......@@ -28,6 +27,8 @@ struct contiguous
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
if(inputs.front().standard())
return inputs.front();
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens};
......@@ -43,6 +44,11 @@ struct contiguous
});
return result;
}
auto apply() const
{
return [](auto x) { return x; };
}
};
} // namespace op
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......@@ -33,9 +32,19 @@ struct convert : unary<convert>
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
}
std::string point_op() const
{
return "${function:convert}<" + shape::cpp_type(target_type) + ">(${0})";
}
auto apply() const
{
return [](auto x) { return x; };
auto type = target_type;
return [type](auto x) {
auto y = x;
shape::visit(type, [&](auto as) { y = std::min(std::max(as(x), as.min()), as.max()); });
return y;
};
}
convert(shape::type_t t) : target_type{t} {}
......
......@@ -3,13 +3,14 @@
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -19,12 +20,12 @@ namespace op {
struct convolution
{
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
std::vector<std::size_t> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> dilation = {1, 1};
padding_mode_t padding_mode = default_;
int group = 1;
padding_mode_t padding_mode = default_;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -32,36 +33,68 @@ struct convolution
return pack(f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"),
f(self.group, "group"));
f(self.group, "group"),
f(self.padding_mode, "padding_mode"));
}
std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
void check_attribute_size() const
{
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{
MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
}
}
value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size();
// dim num of input and attribute should match
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
}
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
auto t = input.type();
return {t,
{
input.lens()[0],
weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
2 * padding[0]) /
stride[0] +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
2 * padding[1]) /
stride[1] +
1)),
}};
size_t kdims = input_size - 2;
if(kdims != this->kdims())
{
MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
}
if(input.lens().at(1) != (weights.lens().at(1) * group))
MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
for(size_t i = 0; i < kdims; i++)
{
auto padding_factor = 2 * padding[i];
if(padding_size == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
padding_factor) /
stride[i] +
1)));
}
return inputs[0].with_lens(output_lens);
}
size_t kdims() const
{
check_attribute_size();
return stride.size();
}
};
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_dfor.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Transposed convolution ("deconvolution") operator. Attributes mirror the
// convolution operator: one padding/stride/dilation entry per kernel
// dimension (padding may also carry 2*kdims entries — begin and end pads),
// plus grouped-channel support via `group`.
struct deconvolution
{
    std::vector<std::size_t> padding = {0, 0};
    std::vector<std::size_t> stride = {1, 1};
    std::vector<std::size_t> dilation = {1, 1};
    padding_mode_t padding_mode = default_;
    int group = 1;

    // Exposes attributes for serialization/reflection under their wire names.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
                    f(self.padding_mode, "padding_mode"),
                    f(self.group, "group"));
    }

    std::string name() const { return "deconvolution"; }

    // Validates that padding/stride/dilation agree on the number of kernel
    // dimensions (padding may legally be double-length for begin/end pads).
    // Throws MIGRAPHX_THROW on mismatch.
    void check_attribute_size() const
    {
        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
               stride.size() == dilation.size()))
        {
            MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
        }
    }

    // Computes the output shape from {input, weights}. Each spatial dim is
    //   stride*(in - 1) + ((k - 1)*dilation + 1) - 2*padding, clamped to >= 1.
    // Output lens are {batch, weights[1], spatial...}; note weights' second
    // dim supplies the output channel count here (transposed layout).
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
        const shape& input = inputs.at(0);
        const shape& weights = inputs.at(1);
        size_t kdims = input.lens().size() - 2;
        if(kdims != this->kdims())
        {
            MIGRAPHX_THROW("deconvolution: input k-dims does not match attribute size");
        }
        std::vector<size_t> output_lens{input.lens()[0], weights.lens()[1]};
        for(size_t i = 0; i < kdims; i++)
        {
            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
                1,
                stride[i] * (input.lens()[i + 2] - 1) +
                    ((weights.lens()[i + 2] - 1) * dilation[i] + 1) - 2 * padding[i])));
        }
        return inputs[0].with_lens(output_lens);
    }

    // Reference implementation: scatter-accumulate each input element,
    // multiplied by the kernel weights, into its strided/dilated output
    // window. Parallelized over (batch, weight-output-channel) pairs; each
    // (o, k) pair writes a distinct output channel, so tasks do not race.
    argument compute(shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto kdims = this->kdims();
        visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
            using type = typename decltype(output)::value_type;
            std::fill(output.begin(), output.end(), type{0});
            auto in_lens = input.get_shape().lens();
            auto in_n = in_lens[0];
            auto in_c = in_lens[1];
            auto wei = weights.get_shape().lens();
            auto wei_n = wei[0];
            auto wei_c = wei[1];
            auto out_lens = output_shape.lens();
            // Iteration space per (o, k): input channel, input spatial
            // position, then kernel spatial position.
            std::vector<std::size_t> win_size{in_c};
            std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size));
            std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size));
            shape win_shape{output_shape.type(), win_size};
            par_dfor(in_n, wei_c)([&](int o, int k) {
                shape_for_each(win_shape, [&](auto idx_win) {
                    const int w = idx_win[0];
                    auto input_dims_start = idx_win.begin() + 1;
                    auto wei_dims_start = idx_win.begin() + kdims + 1;
                    // Top-left corner of the output window for this input
                    // position (may be negative before padding is clipped).
                    std::vector<std::ptrdiff_t> win_start;
                    for(std::size_t n = 0; n < kdims; ++n)
                    {
                        win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * stride[n]) -
                                            std::ptrdiff_t(padding[n]));
                    }
                    // Grouped channels: input channel w belongs to group_id;
                    // its outputs land in that group's slice of wei_c channels.
                    const int group_id = w / (wei_n / group);
                    const int in_ch = group_id * wei_c + k;
                    std::vector<std::ptrdiff_t> idx_out{o, in_ch};
                    for(size_t n = 0; n < kdims; n++)
                    {
                        idx_out.push_back(win_start[n] + *(wei_dims_start + n) * dilation[n]);
                    }
                    std::vector<std::ptrdiff_t> idx_wei{w, k};
                    std::copy(wei_dims_start, idx_win.end(), std::back_inserter(idx_wei));
                    std::vector<std::ptrdiff_t> idx_in{o, w};
                    std::copy(input_dims_start, wei_dims_start, std::back_inserter(idx_in));
                    // Accumulate only when the target index falls inside the
                    // output bounds (skips padding-clipped positions).
                    if(std::all_of(
                           idx_out.begin() + 2, idx_out.end(), [&](auto ii) { return ii >= 0; }) and
                       std::equal(idx_out.begin() + 2,
                                  idx_out.end(),
                                  out_lens.begin() + 2,
                                  out_lens.end(),
                                  std::less<std::ptrdiff_t>{}))
                    {
                        output(idx_out.begin(), idx_out.end()) +=
                            input(idx_in.begin(), idx_in.end()) *
                            weights(idx_wei.begin(), idx_wei.end());
                    }
                });
            });
        });
        return result;
    }

    // Number of kernel (spatial) dimensions, derived from the stride length
    // after validating attribute consistency.
    size_t kdims() const
    {
        check_attribute_size();
        return stride.size();
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Dequantization operator: y[i] = (x[i] - zero_point[i]) * scale[i].
// Inputs: {x (quantized), x_scale, optional x_zero_point}; all must share
// dimensions. The output takes the scale's (real) element type with x's
// lens/strides.
struct dequantizelinear
{
    std::string name() const { return "dequantizelinear"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.same_dims();
        return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        auto x       = args.at(0);
        auto x_scale = args.at(1);
        // Default zero-point: a zero-filled tensor with x's element type and
        // the output's lens. Size the backing buffer from that shape — not
        // from output_shape.bytes(), whose element type (the scale type) can
        // differ from x's and so gives the wrong byte count for an x-typed
        // buffer.
        shape zp_shape{x.get_shape().type(), output_shape.lens()};
        std::vector<int8_t> zeros(zp_shape.bytes(), 0);
        argument x_zero_point{zp_shape, zeros.data()};
        if(args.size() == 3)
        {
            x_zero_point = args.at(2);
        }
        argument result{output_shape};
        // Widen to int64 before subtracting so narrow quantized types cannot
        // wrap, multiply in double, then store as the output element type.
        visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
            visit_all(result, x_scale)([&](auto output, auto scales) {
                par_for(output_shape.elements(), [&](auto i) {
                    output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
                                                    static_cast<int64_t>(zero_pts[i])) *
                                scales[i];
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......@@ -19,6 +18,7 @@ namespace op {
struct div : binary<div>
{
std::string point_function() const { return "/"; }
auto apply() const
{
return [](auto x, auto y) { return x / y; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment