Commit 4d913223 authored by charlie's avatar charlie
Browse files

Add dynamic shape constructor

Constructs a dynamic shape from three vectors of lengths,
the minimums, maximums, and optimals
parent b162c4ec
...@@ -176,11 +176,7 @@ struct convolution ...@@ -176,11 +176,7 @@ struct convolution
auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens()); auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens());
auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens()); auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens());
auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens()); auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens());
for(size_t i = 0; i < num_spatial_dims; ++i) return shape{x_shape.type(), min_spatial_dims, max_spatial_dims, opt_spatial_dims};
{
output_dyn_dims.push_back(shape::dynamic_dimension{
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
}
} }
return shape{x_shape.type(), output_dyn_dims}; return shape{x_shape.type(), output_dyn_dims};
} }
......
...@@ -115,6 +115,12 @@ struct shape ...@@ -115,6 +115,12 @@ struct shape
shape(type_t t, std::vector<dynamic_dimension> dims); shape(type_t t, std::vector<dynamic_dimension> dims);
// Construct a dynamic shape from three sets of lengths (of the same rank)
shape(type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::size_t> opts);
template <class Range> template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end())) shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{ {
......
...@@ -71,6 +71,19 @@ struct shape_impl ...@@ -71,6 +71,19 @@ struct shape_impl
{ {
} }
shape_impl(shape::type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::size_t> opts)
: m_type(t)
{
assert(mins.size() == maxes.size() and maxes.size() == opts.size());
for(size_t i = 0; i < mins.size(); ++i)
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], opts[i]});
}
}
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {} shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type; shape::type_t m_type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment