Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
#define MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP #define MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP
#include <array> #include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -34,7 +34,7 @@ struct as_shape ...@@ -34,7 +34,7 @@ struct as_shape
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args.front().reshape(output_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ASINH_HPP
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct asinh : unary<asinh>
{
auto apply() const
{
return [](auto x) { return std::asinh(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ATANH_HPP
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct atanh : unary<atanh>
{
auto apply() const
{
return [](auto x) { return std::atanh(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
...@@ -42,9 +41,8 @@ struct batch_norm_inference ...@@ -42,9 +41,8 @@ struct batch_norm_inference
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(5); check_shapes{inputs, *this}.has(5);
check_shapes{inputs.data(), inputs.data() + 1, *this}.only_dims(4); check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims();
check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape().elements( check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape();
inputs.front().lens()[1]);
return inputs.front(); return inputs.front();
} }
}; };
......
...@@ -2,6 +2,11 @@ ...@@ -2,6 +2,11 @@
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP #define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <migraphx/op/name.hpp> #include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/value.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -10,15 +15,45 @@ namespace op { ...@@ -10,15 +15,45 @@ namespace op {
template <class Derived> template <class Derived>
struct binary : op_name<Derived> struct binary : op_name<Derived>
{ {
std::string point_function() const { return this->name(); }
std::string point_op() const
{
const auto& self = static_cast<const Derived&>(*this);
auto pf = self.point_function();
if(pf.empty())
return {};
if(with_char(::ispunct)(pf.front()))
{
return "${0} " + pf + " ${1}";
}
else
{
return "${function:" + pf + "}(${0}, ${1})";
}
}
value base_attributes() const
{
const auto& self = static_cast<const Derived&>(*this);
return {{"pointwise", true}, {"point_op", self.point_op()}};
}
value attributes() const { return base_attributes(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(2).same_type().same_dims(); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(2).same_type().same_dims();
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
auto s1 = inputs.at(1); auto s1 = inputs.at(1);
if(s0 == s1 and s0.packed()) if(s0 == s1 and s0.packed())
{ {
return s0; return s0;
} }
else if(s0.packed() != s1.packed())
{
return s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
return s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else else
{ {
return {s0.type(), s0.lens()}; return {s0.type(), s0.lens()};
...@@ -28,32 +63,13 @@ struct binary : op_name<Derived> ...@@ -28,32 +63,13 @@ struct binary : op_name<Derived>
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto s1 = args[0].get_shape(); visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
auto s2 = args[1].get_shape(); std::transform(input1.begin(),
if(s1 == s2 and s1.packed()) input1.end(),
{ input2.begin(),
shape std_shape{s1.type(), s1.lens()}; output.begin(),
argument std_result{std_shape, result.data()}; static_cast<const Derived&>(*this).apply());
argument std_arg0{std_shape, args[0].data()}; });
argument std_arg1{std_shape, args[1].data()};
visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
std::transform(input1.begin(),
input1.end(),
input2.begin(),
output.begin(),
static_cast<const Derived&>(*this).apply());
});
}
else
{
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
});
});
}
return result; return result;
} }
}; };
......
...@@ -2,13 +2,11 @@ ...@@ -2,13 +2,11 @@
#define MIGRAPHX_GUARD_OPERATORS_BROADCAST_HPP #define MIGRAPHX_GUARD_OPERATORS_BROADCAST_HPP
#include <array> #include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/argument.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/functional.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -32,36 +30,42 @@ struct broadcast ...@@ -32,36 +30,42 @@ struct broadcast
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis"), f(self.broadcast_lens, "dims")); return pack(f(self.axis, "axis"), f(self.broadcast_lens, "out_lens"));
} }
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto t = inputs.at(0).type();
auto input = inputs.at(0); auto input = inputs.at(0);
auto t = input.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broacast op is deprecated now, so not handling the negative
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(std::all_of( if(broadcast_lens.size() - axis < input.lens().size())
broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
{ {
if(axis != 0) MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_lens, std::move(bcast_strides)};
} }
else
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{ {
assert(broadcast_lens.size() - axis >= input.lens().size()); MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_lens, std::move(bcast_strides)};
} }
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < input.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
return output;
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return args[0].reshape(output_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP #define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#include <array> #include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -30,7 +30,9 @@ struct capture ...@@ -30,7 +30,9 @@ struct capture
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); } shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(const shape&, std::vector<argument> args) const // the context argument is added to prevent the op from be eliminated by
// constant propagation
argument compute(context&, const shape&, const std::vector<argument>& args) const
{ {
if(f) if(f)
{ {
...@@ -43,6 +45,8 @@ struct capture ...@@ -43,6 +45,8 @@ struct capture
return args.front(); return args.front();
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
} // namespace op } // namespace op
......
...@@ -3,12 +3,11 @@ ...@@ -3,12 +3,11 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -18,29 +17,31 @@ namespace migraphx { ...@@ -18,29 +17,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct clip : unary<clip> struct clip
{ {
float max_val = std::numeric_limits<float>::max(); std::string name() const { return "clip"; }
float min_val = std::numeric_limits<float>::min();
clip() {} value attributes() const
{
clip(float max, float min) : max_val(max), min_val(min) {} return {{"pointwise", true},
{"point_op", "${function:min}(${function:max}(${1}, ${0}), ${2})"}};
}
auto apply() const shape compute_shape(std::vector<shape> inputs) const
{ {
auto max = max_val; check_shapes{inputs, *this}.has(3).same_type().same_dims();
auto min = min_val; return inputs.front();
return [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
};
} }
template <class Self, class F> argument compute(const shape& output_shape, std::vector<argument> args) const
static auto reflect(Self& self, F f)
{ {
return pack(f(self.max_val, "max"), f(self.min_val, "min")); argument result{output_shape};
visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) {
par_for(output_shape.elements(),
[&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); });
});
return result;
} }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <array> #include <ostream>
#include <migraphx/operation.hpp> #include <vector>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <cmath>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -23,6 +17,15 @@ enum padding_mode_t ...@@ -23,6 +17,15 @@ enum padding_mode_t
valid valid
}; };
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max,
lpnorm
};
// indicate rnn computation direction // indicate rnn computation direction
enum class rnn_direction enum class rnn_direction
{ {
...@@ -31,6 +34,7 @@ enum class rnn_direction ...@@ -31,6 +34,7 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v); std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
......
...@@ -2,15 +2,18 @@ ...@@ -2,15 +2,18 @@
#define MIGRAPHX_GUARD_OPERATORS_CONCAT_HPP #define MIGRAPHX_GUARD_OPERATORS_CONCAT_HPP
#include <array> #include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -26,23 +29,29 @@ struct concat ...@@ -26,23 +29,29 @@ struct concat
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"));
} }
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "concat"; } std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape, std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
auto n_dims = args[0].get_shape().lens().size(); auto n_dims = args[0].get_shape().lens().size();
std::size_t axis_index = (axis < 0) ? axis + n_dims : axis;
std::vector<std::size_t> offsets; std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(n_dims, 0); std::vector<std::size_t> offset(n_dims, 0);
offset[axis_index] = 0; offset[axis] = 0;
for(const auto& arg : args) for(const auto& arg : args)
{ {
offsets.push_back(output_shape.index(offset)); offsets.push_back(output_shape.index(offset));
offset[axis_index] += arg.get_shape().lens()[axis_index]; offset[axis] += arg.get_shape().lens()[axis];
} }
return offsets; return offsets;
} }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) if(inputs.empty())
{ {
...@@ -51,10 +60,9 @@ struct concat ...@@ -51,10 +60,9 @@ struct concat
const auto& first_shape_lens = inputs.front().lens(); const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type(); const auto& type = inputs.front().type();
std::size_t axis_index = (axis < 0) ? (first_shape_lens.size() + axis) : axis;
for(std::size_t l = 0; l < first_shape_lens.size(); l++) for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{ {
if(l != axis_index) if(l != axis)
{ {
if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) { if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l]; return s.lens()[l] == first_shape_lens[l];
...@@ -68,12 +76,12 @@ struct concat ...@@ -68,12 +76,12 @@ struct concat
for(const auto& input : inputs) for(const auto& input : inputs)
{ {
const auto& lens = input.lens(); const auto& lens = input.lens();
new_dim_axis += lens[axis_index]; new_dim_axis += lens[axis];
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens)); std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis_index] = new_dim_axis; new_lens[axis] = new_dim_axis;
return {type, new_lens}; return shape::from_permutation(type, new_lens, find_permutation(inputs));
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
...@@ -81,17 +89,12 @@ struct concat ...@@ -81,17 +89,12 @@ struct concat
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args); std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++) for(std::size_t l = 0; l < args.size(); l++)
{ {
auto argl = args[l]; auto argl = args[l];
std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) { visit_all(result, argl)([&](auto output, auto input) {
auto slice_shape = auto slice_shape =
shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()}; shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
auto slice = make_view(slice_shape, output.data() + coffsets[l]); auto slice = make_view(slice_shape, output.data() + coffsets[l]);
// cppcheck-suppress useStlAlgorithm std::copy(input.begin(), input.end(), slice.begin());
for(std::size_t i = 0; i < nelements; i++)
{
slice[i] = input[i];
}
}); });
} }
return result; return result;
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_CONTIGUOUS_HPP #define MIGRAPHX_GUARD_OPERATORS_CONTIGUOUS_HPP
#include <array> #include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
...@@ -28,6 +27,8 @@ struct contiguous ...@@ -28,6 +27,8 @@ struct contiguous
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
if(inputs.front().standard())
return inputs.front();
auto lens = inputs.at(0).lens(); auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
return {t, lens}; return {t, lens};
...@@ -43,6 +44,11 @@ struct contiguous ...@@ -43,6 +44,11 @@ struct contiguous
}); });
return result; return result;
} }
auto apply() const
{
return [](auto x) { return x; };
}
}; };
} // namespace op } // namespace op
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
...@@ -33,9 +32,19 @@ struct convert : unary<convert> ...@@ -33,9 +32,19 @@ struct convert : unary<convert>
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()}; return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
} }
std::string point_op() const
{
return "${function:convert}<" + shape::cpp_type(target_type) + ">(${0})";
}
auto apply() const auto apply() const
{ {
return [](auto x) { return x; }; auto type = target_type;
return [type](auto x) {
auto y = x;
shape::visit(type, [&](auto as) { y = std::min(std::max(as(x), as.min()), as.max()); });
return y;
};
} }
convert(shape::type_t t) : target_type{t} {} convert(shape::type_t t) : target_type{t} {}
......
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
#include <array> #include <array>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -19,12 +20,12 @@ namespace op { ...@@ -19,12 +20,12 @@ namespace op {
struct convolution struct convolution
{ {
std::array<std::size_t, 2> padding = {{0, 0}}; std::vector<std::size_t> padding = {0, 0};
std::array<std::size_t, 2> stride = {{1, 1}}; std::vector<std::size_t> stride = {1, 1};
std::array<std::size_t, 2> dilation = {{1, 1}}; std::vector<std::size_t> dilation = {1, 1};
padding_mode_t padding_mode = default_;
int group = 1; int group = 1;
padding_mode_t padding_mode = default_;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -32,36 +33,68 @@ struct convolution ...@@ -32,36 +33,68 @@ struct convolution
return pack(f(self.padding, "padding"), return pack(f(self.padding, "padding"),
f(self.stride, "stride"), f(self.stride, "stride"),
f(self.dilation, "dilation"), f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"), f(self.group, "group"),
f(self.group, "group")); f(self.padding_mode, "padding_mode"));
} }
std::string name() const { return "convolution"; } std::string name() const { return "convolution"; }
shape compute_shape(std::vector<shape> inputs) const
void check_attribute_size() const
{
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{
MIGRAPHX_THROW("CONVOLUTION: inconsistent attribute sizes");
}
}
value attributes() const { return {{"normalize_padding", "padding"}}; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4); check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
check_attribute_size();
// dim num of input and attribute should match
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size();
if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
{
MIGRAPHX_THROW("CONVOLUTION: input and attribute size mismatch!");
}
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
const shape& weights = inputs.at(1); const shape& weights = inputs.at(1);
auto t = input.type(); size_t kdims = input_size - 2;
if(kdims != this->kdims())
return {t, {
{ MIGRAPHX_THROW("convolution: input k-dims does not match attribute size");
input.lens()[0], }
weights.lens()[0],
std::size_t(std::max<std::ptrdiff_t>( if(input.lens().at(1) != (weights.lens().at(1) * group))
1, MIGRAPHX_THROW("CONVOLUTION: Mismatch channel numbers");
(input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
2 * padding[0]) / std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
stride[0] +
1)), for(size_t i = 0; i < kdims; i++)
std::size_t(std::max<std::ptrdiff_t>( {
1, auto padding_factor = 2 * padding[i];
(input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) + if(padding_size == 2 * kdims)
2 * padding[1]) / padding_factor = padding[i] + padding[i + kdims];
stride[1] + output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1)), 1,
}}; (input.lens()[i + 2] - (1 + dilation[i] * (weights.lens()[i + 2] - 1)) +
padding_factor) /
stride[i] +
1)));
}
return inputs[0].with_lens(output_lens);
}
size_t kdims() const
{
check_attribute_size();
return stride.size();
} }
}; };
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_DECONVOLUTION_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_dfor.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct deconvolution
{
std::vector<std::size_t> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> dilation = {1, 1};
padding_mode_t padding_mode = default_;
int group = 1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.dilation, "dilation"),
f(self.padding_mode, "padding_mode"),
f(self.group, "group"));
}
std::string name() const { return "deconvolution"; }
void check_attribute_size() const
{
if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
stride.size() == dilation.size()))
{
MIGRAPHX_THROW("deconvolution: inconsistent attribute sizes");
}
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).same_type().same_ndims().min_ndims(3);
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
size_t kdims = input.lens().size() - 2;
if(kdims != this->kdims())
{
MIGRAPHX_THROW("deconvolution: input k-dims does not match attribute size");
}
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[1]};
for(size_t i = 0; i < kdims; i++)
{
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
stride[i] * (input.lens()[i + 2] - 1) +
((weights.lens()[i + 2] - 1) * dilation[i] + 1) - 2 * padding[i])));
}
return inputs[0].with_lens(output_lens);
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto kdims = this->kdims();
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
using type = typename decltype(output)::value_type;
std::fill(output.begin(), output.end(), type{0});
auto in_lens = input.get_shape().lens();
auto in_n = in_lens[0];
auto in_c = in_lens[1];
auto wei = weights.get_shape().lens();
auto wei_n = wei[0];
auto wei_c = wei[1];
auto out_lens = output_shape.lens();
std::vector<std::size_t> win_size{in_c};
std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size));
std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size));
shape win_shape{output_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0];
auto input_dims_start = idx_win.begin() + 1;
auto wei_dims_start = idx_win.begin() + kdims + 1;
std::vector<std::ptrdiff_t> win_start;
for(std::size_t n = 0; n < kdims; ++n)
{
win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * stride[n]) -
std::ptrdiff_t(padding[n]));
}
const int group_id = w / (wei_n / group);
const int in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx_out{o, in_ch};
for(size_t n = 0; n < kdims; n++)
{
idx_out.push_back(win_start[n] + *(wei_dims_start + n) * dilation[n]);
}
std::vector<std::ptrdiff_t> idx_wei{w, k};
std::copy(wei_dims_start, idx_win.end(), std::back_inserter(idx_wei));
std::vector<std::ptrdiff_t> idx_in{o, w};
std::copy(input_dims_start, wei_dims_start, std::back_inserter(idx_in));
if(std::all_of(
idx_out.begin() + 2, idx_out.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx_out.begin() + 2,
idx_out.end(),
out_lens.begin() + 2,
out_lens.end(),
std::less<std::ptrdiff_t>{}))
{
output(idx_out.begin(), idx_out.end()) +=
input(idx_in.begin(), idx_in.end()) *
weights(idx_wei.begin(), idx_wei.end());
}
});
});
});
return result;
}
size_t kdims() const
{
check_attribute_size();
return stride.size();
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct dequantizelinear
{
std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_dims();
return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto x = args.at(0);
auto x_scale = args.at(1);
std::vector<int8_t> zeros(output_shape.bytes(), 0);
argument x_zero_point{{x.get_shape().type(), output_shape.lens()}, zeros.data()};
if(args.size() == 3)
{
x_zero_point = args.at(2);
}
argument result{output_shape};
visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
visit_all(result, x_scale)([&](auto output, auto scales) {
par_for(output_shape.elements(), [&](auto i) {
output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
static_cast<int64_t>(zero_pts[i])) *
scales[i];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/binary.hpp> #include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
...@@ -19,6 +18,7 @@ namespace op { ...@@ -19,6 +18,7 @@ namespace op {
struct div : binary<div> struct div : binary<div>
{ {
std::string point_function() const { return "/"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return x / y; }; return [](auto x, auto y) { return x / y; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment