Commit 72c188be authored by Paul's avatar Paul
Browse files

Formatting

parent 25bad0f3
...@@ -54,29 +54,29 @@ auto get_shape(const T& x) -> decltype(x.get_shape()) ...@@ -54,29 +54,29 @@ auto get_shape(const T& x) -> decltype(x.get_shape())
template <class V, class F, class... Ts> template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{ {
visit_tensor_size(s.lens().size(), [&](auto ndim) { visit_tensor_size(s.lens().size(),
s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); [&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
});
} }
template<class F> template <class F>
struct hip_convert struct hip_convert
{ {
F f; F f;
template<class RawData, class N, class As> template <class RawData, class N, class As>
auto operator()(RawData x, N ndim, As as) const -> decltype(make_hip_view<ndim>(x.get_shape(), f(as.from(x.data())))) auto operator()(RawData x, N ndim, As as) const
-> decltype(make_hip_view<ndim>(x.get_shape(), f(as.from(x.data()))))
{ {
return make_hip_view<ndim>(x.get_shape(), f(as.from(x.data()))); return make_hip_view<ndim>(x.get_shape(), f(as.from(x.data())));
} }
template<class N, class As> template <class N, class As>
auto operator()(const shape& s, N ndim, As) const auto operator()(const shape& s, N ndim, As) const
{ {
return make_hip_shape<ndim>(s); return make_hip_shape<ndim>(s);
} }
}; };
template<class F> template <class F>
hip_convert<F> make_hip_convert(F f) hip_convert<F> make_hip_convert(F f)
{ {
return {f}; return {f};
...@@ -86,7 +86,8 @@ template <class T, class... Ts> ...@@ -86,7 +86,8 @@ template <class T, class... Ts>
auto hip_visit_all(T&& x, Ts&&... xs) auto hip_visit_all(T&& x, Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) {
hip_visit_all_impl(get_shape(x), make_hip_convert([](auto* p) {return device_cast(p);}), f, x, xs...); hip_visit_all_impl(
get_shape(x), make_hip_convert([](auto* p) { return device_cast(p); }), f, x, xs...);
}; };
} }
...@@ -94,7 +95,11 @@ template <std::size_t N, class T, class... Ts> ...@@ -94,7 +95,11 @@ template <std::size_t N, class T, class... Ts>
auto hip_vec_visit_all(T&& x, Ts&&... xs) auto hip_vec_visit_all(T&& x, Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) {
hip_visit_all_impl(get_shape(x), make_hip_convert([](auto* p) {return as_vec<N>(device_cast(p));}), f, x, xs...); hip_visit_all_impl(get_shape(x),
make_hip_convert([](auto* p) { return as_vec<N>(device_cast(p)); }),
f,
x,
xs...);
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment