Commit aa44099d authored by Shiv's avatar Shiv
Browse files

decoder shape fix

parent 00ef197c
...@@ -62,8 +62,11 @@ struct shape_impl ...@@ -62,8 +62,11 @@ struct shape_impl
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
std::vector<std::size_t> std_strides = {};
this->calculate_standard_strides(std_strides);
bool is_scalar = std::accumulate(m_strides.begin(), m_strides.end(), std::size_t(0)) == 0;
m_standard = this->elements() == this->element_space() and not skips() and m_standard = this->elements() == this->element_space() and not skips() and
std::is_sorted(m_strides.rbegin(), m_strides.rend()); (m_strides == std_strides or is_scalar);
} }
shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims) shape_impl(shape::type_t t, std::vector<shape::dynamic_dimension> dims)
...@@ -94,16 +97,18 @@ struct shape_impl ...@@ -94,16 +97,18 @@ struct shape_impl
std::vector<shape::dynamic_dimension> m_dyn_dims = {}; std::vector<shape::dynamic_dimension> m_dyn_dims = {};
void calculate_strides() void calculate_strides() { this->calculate_standard_strides(m_strides); }
void calculate_standard_strides(std::vector<std::size_t>& strides)
{ {
m_strides.clear(); strides.clear();
m_strides.resize(m_lens.size(), 0); strides.resize(m_lens.size(), 0);
if(m_strides.empty()) if(strides.empty())
return; return;
m_strides.back() = 1; strides.back() = 1;
std::partial_sum(m_lens.rbegin(), std::partial_sum(m_lens.rbegin(),
m_lens.rend() - 1, m_lens.rend() - 1,
m_strides.rbegin() + 1, strides.rbegin() + 1,
std::multiplies<std::size_t>()); std::multiplies<std::size_t>());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment