Unverified Commit 8443ecd1 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Normalize ops (#667)



* add a pass to normalize ops

* clang format

* add unit tests

* clang format

* code backup

* clang format

* code backup

* clang format

* add support for slice in the normalize_op function

* clang format

* add operation method api for whether we need to call normalize_op

* clang format

* fix review comments

* clang format

* rename a function name

* clang format

* change compute_shape to normalize_compute_shape for corresponding operators

* clang format

* remove unnecessary code

* fix various issues

* clang format

* add attributes to operators having axis attributes

* clang format

* fixed jenkins build error

* clang format

* fix a bug related to slice

* clang format

* code backup

* clang format

* code backup

* clang format

* rename a file

* fix cppcheck error

* some code refinement

* clang format

* change attributes to enum

* clang format

* refine the enum

* clang format

* remove unnecessary code

* add unit tests for more code coverage and fixed a bug

* clang format

* remove unnecessary changes

* change normalize_axes to normalize

* clang format

* revert back the changes in broadcast.hpp

* rename normalize_axes to normalize

* fix review comments

* clang format

* Add flag to enable cpu backend

* Make buffers shared

* Enable optimizations

* Formatting

* Try to avoid ambiguous assign in value class

* fixed a build error

* clang format

* add the normalize_ops pass to the ref target

* refactor program to module to normalize_ops pass
Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent f8b56a66
...@@ -45,6 +45,8 @@ add_library(migraphx ...@@ -45,6 +45,8 @@ add_library(migraphx
convert_to_json.cpp convert_to_json.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp opt/memory_coloring_impl.cpp
normalize_attributes.cpp
normalize_ops.cpp
) )
rocm_set_soversion(migraphx ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx ${MIGRAPHX_SO_VERSION})
function(register_migraphx_ops) function(register_migraphx_ops)
......
#ifndef MIGRAPHX_GUARD_RTGLIB_NORMALIZE_ATTRIBUTES_HPP
#define MIGRAPHX_GUARD_RTGLIB_NORMALIZE_ATTRIBUTES_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct operation;
// Identity metafunction: the nested `type` is always T. The trailing parameter
// pack is ignored; it only makes the resulting type depend on further template
// parameters.
template <class T, class... Ignored>
struct select_dependent_type
{
    using type = T;
};
// Shorthand for select_dependent_type<T, Deps...>::type, i.e. T spelled as a
// dependent type. This header only forward-declares `operation`, so templates
// can name it through this alias and have it resolved at instantiation.
template <class T, class... Deps>
using dependent_type = typename select_dependent_type<T, Deps...>::type;
// Normalizes the attributes of `op` in place (it is taken by mutable
// reference) against the dimension lengths given in `lens`.
// NOTE(review): the meaning of the bool result is not visible from this
// header (presumably whether any attribute was changed) - confirm against
// normalize_attributes.cpp.
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_NORMALIZE_OPS_HPP
#define MIGRAPHX_GUARD_RTGLIB_NORMALIZE_OPS_HPP
#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
 * Pass that processes negative axis attributes of ops.
 */
struct normalize_ops
{
    /// Identifier of this pass.
    std::string name() const
    {
        return "normalize_ops";
    }

    /// Applies the pass to module `m`; the definition lives out of line.
    void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -21,14 +23,21 @@ struct argmax ...@@ -21,14 +23,21 @@ struct argmax
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"));
} }
// Normalization metadata for argmax: under the key "normalize_axes", the
// "axis" attribute is tagged with the include_min flag.
value attributes() const
{
    value axis_rules;
    axis_rules["axis"] = value::array{normalize_attribute::include_min};
    return {{"normalize_axes", axis_rules}};
}
std::string name() const { return "argmax"; } std::string name() const { return "argmax"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size()); int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < -n_dim) if(axis >= n_dim || axis < 0)
{ {
MIGRAPHX_THROW("ARGMAX: axis is out of range."); MIGRAPHX_THROW("ARGMAX: axis is out of range.");
} }
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -21,14 +23,21 @@ struct argmin ...@@ -21,14 +23,21 @@ struct argmin
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"));
} }
// Normalization metadata for argmin: under the key "normalize_axes", the
// "axis" attribute is tagged with the include_min flag.
value attributes() const
{
    value axis_rules;
    axis_rules["axis"] = value::array{normalize_attribute::include_min};
    return {{"normalize_axes", axis_rules}};
}
std::string name() const { return "argmin"; } std::string name() const { return "argmin"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size()); int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < -n_dim) if(axis >= n_dim || axis < 0)
{ {
MIGRAPHX_THROW("ARGMIN: axis is out of range."); MIGRAPHX_THROW("ARGMIN: axis is out of range.");
} }
......
...@@ -35,10 +35,16 @@ struct broadcast ...@@ -35,10 +35,16 @@ struct broadcast
std::string name() const { return "broadcast"; } std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
auto t = inputs.at(0).type();
auto input = inputs.at(0); auto input = inputs.at(0);
auto t = input.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0); std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so negative values of axis are
// no longer handled
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(std::all_of( if(std::all_of(
broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; })) broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
...@@ -49,9 +55,15 @@ struct broadcast ...@@ -49,9 +55,15 @@ struct broadcast
} }
else else
{ {
assert(broadcast_lens.size() - axis >= input.lens().size()); if(broadcast_lens.size() - axis < input.lens().size())
{
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
}
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis)) if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match"); MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
}
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_lens, std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
} }
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -25,6 +27,13 @@ struct concat ...@@ -25,6 +27,13 @@ struct concat
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"));
} }
// Normalization metadata for concat: under the key "normalize_axes", the
// "axis" attribute is tagged with the include_min flag.
value attributes() const
{
    value axis_rules;
    axis_rules["axis"] = value::array{normalize_attribute::include_min};
    return {{"normalize_axes", axis_rules}};
}
std::string name() const { return "concat"; } std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape, std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
...@@ -41,7 +50,7 @@ struct concat ...@@ -41,7 +50,7 @@ struct concat
} }
return offsets; return offsets;
} }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) if(inputs.empty())
{ {
...@@ -50,10 +59,9 @@ struct concat ...@@ -50,10 +59,9 @@ struct concat
const auto& first_shape_lens = inputs.front().lens(); const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type(); const auto& type = inputs.front().type();
std::size_t axis_index = (axis < 0) ? (first_shape_lens.size() + axis) : axis;
for(std::size_t l = 0; l < first_shape_lens.size(); l++) for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{ {
if(l != axis_index) if(l != axis)
{ {
if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) { if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l]; return s.lens()[l] == first_shape_lens[l];
...@@ -67,11 +75,11 @@ struct concat ...@@ -67,11 +75,11 @@ struct concat
for(const auto& input : inputs) for(const auto& input : inputs)
{ {
const auto& lens = input.lens(); const auto& lens = input.lens();
new_dim_axis += lens[axis_index]; new_dim_axis += lens[axis];
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens)); std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis_index] = new_dim_axis; new_lens[axis] = new_dim_axis;
return {type, new_lens}; return {type, new_lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -25,23 +27,23 @@ struct flatten ...@@ -25,23 +27,23 @@ struct flatten
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"));
} }
// Normalization metadata for flatten: under the key "normalize_axes", the
// "axis" attribute is tagged with both include_min and include_max.
value attributes() const
{
    value axis_rules;
    axis_rules["axis"] =
        value::array{normalize_attribute::include_min, normalize_attribute::include_max};
    return {{"normalize_axes", axis_rules}};
}
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto&& lens = inputs.front().lens(); auto&& lens = inputs.front().lens();
int64_t n_dim = static_cast<int64_t>(lens.size()); auto x =
if(axis > n_dim or axis < -n_dim) std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
{ auto y =
MIGRAPHX_THROW("FLATTEN: axis for flatten is out of range"); std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
}
auto tuned_axis = (axis < 0) ? axis + n_dim : axis;
auto x = std::accumulate(
lens.begin(), lens.begin() + tuned_axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate(
lens.begin() + tuned_axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}}; return {inputs.at(0).type(), {x, y}};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -17,7 +19,7 @@ namespace op { ...@@ -17,7 +19,7 @@ namespace op {
struct gather struct gather
{ {
int axis = 0; int64_t axis = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -25,27 +27,25 @@ struct gather ...@@ -25,27 +27,25 @@ struct gather
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"));
} }
// Normalization metadata for gather: under the key "normalize_axes", the
// "axis" attribute is tagged with the include_min flag.
value attributes() const
{
    value axis_rules;
    axis_rules["axis"] = value::array{normalize_attribute::include_min};
    return {{"normalize_axes", axis_rules}};
}
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
{
MIGRAPHX_THROW("Gather: axis is out of range.");
}
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type(); auto type = inputs[0].type();
lens.erase(lens.begin() + axis_index); lens.erase(lens.begin() + axis);
if(!inputs[1].scalar()) if(!inputs[1].scalar())
{ {
auto ind_lens = inputs[1].lens(); auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end()); lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
} }
// for scalar output // for scalar output
...@@ -61,10 +61,8 @@ struct gather ...@@ -61,10 +61,8 @@ struct gather
{ {
argument result{output_shape}; argument result{output_shape};
// negative axis means counting dimensions from back // negative axis means counting dimensions from back
auto lens = args[0].get_shape().lens(); auto lens = args[0].get_shape().lens();
int axis_index = (axis < 0) ? static_cast<int>(lens.size() + axis) : axis; std::size_t axis_dim_size = lens[axis];
std::size_t axis_dim_size = lens[axis_index];
// max dimension in axis // max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
...@@ -76,14 +74,14 @@ struct gather ...@@ -76,14 +74,14 @@ struct gather
} }
else else
{ {
auto out_lens = data.get_shape().lens(); auto out_lens = data.get_shape().lens();
out_lens[axis_index] = indices.get_shape().elements(); out_lens[axis] = indices.get_shape().elements();
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens}; migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) { shape_for_each(out_comp_shape, [&](const auto& out_idx) {
auto data_idx = out_idx; auto data_idx = out_idx;
auto in_index = indices[data_idx[axis_index]]; auto in_index = indices[data_idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
data_idx[axis_index] = in_index; data_idx[axis] = in_index;
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] = output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
data(data_idx.begin(), data_idx.end()); data(data_idx.begin(), data_idx.end());
}); });
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP #define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -18,16 +20,17 @@ struct logsoftmax ...@@ -18,16 +20,17 @@ struct logsoftmax
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"));
} }
// Normalization metadata for logsoftmax: under the key "normalize_axes", the
// "axis" attribute is tagged with the include_min flag.
value attributes() const
{
    value axis_rules;
    axis_rules["axis"] = value::array{normalize_attribute::include_min};
    return {{"normalize_axes", axis_rules}};
}
std::string name() const { return "logsoftmax"; } std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
int64_t n_dim = static_cast<int64_t>(inputs[0].lens().size());
if(axis < -n_dim || axis >= n_dim)
{
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range");
}
return inputs.at(0); return inputs.at(0);
} }
......
#ifndef MIGRAPHX_GUARD_OPERATORS_OP_NORMALIZE_ATTRIBUTE_HPP
#define MIGRAPHX_GUARD_OPERATORS_OP_NORMALIZE_ATTRIBUTE_HPP
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Flags an operator attaches to an attribute to describe how that attribute
// is normalized. The options come in pairs:
//   1) use_input (default) / use_output
//   2) use_rank (default) / use_len
//   3) clip_min (default) / not_clip_min
//      3.1) include_min (default) / exclude_min
//   4) clip_max (default) / not_clip_max
//      4.1) exclude_max (default) / include_max
// NOTE(review): the list above marks clip_min, clip_max and include_min as
// defaults although they are also explicit enumerators below - confirm which
// side of each pair is really the default.
enum class normalize_attribute
{
    use_len,     // pair 2: presumably normalize against a dimension length, not the rank
    use_output,  // pair 1: presumably normalize against the output, not the input
    clip_max,    // pair 4: presumably clamp values at the upper bound
    clip_min,    // pair 3: presumably clamp values at the lower bound
    include_max, // pair 4.1: presumably the upper bound itself is a valid value
    include_min, // pair 3.1: presumably the lower bound itself is a valid value
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <vector> #include <vector>
namespace migraphx { namespace migraphx {
...@@ -60,6 +62,13 @@ struct reduce_op : op_name<Derived> ...@@ -60,6 +62,13 @@ struct reduce_op : op_name<Derived>
return pack(f(self.axes, "axes")); return pack(f(self.axes, "axes"));
} }
// Normalization metadata shared by the reduce operators: under the key
// "normalize_axes", the "axes" attribute is tagged with the include_min flag.
value attributes() const
{
    value axes_rules;
    axes_rules["axes"] = value::array{normalize_attribute::include_min};
    return {{"normalize_axes", axes_rules}};
}
std::vector<int64_t> tune_axes(std::size_t n_dim) const std::vector<int64_t> tune_axes(std::size_t n_dim) const
{ {
auto tuned_axes = axes; auto tuned_axes = axes;
...@@ -68,26 +77,11 @@ struct reduce_op : op_name<Derived> ...@@ -68,26 +77,11 @@ struct reduce_op : op_name<Derived>
tuned_axes.resize(n_dim); tuned_axes.resize(n_dim);
std::iota(tuned_axes.begin(), tuned_axes.end(), 0); std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
} }
else
{
for(auto& axis : tuned_axes)
{
int64_t s_dim = static_cast<int64_t>(n_dim);
if(axis >= s_dim or axis < -s_dim)
{
MIGRAPHX_THROW("REDUCE_OP: axis out of range");
}
if(axis < 0)
{
axis += n_dim;
}
}
}
return tuned_axes; return tuned_axes;
} }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -25,6 +27,23 @@ struct slice ...@@ -25,6 +27,23 @@ struct slice
return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends")); return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
} }
// Normalization metadata for slice, returned under the key "normalize_axes".
// "axes" carries only include_min; "starts" and "ends" both carry the same
// five flags (clip_max, clip_min, include_max, use_len, include_min).
value attributes() const
{
    value rules = value::object{};
    rules["axes"] = value::array{normalize_attribute::include_min};
    rules["starts"] = value::array{normalize_attribute::clip_max,
                                   normalize_attribute::clip_min,
                                   normalize_attribute::include_max,
                                   normalize_attribute::use_len,
                                   normalize_attribute::include_min};
    rules["ends"] = value::array{normalize_attribute::clip_max,
                                 normalize_attribute::clip_min,
                                 normalize_attribute::include_max,
                                 normalize_attribute::use_len,
                                 normalize_attribute::include_min};
    return {{"normalize_axes", rules}};
}
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
void tune_attributes(std::vector<int64_t>& tuned_axes, void tune_attributes(std::vector<int64_t>& tuned_axes,
...@@ -111,30 +130,34 @@ struct slice ...@@ -111,30 +130,34 @@ struct slice
return offset; return offset;
} }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto t = input_shape.type(); auto t = input_shape.type();
const auto& old_lens = input_shape.lens(); const auto& old_lens = input_shape.lens();
const auto& old_strides = input_shape.strides(); const auto& old_strides = input_shape.strides();
if(std::any_of(
axes.begin(), axes.end(), [&](auto i) { return (i >= old_lens.size() and i < 0); }))
{
MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range");
}
if(starts.size() != axes.size() || axes.size() != ends.size()) if(starts.size() != axes.size() || axes.size() != ends.size())
{ {
MIGRAPHX_THROW("SLICE: inconsistent sizes"); MIGRAPHX_THROW("SLICE: inconsistent sizes");
} }
std::vector<int64_t> tuned_axes = axes;
std::vector<int64_t> tuned_starts = starts;
std::vector<int64_t> tuned_ends = ends;
tune_attributes(tuned_axes, tuned_starts, tuned_ends, old_lens);
std::vector<std::size_t> new_lens = old_lens; std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < tuned_axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
auto axis = tuned_axes[i]; auto axis = axes[i];
new_lens[axis] = fix_index(old_lens, axis, tuned_ends[i]) - new_lens[axis] =
fix_index(old_lens, axis, tuned_starts[i]); fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
} }
return shape{t, new_lens, old_strides}; return shape{t, new_lens, old_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP #define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -18,16 +20,17 @@ struct softmax ...@@ -18,16 +20,17 @@ struct softmax
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"));
} }
// Normalization metadata for softmax: under the key "normalize_axes", the
// "axis" attribute is tagged with the include_min flag.
value attributes() const
{
    value axis_rules;
    axis_rules["axis"] = value::array{normalize_attribute::include_min};
    return {{"normalize_axes", axis_rules}};
}
std::string name() const { return "softmax"; } std::string name() const { return "softmax"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
int64_t n_dim = inputs[0].lens().size();
if(axis < -n_dim || axis >= n_dim)
{
MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) +
" is out of range");
}
return inputs.at(0); return inputs.at(0);
} }
......
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -25,28 +26,27 @@ struct squeeze ...@@ -25,28 +26,27 @@ struct squeeze
return pack(f(self.axes, "axes")); return pack(f(self.axes, "axes"));
} }
// Normalization metadata for squeeze: under the key "normalize_axes", the
// "axes" attribute is tagged with the include_min flag.
value attributes() const
{
    value axes_rules;
    axes_rules["axes"] = value::array{normalize_attribute::include_min};
    return {{"normalize_axes", axes_rules}};
}
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
// change to support negative axis value if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
std::vector<int64_t> tuned_axes(axes.size());
std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [&](auto i) {
return i >= 0 ? i : i + old_lens.size();
});
if(std::any_of(tuned_axes.begin(), tuned_axes.end(), [&](auto axis) {
return old_lens[axis] != 1;
}))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
if(tuned_axes.empty()) if(axes.empty())
{ {
std::copy_if(old_lens.begin(), std::copy_if(old_lens.begin(),
old_lens.end(), old_lens.end(),
...@@ -57,7 +57,7 @@ struct squeeze ...@@ -57,7 +57,7 @@ struct squeeze
{ {
for(std::size_t i = 0; i < old_lens.size(); i++) for(std::size_t i = 0; i < old_lens.size(); i++)
{ {
if(std::find(tuned_axes.begin(), tuned_axes.end(), i) == tuned_axes.end()) if(std::find(axes.begin(), axes.end(), i) == axes.end())
{ {
new_lens.push_back(old_lens[i]); new_lens.push_back(old_lens[i]);
} }
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -25,8 +25,16 @@ struct unsqueeze ...@@ -25,8 +25,16 @@ struct unsqueeze
return pack(f(self.axes, "axes")); return pack(f(self.axes, "axes"));
} }
// Normalization metadata for unsqueeze: under the key "normalize_axes", the
// "axes" attribute is tagged with include_min and use_output.
value attributes() const
{
    value axes_rules;
    axes_rules["axes"] =
        value::array{normalize_attribute::include_min, normalize_attribute::use_output};
    return {{"normalize_axes", axes_rules}};
}
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard_or_scalar(); check_shapes{inputs, *this}.has(1).standard_or_scalar();
auto input_shape = inputs[0]; auto input_shape = inputs[0];
...@@ -43,17 +51,11 @@ struct unsqueeze ...@@ -43,17 +51,11 @@ struct unsqueeze
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
// in case of axes to be negative, tune to positive
std::vector<int64_t> tuned_axes(axes.size());
std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [new_size](auto i) {
return i >= 0 ? i : i + new_size;
});
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::size_t p = 0; std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++) for(std::size_t i = 0; i < new_size; i++)
{ {
if(std::find(tuned_axes.begin(), tuned_axes.end(), i) != tuned_axes.end()) if(std::find(axes.begin(), axes.end(), i) != axes.end())
{ {
new_lens[i] = 1; new_lens[i] = 1;
} }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <utility> #include <utility>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp> #include <migraphx/auto_any_cast.hpp>
...@@ -58,6 +59,8 @@ struct operation ...@@ -58,6 +59,8 @@ struct operation
/// Returns true if operation does not require a context to run compute /// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x); bool is_context_free(const operation& x);
/// Returns true if operation needs normalization before running compute
bool need_normalization(const operation& x);
/// Returns true if the operation has a finalize method /// Returns true if the operation has a finalize method
bool has_finalize(const operation& x); bool has_finalize(const operation& x);
...@@ -96,6 +99,14 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name()) ...@@ -96,6 +99,14 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators } // namespace operation_operators
// Computes the output shape of an operator that provides
// normalize_compute_shape(). `x` is copied into an `operation`, the attributes
// of that copy are normalized against the dimension lengths of the first
// input, and the normalized copy then computes the shape from `inputs`.
// Throws std::out_of_range when `inputs` is empty.
template <class T>
shape normalize_compute_shape_op(T&& x, std::vector<shape> inputs)
{
    dependent_type<operation, T> y = x;
    // at() instead of operator[]: an empty input list raises an exception
    // here rather than reading past the end of the vector.
    normalize_attributes(y, inputs.at(0).lens());
    return any_cast<T>(y).normalize_compute_shape(inputs);
}
template <class T> template <class T>
auto compute_op(rank<2>, auto compute_op(rank<2>,
const T& x, const T& x,
...@@ -175,6 +186,20 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op( ...@@ -175,6 +186,20 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
return {}; return {};
} }
// Detection overload: participates in overload resolution only when
// x.normalize_compute_shape(inputs) is a well-formed expression, in which
// case its return type is std::true_type. No definition is visible here; it
// is only named inside decltype below.
template <class T>
auto need_normalization_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs), std::true_type{});
// Fallback overload with return type std::false_type.
// NOTE(review): presumably rank<1> converts to rank<0>, so this one is chosen
// only when the overload above is not viable - confirm against rank's
// definition.
template <class T>
auto need_normalization_op(rank<0>, const T&, const std::vector<shape>&) -> std::false_type;
// Entry point: returns a value-initialized std::true_type or std::false_type,
// depending on which of the overloads above is selected for T.
template <class T>
auto need_normalization_op(const T& x)
-> decltype(need_normalization_op(rank<1>{}, x, std::declval<std::vector<shape>>()))
{
return {};
}
template <class T> template <class T>
std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&) std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
{ {
...@@ -246,6 +271,7 @@ void from_value_op(T& x, const value& v) ...@@ -246,6 +271,7 @@ void from_value_op(T& x, const value& v)
* { * {
* std::string name() const; * std::string name() const;
* bool is_context_free() const; * bool is_context_free() const;
* bool need_normalization() const;
* bool has_finalize() const; * bool has_finalize() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const; * std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ; * void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
...@@ -336,6 +362,12 @@ struct operation ...@@ -336,6 +362,12 @@ struct operation
return (*this).private_detail_te_get_handle().is_context_free(); return (*this).private_detail_te_get_handle().is_context_free();
} }
// Forwards the need_normalization() query to the wrapped handle.
// The assert checks that the handle member is set before it is dereferenced.
bool need_normalization() const
{
    assert(this->private_detail_te_handle_mem_var);
    return this->private_detail_te_get_handle().need_normalization();
}
bool has_finalize() const bool has_finalize() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -417,6 +449,7 @@ struct operation ...@@ -417,6 +449,7 @@ struct operation
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual bool is_context_free() const = 0; virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0; virtual bool has_finalize() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0; virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual void virtual void
...@@ -445,6 +478,19 @@ struct operation ...@@ -445,6 +478,19 @@ struct operation
return detail::is_context_free_op(private_detail_te_self); return detail::is_context_free_op(private_detail_te_self);
} }
// Used when the wrapped type provides its own need_normalization() member:
// the trailing decltype removes this overload from consideration otherwise.
// The leading `char` parameter is a dispatch tag; callers pass char(0), which
// matches this overload exactly and so outranks the `float`-tagged overload.
template <class T>
static auto private_detail_te_default_need_normalization(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.need_normalization())
{
return private_detail_te_self.need_normalization();
}
// Fallback overload: reached when the `char`-tagged overload is not viable
// (the char(0) tag argument then converts to float). It delegates to
// detail::need_normalization_op and converts that result to bool.
template <class T>
static bool private_detail_te_default_need_normalization(float, T&& private_detail_te_self)
{
return detail::need_normalization_op(private_detail_te_self);
}
template <class T> template <class T>
static auto private_detail_te_default_has_finalize(char, T&& private_detail_te_self) static auto private_detail_te_default_has_finalize(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.has_finalize()) -> decltype(private_detail_te_self.has_finalize())
...@@ -496,6 +542,23 @@ struct operation ...@@ -496,6 +542,23 @@ struct operation
detail::finalize_op(private_detail_te_self, ctx, output, input); detail::finalize_op(private_detail_te_self, ctx, output, input);
} }
// Used when the wrapped type provides a compute_shape(input) member: the
// trailing decltype removes this overload from consideration otherwise.
// The leading `char` parameter is a dispatch tag; callers pass char(0), which
// matches this overload exactly and so outranks the `float`-tagged overload.
template <class T>
static auto private_detail_te_default_compute_shape(char,
T&& private_detail_te_self,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.compute_shape(input))
{
return private_detail_te_self.compute_shape(input);
}
// Fallback overload: reached when the `char`-tagged overload is not viable
// (the char(0) tag argument then converts to float). It forwards the wrapped
// value and the input shapes to detail::normalize_compute_shape_op.
template <class T>
static shape private_detail_te_default_compute_shape(float,
T&& private_detail_te_self,
const std::vector<shape>& input)
{
return detail::normalize_compute_shape_op(private_detail_te_self, input);
}
template <class T> template <class T>
static auto private_detail_te_default_compute(char, static auto private_detail_te_default_compute(char,
T&& private_detail_te_self, T&& private_detail_te_self,
...@@ -613,6 +676,12 @@ struct operation ...@@ -613,6 +676,12 @@ struct operation
return private_detail_te_default_is_context_free(char(0), private_detail_te_value); return private_detail_te_default_is_context_free(char(0), private_detail_te_value);
} }
// Virtual override: answers the query for the stored value by calling
// private_detail_te_default_need_normalization with a char(0) dispatch tag.
bool need_normalization() const override
{
return private_detail_te_default_need_normalization(char(0), private_detail_te_value);
}
bool has_finalize() const override bool has_finalize() const override
{ {
...@@ -635,7 +704,7 @@ struct operation ...@@ -635,7 +704,7 @@ struct operation
shape compute_shape(const std::vector<shape>& input) const override shape compute_shape(const std::vector<shape>& input) const override
{ {
return private_detail_te_value.compute_shape(input); return private_detail_te_default_compute_shape(char(0), private_detail_te_value, input);
} }
argument compute(context& ctx, argument compute(context& ctx,
...@@ -759,6 +828,14 @@ bool is_context_free(const T& x) ...@@ -759,6 +828,14 @@ bool is_context_free(const T& x)
return detail::is_context_free_op(x); return detail::is_context_free_op(x);
} }
// Free-function query for a type-erased operation: forwards to the member
// operation::need_normalization().
inline bool need_normalization(const operation& op)
{
    const bool required = op.need_normalization();
    return required;
}

// Free-function query for any other type: forwards to
// detail::need_normalization_op and converts its result to bool.
template <class T>
bool need_normalization(const T& x)
{
    const bool required = detail::need_normalization_op(x);
    return required;
}
inline bool has_finalize(const operation& op) { return op.has_finalize(); } inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T> template <class T>
......
...@@ -224,6 +224,7 @@ struct value ...@@ -224,6 +224,7 @@ struct value
} }
value& operator=(std::nullptr_t); value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i);
bool is_array() const; bool is_array() const;
const std::vector<value>& get_array() const; const std::vector<value>& get_array() const;
......
#include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Normalizes the values of one operator attribute against the dimensions `lens`.
//
// `val` holds a list of op::normalize_attribute flags. Every behavior below is
// opt-in: it is enabled only when the corresponding flag is present in `val`.
//   use_output  : the bound is lens.size() + vec.size() instead of lens.size()
//   use_len     : the bound of element i is lens[axes[i]] instead of the rank
//   clip_max    : values beyond the upper limit are clamped instead of rejected
//   include_max : the bound itself is a legal value (otherwise the largest legal
//                 value is bound - 1)
//   clip_min    : values beyond the lower limit are clamped instead of rejected
//   include_min : -bound itself is a legal value (otherwise the smallest legal
//                 value is -bound + 1)
//
// Parameters:
//   vec  - attribute values to normalize
//   axes - axis indices used to look up `lens` when use_len is set; they are used
//          as-is to index `lens`, so they must already be valid indices
//   val  - the normalize_attribute flags for this attribute
//   lens - dimensions the values are normalized against
//
// Returns the normalized values; negative values that remain after the range
// handling are wrapped by adding the bound.
// Throws (MIGRAPHX_THROW) when a value is out of range and clipping is not
// enabled for that side, or when use_len is set and `axes` is longer than `vec`.
auto tune_attribute(const std::vector<int64_t>& vec,
                    const std::vector<int64_t>& axes,
                    const value& val,
                    const std::vector<std::size_t>& lens)
{
    std::vector<int64_t> result(vec);
    int64_t n_rank = static_cast<int64_t>(lens.size());
    std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
    if(contains(vec_attrs, op::normalize_attribute::use_output))
    {
        n_rank += static_cast<int64_t>(vec.size());
    }

    // Per-element bound: the rank by default, or the length of the matching axis.
    std::vector<int64_t> max_vals(vec.size(), n_rank);
    if(contains(vec_attrs, op::normalize_attribute::use_len))
    {
        // max_vals has one slot per value; more axes than values would write
        // past its end.
        if(axes.size() > max_vals.size())
        {
            MIGRAPHX_THROW("TUNE_VECTOR: more axes than attribute values!");
        }
        std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { return lens[i]; });
    }

    // Upper limit: clamp or validate.
    if(contains(vec_attrs, op::normalize_attribute::clip_max))
    {
        if(contains(vec_attrs, op::normalize_attribute::include_max))
        {
            std::transform(result.begin(),
                           result.end(),
                           max_vals.begin(),
                           result.begin(),
                           [](auto v, auto mv) { return v > mv ? mv : v; });
        }
        else
        {
            std::transform(result.begin(),
                           result.end(),
                           max_vals.begin(),
                           result.begin(),
                           [](auto v, auto mv) { return v >= mv ? mv - 1 : v; });
        }
    }
    else
    {
        if(contains(vec_attrs, op::normalize_attribute::include_max))
        {
            if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
            {
                MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
            }
        }
        else
        {
            if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
            {
                MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
            }
        }
    }

    // Lower limit is the negated bound: clamp or validate.
    std::vector<int64_t> min_vals = max_vals;
    std::transform(min_vals.begin(), min_vals.end(), min_vals.begin(), [](auto v) { return -v; });
    if(contains(vec_attrs, op::normalize_attribute::clip_min))
    {
        if(contains(vec_attrs, op::normalize_attribute::include_min))
        {
            std::transform(result.begin(),
                           result.end(),
                           min_vals.begin(),
                           result.begin(),
                           [](auto v, auto mv) { return v < mv ? mv : v; });
        }
        else
        {
            std::transform(result.begin(),
                           result.end(),
                           min_vals.begin(),
                           result.begin(),
                           [](auto v, auto mv) { return v < mv + 1 ? mv + 1 : v; });
        }
    }
    else
    {
        if(contains(vec_attrs, op::normalize_attribute::include_min))
        {
            if(!std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
            {
                MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
            }
        }
        else
        {
            // Minimum excluded: every value must be strictly greater than its
            // lower limit (min < value), mirroring the inclusive check above.
            if(!std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less<>{}))
            {
                MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
            }
        }
    }

    // Wrap the remaining negative values into the non-negative range.
    std::transform(
        result.begin(), result.end(), max_vals.begin(), result.begin(), [](auto v, auto mv) {
            return v < 0 ? v + mv : v;
        });

    return result;
}
// Normalizes every attribute listed under the operator's "normalize_axes"
// attribute entry against `lens`, writing the tuned values back into `op`.
//
// For each listed key the current attribute value is read from op.to_value():
//   - array values are tuned element-wise, using the operator's "axes" value
//     (when it has one) as the axis list;
//   - scalar values are tuned as a one-element list that is also its own axis.
// After each key the operator is rebuilt with from_value() and its value is
// re-read, so later keys see the already-normalized state.
//
// Returns true when at least one attribute was processed, false when the
// operator has no "normalize_axes" entry (or the entry is empty).
// Throws (MIGRAPHX_THROW) when a listed key is missing from the operator value.
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
{
    auto op_attrs = op.attributes();
    auto op_val   = op.to_value();
    if(!op_attrs.contains("normalize_axes"))
    {
        return false;
    }

    bool changed    = false;
    auto norm_specs = op_attrs.at("normalize_axes").without_key();
    for(const auto& spec : norm_specs)
    {
        const auto& name = spec.get_key();
        if(!op_val.contains(name))
        {
            MIGRAPHX_THROW("NORMALIZE_ATTR : op " + op.name() + " attribute \"" + name +
                           "\" not exist!");
        }

        auto current = op_val.at(name).without_key();
        if(current.is_array())
        {
            std::vector<int64_t> axes;
            if(op_val.contains("axes"))
            {
                axes = op_val.at("axes").without_key().to_vector<int64_t>();
            }
            auto values       = current.to_vector<int64_t>();
            auto tuned_values = tune_attribute(values, axes, spec.without_key(), lens);
            op_val[name]      = tuned_values;
        }
        else
        {
            auto scalar       = current.to<int64_t>();
            auto tuned_values = tune_attribute({scalar}, {scalar}, spec.without_key(), lens);
            op_val[name]      = tuned_values.front();
        }

        // Round-trip through the operator so the next key works on fresh state.
        op.from_value(op_val);
        op_val  = op.to_value();
        changed = true;
    }

    return changed;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <unordered_set>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/value.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Walks every instruction of the module and, for those with at least one input,
// normalizes a copy of the operator against the dimensions of the first input.
// The instruction is replaced with the tuned operator (same inputs) only when
// normalize_attributes reports that something was tuned.
void normalize_ops::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        auto args = ins->inputs();
        if(!args.empty())
        {
            auto first_lens              = args[0]->get_shape().lens();
            migraphx::operation candidate = ins->get_operator();
            const bool changed           = normalize_attributes(candidate, first_lens);
            if(changed)
            {
                m.replace_instruction(ins, candidate, args);
            }
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment