Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <ostream>
#include <vector>
#include <migraphx/config.hpp>
#include <utility>
......@@ -15,6 +17,15 @@ enum padding_mode_t
valid
};
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max,
lpnorm
};
// indicate rnn computation direction
enum class rnn_direction
{
......@@ -23,6 +34,7 @@ enum class rnn_direction
bidirectional,
};
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op
......
......@@ -97,7 +97,6 @@ struct deconvolution
shape win_shape{output_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) {
const int w = idx_win[0];
......@@ -140,9 +139,7 @@ struct deconvolution
weights(idx_wei.begin(), idx_wei.end());
}
});
});
});
return result;
}
......
......@@ -51,7 +51,6 @@ struct flatten
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct gathernd
{
int batch_dims = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.batch_dims, "batch_dims"));
}
std::string name() const { return "gathernd"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto r = inputs.front().lens().size();
auto q = inputs.back().lens().size();
auto k = inputs.back().lens().back();
if(k > r - batch_dims)
{
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " +
std::to_string(r - batch_dims));
}
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1;
std::vector<std::size_t> output_lens(output_lens_size);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
if(k < r - batch_dims)
{
auto data_lens = inputs.front().lens();
std::copy(
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
}
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape();
auto indices_shape_lens = indices_shape.lens();
auto data_shape = data.get_shape();
auto data_shape_lens = data_shape.lens();
auto k = indices_shape.lens().back();
const auto num_slice_dims = k;
std::size_t num_slices = std::accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
std::size_t num_batches = std::accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
std::size_t data_batch_stride =
std::accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
auto num_slices_per_batch = num_slices / num_batches;
std::vector<std::size_t> sizes_from_slice_dims(num_slice_dims);
{
auto running_product = slice_size;
for(std::size_t i = 0; i < num_slice_dims; ++i)
{
sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product;
running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i];
}
}
std::vector<std::size_t> input_slice_offsets(num_slices);
par_for(num_slices, [&](const auto i) {
std::size_t batch_idx = i / num_slices_per_batch;
auto slice_indices = indices.begin() + (i * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx)
{
int64_t index = *(slice_indices + dim_idx);
const std::size_t input_dim_idx = batch_dims + dim_idx;
const auto input_dim = data_shape_lens[input_dim_idx];
if(index < -static_cast<int64_t>(input_dim) or
index >= static_cast<int64_t>(input_dim))
MIGRAPHX_THROW("GatherND: index " + std::to_string(index) +
" is out of bounds for dim of len " +
std::to_string(input_dim));
if(index < 0)
index += input_dim;
relative_slice_offset += index * sizes_from_slice_dims[dim_idx];
}
input_slice_offsets[i] =
(batch_idx * data_batch_stride) + relative_slice_offset;
});
par_for(num_slices * slice_size, [&](const auto i) {
auto slice_offset = input_slice_offsets[i / slice_size];
output[i] = data[slice_offset + i % slice_size];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ISNAN_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct isnan : unary<isnan>
{
auto apply() const
{
return [](auto x) { return std::isnan(x); };
}
std::string name() const { return "isnan"; }
shape compute_shape(std::vector<shape> inputs) const
{
return unary<isnan>::compute_shape(std::move(inputs)).with_type(shape::bool_type);
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -69,7 +69,6 @@ struct multibroadcast
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -181,14 +181,15 @@ struct nonmaxsuppression
make_function_output_iterator([&](const auto& x) { sorted_boxes.push(x); });
int64_t box_idx = 0;
transform_if(scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
transform_if(
scores.begin() + score_offset,
scores.begin() + score_offset + box_num,
insert_to_sorted_boxes,
[&](auto sc) {
box_idx++;
return sc >= score_threshold;
},
[&](auto sc) { return std::make_pair(sc, box_idx - 1); });
selected_boxes_inside_class.clear();
// Get the next box with top score, filter by iou_threshold
......
......@@ -8,6 +8,7 @@
#include <migraphx/streamutils.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp>
......@@ -16,16 +17,18 @@
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct pooling
{
std::string mode = "average";
pooling_mode mode = {pooling_mode::average};
std::vector<std::size_t> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> lengths = {1, 1};
bool ceil_mode = false;
int lp_order = 2;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -34,7 +37,8 @@ struct pooling
f(self.padding, "padding"),
f(self.stride, "stride"),
f(self.lengths, "lengths"),
f(self.ceil_mode, "ceil_mode"));
f(self.ceil_mode, "ceil_mode"),
f(self.lp_order, "lp_order"));
}
std::string name() const { return "pooling"; }
......@@ -88,6 +92,114 @@ struct pooling
check_attribute_size();
return stride.size();
}
struct lpnorm_pool
{
int p = 0;
lpnorm_pool() = delete;
explicit lpnorm_pool(int x) : p{x} {};
template <class T>
double init() const
{
return 0.0;
}
double operator()(double x, double y) const { return x + std::pow(std::abs(y), p); }
double final(double x, std::size_t) const { return std::pow(x, 1. / p); }
};
struct avg_pool
{
template <class T>
double init() const
{
return 0.0;
}
double operator()(double x, double y) const { return x + y; }
double final(double x, std::size_t y) const { return (y == 0) ? 0.0 : (x / y); }
};
struct max_pool
{
template <class T>
T init() const
{
return std::numeric_limits<T>::lowest();
}
double operator()(double x, double y) const { return std::max(x, y); }
double final(double x, std::size_t) const { return (x); }
};
template <class Type, class Out, class In, class Op>
void calc_pooling(const shape& output_shape, Out& output, const In& input, Op op) const
{
auto in_s = input.get_shape();
auto in_lens = in_s.lens();
par_for(output_shape.elements(), [&](auto i) {
auto idx_o = output_shape.multi(i);
auto n_dim = idx_o.size();
std::vector<std::size_t> win_start;
std::vector<std::size_t> win_size;
for(std::size_t dim = 2; dim < n_dim; ++dim)
{
auto d_2 = dim - 2;
int start =
static_cast<int>(idx_o[dim] * stride[d_2]) - static_cast<int>(padding[d_2]);
int end = std::min(start + lengths[d_2], in_lens[dim]);
start = std::max(start, 0);
win_start.push_back(start);
win_size.push_back(end - start);
}
shape win_shape{output_shape.type(), win_size};
auto pool_size = win_shape.elements();
double output_val = op.template init<Type>();
shape_for_each(win_shape, [&](auto idx_w) {
auto idx = idx_o;
std::transform(idx_w.begin(),
idx_w.end(),
win_start.begin(),
idx.begin() + 2,
[](auto ii, auto jj) { return ii + jj; });
if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
idx < in_lens)
{
output_val = op(output_val, input[in_s.index(idx)]);
}
});
output[i] = Type(op.final(output_val, pool_size));
});
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
switch(mode)
{
case migraphx::op::pooling_mode::average:
calc_pooling<type>(output_shape, output, input, avg_pool{});
break;
case migraphx::op::pooling_mode::max:
calc_pooling<type>(output_shape, output, input, max_pool{});
break;
case migraphx::op::pooling_mode::lpnorm:
calc_pooling<type>(output_shape, output, input, lpnorm_pool{lp_order});
break;
}
});
return result;
}
};
} // namespace op
......
......@@ -38,18 +38,38 @@ struct prefix_scan_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.at(0);
auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
argument compute(const shape&, std::vector<argument> args) const
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result = args[0].copy();
auto s = result.get_shape();
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens();
lens[axis] = 1;
auto batch = shape{s.type(), lens, s.strides()};
auto& self = static_cast<const Derived&>(*this);
argument result{output_shape};
auto s = args[0].get_shape();
if(s == output_shape)
{
result = args[0].copy();
}
else
{
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(),
[&](auto i) { output[output_shape.index(i)] = input[s.index(i)]; });
});
s = output_shape;
}
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens();
lens[axis] = 1;
auto batch = shape{s.type(), lens, s.strides()};
auto& self = static_cast<const Derived&>(*this);
result.visit([&](auto output) {
using type = decltype(output);
par_for(batch.elements(), [&](auto i) {
......
......@@ -9,6 +9,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <utility>
......@@ -26,6 +27,8 @@ struct reshape
return pack(f(self.dims, "dims"));
}
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
......@@ -72,7 +75,6 @@ struct reshape
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -3,6 +3,7 @@
#include <limits>
#include <migraphx/check_shapes.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
......@@ -21,7 +22,7 @@ namespace op {
struct roialign
{
std::string coord_trans_mode = "half_pixel";
std::string mode = "avg";
pooling_mode mode = {pooling_mode::average};
int64_t output_height = 1;
int64_t output_width = 1;
int64_t sampling_ratio = 0;
......@@ -42,7 +43,7 @@ struct roialign
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).standard();
check_shapes{inputs, *this}.has(3);
auto x_lens = inputs.at(0).lens();
auto roi_lens = inputs.at(1).lens();
auto bi_lens = inputs.at(2).lens();
......@@ -241,19 +242,19 @@ struct roialign
in_dims[0] * in_dims[1]);
double output_val;
std::tie(output_val, vec_index[c]) =
(mode == "avg") ? this->calc_pooling(offset_bottom_data,
bin_grid_size,
pre_calc,
vec_index[c],
avg_pool{})
: this->calc_pooling(offset_bottom_data,
bin_grid_size,
pre_calc,
vec_index[c],
max_pool{});
(mode == migraphx::op::pooling_mode::average)
? this->calc_pooling(offset_bottom_data,
bin_grid_size,
pre_calc,
vec_index[c],
avg_pool{})
: this->calc_pooling(offset_bottom_data,
bin_grid_size,
pre_calc,
vec_index[c],
max_pool{});
output(n, c, ph, pw) = output_val;
});
});
});
......
......@@ -40,7 +40,6 @@ struct scalar
{
return args[0].reshape(output_shape);
}
lifetime get_lifetime() const { return lifetime::borrow; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -8,6 +8,7 @@
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/name.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
......@@ -16,7 +17,17 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter
// The scatter operator fetches a subset of data given by an index array and then performs a
// reduction operation (add, multiply, or just set the data) on each element returned. We implement
// it as a separate derived struct for each of the three reduction methods. The related operator
// scatterND is a generalization that works on a set of 3 tensors of different ranks. The
// complementary operations are gather/gatherND.
//
// This is a template for deriving child structs from. Each child needs to define
// only a reduction() method. Names are automatically handled by the op_name template.
template <class Derived>
struct scatter : op_name<Derived>
{
int64_t axis = 0;
......@@ -33,29 +44,44 @@ struct scatter
return {{"normalize_axes", normalize}};
}
std::string name() const { return "scatter"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).standard();
return inputs.front();
// If non-packed, this converts to a packed output while preserving permutation of tensor
return inputs.front().with_lens(inputs.front().lens());
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// max dimension in axis
auto& self = static_cast<const Derived&>(*this);
// max dimension in each axis
auto axis_dim_size = output_shape.lens()[axis];
// cast all arguments as correct type
visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) {
// copy all of data to output
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto ind_s = indices.get_shape();
// iterate through items in shape
shape_for_each(ind_s, [&](const auto& idx) {
auto out_idx = idx;
auto index = indices[ind_s.index(idx)];
auto out_idx = idx;
// Overloaded tensor_view::() invokes indexing logic of
// std::size_t shape::index(std::size_t i) const
// which handles nonstandard shapes correctly
auto index = indices(idx.begin(), idx.end());
// normalize negative indexes (may be redundant after using
// normalize_compute_shape())
index = (index < 0) ? index + axis_dim_size : index;
out_idx[axis] = index;
output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
// look up the appropriate locations in output, using idx and out_idx.
// call reduction() method of derived struct to copy and reduce that element
self.reduction()(output(out_idx.begin(), out_idx.end()),
update(idx.begin(), idx.end()));
});
});
});
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "add" function as reduction.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter_add : scatter<scatter_add>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter methods, there are three different reduction functions.
auto reduction() const
{
return [](auto& x, const auto& y) { x += y; };
}
// name of this struct is automatically assigned by the op_name<>
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "multiply" as the reduction function.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter_mul : scatter<scatter_mul>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto reduction() const
{
return [](auto& x, const auto& y) { x *= y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/scatter.hpp>
#include <cmath>
#include <utility>
// Scatter op. with "none" as the reduction function (just copies the value). This is identical to
// the previously existing Scatter op.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter_none : scatter<scatter_none>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto reduction() const
{
return [](auto& x, const auto& y) { x = y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_add : scatternd_op<scatternd_add>
{
scatternd_add() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x += y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_mul : scatternd_op<scatternd_mul>
{
scatternd_mul() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x *= y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_none : scatternd_op<scatternd_none>
{
scatternd_none() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x = y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
template <class Derived>
struct scatternd_op : op_name<Derived>
{
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3);
auto r = inputs.front().lens().size();
auto q = inputs.at(1).lens().size();
auto k = inputs.at(1).lens().back();
auto ind_lens = inputs.at(1).lens();
auto upd_lens = inputs.back().lens();
auto data_lens = inputs.front().lens();
if(k > r)
MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
" is too large for tensor of rank " + std::to_string(r));
if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
"++ data.lens[k:r-1]");
auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto& self = static_cast<const Derived&>(*this);
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto updates_shape = updates.get_shape();
auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size();
auto r = output_shape.lens().size();
par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0);
std::copy(
updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
auto index_start = indices.begin() +
indices_shape.index(indices_idx.begin(), indices_idx.end());
auto index_end = index_start + k;
std::vector<std::size_t> out_idx(r, 0);
std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]);
});
});
});
return result;
}
auto init() const {}
scatternd_op() {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment