Commit 32b83c9c authored by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into inner_bcast_fix

parents 92f5a6cd 434a06cf
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -164,7 +164,7 @@ struct convolution_backwards
         shape win_shape{dyn_out.computed_shape.type(), win_size};
         par_dfor(in_n, wei_c)([&](int o, int k) {
-            shape_for_each(win_shape, [&](auto idx_win) {
+            shape_for_each(win_shape, [&](const auto& idx_win) {
                 const int w = idx_win[0];
                 auto input_dims_start = idx_win.begin() + 1;
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_FILL_HPP
#define MIGRAPHX_GUARD_OPERATORS_FILL_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* fill(default_value, output_buffer)
* Fill an output buffer with the given default_value.
* Note that if the default_value is a literal and the output_buffer
* has a static shape this operator can be replaced with a literal.
*/
struct fill
{
std::string name() const { return "fill"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(2).same_type();
if(inputs.at(0).dynamic() or inputs.at(0).elements() != 1)
{
MIGRAPHX_THROW("FILL: default_value is dynamic or more than one element");
}
return inputs.back();
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
visit_all(args[0], args[1])([&](auto value, auto output) {
par_for(dyn_out.computed_shape.elements(), [&](auto i) { output[i] = value.front(); });
});
return args[1];
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 1; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
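
For reference, the new operator's effect is easy to state outside the framework. The following is a minimal, hypothetical mirror of fill::compute() on plain vectors, not MIGraphX API code:

#include <cstddef>
#include <vector>

// Sketch of fill's semantics: every element of the output buffer is set to the
// single default_value, which is what the par_for in compute() does in parallel.
std::vector<float> fill_like(float default_value, std::size_t elements)
{
    std::vector<float> output(elements);
    for(std::size_t i = 0; i < elements; ++i)
        output[i] = default_value;
    return output;
}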
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -125,13 +125,12 @@ struct gather
                 auto out_lens  = data.get_shape().lens();
                 out_lens[axis] = indices.get_shape().elements();
                 migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
-                shape_for_each(out_comp_shape, [&](const auto& out_idx) {
-                    auto data_idx = out_idx;
+                shape_for_each(out_comp_shape, [&](const auto& out_idx_v, size_t out_idx) {
+                    auto data_idx = out_idx_v;
                     auto in_index = indices[data_idx[axis]];
                     in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
                     data_idx[axis] = in_index;
-                    output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
-                        data(data_idx.begin(), data_idx.end());
+                    output[out_idx] = data(data_idx.begin(), data_idx.end());
                 });
             }
         });
...
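
The index normalization kept by this hunk is worth seeing in isolation; a tiny standalone sketch (hypothetical helper name, the operator inlines this logic):

#include <cstdint>

// A negative gather index counts from the end of the axis, so -1 maps to
// axis_dim_size - 1, matching the ternary in the hunk above.
std::int64_t normalize_gather_index(std::int64_t in_index, std::int64_t axis_dim_size)
{
    return (in_index < 0) ? in_index + axis_dim_size : in_index;
}
// e.g. axis_dim_size = 5: in_index -1 -> 4, in_index 2 -> 2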
@@ -71,7 +71,7 @@ struct if_op
         std::unordered_map<std::string, argument> params;
         std::set<std::string> pnames;
-        for(const auto& smod : mods)
+        for(const_module_ref smod : mods)
         {
             auto names = smod->get_parameter_names();
             pnames.insert(names.begin(), names.end());
...
@@ -59,7 +59,7 @@ struct loop
             MIGRAPHX_THROW("LOOP: operator should have one submodule.");
         }
-        const auto& mod     = mods.front();
+        const_module_ref mod = mods.front();
         auto mod_out_shapes = mod->get_output_shapes();
         auto dep_param_num  = inputs.size() - 2;
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -258,7 +258,7 @@ struct nonmaxsuppression
         selected_boxes_inside_class.reserve(max_output_shape.elements());
         // iterate over batches and classes
         shape comp_s{shape::double_type, {num_batches, num_classes}};
-        shape_for_each(comp_s, [&](auto idx) {
+        shape_for_each(comp_s, [&](const auto& idx) {
             auto batch_idx = idx[0];
             auto class_idx = idx[1];
             // index offset for this class
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -56,10 +56,10 @@ struct nonzero
         std::vector<std::vector<std::size_t>> vec_idx;
         auto s = args.front().get_shape();
         args.front().visit([&](auto v) {
-            shape_for_each(s, [&](auto idx) {
-                if(not float_equal(v[s.index(idx)], 0))
+            shape_for_each(s, [&](const auto& idx_v, size_t idx) {
+                if(not float_equal(v[idx], 0))
                 {
-                    vec_idx.push_back(idx);
+                    vec_idx.push_back(idx_v);
                 }
             });
         });
...
@@ -29,6 +29,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/argument.hpp>
+#include <migraphx/pad_calc.hpp>
 #include <migraphx/par_for.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/dyn_output.hpp>
@@ -40,10 +41,20 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
+// The Pooling operator mostly follows the specifications for the Onnx pooling op.
+// It assumes an NCHW layout, extended to support any number of spatial dimensions
+// from 1 on up; dimensions are <batch index, channels, spatial dimensions...>
+//
 struct pooling
 {
+    // Class members mode, ceil_mode, padding_mode have similar names but refer to separate
+    // concepts.
     pooling_mode mode = {pooling_mode::average};
 
+    // If the input has rank other than 4 then padding, stride, lengths must all be specified
+    // since the defaults have 2 dimensions. Exception: padding not required if
+    // padding_mode != default_
+
     // Padding along each spatial input dimension
     // Can be ndim or 2*ndim values where ndim is size of lengths
     // ndim values means pad the same before and after each dimension
@@ -63,13 +74,14 @@ struct pooling
     // ceiling mode is a flag affecting output size
     // or equivalently, placements of the pooling kernel.
-    // When true, round the size upwards, possibly
-    // including partial placements where the kernel extends beyond the edge
-    // of input and even padding. When false, round down so that all
+    // When true, round the size upwards. When false, round down so that all
     // kernel placements fit but some input values may be dropped.
     bool ceil_mode = false;
     int lp_order   = 2;
 
+    // Mode for auto padding. default_ indicates no auto padding.
+    padding_mode_t padding_mode = padding_mode_t::default_;
+
     // Global pooling with dynamic shape input
     bool dyn_global = false;
@@ -84,6 +96,7 @@ struct pooling
     {
         return pack(f(self.mode, "mode"),
                     f(self.padding, "padding"),
+                    f(self.padding_mode, "padding_mode"),
                     f(self.stride, "stride"),
                     f(self.lengths, "lengths"),
                     f(self.ceil_mode, "ceil_mode"),
@@ -97,7 +110,8 @@ struct pooling
     {
         if(dyn_global)
             return;
-        if((padding.size() != stride.size() and (padding.size()) != stride.size() * 2) or
+        if((padding_mode != default_ and padding.size() != stride.size() and
+            (padding.size()) != stride.size() * 2) or
            stride.size() != lengths.size())
         {
             MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
@@ -137,8 +151,19 @@ struct pooling
             std::size_t padding_factor = 2 * padding[i];
             if(padding.size() == 2 * kdims)
                 padding_factor = padding[i] + padding[i + kdims];
-            assert(input_lens[i + 2] + padding_factor >= lengths[i]);
-            std::size_t dim_size = input_lens[i + 2] + padding_factor - lengths[i];
+            std::size_t dim_size;
+            if(input_lens[i + 2] + padding_factor < lengths[i])
+            {
+                if(padding_mode == default_)
+                    MIGRAPHX_THROW("POOLING: not enough padding for the given kernel size");
+                // lengths can be legitimately larger only if we're doing auto padding
+                // with a dynamic shape, in which case given padding is ignored. Set a dummy value.
+                dim_size = 2;
+            }
+            else
+            {
+                dim_size = input_lens[i + 2] + padding_factor - lengths[i];
+            }
             std::size_t len =
                 (ceil_mode)
                     ? dim_size / stride[i] +
@@ -151,17 +176,13 @@ struct pooling
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this, true}.has(1);
+        check_shapes{inputs, *this, true}.has(1).min_ndims(3);
         check_attribute_size();
         const shape& input = inputs.at(0);
-        auto padding_size  = padding.size();
+        auto stride_size   = stride.size();
         size_t kdims       = input.ndim() - 2;
-        if(input.ndim() < 3)
-        {
-            MIGRAPHX_THROW("POOLING: input must have 3 or more dimensions and be nonempty");
-        }
-        if(input.ndim() * 2 != padding_size + 4 and input.ndim() != padding_size + 2)
+        if(input.ndim() != stride_size + 2)
         {
             MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
         }
@@ -179,6 +200,28 @@ struct pooling
             }
             return {input.type(), output_dyn_dims};
         }
+        else if(padding_mode != default_)
+        {
+            const size_t num_spatial_dims = inputs[0].ndim() - 2;
+            const shape& x_shape          = inputs[0];
+            // same as convolution::dynamic_compute_shape()
+            for(std::size_t i = 0; i < num_spatial_dims; ++i)
+            {
+                auto ceil_div = [](std::size_t x, std::size_t y) { return (x + y - 1) / y; };
+                auto s        = stride[i];
+                auto x        = x_shape.dyn_dims()[i + 2];
+                std::set<std::size_t> optimals{};
+                std::transform(x.optimals.begin(),
+                               x.optimals.end(),
+                               std::inserter(optimals, optimals.begin()),
+                               [&](auto o) { return ceil_div(o, s); });
+                output_dyn_dims.push_back(
+                    shape::dynamic_dimension{ceil_div(x.min, s), ceil_div(x.max, s), optimals});
+            }
+            return {input.type(), output_dyn_dims};
+        }
         else
         {
             // does not compute optimals
@@ -267,6 +310,7 @@ struct pooling
                       Out& output,
                       const In& input,
                       const std::vector<std::size_t>& kernel_dims,
+                      const std::vector<std::size_t>& padding_vals,
                       Op op) const
     {
         auto in_s = input.get_shape();
@@ -284,8 +328,8 @@ struct pooling
             for(std::size_t dim = 2; dim < n_dim; ++dim)
             {
                 auto d_2 = dim - 2;
-                int start =
-                    static_cast<int>(idx_o[dim] * stride[d_2]) - static_cast<int>(padding[d_2]);
+                int start = static_cast<int>(idx_o[dim] * stride[d_2]) -
+                            static_cast<int>(padding_vals[d_2]);
                 int end;
                 // NOLINT
                 if(count_include_pad and ceil_mode and (mode != pooling_mode::max))
@@ -297,7 +341,7 @@ struct pooling
                     // Check if this kernel extends beyond the padding at end of dimension
                     end = std::min(start + kernel_dims[d_2],
-                                   in_lens[dim] + static_cast<int>(padding[d_2]));
+                                   in_lens[dim] + static_cast<int>(padding_vals[d_2]));
                 }
                 else
                 {
@@ -316,11 +360,12 @@ struct pooling
             }
             shape win_shape{output_shape.type(), win_size};
             auto pool_size    = win_shape.elements();
             double output_val = op.template init<Type>();
 
             // for each element in the window...
-            shape_for_each(win_shape, [&](auto idx_w) {
+            shape_for_each(win_shape, [&](const auto& idx_w) {
                 // the coordinates of this element
                 auto idx = idx_o;
@@ -354,30 +399,65 @@ struct pooling
     argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{dyn_out.computed_shape};
+        argument result;
         auto input_lens = args[0].get_shape().lens();
         std::vector<std::size_t> kernel_dims;
+        shape output_shape;
+        // If we have to auto-calculate padding, it will be passed to calc_pooling() as an argument
+        // instead of the member variable padding.
+        std::vector<std::size_t> temp_padding(padding);
         if(dyn_global)
         {
+            // for dynamic GlobalPooling, there's no padding
             kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end());
+            output_shape = dyn_out.computed_shape;
+            result       = dyn_out.computed_shape;
         }
-        else
+        else if((padding_mode != op::padding_mode_t::default_))
+        {
+            // if padding_mode is set, input was a dynamic size. Calculate padded size now.
+            // kernel_lens is the same as kernel_dims, but prepended with the 2 non-
+            // spatial dimensions. For size computations, it's used like the weights
+            // tensor for convolutions.
+            std::vector<std::size_t> kernel_lens;
+            kernel_lens.insert(kernel_lens.end(), input_lens.begin(), input_lens.begin() + 2);
+            kernel_lens.insert(kernel_lens.end(), lengths.begin(), lengths.end());
+            kernel_dims = this->lengths;
+            auto type   = args[0].get_shape().type();
+            // dilation not currently supported for pooling, so default to all 1's
+            temp_padding = calc_dyn_auto_pad(
+                input_lens, kernel_lens, stride, {1, 1}, bool(padding_mode == op::same_upper));
+            output_shape = compute_padded_pool_shape(
+                args[0].get_shape(), shape(type, kernel_dims), temp_padding, stride, {1, 1});
+            result = argument(output_shape);
+        }
+        else // fixed/static input
         {
             kernel_dims  = this->lengths;
+            output_shape = dyn_out.computed_shape;
+            result       = dyn_out.computed_shape;
         }
 
+        // Perform the computation and populate result
         visit_all(result, args[0])([&](auto output, auto input) {
             using type = typename decltype(output)::value_type;
             switch(mode)
             {
             case migraphx::op::pooling_mode::average:
-                calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, avg_pool{});
+                calc_pooling<type>(
+                    output_shape, output, input, kernel_dims, temp_padding, avg_pool{});
                 break;
             case migraphx::op::pooling_mode::max:
-                calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, max_pool{});
+                calc_pooling<type>(
+                    output_shape, output, input, kernel_dims, temp_padding, max_pool{});
                 break;
             case migraphx::op::pooling_mode::lpnorm:
                 calc_pooling<type>(
-                    dyn_out.computed_shape, output, input, kernel_dims, lpnorm_pool{lp_order});
+                    output_shape, output, input, kernel_dims, temp_padding, lpnorm_pool{lp_order});
                 break;
             }
         });
...
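
The ceil_mode comment rewritten in this hunk is easier to follow with the output-length arithmetic spelled out. A standalone sketch (hypothetical helper, simplified from the operator's code path):

#include <cstddef>

// Output length for one spatial axis: with padded extent P, window K, and
// stride S there are floor((P - K) / S) + 1 full placements; ceil_mode rounds
// the division up instead, admitting one final partial placement.
std::size_t pool_out_len(std::size_t padded, std::size_t k, std::size_t s, bool ceil_mode)
{
    std::size_t dim_size = padded - k; // assumes padded >= k, as checked above
    std::size_t len      = dim_size / s + 1;
    if(ceil_mode and dim_size % s != 0)
        ++len; // partial placement at the end of the axis
    return len;
}
// e.g. padded = 8, k = 3, s = 2: floor mode gives 3; ceil_mode gives 4 since 5 % 2 != 0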
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_RANDOM_SEED_HPP
#define MIGRAPHX_GUARD_OPERATORS_RANDOM_SEED_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <random>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
 * Generates a random seed for use by random number generators. Generating the seed
 * at runtime guarantees a different random sequence on every execution.
 * This operation has no inputs or attributes; it outputs an unsigned integer tensor
 * with a single value.
 */
struct random_seed
{
shape::type_t dtype = shape::type_t::uint64_type;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dtype, "dtype"));
}
std::string name() const { return "random_seed"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(0);
return shape{dtype};
}
argument compute(const shape& output_shape, const std::vector<argument>&) const
{
argument result(output_shape);
result.visit([&](auto output) { output.front() = std::random_device{}(); });
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
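
What random_seed::compute() boils down to is a single call into the standard library; a minimal standalone demonstration:

#include <iostream>
#include <random>

// std::random_device produces a fresh, nondeterministic value per call, so
// two program runs will (almost surely) print different seeds.
int main()
{
    std::cout << std::random_device{}() << '\n';
    return 0;
}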
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
/**
 * Random Uniform distribution operator. Given a shape, populate it with random
 * values. Calls to random_uniform using the same randomization seed as a
 * literal input will always generate the same pseudo-random sequence.
 *
 * Inputs: (1) randomization seed (any type is allowed)
 *         (2) output buffer argument to be populated.
 *
 * Attributes: none
 *
 * Output: Returns the buffer from input #2.
 */
#ifndef MIGRAPHX_GUARD_OPERATORS_RANDOM_UNIFORM_HPP
#define MIGRAPHX_GUARD_OPERATORS_RANDOM_UNIFORM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <random>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
 * random_uniform populates the passed buffer with random numbers drawn from a
 * uniform distribution. The range for floating-point data types is [0, 1);
 * for integer types it is [0, <max value for the type>].
 */
struct random_uniform
{
// The random_uniform operation needs the random number generator seed
// to be passed as a runtime input.
std::string name() const { return "random_uniform"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(2);
return inputs.at(1);
}
argument compute(const shape&, std::vector<argument> args) const
{
// Output goes into the passed buffer, not the shape output.
auto result = args[1];
uint64_t local_seed = args[0].at<uint64_t>(0);
std::mt19937 gen(local_seed);
result.visit([&](auto output) {
using type = typename decltype(output)::value_type;
if constexpr(std::is_integral<type>{})
{
// default range for all integer types is
// [0, std::uniform_int_distribution<type>::max()].
// Todo: enable different ranges
std::uniform_int_distribution<type> dis;
std::generate(output.begin(), output.end(), [&] { return dis(gen); });
}
else
{
// default real distribution type is double with range [0, 1);
std::uniform_real_distribution<> dis;
std::generate(output.begin(), output.end(), [&] { return dis(gen); });
}
});
return result;
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 1; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
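
The reproducibility promise in the header comment follows directly from how mersenne-twister engines work; a small standalone check using only the standard library:

#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

// A std::mt19937 seeded identically always emits the identical sequence, which
// is why passing random_uniform the same seed input reproduces its output.
std::vector<double> draw(std::uint64_t seed, std::size_t n)
{
    std::mt19937 gen(seed);
    std::uniform_real_distribution<> dis; // defaults to [0, 1)
    std::vector<double> out(n);
    for(auto& x : out)
        x = dis(gen);
    return out;
}

int main()
{
    assert(draw(42, 8) == draw(42, 8)); // same seed, same pseudo-random data
    return 0;
}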
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -163,7 +163,7 @@ struct reduce_op : op_name<Derived>
         auto& self      = static_cast<const Derived&>(*this);
         auto data_idx   = out_idx;
         accumulator val = self.init();
-        shape_for_each(batch_shape, [&](auto b_idx) {
+        shape_for_each(batch_shape, [&](const auto& b_idx) {
             this->tune_dims(tuned_axes, b_idx, data_idx);
             accumulator x = input(data_idx.begin(), data_idx.end());
             val           = self.op()(accumulator{self.input()(x)}, val);
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -70,13 +70,13 @@ struct reverse
         argument result{s};
         auto lens = s.lens();
         visit_all(result, args.front())([&](auto output, auto input) {
-            shape_for_each(s, [&](const auto& out_idx) {
-                auto in_idx = out_idx;
+            shape_for_each(s, [&](const auto& out_idx_v, size_t out_idx) {
+                auto in_idx = out_idx_v;
                 for(const auto& axis : axes)
                 {
-                    in_idx[axis] = lens[axis] - 1 - out_idx[axis];
+                    in_idx[axis] = lens[axis] - 1 - out_idx_v[axis];
                 }
-                output[s.index(out_idx)] = input[s.index(in_idx)];
+                output[out_idx] = input[s.index(in_idx)];
             });
         });
...
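
The per-axis mapping this operator applies is a simple mirror; isolated as a hypothetical helper for clarity:

#include <cstddef>

// Along each reversed axis, output coordinate out_idx reads from its mirror
// image in the input, exactly the in_idx[axis] assignment in the hunk above.
std::size_t mirror_index(std::size_t out_idx, std::size_t axis_len)
{
    return axis_len - 1 - out_idx;
}
// e.g. axis_len = 4: 0 -> 3, 1 -> 2, 2 -> 1, 3 -> 0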
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -113,10 +113,9 @@ struct roialign
     {
         std::vector<pos_weight> results(bin_grid_size[0] * bin_grid_size[1] * output_height *
                                         output_width);
-        shape_for_each(comp_s, [&](auto idx) {
-            std::array<std::size_t, 2> p = {idx[0], idx[1]};
-            std::array<std::size_t, 2> i = {idx[2], idx[3]};
-            auto index                   = comp_s.index(idx);
+        shape_for_each(comp_s, [&](const auto& idx_v, size_t index) {
+            std::array<std::size_t, 2> p = {idx_v[0], idx_v[1]};
+            std::array<std::size_t, 2> i = {idx_v[2], idx_v[3]};
             std::array<float, 2> xy{};
             std::array<int64_t, 2> low{};
@@ -125,7 +124,7 @@ struct roialign
             {
                 xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] +
                          (i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii];
-                xy[ii] = (coord_trans_mode == "output_half_pixel") ? (xy[ii] - 0.5f) : xy[ii];
+                xy[ii] = (coord_trans_mode == "half_pixel") ? (xy[ii] - 0.5f) : xy[ii];
                 if(xy[ii] < -1.0 or xy[ii] > dims[ii])
                 {
                     results[index] = pos_weight{};
@@ -255,7 +254,7 @@ struct roialign
         std::vector<std::size_t> comp_lens1 = {channels, out_dims[0], out_dims[1]};
         shape comp_s1{migraphx::shape::float_type, comp_lens1};
         std::vector<int64_t> vec_index(channels, 0);
-        shape_for_each(comp_s1, [&](auto idx) {
+        shape_for_each(comp_s1, [&](const auto& idx) {
             auto c  = idx[0];
             auto ph = idx[1];
             auto pw = idx[2];
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
...
@@ -27,19 +27,34 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
-#include <migraphx/dyn_output.hpp>
 #include <migraphx/value.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/normalize_attributes.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
 
+/**
+ * Slice operator that accepts variable axes, starts and ends.
+ *
+ * Attributes:
+ * axes: constant axes to slice over (optional)
+ * starts: constant slice starting indices (optional)
+ * ends: constant slice ending indices (optional)
+ *
+ * Parameters:
+ * data: the input tensor to slice (dynamic or static shape)
+ * input_starts: starting indices of slice (optional, static shape)
+ * input_ends: ending indices of slice (optional, static shape)
+ * input_axes: axes to slice over (optional, static shape)
+ */
 struct slice
 {
-    std::vector<int64_t> axes;
-    std::vector<int64_t> starts;
-    std::vector<int64_t> ends;
+    std::vector<int64_t> axes{};
+    std::vector<int64_t> starts{};
+    std::vector<int64_t> ends{};
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -48,8 +63,8 @@ struct slice
     }
 
     /**
-     * Ensure that attribute vectors axes, starts, and ends are all the same size and values are in
-     * limits.
+     * Ensure that attribute vectors axes, starts, and ends are all the same size and values are
+     * within limits.
      */
     value attributes() const
     {
@@ -70,100 +85,235 @@ struct slice
     std::string name() const { return "slice"; }
 
-    auto compute_offset(const shape& s) const
-    {
-        const std::vector<std::size_t>& lens    = s.lens();
-        const std::vector<std::size_t>& strides = s.strides();
-        auto offset                             = 0;
-        if(not axes.empty())
-        {
-            for(std::size_t i = 0; i < axes.size(); i++)
-            {
-                auto axis = axes[i];
-                offset += starts[i] * strides[axis];
-            }
-        }
-        else
-        {
-            for(std::size_t axis = 0; axis < lens.size(); axis++)
-            {
-                offset += starts[axis] * strides[axis];
-            }
-        }
-        return offset;
-    }
+    /**
+     * Computes the slice output shape dimensions for given starts, ends, and axes.
+     * Templated to also handle tensor views.
+     * Possibly different type between [in_starts, in_ends] and [in_axes] if in_axes is this
+     * object's axes attribute. Assumes in_starts and in_ends are normalized; in_axes are valid.
+     */
+    template <class A, class B>
+    std::vector<std::size_t>
+    lens_calc(const std::vector<std::size_t>& lengths, A in_starts, A in_ends, B in_axes) const
+    {
+        auto new_lens = lengths;
+        for(std::size_t i = 0; i < in_axes.size(); ++i)
+        {
+            auto axis      = in_axes[i];
+            new_lens[axis] = in_ends[i] - in_starts[i];
+        }
+        return new_lens;
+    }
 
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this, true}.has(1);
+        check_shapes{inputs, *this, true}.has(1, 3, 4);
         auto input_shape = inputs[0];
-        auto t           = input_shape.type();
-        if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
-               return not input_shape.dyn_dims()[axis].is_fixed();
-           }))
-        {
-            MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis ");
-        }
-        // For a static shape, old_lens will be adjusted to a new size
-        // for those axes that are sliced.
-        // For dynamic shape, the adjusted old_lens become the new max values,
-        // while updating the old mins and optimals if possible.
-        std::vector<std::size_t> new_mins;
-        std::vector<std::size_t> old_lens;
-        std::vector<std::size_t> old_strides;
-        if(input_shape.dynamic())
-        {
-            old_lens = input_shape.max_lens();
-            new_mins = input_shape.min_lens();
-        }
-        else
-        {
-            old_lens = input_shape.lens();
-            // For static shape (including during eval step after a dynamic input) the strides are
-            // indexed into the pre-slice array, so they are larger than the apparent size of the
-            // resulting shape.
-            old_strides = input_shape.strides();
-        }
-        std::vector<std::size_t> new_lens = old_lens;
-        for(std::size_t i = 0; i < axes.size(); i++)
-        {
-            auto axis            = axes[i];
-            size_t sliced_length = ends[i] - starts[i];
-            // A Numpy indexing convention: a slice size larger than the actual dimension
-            // is legal and the "ends" value is clipped to the axis size
-            new_lens[axis] = std::min(new_lens[axis], sliced_length);
-            if(input_shape.dynamic())
-            {
-                // TODO: when non-fixed shape slicing is allowed, this will be different than
-                // sliced_length, making use of TBD start/end values.
-                std::size_t sliced_min_length = ends[i] - starts[i];
-                // if the slice size is smaller than maxes but larger than mins
-                new_mins[axis] = std::min(sliced_min_length, new_mins[axis]);
-            }
-        }
-        if(input_shape.dynamic())
-        {
-            return shape{t, new_mins, new_lens, {}};
-        }
-        else
-        {
-            return shape{t, new_lens, old_strides};
-        }
+        if(inputs.size() == 1)
+        {
+            auto t = input_shape.type();
+            // TODO: When support for dynamic shapes is added to normalize_attributes,
+            // remove this restriction.
+            if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
+                   return not input_shape.dyn_dims()[axis].is_fixed();
+               }))
+            {
+                MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis ");
+            }
+            // Doesn't handle optimals
+            if(input_shape.dynamic())
+            {
+                return shape{t,
+                             lens_calc(input_shape.min_lens(), starts, ends, axes),
+                             lens_calc(input_shape.max_lens(), starts, ends, axes),
+                             {}};
+            }
+            else
+            {
+                return shape{
+                    t, lens_calc(input_shape.lens(), starts, ends, axes), input_shape.strides()};
+            }
+        }
+        else
+        {
+            // check that starts, ends, and optionally input_axes are all 1D, have the same
+            // dimension, and are static
+            check_shapes{inputs.begin() + 1,
+                         inputs.end(),
+                         std::string("SLICE: inputs (starts, ends, and input_axes)"),
+                         false}
+                .only_dims(1)
+                .same_dims();
+            auto dds = input_shape.to_dynamic().dyn_dims();
+            if(inputs.size() == 3)
+            {
+                if(inputs[1].lens().at(0) != axes.size())
+                {
+                    MIGRAPHX_THROW("SLICE: inputs starts and ends do not have the same dimension "
+                                   "as the axes attribute");
+                }
+                std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) {
+                    dds.at(axis) = {0, dds.at(axis).max};
+                });
+            }
+            else
+            {
+                // if axes is an input, then all the output dimensions could be 0 to the max value
+                std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) {
+                    return shape::dynamic_dimension{0, dd.max};
+                });
+            }
+            return shape{input_shape.type(), dds};
+        }
     }
 
+    /**
+     * Calculates the starting offset for the sliced tensor.
+     * Used in compute when only data input and all other information are in the attributes.
+     *
+     * \param s static input shape
+     */
+    auto compute_offset(const shape& s) const
+    {
+        const std::vector<std::size_t>& lens    = s.lens();
+        const std::vector<std::size_t>& strides = s.strides();
+        auto offset                             = 0;
+        if(not axes.empty())
+        {
+            for(std::size_t i = 0; i < axes.size(); i++)
+            {
+                auto axis = axes[i];
+                offset += starts[i] * strides[axis];
+            }
+        }
+        else
+        {
+            for(std::size_t axis = 0; axis < lens.size(); axis++)
+            {
+                offset += starts[axis] * strides[axis];
+            }
+        }
+        return offset * s.type_size();
+    }
+
+    /**
+     * Calculates the starting offset for the sliced tensor (for aliasing).
+     * Used when the starts and/or the axes are inputs.
+     *
+     * \param s static input shape
+     * \param input_starts starting indices of slice
+     * \param ax_vec axes to slice on
+     */
+    template <class IndView, class Axes>
+    auto compute_offset(const shape& s, const IndView& input_starts, const Axes& ax_vec) const
+    {
+        auto ret = 0;
+        for(std::size_t i = 0; i < ax_vec.size(); ++i)
+        {
+            auto axis = ax_vec[i];
+            ret += input_starts[i] * s.strides().at(axis);
+        }
+        return ret * s.type_size();
+    }
+
+    std::unordered_map<std::string, std::vector<int64_t>>
+    normalize_inputs(const shape& input_shape,
+                     const std::vector<int64_t>& input_starts,
+                     const std::vector<int64_t>& input_ends) const
+    {
+        auto attrs = this->attributes().at("normalize_axes");
+        return {{"input_starts",
+                 normalize_indices(input_starts,
+                                   this->axes,
+                                   input_shape,
+                                   attrs.at("starts"),
+                                   "Slice variable input_starts")},
+                {"input_ends",
+                 normalize_indices(input_ends,
+                                   this->axes,
+                                   input_shape,
+                                   attrs.at("ends"),
+                                   "Slice variable input_ends")}};
+    }
+
+    /**
+     * Three input version of the normalize_inputs.
+     * This one also checks that the input_axes are valid.
+     */
+    std::unordered_map<std::string, std::vector<int64_t>>
+    normalize_inputs(shape input_shape,
+                     const std::vector<int64_t>& input_starts,
+                     const std::vector<int64_t>& input_ends,
+                     const std::vector<int64_t>& input_axes) const
+    {
+        auto attrs = this->attributes().at("normalize_axes");
+        auto norm_axes =
+            normalize_axes(input_axes, input_shape, attrs.at("axes"), "Slice variable input_axes");
+        return {{"input_starts",
+                 normalize_indices(input_starts,
+                                   norm_axes,
+                                   input_shape,
+                                   attrs.at("starts"),
+                                   "Slice variable input_starts")},
+                {"input_ends",
+                 normalize_indices(input_ends,
+                                   norm_axes,
+                                   input_shape,
+                                   attrs.at("ends"),
+                                   "Slice variable input_ends")},
+                {"input_axes", norm_axes}};
+    }
+
     argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
         auto input       = args[0];
+        auto input_shape = input.get_shape();
 
-        auto offset = compute_offset(input.get_shape()) * dyn_out.computed_shape.type_size();
-        return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
+        switch(args.size())
+        {
+        case 1: {
+            std::size_t offset = compute_offset(input_shape);
+            return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
+        }
+        case 3: {
+            shape calc_shape;
+            std::size_t offset = 0;
+            visit_all(args[1], args[2])([&](auto input_starts, auto input_ends) {
+                auto norm_inputs = normalize_inputs(input_shape,
+                                                    input_starts.template to_vector<int64_t>(),
+                                                    input_ends.template to_vector<int64_t>());
+                offset = compute_offset(input_shape, norm_inputs.at("input_starts"), this->axes);
+                calc_shape = {input_shape.type(),
+                              lens_calc(input_shape.lens(),
+                                        norm_inputs.at("input_starts"),
+                                        norm_inputs.at("input_ends"),
+                                        this->axes),
+                              input_shape.strides()};
+            });
+            return {calc_shape, [=] { return input.data() + offset; }};
+        }
+        case 4: {
+            shape calc_shape;
+            std::size_t offset = 0;
+            visit_all(args[1], args[2], args[3])(
+                [&](auto input_starts, auto input_ends, auto input_axes) {
+                    auto norm_inputs = normalize_inputs(input_shape,
+                                                        input_starts.template to_vector<int64_t>(),
+                                                        input_ends.template to_vector<int64_t>(),
+                                                        input_axes.template to_vector<int64_t>());
+                    offset = compute_offset(
+                        input_shape, norm_inputs.at("input_starts"), norm_inputs.at("input_axes"));
+                    calc_shape = shape{input_shape.type(),
+                                       lens_calc(input_shape.lens(),
+                                                 norm_inputs.at("input_starts"),
+                                                 norm_inputs.at("input_ends"),
+                                                 norm_inputs.at("input_axes")),
+                                       input_shape.strides()};
+                });
+            return {calc_shape, [=] { return input.data() + offset; }};
+        }
+        default: {
+            // Should never get here; covering in case some code change occurs
+            MIGRAPHX_THROW("SLICE: invalid number of inputs");
+        }
+        }
+    }
 
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
...
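
The bookkeeping this rewrite centralizes in lens_calc() and compute_offset() is simple to check by hand. A worked example with hypothetical standalone helpers: slicing axis 1 of a 4x6 row-major tensor from start 2 to end 5 keeps strides {6, 1}, shrinks lens to {4, 3}, and starts the view at element offset 2 * 1 = 2.

#include <cstddef>
#include <vector>

std::vector<std::size_t> sliced_lens(std::vector<std::size_t> lens,
                                     const std::vector<std::size_t>& starts,
                                     const std::vector<std::size_t>& ends,
                                     const std::vector<std::size_t>& axes)
{
    for(std::size_t i = 0; i < axes.size(); ++i)
        lens[axes[i]] = ends[i] - starts[i]; // mirrors lens_calc()
    return lens;
}

std::size_t slice_offset(const std::vector<std::size_t>& starts,
                         const std::vector<std::size_t>& strides,
                         const std::vector<std::size_t>& axes)
{
    std::size_t offset = 0;
    for(std::size_t i = 0; i < axes.size(); ++i)
        offset += starts[i] * strides[axes[i]]; // mirrors compute_offset()
    return offset; // the operator scales this by type_size() for a byte offset
}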
@@ -55,6 +55,7 @@
 #include <migraphx/op/equal.hpp>
 #include <migraphx/op/erf.hpp>
 #include <migraphx/op/exp.hpp>
+#include <migraphx/op/fill.hpp>
 #include <migraphx/op/flatten.hpp>
 #include <migraphx/op/floor.hpp>
 #include <migraphx/op/fmod.hpp>
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -62,6 +62,14 @@ shape compute_padded_shape(const shape& input,
                            const std::vector<std::size_t>& stride,
                            const std::vector<std::size_t>& dilation);
 
+// Used for dynamic auto padding of pooling operators where padding needs to be computed at
+// evaluation time.
+shape compute_padded_pool_shape(const shape& input,
+                                const shape& kernel,
+                                const std::vector<std::size_t>& padding,
+                                const std::vector<std::size_t>& stride,
+                                const std::vector<std::size_t>& dilation);
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
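
For orientation, the auto padding these declarations support follows the usual "same" convention. A sketch of the assumed semantics (modeled on ONNX SAME_UPPER/SAME_LOWER, not MIGraphX's calc_dyn_auto_pad itself): pick total padding so that out = ceil(in / stride), with same_upper placing the odd leftover cell at the end.

#include <cstddef>
#include <utility>

std::pair<std::size_t, std::size_t>
same_auto_pad(std::size_t in, std::size_t k, std::size_t s, bool same_upper)
{
    std::size_t out    = (in + s - 1) / s; // ceil(in / s)
    std::size_t needed = (out - 1) * s + k;
    std::size_t total  = needed > in ? needed - in : 0;
    std::size_t lo     = total / 2;
    std::size_t hi     = total - lo;
    // returns {pad_before, pad_after}
    return same_upper ? std::make_pair(lo, hi) : std::make_pair(hi, lo);
}
// e.g. in = 5, k = 3, s = 2 -> out = 3, total padding = 2, pads {1, 1}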
@@ -205,7 +205,7 @@ void transform(Range1&& r1, Range2&& r2, Iterator it, F f)
 }
 
 template <class Range>
-auto reverse(Range& r)
+auto reverse(Range&& r)
 {
     return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin()));
 }
...
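
The switch to a forwarding reference lets reverse() bind to rvalue range wrappers, not just named lvalues. A minimal sketch against the migraphx ranges utilities (the iterators still point into the named vector, so nothing dangles):

#include <vector>
#include <migraphx/ranges.hpp>

void example(std::vector<int>& v)
{
    // range(...) returns a lightweight iterator-pair object; as a temporary it
    // only binds to reverse() now that the parameter is Range&&.
    for(int x : migraphx::reverse(migraphx::range(v.begin(), v.end())))
        (void)x; // visits v back-to-front
}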
@@ -263,7 +263,7 @@ struct MIGRAPHX_EXPORT shape
     /// no padding
     bool packed() const;
 
-    /// Returns true is the shape has been transposed. That is the strides are not in descending
+    /// Returns true if the shape has been transposed. That is the strides are not in descending
     /// order
     bool transposed() const;
...
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -37,11 +37,11 @@ inline namespace MIGRAPHX_INLINE_NS {
 template <class F>
 void shape_for_each(const migraphx::shape& s, F f)
 {
-    // Ensure calls to f use const ref to vector
-    auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
     std::vector<std::size_t> indices(s.lens().size());
+    const auto& index_const_ref = indices;
     shape ss{s.type(), s.lens()};
-    for(std::size_t i = 0; i < ss.elements(); i++)
+    size_t max = ss.elements();
+    for(std::size_t i = 0; i < max; i++)
     {
         std::transform(ss.strides().begin(),
                        ss.strides().end(),
@@ -51,9 +51,13 @@ void shape_for_each(const migraphx::shape& s, F f)
                            assert(len > 0 and stride > 0);
                            return (i / stride) % len;
                        });
-        call(indices);
+        if constexpr(std::is_invocable<F, decltype(index_const_ref), decltype(i)>{})
+            f(index_const_ref, i);
+        else
+            f(index_const_ref);
     }
 }
 
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
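
This is the change the operator hunks above build on: the if constexpr dispatch keeps old single-argument callers working while letting new callers also receive the linear element index, avoiding a recomputed s.index(idx). A sketch of both call forms:

#include <cstddef>
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>

void example()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    // legacy form: multi-index only
    migraphx::shape_for_each(s, [](const auto& idx) { (void)idx; });
    // new form: multi-index plus linear index i, where i == s.index(idx)
    migraphx::shape_for_each(s, [](const auto& idx, std::size_t i) {
        (void)idx;
        (void)i;
    });
}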