Commit c2db1916 authored by Paul's avatar Paul
Browse files

Simplify vecotrized types

parent 6c057881
......@@ -217,9 +217,9 @@ struct hip_shape
template <class T, std::size_t N>
struct hip_tensor_view
{
using value_type = device_type<T>;
using value_type = T;
__device__ __host__ hip_tensor_view() = default;
__host__ hip_tensor_view(tensor_view<T> x) : d(device_cast(x.data())), s(x.get_shape()) {}
__host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
__host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {}
MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }
......@@ -249,12 +249,6 @@ hip_tensor_view<T, N> make_hip_tensor_view(tensor_view<T> x)
return x;
}
template <std::size_t N, std::size_t M, class T>
hip_tensor_view<vec<device_type<T>, M>, N> make_hip_vec_tensor_view(tensor_view<T> x)
{
return {as_vec<M>(device_cast(x.data())), x.get_shape()};
}
template <std::size_t N, std::size_t M, class T>
hip_vector<hip_tensor_view<T, N>, M> make_hip_tensor_views(const std::vector<tensor_view<T>>& x)
{
......@@ -269,7 +263,7 @@ auto hip_visit_all(T&& x, Ts&&... xs)
{
return [&](auto f) {
visit_tensor_size(x.get_shape().lens().size(), [&](auto dim) {
visit_all(x, xs...)([&](auto... vs) { f(make_hip_tensor_view<dim>(vs)...); });
visit_all(x, xs...)([&](auto... vs) { f(make_hip_tensor_view<dim>(device_cast(vs))...); });
});
};
}
......@@ -279,7 +273,7 @@ auto hip_vec_visit_all(T&& x, Ts&&... xs)
{
return [&](auto f) {
visit_tensor_size(x.get_shape().lens().size(), [&](auto dim) {
visit_all(x, xs...)([&](auto... vs) { f(make_hip_vec_tensor_view<dim, N>(vs)...); });
visit_all(x, xs...)([&](auto... vs) { f(make_hip_tensor_view<dim>(as_vec<N>(device_cast(vs)))...); });
});
};
}
......
......@@ -10,15 +10,23 @@
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tensor_view.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T, std::size_t N>
using vec = T __attribute__((ext_vector_type(N)));
template <std::size_t N, class T>
__device__ __host__ T* as_pointer(vec<T, N>* x)
{
return reinterpret_cast<T*>(x);
}
template <std::size_t N, class T>
__device__ __host__ vec<T, N>* as_vec(T* x)
{
......@@ -26,9 +34,9 @@ __device__ __host__ vec<T, N>* as_vec(T* x)
}
template <std::size_t N, class T>
__device__ __host__ T* as_pointer(vec<T, N>* x)
tensor_view<vec<T, N>> as_vec(tensor_view<T> x)
{
return reinterpret_cast<T*>(x);
return {x.get_shape(), as_vec<N>(x.data())};
}
template <std::size_t N, class... Ts>
......@@ -47,9 +55,9 @@ struct device_type
};
template <class T, std::size_t N>
struct device_type<T __attribute__((ext_vector_type(N)))>
struct device_type<vec<T, N>>
{
using type = typename device_type<T>::type __attribute__((ext_vector_type(N)));
using type = vec<typename device_type<T>::type, N>;
};
template <>
......@@ -102,6 +110,12 @@ device_type<T>* device_cast(T* x)
return reinterpret_cast<device_type<T>*>(x);
}
template <class T>
tensor_view<device_type<T>> device_cast(tensor_view<T> x)
{
return {x.get_shape(), reinterpret_cast<device_type<T>*>(x.data())};
}
template <class T>
T to_hip_type(T x)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment