Commit 1ad95e66 authored by Paul's avatar Paul
Browse files

Formatting

parent bb666690
...@@ -51,15 +51,19 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args) ...@@ -51,15 +51,19 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
}); });
} }
template<class F> template <class F>
void binary_broadcast_vec_impl(F f, const argument& result, const argument& arg1, const argument& arg2) void binary_broadcast_vec_impl(F f,
const argument& result,
const argument& arg1,
const argument& arg2)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim =
std::find_if(b_shape.strides().begin(), std::distance(b_shape.strides().begin(),
b_shape.strides().end(), std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
[](auto x) { return x != 0; })); return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto bdim_next_stride = bdim_stride * bdim_len;
...@@ -102,15 +106,16 @@ void binary_broadcast_vec_impl(F f, const argument& result, const argument& arg1 ...@@ -102,15 +106,16 @@ void binary_broadcast_vec_impl(F f, const argument& result, const argument& arg1
}); });
} }
template<class F> template <class F>
void binary_broadcast_impl(F f, const argument& result, const argument& arg1, const argument& arg2) void binary_broadcast_impl(F f, const argument& result, const argument& arg1, const argument& arg2)
{ {
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape(); const auto& b_shape = arg2.get_shape();
auto bdim = std::distance(b_shape.strides().begin(), auto bdim =
std::find_if(b_shape.strides().begin(), std::distance(b_shape.strides().begin(),
b_shape.strides().end(), std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
[](auto x) { return x != 0; })); return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim]; auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim]; auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len; auto bdim_next_stride = bdim_stride * bdim_len;
...@@ -194,31 +199,24 @@ void nary_impl(F f, argument result, Arguments... args) ...@@ -194,31 +199,24 @@ void nary_impl(F f, argument result, Arguments... args)
nary_standard_impl(f, result, args...); nary_standard_impl(f, result, args...);
else else
nary_nonstandard_impl(f, result, args...); nary_nonstandard_impl(f, result, args...);
} }
template <class... Arguments> template <class... Arguments>
auto nary_nonstandard(argument result, Arguments... args) auto nary_nonstandard(argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) { nary_nonstandard_impl(f, result, args...); };
nary_nonstandard_impl(f, result, args...);
};
} }
template <class... Arguments> template <class... Arguments>
auto nary_standard(argument result, Arguments... args) auto nary_standard(argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) { nary_standard_impl(f, result, args...); };
nary_standard_impl(f, result, args...);
};
} }
template <class... Arguments> template <class... Arguments>
auto nary(argument result, Arguments... args) auto nary(argument result, Arguments... args)
{ {
return [=](auto f) { return [=](auto f) { nary_impl(f, result, args...); };
nary_impl(f, result, args...);
};
} }
inline auto nary(const argument& result, const argument& arg1, const argument& arg2) inline auto nary(const argument& result, const argument& arg1, const argument& arg2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment