"include/ck/utility/sequence.hpp" did not exist on "cd29b09a824311bb33fd3f66b4d97a291b5e90e0"
Unverified Commit 8443ecd1 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Normalize ops (#667)



* add a pass to normalize ops

* clang format

* add unit tests

* clang format

* code backup

* clang format

* code backup

* clang format

* add support for slice in the normalize_op function

* clang format

* add operation method api for whether we need to call normalize_op

* clang format

* fix review comments

* clang format

* rename a function name

* clang format

* change compute_shape to normalize_compute_shape for corresponding operators

* clang format

* remove unnecessary code

* fix various issues

* clang format

* add attributes to operators having axis attributes

* clang format

* fixed jenkins build error

* clang format

* fix a bug related to slice

* clang format

* code backup

* clang format

* code backup

* clang format

* rename a file

* fix cppcheck error

* some code refinement

* clang format

* change attributes to enum

* clang format

* refine the enum

* clang format

* remove unnecessary code

* add unit tests for more code coverage and fixed a bug

* clang format

* remove unnecessary changes

* change normalize_axes to normalize

* clang format

* revert back the changes in broadcast.hpp

* rename normalize_axes to normalize

* fix review comments

* clang format

* Add flag to enable cpu backend

* Make buffers shared

* Enable optimizations

* Formatting

* Try to avoid ambiguous assign in value class

* fixed a build error

* clang format

* add the normalize_ops pass to the ref target

* refactor program to module to normalize_ops pass
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent f8b56a66
......@@ -45,6 +45,8 @@ add_library(migraphx
convert_to_json.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
normalize_attributes.cpp
normalize_ops.cpp
)
rocm_set_soversion(migraphx ${MIGRAPHX_SO_VERSION})
function(register_migraphx_ops)
......
#ifndef MIGRAPHX_GUARD_RTGLIB_NORMALIZE_ATTRIBUTES_HPP
#define MIGRAPHX_GUARD_RTGLIB_NORMALIZE_ATTRIBUTES_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct operation;
/// Metafunction whose nested `type` is always its first argument; any
/// further template arguments are accepted and ignored.
template <class Result, class... Ignored>
struct select_dependent_type
{
    using type = Result;
};

/// `T`, spelled through `Ts...`. Inside a template this makes the named type
/// depend on the template parameters (e.g. `dependent_type<operation, T>`),
/// which lets a header mention a type that is only forward-declared there.
template <class T, class... Ts>
using dependent_type = typename select_dependent_type<T, Ts...>::type;
/// Rewrites, in place, the attributes of `op` that the operator lists under
/// the "normalize_axes" entry of its attributes().
/// @param op   operator whose attributes are updated
/// @param lens dimension lengths the attribute values are normalized against
/// @return true when at least one listed attribute was processed, false when
///         the operator has no "normalize_axes" entry
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_NORMALIZE_OPS_HPP
#define MIGRAPHX_GUARD_RTGLIB_NORMALIZE_OPS_HPP
#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Pass that normalizes operator attributes (for example negative axis
* values) for the instructions of a module.
*/
struct normalize_ops
{
// Name under which the pass identifies itself.
std::string name() const { return "normalize_ops"; }
// Runs the pass over `m`; defined out of line.
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -6,6 +6,8 @@
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -21,14 +23,21 @@ struct argmax
return pack(f(self.axis, "axis"));
}
// Normalization metadata for argmax: tags the "axis" attribute with
// include_min and returns the table under the "normalize_axes" key.
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "argmax"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < -n_dim)
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
}
......
......@@ -6,6 +6,8 @@
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -21,14 +23,21 @@ struct argmin
return pack(f(self.axis, "axis"));
}
// Normalization metadata for argmin: tags the "axis" attribute with
// include_min and returns the table under the "normalize_axes" key.
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "argmin"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < -n_dim)
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
}
......
......@@ -35,10 +35,16 @@ struct broadcast
std::string name() const { return "broadcast"; }
shape compute_shape(std::vector<shape> inputs) const
{
auto t = inputs.at(0).type();
auto input = inputs.at(0);
auto t = input.type();
std::vector<size_t> bcast_strides(broadcast_lens.size(), 0);
// the broadcast op is deprecated now, so negative values of axis are
// no longer handled
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(std::all_of(
broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
......@@ -49,9 +55,15 @@ struct broadcast
}
else
{
assert(broadcast_lens.size() - axis >= input.lens().size());
if(broadcast_lens.size() - axis < input.lens().size())
{
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
}
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
}
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_lens, std::move(bcast_strides)};
}
......
......@@ -8,6 +8,8 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -25,6 +27,13 @@ struct concat
return pack(f(self.axis, "axis"));
}
// Normalization metadata for concat: tags the "axis" attribute with
// include_min and returns the table under the "normalize_axes" key.
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const
......@@ -41,7 +50,7 @@ struct concat
}
return offsets;
}
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
if(inputs.empty())
{
......@@ -50,10 +59,9 @@ struct concat
const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type();
std::size_t axis_index = (axis < 0) ? (first_shape_lens.size() + axis) : axis;
for(std::size_t l = 0; l < first_shape_lens.size(); l++)
{
if(l != axis_index)
if(l != axis)
{
if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[l] == first_shape_lens[l];
......@@ -67,11 +75,11 @@ struct concat
for(const auto& input : inputs)
{
const auto& lens = input.lens();
new_dim_axis += lens[axis_index];
new_dim_axis += lens[axis];
}
std::vector<std::size_t> new_lens;
std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
new_lens[axis_index] = new_dim_axis;
new_lens[axis] = new_dim_axis;
return {type, new_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
......
......@@ -8,6 +8,8 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -25,23 +27,23 @@ struct flatten
return pack(f(self.axis, "axis"));
}
// Normalization metadata for flatten: "axis" is tagged with both
// include_min and include_max, so both ends of the range are accepted.
// The table is returned under the "normalize_axes" key.
value attributes() const
{
value normalize;
normalize["axis"] =
value::array{normalize_attribute::include_min, normalize_attribute::include_max};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "flatten"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto&& lens = inputs.front().lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis > n_dim or axis < -n_dim)
{
MIGRAPHX_THROW("FLATTEN: axis for flatten is out of range");
}
auto tuned_axis = (axis < 0) ? axis + n_dim : axis;
auto x = std::accumulate(
lens.begin(), lens.begin() + tuned_axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate(
lens.begin() + tuned_axis, lens.end(), std::size_t{1}, std::multiplies<>{});
auto&& lens = inputs.front().lens();
auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y =
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {inputs.at(0).type(), {x, y}};
}
argument compute(shape output_shape, std::vector<argument> args) const
......
......@@ -8,6 +8,8 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -17,7 +19,7 @@ namespace op {
struct gather
{
int axis = 0;
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -25,27 +27,25 @@ struct gather
return pack(f(self.axis, "axis"));
}
// Normalization metadata for gather: tags the "axis" attribute with
// include_min and returns the table under the "normalize_axes" key.
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
auto lens = inputs[0].lens();
int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
{
MIGRAPHX_THROW("Gather: axis is out of range.");
}
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type();
lens.erase(lens.begin() + axis_index);
lens.erase(lens.begin() + axis);
if(!inputs[1].scalar())
{
auto ind_lens = inputs[1].lens();
lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
lens.insert(lens.begin() + axis, ind_lens.begin(), ind_lens.end());
}
// for scalar output
......@@ -61,10 +61,8 @@ struct gather
{
argument result{output_shape};
// negative axis means counting dimensions from back
auto lens = args[0].get_shape().lens();
int axis_index = (axis < 0) ? static_cast<int>(lens.size() + axis) : axis;
std::size_t axis_dim_size = lens[axis_index];
auto lens = args[0].get_shape().lens();
std::size_t axis_dim_size = lens[axis];
// max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
......@@ -76,14 +74,14 @@ struct gather
}
else
{
auto out_lens = data.get_shape().lens();
out_lens[axis_index] = indices.get_shape().elements();
auto out_lens = data.get_shape().lens();
out_lens[axis] = indices.get_shape().elements();
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) {
auto data_idx = out_idx;
auto in_index = indices[data_idx[axis_index]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
data_idx[axis_index] = in_index;
auto data_idx = out_idx;
auto in_index = indices[data_idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
data_idx[axis] = in_index;
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
data(data_idx.begin(), data_idx.end());
});
......
......@@ -2,6 +2,8 @@
#define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......@@ -18,16 +20,17 @@ struct logsoftmax
return pack(f(self.axis, "axis"));
}
// Normalization metadata for logsoftmax: tags the "axis" attribute with
// include_min and returns the table under the "normalize_axes" key.
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "logsoftmax"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
int64_t n_dim = static_cast<int64_t>(inputs[0].lens().size());
if(axis < -n_dim || axis >= n_dim)
{
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range");
}
return inputs.at(0);
}
......
#ifndef MIGRAPHX_GUARD_OPERATORS_OP_NORMALIZE_ATTRIBUTE_HPP
#define MIGRAPHX_GUARD_OPERATORS_OP_NORMALIZE_ATTRIBUTE_HPP
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
// Flags an operator attaches to an attribute to describe how it should be
// normalized. Every flag is opt-in; when a flag is absent the behaviour
// listed first applies:
// 1) use_input(default)/use_output
// 2) use_rank(default)/use_len
// 3) out-of-range minimum throws(default)/clip_min
// 3.1) exclude_min(default)/include_min
// 4) out-of-range maximum throws(default)/clip_max
// 4.1) exclude_max(default)/include_max
enum class normalize_attribute
{
// bound is the length of the matching axis instead of the rank
use_len,
// bound is the rank plus the number of attribute values
use_output,
// clamp values above the upper bound instead of throwing
clip_max,
// clamp values below the lower bound instead of throwing
clip_min,
// the upper bound itself is an accepted value
include_max,
// the lower bound (negated bound) itself is an accepted value
include_min
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -7,6 +7,8 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <vector>
namespace migraphx {
......@@ -60,6 +62,13 @@ struct reduce_op : op_name<Derived>
return pack(f(self.axes, "axes"));
}
// Normalization metadata shared by the reduce operators: tags the "axes"
// attribute with include_min and returns the table under "normalize_axes".
value attributes() const
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
auto tuned_axes = axes;
......@@ -68,26 +77,11 @@ struct reduce_op : op_name<Derived>
tuned_axes.resize(n_dim);
std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
}
else
{
for(auto& axis : tuned_axes)
{
int64_t s_dim = static_cast<int64_t>(n_dim);
if(axis >= s_dim or axis < -s_dim)
{
MIGRAPHX_THROW("REDUCE_OP: axis out of range");
}
if(axis < 0)
{
axis += n_dim;
}
}
}
return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
......
......@@ -5,6 +5,8 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <vector>
......@@ -25,6 +27,23 @@ struct slice
return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
}
// Normalization metadata for slice, returned under "normalize_axes".
// "axes" is only range-checked (include_min); "starts" and "ends" are
// clamped at both ends (clip_min/clip_max, bounds inclusive) against the
// length of the matching axis (use_len).
value attributes() const
{
value normalize = value::object{};
normalize["axes"] = value::array{normalize_attribute::include_min};
normalize["starts"] = value::array{normalize_attribute::clip_max,
normalize_attribute::clip_min,
normalize_attribute::include_max,
normalize_attribute::use_len,
normalize_attribute::include_min};
normalize["ends"] = value::array{normalize_attribute::clip_max,
normalize_attribute::clip_min,
normalize_attribute::include_max,
normalize_attribute::use_len,
normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "slice"; }
void tune_attributes(std::vector<int64_t>& tuned_axes,
......@@ -111,30 +130,34 @@ struct slice
return offset;
}
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
auto input_shape = inputs[0];
auto t = input_shape.type();
const auto& old_lens = input_shape.lens();
const auto& old_strides = input_shape.strides();
if(std::any_of(
axes.begin(), axes.end(), [&](auto i) { return (i >= old_lens.size() and i < 0); }))
{
MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range");
}
if(starts.size() != axes.size() || axes.size() != ends.size())
{
MIGRAPHX_THROW("SLICE: inconsistent sizes");
}
std::vector<int64_t> tuned_axes = axes;
std::vector<int64_t> tuned_starts = starts;
std::vector<int64_t> tuned_ends = ends;
tune_attributes(tuned_axes, tuned_starts, tuned_ends, old_lens);
std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < tuned_axes.size(); i++)
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = tuned_axes[i];
new_lens[axis] = fix_index(old_lens, axis, tuned_ends[i]) -
fix_index(old_lens, axis, tuned_starts[i]);
auto axis = axes[i];
new_lens[axis] =
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
}
return shape{t, new_lens, old_strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
auto input = args[0];
......
......@@ -2,6 +2,8 @@
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......@@ -18,16 +20,17 @@ struct softmax
return pack(f(self.axis, "axis"));
}
// Normalization metadata for softmax: tags the "axis" attribute with
// include_min and returns the table under the "normalize_axes" key.
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "softmax"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
int64_t n_dim = inputs[0].lens().size();
if(axis < -n_dim || axis >= n_dim)
{
MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) +
" is out of range");
}
return inputs.at(0);
}
......
......@@ -5,9 +5,10 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -25,28 +26,27 @@ struct squeeze
return pack(f(self.axes, "axes"));
}
// Normalization metadata for squeeze: tags the "axes" attribute with
// include_min and returns the table under the "normalize_axes" key.
value attributes() const
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
// change to support negative axis value
std::vector<int64_t> tuned_axes(axes.size());
std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [&](auto i) {
return i >= 0 ? i : i + old_lens.size();
});
if(std::any_of(tuned_axes.begin(), tuned_axes.end(), [&](auto axis) {
return old_lens[axis] != 1;
}))
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
if(tuned_axes.empty())
if(axes.empty())
{
std::copy_if(old_lens.begin(),
old_lens.end(),
......@@ -57,7 +57,7 @@ struct squeeze
{
for(std::size_t i = 0; i < old_lens.size(); i++)
{
if(std::find(tuned_axes.begin(), tuned_axes.end(), i) == tuned_axes.end())
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]);
}
......
......@@ -5,9 +5,9 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -25,8 +25,16 @@ struct unsqueeze
return pack(f(self.axes, "axes"));
}
// Normalization metadata for unsqueeze: "axes" is tagged with include_min
// and use_output, so the bound also counts the axes being inserted.
// The table is returned under the "normalize_axes" key.
value attributes() const
{
value normalize;
normalize["axes"] =
value::array{normalize_attribute::include_min, normalize_attribute::use_output};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard_or_scalar();
auto input_shape = inputs[0];
......@@ -43,17 +51,11 @@ struct unsqueeze
std::size_t new_size = old_lens.size() + axes.size();
// in case of axes to be negative, tune to positive
std::vector<int64_t> tuned_axes(axes.size());
std::transform(axes.begin(), axes.end(), tuned_axes.begin(), [new_size](auto i) {
return i >= 0 ? i : i + new_size;
});
std::vector<std::size_t> new_lens(new_size);
std::size_t p = 0;
for(std::size_t i = 0; i < new_size; i++)
{
if(std::find(tuned_axes.begin(), tuned_axes.end(), i) != tuned_axes.end())
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
new_lens[i] = 1;
}
......
......@@ -9,6 +9,7 @@
#include <utility>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/auto_any_cast.hpp>
......@@ -58,6 +59,8 @@ struct operation
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
/// Returns true if operation needs normalization before running compute
bool need_normalization(const operation& x);
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
......@@ -96,6 +99,14 @@ auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
} // namespace operation_operators
/// Computes the output shape of an operator that provides
/// `normalize_compute_shape`: a type-erased copy of `x` has its attributes
/// normalized against the dimensions of the first input, and the shape is
/// then computed from that normalized copy.
/// @param x      operator to evaluate (left unmodified)
/// @param inputs input shapes; must contain at least one element
/// @throws via MIGRAPHX_THROW when `inputs` is empty
template <class T>
shape normalize_compute_shape_op(T&& x, std::vector<shape> inputs)
{
    // The first input supplies the reference dimensions, so reject an empty
    // list here instead of indexing past the end of the vector.
    if(inputs.empty())
    {
        MIGRAPHX_THROW("At least one input is required for " + x.name());
    }
    // `operation` is only forward-declared at this point; naming it through
    // dependent_type defers the need for its definition to instantiation.
    dependent_type<operation, T> y = x;
    normalize_attributes(y, inputs[0].lens());
    return any_cast<T>(y).normalize_compute_shape(inputs);
}
template <class T>
auto compute_op(rank<2>,
const T& x,
......@@ -175,6 +186,20 @@ auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
return {};
}
// Detection overload: only viable when `x.normalize_compute_shape(inputs)`
// is a valid expression, in which case it yields std::true_type.
// Declaration only; it is used purely inside decltype.
template <class T>
auto need_normalization_op(rank<1>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs), std::true_type{});
// Fallback overload yielding std::false_type.
// NOTE(review): presumably rank<1> converts to rank<0>, so this is chosen
// only when the overload above is removed by SFINAE - confirm against rank.
template <class T>
auto need_normalization_op(rank<0>, const T&, const std::vector<shape>&) -> std::false_type;
// Entry point: returns a value-initialized instance of whichever
// integral-constant type the overloads above select for T.
template <class T>
auto need_normalization_op(const T& x)
-> decltype(need_normalization_op(rank<1>{}, x, std::declval<std::vector<shape>>()))
{
return {};
}
template <class T>
std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
{
......@@ -246,6 +271,7 @@ void from_value_op(T& x, const value& v)
* {
* std::string name() const;
* bool is_context_free() const;
* bool need_normalization() const;
* bool has_finalize() const;
* std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
* void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
......@@ -336,6 +362,12 @@ struct operation
return (*this).private_detail_te_get_handle().is_context_free();
}
// Forwards to the stored operator through the type-erasure handle.
bool need_normalization() const
{
// the handle must be populated before it is queried
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().need_normalization();
}
bool has_finalize() const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -417,6 +449,7 @@ struct operation
virtual std::string name() const = 0;
virtual bool is_context_free() const = 0;
virtual bool need_normalization() const = 0;
virtual bool has_finalize() const = 0;
virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
virtual void
......@@ -445,6 +478,19 @@ struct operation
return detail::is_context_free_op(private_detail_te_self);
}
// Preferred overload: takes part in overload resolution only when the
// wrapped type has its own need_normalization() member, and calls it.
// Callers pass char(0), which matches this `char` parameter exactly.
template <class T>
static auto private_detail_te_default_need_normalization(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.need_normalization())
{
return private_detail_te_self.need_normalization();
}
// Fallback overload (reached through the char -> float conversion): asks
// detail::need_normalization_op for the answer instead.
template <class T>
static bool private_detail_te_default_need_normalization(float, T&& private_detail_te_self)
{
return detail::need_normalization_op(private_detail_te_self);
}
template <class T>
static auto private_detail_te_default_has_finalize(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.has_finalize())
......@@ -496,6 +542,23 @@ struct operation
detail::finalize_op(private_detail_te_self, ctx, output, input);
}
// Preferred overload: takes part in overload resolution only when the
// wrapped type provides compute_shape(input), and calls it directly.
// Callers pass char(0), which matches this `char` parameter exactly.
template <class T>
static auto private_detail_te_default_compute_shape(char,
T&& private_detail_te_self,
const std::vector<shape>& input)
-> decltype(private_detail_te_self.compute_shape(input))
{
return private_detail_te_self.compute_shape(input);
}
// Fallback overload (reached through the char -> float conversion): the
// shape is computed via detail::normalize_compute_shape_op, which
// normalizes the operator's attributes first.
template <class T>
static shape private_detail_te_default_compute_shape(float,
T&& private_detail_te_self,
const std::vector<shape>& input)
{
return detail::normalize_compute_shape_op(private_detail_te_self, input);
}
template <class T>
static auto private_detail_te_default_compute(char,
T&& private_detail_te_self,
......@@ -613,6 +676,12 @@ struct operation
return private_detail_te_default_is_context_free(char(0), private_detail_te_value);
}
// Type-erasure hook: char(0) selects the overload that calls the wrapped
// value's own need_normalization() when it has one, otherwise the fallback.
bool need_normalization() const override
{
return private_detail_te_default_need_normalization(char(0), private_detail_te_value);
}
bool has_finalize() const override
{
......@@ -635,7 +704,7 @@ struct operation
shape compute_shape(const std::vector<shape>& input) const override
{
return private_detail_te_value.compute_shape(input);
return private_detail_te_default_compute_shape(char(0), private_detail_te_value, input);
}
argument compute(context& ctx,
......@@ -759,6 +828,14 @@ bool is_context_free(const T& x)
return detail::is_context_free_op(x);
}
// Type-erased operation: ask the operation object itself.
inline bool need_normalization(const operation& op) { return op.need_normalization(); }
// Concrete operator type: the answer comes from detail::need_normalization_op,
// which checks whether T offers normalize_compute_shape.
template <class T>
bool need_normalization(const T& x)
{
return detail::need_normalization_op(x);
}
inline bool has_finalize(const operation& op) { return op.has_finalize(); }
template <class T>
......
......@@ -224,6 +224,7 @@ struct value
}
value& operator=(std::nullptr_t);
value& operator=(const std::initializer_list<value>& i);
bool is_array() const;
const std::vector<value>& get_array() const;
......
#include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
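// Normalization options. Every op::normalize_attribute flag is opt-in; when a
// flag is absent the behaviour listed first applies:
// 1) use_input(default)/use_output
// 2) use_rank(default)/use_len
// 3) out-of-range minimum throws(default)/clip_min
// 3.1) exclude_min(default)/include_min
// 4) out-of-range maximum throws(default)/clip_max
// 4.1) exclude_max(default)/include_max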
auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes,
const value& val,
const std::vector<std::size_t>& lens)
{
std::vector<int64_t> result(vec);
int64_t n_rank = static_cast<int64_t>(lens.size());
std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
if(contains(vec_attrs, op::normalize_attribute::use_output))
{
n_rank = n_rank + vec.size();
}
std::vector<int64_t> max_vals(vec.size(), n_rank);
if(contains(vec_attrs, op::normalize_attribute::use_len))
{
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { return lens[i]; });
}
if(contains(vec_attrs, op::normalize_attribute::clip_max))
{
if(contains(vec_attrs, op::normalize_attribute::include_max))
{
std::transform(result.begin(),
result.end(),
max_vals.begin(),
result.begin(),
[](auto v, auto mv) { return v > mv ? mv : v; });
}
else
{
std::transform(result.begin(),
result.end(),
max_vals.begin(),
result.begin(),
[](auto v, auto mv) { return v >= mv ? mv - 1 : v; });
}
}
else
{
if(contains(vec_attrs, op::normalize_attribute::include_max))
{
if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
}
}
else
{
if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
}
}
}
std::vector<int64_t> min_vals = max_vals;
std::transform(min_vals.begin(), min_vals.end(), min_vals.begin(), [](auto v) { return -v; });
if(contains(vec_attrs, op::normalize_attribute::clip_min))
{
if(contains(vec_attrs, op::normalize_attribute::include_min))
{
std::transform(result.begin(),
result.end(),
min_vals.begin(),
result.begin(),
[](auto v, auto mv) { return v < mv ? mv : v; });
}
else
{
std::transform(result.begin(),
result.end(),
min_vals.begin(),
result.begin(),
[](auto v, auto mv) { return v < mv + 1 ? mv + 1 : v; });
}
}
else
{
if(contains(vec_attrs, op::normalize_attribute::include_min))
{
if(!std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
}
}
else
{
if(!std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
}
}
}
std::transform(
result.begin(), result.end(), max_vals.begin(), result.begin(), [](auto v, auto mv) {
return v < 0 ? v + mv : v;
});
return result;
}
/// Normalizes every attribute of `op` that is listed under the
/// "normalize_axes" entry of op.attributes(), using `lens` as the reference
/// dimensions. Array attributes are tuned element-wise (with the operator's
/// "axes" attribute, when present, supplying the axis of each element);
/// scalar attributes are tuned as a one-element list.
/// @return true when at least one attribute was processed, false when the
///         operator has no "normalize_axes" entry
/// @throws via MIGRAPHX_THROW when a listed attribute is not part of the
///         operator's value
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
{
    bool changed       = false;
    auto op_attributes = op.attributes();
    auto op_value      = op.to_value();
    if(!op_attributes.contains("normalize_axes"))
    {
        return false;
    }

    auto rules = op_attributes.at("normalize_axes").without_key();
    for(const auto& rule : rules)
    {
        const auto& key = rule.get_key();
        if(!op_value.contains(key))
        {
            MIGRAPHX_THROW("NORMALIZE_ATTR : op " + op.name() + " attribute \"" + key +
                           "\" not exist!");
        }

        auto current = op_value.at(key).without_key();
        if(current.is_array())
        {
            std::vector<int64_t> axes;
            if(op_value.contains("axes"))
            {
                axes = op_value.at("axes").without_key().to_vector<int64_t>();
            }
            auto values   = current.to_vector<int64_t>();
            auto tuned    = tune_attribute(values, axes, rule.without_key(), lens);
            op_value[key] = tuned;
        }
        else
        {
            auto scalar   = current.to<int64_t>();
            auto tuned    = tune_attribute({scalar}, {scalar}, rule.without_key(), lens);
            op_value[key] = tuned.front();
        }

        // Round-trip through the operator so that later keys see the
        // attributes exactly as the operator stores them.
        op.from_value(op_value);
        op_value = op.to_value();
        changed  = true;
    }
    return changed;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <unordered_set>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/value.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Visits every instruction of `m` that has at least one input, normalizes a
/// copy of its operator against the dimensions of the first input and, when
/// normalize_attributes reports a change, replaces the instruction with the
/// normalized operator on the same inputs.
void normalize_ops::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        auto args = ins->inputs();
        if(!args.empty())
        {
            auto dims                     = args[0]->get_shape().lens();
            migraphx::operation candidate = ins->get_operator();
            if(normalize_attributes(candidate, dims))
            {
                m.replace_instruction(ins, candidate, args);
            }
        }
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment