Commit 1ce83fbe authored by Paul's avatar Paul
Browse files

Make shape ref counted

parent 0dd8ee4f
......@@ -5,11 +5,14 @@
#include <cassert>
#include <ostream>
#include <numeric>
#include <memory>
#include <migraph/errors.hpp>
namespace migraph {
struct shape_impl;
struct shape
{
......@@ -136,7 +139,7 @@ struct shape
template <class Visitor>
void visit_type(Visitor v) const
{
switch(this->m_type)
switch(this->type())
{
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
......@@ -147,12 +150,8 @@ struct shape
}
private:
type_t m_type;
std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides;
bool m_standard;
void calculate_strides();
std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
std::string type_string() const;
};
......
......@@ -8,45 +8,90 @@
namespace migraph {
shape::shape() : m_type(float_type), m_standard(false) {}
struct shape_impl
{
static std::shared_ptr<shape_impl> default_shape()
{
static std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
return result;
}
shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
this->calculate_strides();
assert(m_lens.size() == m_strides.size());
}
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{
assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
"At least one stride must be non-zero");
m_standard = this->elements() == this->element_space() and std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
shape::type_t m_type;
std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides;
bool m_standard;
void calculate_strides()
{
m_strides.clear();
m_strides.resize(m_lens.size(), 0);
if(m_strides.empty())
return;
m_strides.back() = 1;
std::partial_sum(
m_lens.rbegin(), m_lens.rend() - 1, m_strides.rbegin() + 1, std::multiplies<std::size_t>());
}
std::size_t element_space() const
{
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
return std::inner_product(m_lens.begin(),
m_lens.end(),
m_strides.begin(),
std::size_t{0},
std::plus<std::size_t>{},
[](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
1;
}
std::size_t elements() const
{
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
};
shape::shape(type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
shape::shape(type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
: impl(std::make_shared<shape_impl>(t, std::move(l)))
{
this->calculate_strides();
assert(m_lens.size() == m_strides.size());
}
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{
assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
"At least one stride must be non-zero");
m_standard = this->packed() and not this->transposed();
}
void shape::calculate_strides()
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{
m_strides.clear();
m_strides.resize(m_lens.size(), 0);
if(m_strides.empty())
return;
m_strides.back() = 1;
std::partial_sum(
m_lens.rbegin(), m_lens.rend() - 1, m_strides.rbegin() + 1, std::multiplies<std::size_t>());
}
shape::type_t shape::type() const { return this->m_type; }
const std::vector<std::size_t>& shape::lens() const { return this->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return this->m_strides; }
shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::elements() const
{
assert(this->lens().size() == this->strides().size());
if(this->lens().empty())
return 0;
return std::accumulate(
this->lens().begin(), this->lens().end(), std::size_t{1}, std::multiplies<std::size_t>());
return impl->elements();
}
std::size_t shape::bytes() const
{
......@@ -98,25 +143,16 @@ bool shape::broadcasted() const
std::multiplies<std::size_t>()) == 0;
}
bool shape::standard() const { return this->m_standard; }
bool shape::standard() const { return impl->m_standard; }
std::size_t shape::element_space() const
{
assert(this->lens().size() == this->strides().size());
if(this->lens().empty())
return 0;
return std::inner_product(this->lens().begin(),
this->lens().end(),
this->strides().begin(),
std::size_t{0},
std::plus<std::size_t>{},
[](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
1;
return impl->element_space();
}
std::string shape::type_string() const
{
switch(this->m_type)
switch(this->type())
{
#define MIGRAPH_SHAPE_TYPE_STRING_CASE(x, t) \
case x: return #x;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment