Unverified Commit e7ec374f authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Refactor dynamic_dimension to have multiple optimals (#1625)

Makes the optimals into a std::set<std::size_t>
Changes shape object functions to handle the opts change
Changes convolution, flatten, pooling, and slice so that they no longer calculate the output optimal dimensions; instead they return empty optimals. Will need to change this in the future if we want to support dynamic shapes fully.
Many changes to tests and shape calls with respect to the new optimals
parent 1329b9be
...@@ -89,8 +89,8 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -89,8 +89,8 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
} }
else if(a == 1 or b == 1) else if(a == 1 or b == 1)
{ {
// setting opt to 0, may need to be changed // setting optimals to empty, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0}; return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max)};
} }
else else
{ {
......
...@@ -37,7 +37,7 @@ struct onnx_options ...@@ -37,7 +37,7 @@ struct onnx_options
std::size_t default_dim_value = 0; std::size_t default_dim_value = 0;
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set /// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws) /// parser throws)
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
/// Explicitly specify the dims of an input /// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {}; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims /// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
......
...@@ -62,7 +62,7 @@ struct argmax ...@@ -62,7 +62,7 @@ struct argmax
if(s0.dynamic()) if(s0.dynamic())
{ {
auto dyn_dims = s0.dyn_dims(); auto dyn_dims = s0.dyn_dims();
dyn_dims[axis] = {1, 1, 0}; dyn_dims[axis] = {1, 1};
return {shape::int64_type, dyn_dims}; return {shape::int64_type, dyn_dims};
} }
else else
......
...@@ -134,7 +134,7 @@ struct concat ...@@ -134,7 +134,7 @@ struct concat
} }
auto new_dims = inputs[0].dyn_dims(); auto new_dims = inputs[0].dyn_dims();
new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max, 0}; new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max};
return {inputs[0].type(), new_dims}; return {inputs[0].type(), new_dims};
} }
else else
......
...@@ -35,6 +35,10 @@ namespace migraphx { ...@@ -35,6 +35,10 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Convolution operator. Does not support optimal dimensions for spatial dimensions. Returns empty
* optimals.
*/
struct convolution struct convolution
{ {
std::vector<std::size_t> padding = {0, 0}; std::vector<std::size_t> padding = {0, 0};
...@@ -145,7 +149,7 @@ struct convolution ...@@ -145,7 +149,7 @@ struct convolution
else else
{ {
auto l = input_shape.lens().at(0); auto l = input_shape.lens().at(0);
output_dyn_dims.push_back({l, l, 0}); output_dyn_dims.push_back({l, l});
} }
}; };
...@@ -162,25 +166,30 @@ struct convolution ...@@ -162,25 +166,30 @@ struct convolution
if(x_shape.dynamic()) if(x_shape.dynamic())
{ {
auto x = x_shape.dyn_dims()[i + 2]; auto x = x_shape.dyn_dims()[i + 2];
output_dyn_dims.push_back(shape::dynamic_dimension{ std::set<std::size_t> optimals{};
ceil_div(x.min, s), ceil_div(x.max, s), ceil_div(x.opt, s)}); std::transform(x.optimals.begin(),
x.optimals.end(),
std::inserter(optimals, optimals.begin()),
[&](auto o) { return ceil_div(o, s); });
output_dyn_dims.push_back(
shape::dynamic_dimension{ceil_div(x.min, s), ceil_div(x.max, s), optimals});
} }
else else
{ {
auto od = ceil_div(x_shape.lens()[i + 2], s); auto od = ceil_div(x_shape.lens()[i + 2], s);
output_dyn_dims.push_back(shape::dynamic_dimension{od, od, 0}); output_dyn_dims.push_back(shape::dynamic_dimension{od, od});
} }
} }
} }
else else
{ {
// Does not compute for optimals
auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens()); auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens());
auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens()); auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens());
auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens());
for(size_t i = 0; i < num_spatial_dims; ++i) for(size_t i = 0; i < num_spatial_dims; ++i)
{ {
output_dyn_dims.push_back(shape::dynamic_dimension{ output_dyn_dims.push_back(
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]}); shape::dynamic_dimension{min_spatial_dims[i], max_spatial_dims[i], {}});
} }
} }
return shape{x_shape.type(), output_dyn_dims}; return shape{x_shape.type(), output_dyn_dims};
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -59,27 +60,22 @@ struct flatten ...@@ -59,27 +60,22 @@ struct flatten
auto s = inputs[0]; auto s = inputs[0];
if(s.dynamic()) if(s.dynamic())
{ {
// Doesn't handle optimals
auto min_lens = s.min_lens(); auto min_lens = s.min_lens();
auto max_lens = s.max_lens(); auto max_lens = s.max_lens();
auto opt_lens = s.opt_lens();
// If any of the opt values is 0, output opt will be 0 // If any of the opt values is 0, output opt will be 0
shape::dynamic_dimension x = { shape::dynamic_dimension x = {
std::accumulate( std::accumulate(
min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}), min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate( std::accumulate(
max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}), max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate(opt_lens.begin(), {}};
opt_lens.begin() + axis,
std::size_t{1},
std::multiplies<>{})};
shape::dynamic_dimension y = { shape::dynamic_dimension y = {
std::accumulate( std::accumulate(
min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}), min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate( std::accumulate(
max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}), max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate( {}};
opt_lens.begin() + axis, opt_lens.end(), std::size_t{1}, std::multiplies<>{}),
};
return {s.type(), {x, y}}; return {s.type(), {x, y}};
} }
else else
......
...@@ -121,7 +121,7 @@ struct gathernd ...@@ -121,7 +121,7 @@ struct gathernd
// A rank 0 output is a scalar // A rank 0 output is a scalar
if(output_ndim == 0) if(output_ndim == 0)
return shape(data_shape.type(), {shape::dynamic_dimension({1, 1, 0})}); return shape(data_shape.type(), {shape::dynamic_dimension({1, 1})});
// Part of the output shape comes from indices tensor, part from data tensor // Part of the output shape comes from indices tensor, part from data tensor
std::vector<shape::dynamic_dimension> output_dims(output_ndim); std::vector<shape::dynamic_dimension> output_dims(output_ndim);
......
...@@ -119,8 +119,8 @@ struct nonmaxsuppression ...@@ -119,8 +119,8 @@ struct nonmaxsuppression
fixed_shape_error_check(); fixed_shape_error_check();
} }
std::vector<shape::dynamic_dimension> out_lens = {}; std::vector<shape::dynamic_dimension> out_lens = {};
out_lens.push_back({0, max_num_boxes, 0}); out_lens.push_back({0, max_num_boxes});
out_lens.push_back({3, 3, 0}); out_lens.push_back({3, 3});
return {shape::int64_type, out_lens}; return {shape::int64_type, out_lens};
} }
else else
......
...@@ -89,25 +89,17 @@ struct pooling ...@@ -89,25 +89,17 @@ struct pooling
std::vector<std::size_t> output_lens{}; std::vector<std::size_t> output_lens{};
for(size_t i = 0; i < kdims; ++i) for(size_t i = 0; i < kdims; ++i)
{ {
if(input_lens[i + 2] == 0) std::size_t padding_factor = 2 * padding[i];
{ if(padding.size() == 2 * kdims)
// handle opt = 0 padding_factor = padding[i] + padding[i + kdims];
output_lens.push_back(0); assert(input_lens[i + 2] + padding_factor >= lengths[i]);
} std::size_t dim_size = input_lens[i + 2] + padding_factor - lengths[i];
else std::size_t len =
{ (ceil_mode)
std::size_t padding_factor = 2 * padding[i]; ? dim_size / stride[i] +
if(padding.size() == 2 * kdims) static_cast<std::size_t>((dim_size % stride[i] != 0)) // ceil uint divide
padding_factor = padding[i] + padding[i + kdims]; : dim_size / stride[i]; // floor divide
assert(input_lens[i + 2] + padding_factor >= lengths[i]); output_lens.push_back(len + 1);
std::size_t dim_size = input_lens[i + 2] + padding_factor - lengths[i];
std::size_t len =
(ceil_mode)
? dim_size / stride[i] + static_cast<std::size_t>((dim_size % stride[i] !=
0)) // ceil uint divide
: dim_size / stride[i]; // floor divide
output_lens.push_back(len + 1);
}
} }
return output_lens; return output_lens;
} }
...@@ -134,19 +126,19 @@ struct pooling ...@@ -134,19 +126,19 @@ struct pooling
{ {
for(size_t i = 0; i < kdims; ++i) for(size_t i = 0; i < kdims; ++i)
{ {
output_dyn_dims.push_back(shape::dynamic_dimension{1, 1, 1}); output_dyn_dims.push_back(shape::dynamic_dimension{1, 1});
} }
return {input.type(), output_dyn_dims}; return {input.type(), output_dyn_dims};
} }
else else
{ {
// does not compute for optimals
auto min_spatial_dims = calc_spatial_dim_out(input.min_lens(), kdims); auto min_spatial_dims = calc_spatial_dim_out(input.min_lens(), kdims);
auto max_spatial_dims = calc_spatial_dim_out(input.max_lens(), kdims); auto max_spatial_dims = calc_spatial_dim_out(input.max_lens(), kdims);
auto opt_spatial_dims = calc_spatial_dim_out(input.opt_lens(), kdims);
for(size_t i = 0; i < kdims; ++i) for(size_t i = 0; i < kdims; ++i)
{ {
output_dyn_dims.push_back(shape::dynamic_dimension{ output_dyn_dims.push_back(
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]}); shape::dynamic_dimension{min_spatial_dims[i], max_spatial_dims[i], {}});
} }
return {input.type(), output_dyn_dims}; return {input.type(), output_dyn_dims};
} }
......
...@@ -123,9 +123,7 @@ struct reduce_op : op_name<Derived> ...@@ -123,9 +123,7 @@ struct reduce_op : op_name<Derived>
auto tuned_axes = tune_axes(output_dyn_dims.size()); auto tuned_axes = tune_axes(output_dyn_dims.size());
for(const auto& axis : tuned_axes) for(const auto& axis : tuned_axes)
{ {
// At the time of writing, there's no functional difference between output_dyn_dims[axis] = {1, 1};
// optimum of 0 (no opt) or 1.
output_dyn_dims[axis] = {1, 1, 0};
} }
return shape{s.type(), output_dyn_dims}; return shape{s.type(), output_dyn_dims};
......
...@@ -111,16 +111,15 @@ struct slice ...@@ -111,16 +111,15 @@ struct slice
// For a static shape, old_lens will be adjusted to a new size // For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced. // for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values, // For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible. // while updating the old mins and optimals if possible.
std::vector<std::size_t> new_mins; std::vector<std::size_t> new_mins;
std::vector<std::size_t> new_opts;
std::vector<std::size_t> old_lens; std::vector<std::size_t> old_lens;
std::vector<std::size_t> old_strides; std::vector<std::size_t> old_strides;
// Doesn't handle optimals
if(input_shape.dynamic()) if(input_shape.dynamic())
{ {
old_lens = input_shape.max_lens(); old_lens = input_shape.max_lens();
new_mins = input_shape.min_lens(); new_mins = input_shape.min_lens();
new_opts = input_shape.opt_lens();
} }
else else
{ {
...@@ -146,17 +145,11 @@ struct slice ...@@ -146,17 +145,11 @@ struct slice
std::size_t sliced_min_length = ends[i] - starts[i]; std::size_t sliced_min_length = ends[i] - starts[i];
// if the slice size is smaller than maxes but larger than mins // if the slice size is smaller than maxes but larger than mins
new_mins[axis] = std::min(sliced_min_length, new_mins[axis]); new_mins[axis] = std::min(sliced_min_length, new_mins[axis]);
auto sliced_opt_length = ends[i] - starts[i];
if(new_opts[axis] != 0)
new_opts[axis] = sliced_opt_length;
if(new_opts[axis] < new_mins[axis] or new_opts[axis] > new_lens[axis])
new_opts[axis] = 0;
} }
} }
if(input_shape.dynamic()) if(input_shape.dynamic())
{ {
return shape{t, new_mins, new_lens, new_opts}; return shape{t, new_mins, new_lens, {}};
} }
else else
{ {
......
...@@ -81,7 +81,7 @@ struct unsqueeze ...@@ -81,7 +81,7 @@ struct unsqueeze
{ {
if(std::find(axes.begin(), axes.end(), i) != axes.end()) if(std::find(axes.begin(), axes.end(), i) != axes.end())
{ {
dyn_dims.push_back({1, 1, 0}); dyn_dims.push_back({1, 1});
} }
else else
{ {
......
...@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x) ...@@ -188,7 +188,8 @@ auto from_value_impl(rank<3>, const value& v, T& x)
} }
template <class T> template <class T>
auto from_value_impl(rank<4>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void()) auto from_value_impl(rank<4>, const value& v, T& x)
-> decltype(x.insert(*x.begin()), std::declval<typename T::mapped_type>(), void())
{ {
x.clear(); x.clear();
for(auto&& e : v) for(auto&& e : v)
......
...@@ -29,10 +29,12 @@ ...@@ -29,10 +29,12 @@
#include <ostream> #include <ostream>
#include <numeric> #include <numeric>
#include <memory> #include <memory>
#include <set>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -87,12 +89,12 @@ struct shape ...@@ -87,12 +89,12 @@ struct shape
{ {
std::size_t min = 0; std::size_t min = 0;
std::size_t max = 0; std::size_t max = 0;
std::size_t opt = 0; std::set<std::size_t> optimals{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt")); return pack(f(self.min, "min"), f(self.max, "max"), f(self.optimals, "optimals"));
} }
bool is_fixed() const; bool is_fixed() const;
...@@ -132,11 +134,12 @@ struct shape ...@@ -132,11 +134,12 @@ struct shape
shape(type_t t, std::vector<dynamic_dimension> dims); shape(type_t t, std::vector<dynamic_dimension> dims);
// Construct a dynamic shape from three sets of lengths (of the same rank) // Construct a dynamic shape from vectors of mins, maxes, and optimals.
// optimals_list is a vector of optimals that corresponds to each min and max.
shape(type_t t, shape(type_t t,
std::vector<std::size_t> mins, std::vector<std::size_t> mins,
std::vector<std::size_t> maxes, std::vector<std::size_t> maxes,
std::vector<std::size_t> opts); std::vector<std::set<std::size_t>> optimals_list);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
...@@ -198,9 +201,9 @@ struct shape ...@@ -198,9 +201,9 @@ struct shape
/*! /*!
* Optimum lengths for dynamic shape. * Optimum lengths for dynamic shape.
* lens() for fixed shape. * Empty for fixed shape.
*/ */
std::vector<std::size_t> opt_lens() const; std::vector<std::set<std::size_t>> opt_lens() const;
/// Map multiple indices to space index /// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const; std::size_t index(std::initializer_list<std::size_t> l) const;
...@@ -253,7 +256,7 @@ struct shape ...@@ -253,7 +256,7 @@ struct shape
shape with_type(type_t t) const; shape with_type(type_t t) const;
// convert the shape to an equivalent dynamic shape // convert the shape to an equivalent dynamic shape with empty optimals
shape to_dynamic() const; shape to_dynamic() const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
......
...@@ -94,7 +94,7 @@ struct onnx_parser ...@@ -94,7 +94,7 @@ struct onnx_parser
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims; std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false; bool use_dyn_output = false;
......
...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -46,14 +46,14 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
auto dim_val = options.default_dim_value; auto dim_val = options.default_dim_value;
if(dim_val != 0) if(dim_val != 0)
{ {
if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0}) if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1})
{ {
MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value" MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
"set to non-default value"); "set to non-default value");
} }
else else
{ {
parser.default_dyn_dim_value = {dim_val, dim_val, 0}; parser.default_dyn_dim_value = {dim_val, dim_val};
} }
} }
else else
......
...@@ -491,7 +491,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -491,7 +491,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return default_dyn_dim_value; return default_dyn_dim_value;
} }
std::size_t tmp = d.dim_value(); std::size_t tmp = d.dim_value();
return {tmp, tmp, 0}; return {tmp, tmp};
} }
else else
{ {
......
...@@ -74,13 +74,23 @@ struct shape_impl ...@@ -74,13 +74,23 @@ struct shape_impl
shape_impl(shape::type_t t, shape_impl(shape::type_t t,
std::vector<std::size_t> mins, std::vector<std::size_t> mins,
std::vector<std::size_t> maxes, std::vector<std::size_t> maxes,
std::vector<std::size_t> opts) std::vector<std::set<std::size_t>> optimals_list)
: m_type(t) : m_type(t)
{ {
assert(mins.size() == maxes.size() and maxes.size() == opts.size()); if(optimals_list.empty())
for(size_t i = 0; i < mins.size(); ++i)
{ {
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], opts[i]}); for(size_t i = 0; i < mins.size(); ++i)
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i]});
}
}
else
{
assert(mins.size() == maxes.size() and maxes.size() == optimals_list.size());
for(size_t i = 0; i < mins.size(); ++i)
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]});
}
} }
} }
...@@ -147,7 +157,7 @@ struct shape_impl ...@@ -147,7 +157,7 @@ struct shape_impl
std::transform(m_dyn_dims.cbegin(), std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(), m_dyn_dims.cend(),
ret.begin(), ret.begin(),
[](shape::dynamic_dimension x) { return x.min; }); [](const shape::dynamic_dimension& x) { return x.min; });
return ret; return ret;
} }
...@@ -157,19 +167,20 @@ struct shape_impl ...@@ -157,19 +167,20 @@ struct shape_impl
std::transform(m_dyn_dims.cbegin(), std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(), m_dyn_dims.cend(),
ret.begin(), ret.begin(),
[](shape::dynamic_dimension x) { return x.max; }); [](const shape::dynamic_dimension& x) { return x.max; });
return ret; return ret;
} }
std::vector<std::size_t> opt_lens() const std::vector<std::set<std::size_t>> opt_lens() const
{ {
std::vector<std::size_t> ret(m_dyn_dims.size()); std::vector<std::set<std::size_t>> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(), std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(), m_dyn_dims.cend(),
ret.begin(), ret.begin(),
[](shape::dynamic_dimension x) { return x.opt; }); [](const shape::dynamic_dimension& x) { return x.optimals; });
return ret; return ret;
} }
// Does the shape skip over elements? // Does the shape skip over elements?
bool skips() const bool skips() const
{ {
...@@ -240,8 +251,9 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims) ...@@ -240,8 +251,9 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
shape::shape(type_t t, shape::shape(type_t t,
std::vector<std::size_t> mins, std::vector<std::size_t> mins,
std::vector<std::size_t> maxes, std::vector<std::size_t> maxes,
std::vector<std::size_t> opts) std::vector<std::set<std::size_t>> optimals_list)
: impl(std::make_shared<shape_impl>(t, std::move(mins), std::move(maxes), std::move(opts))) : impl(std::make_shared<shape_impl>(
t, std::move(mins), std::move(maxes), std::move(optimals_list)))
{ {
} }
...@@ -473,8 +485,7 @@ shape shape::to_dynamic() const ...@@ -473,8 +485,7 @@ shape shape::to_dynamic() const
{ {
return *this; return *this;
} }
std::vector<std::size_t> zeroes(this->ndim(), 0); return {type(), lens(), lens(), {}};
return {type(), lens(), lens(), zeroes};
} }
std::size_t shape::element_space() const { return impl->element_space(); } std::size_t shape::element_space() const { return impl->element_space(); }
...@@ -506,23 +517,22 @@ std::vector<std::size_t> shape::max_lens() const ...@@ -506,23 +517,22 @@ std::vector<std::size_t> shape::max_lens() const
return this->dynamic() ? impl->max_lens() : this->lens(); return this->dynamic() ? impl->max_lens() : this->lens();
} }
std::vector<std::size_t> shape::opt_lens() const std::vector<std::set<std::size_t>> shape::opt_lens() const { return impl->opt_lens(); }
{
return this->dynamic() ? impl->opt_lens() : this->lens();
}
bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; } bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; } bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); }
shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x) shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
{ {
this->min += x; this->min += x;
this->max += x; this->max += x;
if(this->opt != 0) std::set<std::size_t> new_optimals;
{ std::transform(this->optimals.begin(),
this->opt += x; this->optimals.end(),
}; std::inserter(new_optimals, new_optimals.begin()),
[&x](const auto& opt) { return (opt + x); });
this->optimals = new_optimals;
return *this; return *this;
} }
...@@ -532,19 +542,23 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t ...@@ -532,19 +542,23 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t
assert(this->max >= x); assert(this->max >= x);
this->min -= x; this->min -= x;
this->max -= x; this->max -= x;
if(this->opt != 0) std::set<std::size_t> new_optimals;
{ std::transform(this->optimals.begin(),
assert(this->opt >= x); this->optimals.end(),
this->opt -= x; std::inserter(new_optimals, new_optimals.begin()),
} [&x](const auto& opt) {
assert(opt >= x);
return (opt - x);
});
this->optimals = new_optimals;
return *this; return *this;
} }
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{ {
// don't check opt if both are fixed // don't check optimals if both are fixed
return (x.min == y.min and x.max == y.max and return (x.min == y.min and x.max == y.max and
((x.is_fixed() and y.is_fixed()) or (x.opt == y.opt))); ((x.is_fixed() and y.is_fixed()) or (x.optimals == y.optimals)));
} }
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
...@@ -553,7 +567,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio ...@@ -553,7 +567,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
} }
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x) std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{ {
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]"; os << "[ " << x.min << ", " << x.max << ", {" << migraphx::to_string_range(x.optimals) << "} ]";
return os; return os;
} }
...@@ -663,10 +677,12 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -663,10 +677,12 @@ void migraphx_from_value(const value& v, shape& s)
auto v_dd = v.at("dynamic_dimensions"); auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size()); std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) { std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) {
auto x_min = x.at("min").template to<size_t>(); auto x_min = x.at("min").template to<size_t>();
auto x_max = x.at("max").template to<size_t>(); auto x_max = x.at("max").template to<size_t>();
auto x_opt = x.at("opt").template to<size_t>(); auto v_optimals = x.at("optimals");
return shape::dynamic_dimension{x_min, x_max, x_opt}; std::set<size_t> set_x_optimals =
from_value<std::set<std::size_t>>(x.at("optimals"));
return shape::dynamic_dimension{x_min, x_max, set_x_optimals};
}); });
s = shape{shape::parse_type(t), dyn_dims}; s = shape{shape::parse_type(t), dyn_dims};
......
...@@ -36,7 +36,7 @@ bool create_shapes(bool dynamic_allowed) ...@@ -36,7 +36,7 @@ bool create_shapes(bool dynamic_allowed)
try try
{ {
shape a{shape::int64_type, {3}}; shape a{shape::int64_type, {3}};
shape b{shape::float_type, {{3, 6, 0}, {4, 4, 0}}}; shape b{shape::float_type, {{3, 6}, {4, 4}}};
auto op = migraphx::make_op("add"); auto op = migraphx::make_op("add");
migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2); migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2);
return true; return true;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment