"vscode:/vscode.git/clone" did not exist on "5aa0ba498414731c46f5fcd55ccf0540407aff7c"
Commit 25bad0f3 authored by Paul's avatar Paul
Browse files

Improve visit_all to handle shapes as well

parent ddb6356b
......@@ -76,7 +76,7 @@ struct hip_shape
};
template <std::size_t N>
hip_shape<N> make_hip(const shape& x)
hip_shape<N> make_hip_shape(const shape& x)
{
return x;
}
......
......@@ -39,9 +39,9 @@ struct hip_tensor_view
};
template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip(tensor_view<T> x)
hip_tensor_view<T, N> make_hip_view(const shape& s, T* x)
{
return x;
return {x, s};
}
} // namespace device
......
......@@ -43,21 +43,50 @@ void visit_tensor_size(std::size_t n, F f)
}
}
inline std::size_t tensor_size(const shape& x) { return x.lens().size(); }
inline shape get_shape(const shape& x) { return x; }
template <class T>
auto tensor_size(const T& x) -> decltype(x.get_shape().lens().size())
auto get_shape(const T& x) -> decltype(x.get_shape())
{
return x.get_shape().lens().size();
return x.get_shape();
}
template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
visit_tensor_size(s.lens().size(), [&](auto ndim) {
s.visit_type([&](auto as) { v(f(xs, ndim, as)...); });
});
}
template<class F>
struct hip_convert
{
F f;
template<class RawData, class N, class As>
auto operator()(RawData x, N ndim, As as) const -> decltype(make_hip_view<ndim>(x.get_shape(), f(as.from(x.data()))))
{
return make_hip_view<ndim>(x.get_shape(), f(as.from(x.data())));
}
template<class N, class As>
auto operator()(const shape& s, N ndim, As) const
{
return make_hip_shape<ndim>(s);
}
};
template<class F>
hip_convert<F> make_hip_convert(F f)
{
return {f};
}
template <class T, class... Ts>
auto hip_visit_all(T&& x, Ts&&... xs)
{
return [&](auto f) {
visit_tensor_size(tensor_size(x), [&](auto dim) {
visit_all(x, xs...)([&](auto... vs) { f(make_hip<dim>(device_cast(vs))...); });
});
hip_visit_all_impl(get_shape(x), make_hip_convert([](auto* p) {return device_cast(p);}), f, x, xs...);
};
}
......@@ -65,10 +94,7 @@ template <std::size_t N, class T, class... Ts>
auto hip_vec_visit_all(T&& x, Ts&&... xs)
{
return [&](auto f) {
visit_tensor_size(tensor_size(x), [&](auto dim) {
visit_all(x,
xs...)([&](auto... vs) { f(make_hip<dim>(as_vec<N>(device_cast(vs)))...); });
});
hip_visit_all_impl(get_shape(x), make_hip_convert([](auto* p) {return as_vec<N>(device_cast(p));}), f, x, xs...);
};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment