Commit 25bad0f3 authored by Paul
Browse files

Improve visit_all to handle shapes as well

parent ddb6356b
...@@ -76,7 +76,7 @@ struct hip_shape ...@@ -76,7 +76,7 @@ struct hip_shape
}; };
template <std::size_t N> template <std::size_t N>
hip_shape<N> make_hip(const shape& x) hip_shape<N> make_hip_shape(const shape& x)
{ {
return x; return x;
} }
......
...@@ -39,9 +39,9 @@ struct hip_tensor_view ...@@ -39,9 +39,9 @@ struct hip_tensor_view
}; };
template <std::size_t N, class T> template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip(tensor_view<T> x) hip_tensor_view<T, N> make_hip_view(const shape& s, T* x)
{ {
return x; return {x, s};
} }
} // namespace device } // namespace device
......
...@@ -43,21 +43,50 @@ void visit_tensor_size(std::size_t n, F f) ...@@ -43,21 +43,50 @@ void visit_tensor_size(std::size_t n, F f)
} }
} }
inline std::size_t tensor_size(const shape& x) { return x.lens().size(); } inline shape get_shape(const shape& x) { return x; }
template <class T> template <class T>
auto tensor_size(const T& x) -> decltype(x.get_shape().lens().size()) auto get_shape(const T& x) -> decltype(x.get_shape())
{ {
return x.get_shape().lens().size(); return x.get_shape();
}
// Shared driver behind the hip visit helpers.
//
// s  : shape that selects both the compile-time rank (via visit_tensor_size
//      on s.lens().size()) and the element type (via s.visit_type).
// f  : converter invoked as f(x, ndim, as) once for every x in xs.
// v  : visitor that receives all converted values in a single call.
// xs : the values to convert.
//
// NOTE(review): only `s` drives the rank/type dispatch here; nothing in this
// function compares the entries of xs against it -- confirm callers guarantee
// that they agree.
template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
    const auto rank = s.lens().size();
    visit_tensor_size(rank, [&](auto ndim) {
        s.visit_type([&](auto as) {
            // Convert every argument with the same (ndim, as) pair, then
            // hand the whole pack to the visitor at once.
            v(f(xs, ndim, as)...);
        });
    });
}
template<class F>
struct hip_convert
{
F f;
template<class RawData, class N, class As>
auto operator()(RawData x, N ndim, As as) const -> decltype(make_hip_view<ndim>(x.get_shape(), f(as.from(x.data()))))
{
return make_hip_view<ndim>(x.get_shape(), f(as.from(x.data())));
}
template<class N, class As>
auto operator()(const shape& s, N ndim, As) const
{
return make_hip_shape<ndim>(s);
}
};
// Factory for hip_convert: wraps the callable `f` in a hip_convert<F> so that
// F is deduced from the argument instead of being spelled out by the caller.
template<class F>
hip_convert<F> make_hip_convert(F f)
{
return {f};
} }
template <class T, class... Ts> template <class T, class... Ts>
auto hip_visit_all(T&& x, Ts&&... xs) auto hip_visit_all(T&& x, Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) {
visit_tensor_size(tensor_size(x), [&](auto dim) { hip_visit_all_impl(get_shape(x), make_hip_convert([](auto* p) {return device_cast(p);}), f, x, xs...);
visit_all(x, xs...)([&](auto... vs) { f(make_hip<dim>(device_cast(vs))...); });
});
}; };
} }
...@@ -65,10 +94,7 @@ template <std::size_t N, class T, class... Ts> ...@@ -65,10 +94,7 @@ template <std::size_t N, class T, class... Ts>
auto hip_vec_visit_all(T&& x, Ts&&... xs) auto hip_vec_visit_all(T&& x, Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) {
visit_tensor_size(tensor_size(x), [&](auto dim) { hip_visit_all_impl(get_shape(x), make_hip_convert([](auto* p) {return as_vec<N>(device_cast(p));}), f, x, xs...);
visit_all(x,
xs...)([&](auto... vs) { f(make_hip<dim>(as_vec<N>(device_cast(vs)))...); });
});
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment