"vscode:/vscode.git/clone" did not exist on "1eda0a1709b5a9cc77753e1bfd45e2067d523cd6"
Commit 60557bc6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 795a7083
......@@ -34,20 +34,23 @@ struct reduce_mean
return {s.type(), lens};
}
template<class T>
void calc_mean(tensor_view<T>& input, shape& batch_shape, std::vector<std::size_t>& out_idx, tensor_view<T>& output) const
template <class T>
void calc_mean(tensor_view<T>& input,
shape& batch_shape,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for (auto axis : axes)
for(auto axis : axes)
{
data_idx[axis] = b_idx[axis];
}
val += input(data_idx.begin(), data_idx.end());
});
output(out_idx.begin(), out_idx.end()) = val / batch_shape.elements();
output(out_idx.begin(), out_idx.end()) = val / batch_shape.elements();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
......@@ -55,7 +58,8 @@ struct reduce_mean
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for (auto axis : axes) {
for(auto axis : axes)
{
batch_lens[axis] = arg_lens[axis];
}
shape batch_shape{output_shape.type(), batch_lens};
......
......@@ -34,20 +34,23 @@ struct reduce_sum
return {s.type(), lens};
}
template<class T>
void calc_sum(tensor_view<T>& input, shape& batch_shape, std::vector<std::size_t>& out_idx, tensor_view<T>& output) const
template <class T>
void calc_sum(tensor_view<T>& input,
shape& batch_shape,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for (auto axis : axes)
for(auto axis : axes)
{
data_idx[axis] = b_idx[axis];
}
val += input(data_idx.begin(), data_idx.end());
});
output(out_idx.begin(), out_idx.end()) = val;
output(out_idx.begin(), out_idx.end()) = val;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
......@@ -55,7 +58,8 @@ struct reduce_sum
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for (auto axis : axes) {
for(auto axis : axes)
{
batch_lens[axis] = arg_lens[axis];
}
shape batch_shape{output_shape.type(), batch_lens};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment