Commit 60557bc6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 795a7083
...@@ -34,13 +34,16 @@ struct reduce_mean ...@@ -34,13 +34,16 @@ struct reduce_mean
return {s.type(), lens}; return {s.type(), lens};
} }
template<class T> template <class T>
void calc_mean(tensor_view<T>& input, shape& batch_shape, std::vector<std::size_t>& out_idx, tensor_view<T>& output) const void calc_mean(tensor_view<T>& input,
shape& batch_shape,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{ {
auto data_idx = out_idx; auto data_idx = out_idx;
T val = T{0}; T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) { shape_for_each(batch_shape, [&](auto b_idx) {
for (auto axis : axes) for(auto axis : axes)
{ {
data_idx[axis] = b_idx[axis]; data_idx[axis] = b_idx[axis];
} }
...@@ -55,7 +58,8 @@ struct reduce_mean ...@@ -55,7 +58,8 @@ struct reduce_mean
argument result{output_shape}; argument result{output_shape};
auto arg_lens = args.front().get_shape().lens(); auto arg_lens = args.front().get_shape().lens();
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1); std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for (auto axis : axes) { for(auto axis : axes)
{
batch_lens[axis] = arg_lens[axis]; batch_lens[axis] = arg_lens[axis];
} }
shape batch_shape{output_shape.type(), batch_lens}; shape batch_shape{output_shape.type(), batch_lens};
......
...@@ -34,13 +34,16 @@ struct reduce_sum ...@@ -34,13 +34,16 @@ struct reduce_sum
return {s.type(), lens}; return {s.type(), lens};
} }
template<class T> template <class T>
void calc_sum(tensor_view<T>& input, shape& batch_shape, std::vector<std::size_t>& out_idx, tensor_view<T>& output) const void calc_sum(tensor_view<T>& input,
shape& batch_shape,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{ {
auto data_idx = out_idx; auto data_idx = out_idx;
T val = T{0}; T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) { shape_for_each(batch_shape, [&](auto b_idx) {
for (auto axis : axes) for(auto axis : axes)
{ {
data_idx[axis] = b_idx[axis]; data_idx[axis] = b_idx[axis];
} }
...@@ -55,7 +58,8 @@ struct reduce_sum ...@@ -55,7 +58,8 @@ struct reduce_sum
argument result{output_shape}; argument result{output_shape};
auto arg_lens = args.front().get_shape().lens(); auto arg_lens = args.front().get_shape().lens();
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1); std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for (auto axis : axes) { for(auto axis : axes)
{
batch_lens[axis] = arg_lens[axis]; batch_lens[axis] = arg_lens[axis];
} }
shape batch_shape{output_shape.type(), batch_lens}; shape batch_shape{output_shape.type(), batch_lens};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment