Commit 5ec978eb authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent edc23800
...@@ -46,7 +46,8 @@ struct parse_nonzero : op_parser<parse_nonzero> ...@@ -46,7 +46,8 @@ struct parse_nonzero : op_parser<parse_nonzero>
}); });
shape in_s = args[0]->get_shape(); shape in_s = args[0]->get_shape();
shape out_s{shape::int64_type, {static_cast<int>(in_s.lens().size()), static_cast<int>(indices.size())}}; shape out_s{shape::int64_type,
{static_cast<int>(in_s.lens().size()), static_cast<int>(indices.size())}};
std::vector<int64_t> out_data(out_s.elements()); std::vector<int64_t> out_data(out_s.elements());
for(int i = 0; i < indices.size(); ++i) for(int i = 0; i < indices.size(); ++i)
......
...@@ -84,12 +84,8 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -84,12 +84,8 @@ struct parse_pooling : op_parser<parse_pooling>
{ {
values["padding"].clear(); values["padding"].clear();
// return paddings could be empty, then setting to 0 for no padding // return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size(info, cal_auto_padding_size(
values, info, values, values["lengths"].to_vector<int>(), {1, 1}, in_lens, paddings);
values["lengths"].to_vector<int>(),
{1, 1},
in_lens,
paddings);
} }
if(paddings.size() != 2 * kdims) if(paddings.size() != 2 * kdims)
......
...@@ -1295,8 +1295,9 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l ...@@ -1295,8 +1295,9 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l
return is_var_lens; return is_var_lens;
} }
int int rewrite_rnn::get_seq_len(const module& prog,
rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const instruction_ref input,
instruction_ref seq_lens) const
{ {
bool is_var_lens = is_variable_seq_lens(prog, seq_lens); bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
auto input_shape = input->get_shape(); auto input_shape = input->get_shape();
......
...@@ -58,10 +58,8 @@ struct shape_impl ...@@ -58,10 +58,8 @@ struct shape_impl
if(m_strides.empty()) if(m_strides.empty())
return; return;
m_strides.back() = 1; m_strides.back() = 1;
std::partial_sum(m_lens.rbegin(), std::partial_sum(
m_lens.rend() - 1, m_lens.rbegin(), m_lens.rend() - 1, m_strides.rbegin() + 1, std::multiplies<int>());
m_strides.rbegin() + 1,
std::multiplies<int>());
} }
int element_space() const int element_space() const
...@@ -83,8 +81,7 @@ struct shape_impl ...@@ -83,8 +81,7 @@ struct shape_impl
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
if(m_lens.empty()) if(m_lens.empty())
return 0; return 0;
return std::accumulate( return std::accumulate(m_lens.begin(), m_lens.end(), int{1}, std::multiplies<int>());
m_lens.begin(), m_lens.end(), int{1}, std::multiplies<int>());
} }
}; };
...@@ -124,10 +121,7 @@ std::string shape::cpp_type(shape::type_t t) ...@@ -124,10 +121,7 @@ std::string shape::cpp_type(shape::type_t t)
shape::shape() : impl(shape_impl::default_shape()) {} shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {} shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
shape::shape(type_t t, std::vector<int> l) shape::shape(type_t t, std::vector<int> l) : impl(std::make_shared<shape_impl>(t, std::move(l))) {}
: impl(std::make_shared<shape_impl>(t, std::move(l)))
{
}
shape::shape(type_t t, std::vector<int> l, std::vector<int> s) shape::shape(type_t t, std::vector<int> l, std::vector<int> s)
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s))) : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{ {
...@@ -135,9 +129,7 @@ shape::shape(type_t t, std::vector<int> l, std::vector<int> s) ...@@ -135,9 +129,7 @@ shape::shape(type_t t, std::vector<int> l, std::vector<int> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape shape::from_permutation(type_t t, shape shape::from_permutation(type_t t, const std::vector<int>& l, const std::vector<int64_t>& perm)
const std::vector<int>& l,
const std::vector<int64_t>& perm)
{ {
auto new_lens = reorder_dims(l, perm); auto new_lens = reorder_dims(l, perm);
shape result = reorder_shape({t, new_lens}, invert_permutation(perm)); shape result = reorder_shape({t, new_lens}, invert_permutation(perm));
...@@ -221,11 +213,8 @@ void shape::multi_copy(int i, int* start, const int* end) const ...@@ -221,11 +213,8 @@ void shape::multi_copy(int i, int* start, const int* end) const
assert(this->standard()); assert(this->standard());
(void)end; (void)end;
assert(lens().size() <= (end - start)); assert(lens().size() <= (end - start));
std::transform(strides().begin(), std::transform(
strides().end(), strides().begin(), strides().end(), lens().begin(), start, [&](int stride, int len) {
lens().begin(),
start,
[&](int stride, int len) {
assert(len > 0 and stride > 0); assert(len > 0 and stride > 0);
return (i / stride) % len; return (i / stride) % len;
}); });
...@@ -258,10 +247,8 @@ bool shape::transposed() const ...@@ -258,10 +247,8 @@ bool shape::transposed() const
bool shape::broadcasted() const bool shape::broadcasted() const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::accumulate(this->strides().begin(), return std::accumulate(
this->strides().end(), this->strides().begin(), this->strides().end(), int{1}, std::multiplies<int>()) == 0;
int{1},
std::multiplies<int>()) == 0;
} }
bool shape::scalar() const bool shape::scalar() const
...@@ -289,10 +276,7 @@ shape shape::with_lens(type_t t, const std::vector<int>& l) const ...@@ -289,10 +276,7 @@ shape shape::with_lens(type_t t, const std::vector<int>& l) const
return shape::from_permutation(t, l, perm); return shape::from_permutation(t, l, perm);
} }
shape shape::with_lens(const std::vector<int>& l) const shape shape::with_lens(const std::vector<int>& l) const { return this->with_lens(this->type(), l); }
{
return this->with_lens(this->type(), l);
}
int shape::element_space() const { return impl->element_space(); } int shape::element_space() const { return impl->element_space(); }
...@@ -350,9 +334,8 @@ void migraphx_from_value(const value& v, shape& s) ...@@ -350,9 +334,8 @@ void migraphx_from_value(const value& v, shape& s)
} }
else else
{ {
s = shape{shape::parse_type(t), s = shape{
v.at("lens").to_vector<int>(), shape::parse_type(t), v.at("lens").to_vector<int>(), v.at("strides").to_vector<int>()};
v.at("strides").to_vector<int>()};
} }
} }
......
...@@ -626,9 +626,7 @@ struct find_split_concat ...@@ -626,9 +626,7 @@ struct find_split_concat
} }
}; };
bool axis_equal(const std::vector<int>& x, bool axis_equal(const std::vector<int>& x, const std::vector<int>& y, int axis)
const std::vector<int>& y,
int axis)
{ {
return x.size() == y.size() and x.size() > axis and return x.size() == y.size() and x.size() > axis and
std::equal(x.begin(), x.begin() + axis, y.begin()) and std::equal(x.begin(), x.begin() + axis, y.begin()) and
...@@ -912,8 +910,8 @@ struct find_split_reshape ...@@ -912,8 +910,8 @@ struct find_split_reshape
// ensure reshape happens after the axis dimension // ensure reshape happens after the axis dimension
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0]; auto axis = any_cast<op::slice>(slc->get_operator()).axes[0];
auto slc_lens = slc->get_shape().lens(); auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate( auto slc_dim_size =
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<int>()); std::accumulate(slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<int>());
// search the reshape output (standard shape) to decide which axis are // search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size // in its output corresponding to the slc_dim_size
......
...@@ -92,8 +92,7 @@ void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, ...@@ -92,8 +92,7 @@ void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat,
{ {
auto lens = amat.get_shape().lens(); auto lens = amat.get_shape().lens();
bool batch_mul = bool batch_mul =
std::accumulate( std::accumulate(lens.rbegin() + 2, lens.rend(), int{1}, std::multiplies<int>()) == 1;
lens.rbegin() + 2, lens.rend(), int{1}, std::multiplies<int>()) == 1;
if(batch_mul) if(batch_mul)
{ {
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{}); migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment