"...resnet50_tensorflow.git" did not exist on "321b8633441aa7a1905a65e628a3459437bc36a9"
Commit 39dc2b4a authored by Paul's avatar Paul
Browse files

Formatting

parent 1ce83fbe
...@@ -151,7 +151,7 @@ struct shape ...@@ -151,7 +151,7 @@ struct shape
private: private:
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const; std::size_t element_space() const;
std::string type_string() const; std::string type_string() const;
}; };
......
...@@ -11,67 +11,70 @@ namespace migraph { ...@@ -11,67 +11,70 @@ namespace migraph {
struct shape_impl struct shape_impl
{ {
static std::shared_ptr<shape_impl> default_shape() static std::shared_ptr<shape_impl> default_shape()
{ {
static std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>(); static std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
return result; return result;
} }
shape_impl() : m_type(shape::float_type), m_standard(false) {} shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {} shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l) shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
this->calculate_strides(); this->calculate_strides();
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s)) : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{ {
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
"At least one stride must be non-zero"); "At least one stride must be non-zero");
m_standard = this->elements() == this->element_space() and std::is_sorted(m_strides.rbegin(), m_strides.rend()); m_standard = this->elements() == this->element_space() and
} std::is_sorted(m_strides.rbegin(), m_strides.rend());
shape::type_t m_type; }
std::vector<std::size_t> m_lens; shape::type_t m_type;
std::vector<std::size_t> m_strides; std::vector<std::size_t> m_lens;
bool m_standard; std::vector<std::size_t> m_strides;
bool m_standard;
void calculate_strides()
{ void calculate_strides()
m_strides.clear(); {
m_strides.resize(m_lens.size(), 0); m_strides.clear();
if(m_strides.empty()) m_strides.resize(m_lens.size(), 0);
return; if(m_strides.empty())
m_strides.back() = 1; return;
std::partial_sum( m_strides.back() = 1;
m_lens.rbegin(), m_lens.rend() - 1, m_strides.rbegin() + 1, std::multiplies<std::size_t>()); std::partial_sum(m_lens.rbegin(),
} m_lens.rend() - 1,
m_strides.rbegin() + 1,
std::size_t element_space() const std::multiplies<std::size_t>());
{ }
assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) std::size_t element_space() const
return 0; {
return std::inner_product(m_lens.begin(), assert(m_lens.size() == m_strides.size());
m_lens.end(), if(m_lens.empty())
m_strides.begin(), return 0;
std::size_t{0}, return std::inner_product(m_lens.begin(),
std::plus<std::size_t>{}, m_lens.end(),
[](std::size_t l, std::size_t s) { return (l - 1) * s; }) + m_strides.begin(),
1; std::size_t{0},
} std::plus<std::size_t>{},
[](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
std::size_t elements() const 1;
{ }
assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) std::size_t elements() const
return 0; {
return std::accumulate( assert(m_lens.size() == m_strides.size());
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); if(m_lens.empty())
} return 0;
return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
}; };
shape::shape() : impl(shape_impl::default_shape()) {} shape::shape() : impl(shape_impl::default_shape()) {}
...@@ -89,10 +92,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -89,10 +92,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::type_t shape::type() const { return impl->m_type; } shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; } const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; } const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::elements() const std::size_t shape::elements() const { return impl->elements(); }
{
return impl->elements();
}
std::size_t shape::bytes() const std::size_t shape::bytes() const
{ {
std::size_t n = 0; std::size_t n = 0;
...@@ -145,10 +145,7 @@ bool shape::broadcasted() const ...@@ -145,10 +145,7 @@ bool shape::broadcasted() const
bool shape::standard() const { return impl->m_standard; } bool shape::standard() const { return impl->m_standard; }
std::size_t shape::element_space() const std::size_t shape::element_space() const { return impl->element_space(); }
{
return impl->element_space();
}
std::string shape::type_string() const std::string shape::type_string() const
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment