Unverified Commit b00489b3 authored by Lakhinder Walia's avatar Lakhinder Walia Committed by GitHub
Browse files

Move operation for memory performance + misc changes for cpu performance (#2130)

Reduce memory footprint by std::move of temporary (potentially very large) containers.
Minor cleanup for performance optimization: e.g. of Index() calculation -- which can get repeated millions of times in large tensors/vectors in a single Visit.
parent 20f6ed92
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri ...@@ -62,7 +62,7 @@ void convolution(Output output, T input, T weights, Padding padding, Stride stri
shape win_shape{output_shape.type(), win_size}; shape win_shape{output_shape.type(), win_size};
double acc = 0.0; double acc = 0.0;
shape_for_each(win_shape, [&](auto idx_win) { shape_for_each(win_shape, [&](const auto& idx_win) {
auto k = idx_win[0]; auto k = idx_win[0];
const auto in_ch = group_id * wei_c + k; const auto in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end()); std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -164,7 +164,7 @@ struct convolution_backwards ...@@ -164,7 +164,7 @@ struct convolution_backwards
shape win_shape{dyn_out.computed_shape.type(), win_size}; shape win_shape{dyn_out.computed_shape.type(), win_size};
par_dfor(in_n, wei_c)([&](int o, int k) { par_dfor(in_n, wei_c)([&](int o, int k) {
shape_for_each(win_shape, [&](auto idx_win) { shape_for_each(win_shape, [&](const auto& idx_win) {
const int w = idx_win[0]; const int w = idx_win[0];
auto input_dims_start = idx_win.begin() + 1; auto input_dims_start = idx_win.begin() + 1;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -125,13 +125,12 @@ struct gather ...@@ -125,13 +125,12 @@ struct gather
auto out_lens = data.get_shape().lens(); auto out_lens = data.get_shape().lens();
out_lens[axis] = indices.get_shape().elements(); out_lens[axis] = indices.get_shape().elements();
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens}; migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx) { shape_for_each(out_comp_shape, [&](const auto& out_idx_v, size_t out_idx) {
auto data_idx = out_idx; auto data_idx = out_idx_v;
auto in_index = indices[data_idx[axis]]; auto in_index = indices[data_idx[axis]];
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
data_idx[axis] = in_index; data_idx[axis] = in_index;
output[out_comp_shape.index(out_idx.begin(), out_idx.end())] = output[out_idx] = data(data_idx.begin(), data_idx.end());
data(data_idx.begin(), data_idx.end());
}); });
} }
}); });
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -258,7 +258,7 @@ struct nonmaxsuppression ...@@ -258,7 +258,7 @@ struct nonmaxsuppression
selected_boxes_inside_class.reserve(max_output_shape.elements()); selected_boxes_inside_class.reserve(max_output_shape.elements());
// iterate over batches and classes // iterate over batches and classes
shape comp_s{shape::double_type, {num_batches, num_classes}}; shape comp_s{shape::double_type, {num_batches, num_classes}};
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](const auto& idx) {
auto batch_idx = idx[0]; auto batch_idx = idx[0];
auto class_idx = idx[1]; auto class_idx = idx[1];
// index offset for this class // index offset for this class
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -56,10 +56,10 @@ struct nonzero ...@@ -56,10 +56,10 @@ struct nonzero
std::vector<std::vector<std::size_t>> vec_idx; std::vector<std::vector<std::size_t>> vec_idx;
auto s = args.front().get_shape(); auto s = args.front().get_shape();
args.front().visit([&](auto v) { args.front().visit([&](auto v) {
shape_for_each(s, [&](auto idx) { shape_for_each(s, [&](const auto& idx_v, size_t idx) {
if(not float_equal(v[s.index(idx)], 0)) if(not float_equal(v[idx], 0))
{ {
vec_idx.push_back(idx); vec_idx.push_back(idx_v);
} }
}); });
}); });
......
...@@ -365,7 +365,7 @@ struct pooling ...@@ -365,7 +365,7 @@ struct pooling
double output_val = op.template init<Type>(); double output_val = op.template init<Type>();
// for each element in the window... // for each element in the window...
shape_for_each(win_shape, [&](auto idx_w) { shape_for_each(win_shape, [&](const auto& idx_w) {
// the coordinates of this element // the coordinates of this element
auto idx = idx_o; auto idx = idx_o;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -163,7 +163,7 @@ struct reduce_op : op_name<Derived> ...@@ -163,7 +163,7 @@ struct reduce_op : op_name<Derived>
auto& self = static_cast<const Derived&>(*this); auto& self = static_cast<const Derived&>(*this);
auto data_idx = out_idx; auto data_idx = out_idx;
accumulator val = self.init(); accumulator val = self.init();
shape_for_each(batch_shape, [&](auto b_idx) { shape_for_each(batch_shape, [&](const auto& b_idx) {
this->tune_dims(tuned_axes, b_idx, data_idx); this->tune_dims(tuned_axes, b_idx, data_idx);
accumulator x = input(data_idx.begin(), data_idx.end()); accumulator x = input(data_idx.begin(), data_idx.end());
val = self.op()(accumulator{self.input()(x)}, val); val = self.op()(accumulator{self.input()(x)}, val);
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -70,13 +70,13 @@ struct reverse ...@@ -70,13 +70,13 @@ struct reverse
argument result{s}; argument result{s};
auto lens = s.lens(); auto lens = s.lens();
visit_all(result, args.front())([&](auto output, auto input) { visit_all(result, args.front())([&](auto output, auto input) {
shape_for_each(s, [&](const auto& out_idx) { shape_for_each(s, [&](const auto& out_idx_v, size_t out_idx) {
auto in_idx = out_idx; auto in_idx = out_idx_v;
for(const auto& axis : axes) for(const auto& axis : axes)
{ {
in_idx[axis] = lens[axis] - 1 - out_idx[axis]; in_idx[axis] = lens[axis] - 1 - out_idx_v[axis];
} }
output[s.index(out_idx)] = input[s.index(in_idx)]; output[out_idx] = input[s.index(in_idx)];
}); });
}); });
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -113,10 +113,9 @@ struct roialign ...@@ -113,10 +113,9 @@ struct roialign
{ {
std::vector<pos_weight> results(bin_grid_size[0] * bin_grid_size[1] * output_height * std::vector<pos_weight> results(bin_grid_size[0] * bin_grid_size[1] * output_height *
output_width); output_width);
shape_for_each(comp_s, [&](auto idx) { shape_for_each(comp_s, [&](const auto& idx_v, size_t index) {
std::array<std::size_t, 2> p = {idx[0], idx[1]}; std::array<std::size_t, 2> p = {idx_v[0], idx_v[1]};
std::array<std::size_t, 2> i = {idx[2], idx[3]}; std::array<std::size_t, 2> i = {idx_v[2], idx_v[3]};
auto index = comp_s.index(idx);
std::array<float, 2> xy{}; std::array<float, 2> xy{};
std::array<int64_t, 2> low{}; std::array<int64_t, 2> low{};
...@@ -255,7 +254,7 @@ struct roialign ...@@ -255,7 +254,7 @@ struct roialign
std::vector<std::size_t> comp_lens1 = {channels, out_dims[0], out_dims[1]}; std::vector<std::size_t> comp_lens1 = {channels, out_dims[0], out_dims[1]};
shape comp_s1{migraphx::shape::float_type, comp_lens1}; shape comp_s1{migraphx::shape::float_type, comp_lens1};
std::vector<int64_t> vec_index(channels, 0); std::vector<int64_t> vec_index(channels, 0);
shape_for_each(comp_s1, [&](auto idx) { shape_for_each(comp_s1, [&](const auto& idx) {
auto c = idx[0]; auto c = idx[0];
auto ph = idx[1]; auto ph = idx[1];
auto pw = idx[2]; auto pw = idx[2];
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
...@@ -205,7 +205,7 @@ void transform(Range1&& r1, Range2&& r2, Iterator it, F f) ...@@ -205,7 +205,7 @@ void transform(Range1&& r1, Range2&& r2, Iterator it, F f)
} }
template <class Range> template <class Range>
auto reverse(Range& r) auto reverse(Range&& r)
{ {
return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin())); return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin()));
} }
......
...@@ -263,7 +263,7 @@ struct MIGRAPHX_EXPORT shape ...@@ -263,7 +263,7 @@ struct MIGRAPHX_EXPORT shape
/// no padding /// no padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending /// Returns true if the shape has been transposed. That is the strides are not in descending
/// order /// order
bool transposed() const; bool transposed() const;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -37,11 +37,11 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -37,11 +37,11 @@ inline namespace MIGRAPHX_INLINE_NS {
template <class F> template <class F>
void shape_for_each(const migraphx::shape& s, F f) void shape_for_each(const migraphx::shape& s, F f)
{ {
// Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
std::vector<std::size_t> indices(s.lens().size()); std::vector<std::size_t> indices(s.lens().size());
const auto& index_const_ref = indices;
shape ss{s.type(), s.lens()}; shape ss{s.type(), s.lens()};
for(std::size_t i = 0; i < ss.elements(); i++) size_t max = ss.elements();
for(std::size_t i = 0; i < max; i++)
{ {
std::transform(ss.strides().begin(), std::transform(ss.strides().begin(),
ss.strides().end(), ss.strides().end(),
...@@ -51,9 +51,13 @@ void shape_for_each(const migraphx::shape& s, F f) ...@@ -51,9 +51,13 @@ void shape_for_each(const migraphx::shape& s, F f)
assert(len > 0 and stride > 0); assert(len > 0 and stride > 0);
return (i / stride) % len; return (i / stride) % len;
}); });
call(indices); if constexpr(std::is_invocable<F, decltype(index_const_ref), decltype(i)>{})
f(index_const_ref, i);
else
f(index_const_ref);
} }
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode) ...@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode)
static std::vector<int> static std::vector<int>
calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind, calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind,
int i_dim, int i_dim,
const std::vector<std::vector<std::size_t>>& vec_dims, std::vector<std::vector<std::size_t>> vec_dims,
const shape& in_s) const shape& in_s)
{ {
if(i_dim == vvv_ind.size()) if(i_dim == vvv_ind.size())
{ {
std::vector<int> vec_ind; std::vector<int> vec_ind(vec_dims.size());
vec_ind.resize(vec_dims.size());
std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) { std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) {
return static_cast<int>(in_s.index(idx)); return static_cast<int>(in_s.index(idx));
}); });
return vec_ind; return vec_ind;
} }
const auto& vv_ind = vvv_ind[i_dim]; const auto& vv_lo = vvv_ind[i_dim][0];
const auto& vv_lo = vv_ind.at(0);
std::vector<std::vector<std::size_t>> vec_dims1; std::vector<std::vector<std::size_t>> vec_dims1;
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
{ {
...@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v ...@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
}); });
} }
const auto& vv_hi = vv_ind.at(1); const auto& vv_hi = vvv_ind[i_dim][1];
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) for(std::size_t start = 0; start < vec_dims.size(); start += vv_hi.size())
{ {
std::transform(vv_hi.begin(), std::transform(vv_hi.begin(),
vv_hi.end(), vv_hi.end(),
...@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v ...@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
return dim; return dim;
}); });
} }
vec_dims.clear();
return calc_neighbor_points(vvv_ind, i_dim + 1, vec_dims1, in_s); return calc_neighbor_points(vvv_ind, i_dim + 1, std::move(vec_dims1), in_s);
} }
static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr) static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr)
...@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize> ...@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize>
auto arg_out_s = arg->eval(); auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s, check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!"); "PARSE_" + opd.op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); }); arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size()) if(out_lens.size() != in_lens.size())
{ {
...@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize> ...@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize>
"PARSE_" + opd.op_name + "PARSE_" + opd.op_name +
": dynamic input scale is not supported!"); ": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); }); arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
if(in_lens.size() != vec_scale.size()) if(in_lens.size() != vec_scale.size())
{ {
MIGRAPHX_THROW("PARSE_" + opd.op_name + MIGRAPHX_THROW("PARSE_" + opd.op_name +
...@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize> ...@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize>
// map out_idx to in_idx // map out_idx to in_idx
auto nearest_op = get_nearest_op(nearest_mode); auto nearest_op = get_nearest_op(nearest_mode);
shape_for_each(out_s, [&](auto idx) { shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
auto in_idx = idx; std::vector<size_t> in_idx(out_idx_v.size());
for(auto ii = 0; ii < in_lens.size(); ++ii) for(auto ii = 0; ii < in_lens.size(); ++ii)
{ {
auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]); auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
in_idx[ii] = nearest_op(in_lens[ii], idx_val); in_idx[ii] = nearest_op(in_lens[ii], idx_val);
} }
ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx)); ind[out_idx] = static_cast<int64_t>(in_s.index(in_idx));
}); });
shape ind_s{shape::int32_type, out_lens}; shape ind_s{shape::int32_type, out_lens};
...@@ -323,24 +320,21 @@ struct parse_resize : op_parser<parse_resize> ...@@ -323,24 +320,21 @@ struct parse_resize : op_parser<parse_resize>
// get the number of dimensions // get the number of dimensions
std::size_t n_dim = out_lens.size(); std::size_t n_dim = out_lens.size();
std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements)); auto vvv_ind = std::vector(n_dim, std::vector(2, std::vector<size_t>(out_elements)));
std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements)); std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
shape_for_each(out_s, [&](auto idx) { shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
auto in_idx = idx;
auto out_idx = out_s.index(idx);
for(auto ii = 0; ii < in_lens.size(); ++ii) for(auto ii = 0; ii < in_lens.size(); ++ii)
{ {
auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]); auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val); vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val);
vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val); vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val);
delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx]; delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx];
} }
}); });
std::vector<std::vector<std::size_t>> vec_dims(out_elements); auto ind = calc_neighbor_points(
auto ind = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s); vvv_ind, 0, std::vector<std::vector<std::size_t>>(out_elements), in_s);
auto ind_lens = out_lens; auto ind_lens = out_lens;
ind_lens[0] *= (std::size_t{1} << n_dim); ind_lens[0] *= (std::size_t{1} << n_dim);
shape ind_s{shape::int32_type, ind_lens}; shape ind_s{shape::int32_type, ind_lens};
......
...@@ -50,13 +50,14 @@ struct shape_impl ...@@ -50,13 +50,14 @@ struct shape_impl
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l) shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
this->calculate_strides(); this->calculate_strides();
assert(m_lens.size() == m_strides.size());
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s)) : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{ {
...@@ -151,6 +152,22 @@ struct shape_impl ...@@ -151,6 +152,22 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
std::size_t get_index(size_t i) const
{
std::size_t result = 0;
std::size_t s = 1;
for(auto k : migraphx::reverse(migraphx::range(m_lens.size())))
{
std::size_t stride = m_strides[k];
std::size_t len = m_lens[k];
std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
std::vector<std::size_t> min_lens() const std::vector<std::size_t> min_lens() const
{ {
std::vector<std::size_t> ret(m_dyn_dims.size()); std::vector<std::size_t> ret(m_dyn_dims.size());
...@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t) ...@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t)
} }
MIGRAPHX_THROW("Invalid type"); MIGRAPHX_THROW("Invalid type");
} }
std::string shape::cpp_type(shape::type_t t) std::string shape::cpp_type(shape::type_t t)
{ {
switch(t) switch(t)
...@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t) ...@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t)
shape::shape() : impl(shape_impl::default_shape()) {} shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {} shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
shape::shape(type_t t, std::vector<std::size_t> l) shape::shape(type_t t, std::vector<std::size_t> l)
: impl(std::make_shared<shape_impl>(t, std::move(l))) : impl(std::make_shared<shape_impl>(t, std::move(l)))
{ {
} }
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s))) : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{ {
...@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const ...@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->standard()) if(this->standard())
return i; return i;
else
{ return impl->get_index(i);
std::size_t s = 1;
std::size_t result = 0;
for(std::size_t j = 0; j < this->lens().size(); j++)
{
const std::size_t k = this->lens().size() - j - 1;
const std::size_t stride = this->strides()[k];
const std::size_t len = this->lens()[k];
const std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
} }
std::vector<std::size_t> shape::multi(std::size_t idx) const std::vector<std::size_t> shape::multi(std::size_t idx) const
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment