"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "39ca7206b7dabd4ef142c65df870c8d28b876fe0"
Commit 39dc2b4a authored by Paul's avatar Paul
Browse files

Formatting

parent 1ce83fbe
......@@ -151,7 +151,7 @@ struct shape
private:
std::shared_ptr<const shape_impl> impl;
std::size_t element_space() const;
std::string type_string() const;
};
......
......@@ -11,67 +11,70 @@ namespace migraph {
struct shape_impl
{
static std::shared_ptr<shape_impl> default_shape()
{
static std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
return result;
}
shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
this->calculate_strides();
assert(m_lens.size() == m_strides.size());
}
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{
assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
"At least one stride must be non-zero");
m_standard = this->elements() == this->element_space() and std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
shape::type_t m_type;
std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides;
bool m_standard;
void calculate_strides()
{
m_strides.clear();
m_strides.resize(m_lens.size(), 0);
if(m_strides.empty())
return;
m_strides.back() = 1;
std::partial_sum(
m_lens.rbegin(), m_lens.rend() - 1, m_strides.rbegin() + 1, std::multiplies<std::size_t>());
}
std::size_t element_space() const
{
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
return std::inner_product(m_lens.begin(),
m_lens.end(),
m_strides.begin(),
std::size_t{0},
std::plus<std::size_t>{},
[](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
1;
}
std::size_t elements() const
{
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
static std::shared_ptr<shape_impl> default_shape()
{
static std::shared_ptr<shape_impl> result = std::make_shared<shape_impl>();
return result;
}
shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({1}), m_standard(true) {}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
this->calculate_strides();
assert(m_lens.size() == m_strides.size());
}
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{
assert(m_lens.size() == m_strides.size());
assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
"At least one stride must be non-zero");
m_standard = this->elements() == this->element_space() and
std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
shape::type_t m_type;
std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides;
bool m_standard;
void calculate_strides()
{
m_strides.clear();
m_strides.resize(m_lens.size(), 0);
if(m_strides.empty())
return;
m_strides.back() = 1;
std::partial_sum(m_lens.rbegin(),
m_lens.rend() - 1,
m_strides.rbegin() + 1,
std::multiplies<std::size_t>());
}
std::size_t element_space() const
{
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
return std::inner_product(m_lens.begin(),
m_lens.end(),
m_strides.begin(),
std::size_t{0},
std::plus<std::size_t>{},
[](std::size_t l, std::size_t s) { return (l - 1) * s; }) +
1;
}
std::size_t elements() const
{
assert(m_lens.size() == m_strides.size());
if(m_lens.empty())
return 0;
return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
};
shape::shape() : impl(shape_impl::default_shape()) {}
......@@ -89,10 +92,7 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
std::size_t shape::elements() const
{
return impl->elements();
}
std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const
{
std::size_t n = 0;
......@@ -145,10 +145,7 @@ bool shape::broadcasted() const
bool shape::standard() const { return impl->m_standard; }
std::size_t shape::element_space() const
{
return impl->element_space();
}
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment