Commit 55422f0e authored by Paul

Add vectorized nary broadcast

parent 8ec57ece
@@ -14,28 +14,10 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-// template <class T>
-// using vec4 = T __attribute__((ext_vector_type(4)));
-template <class T, std::size_t N>
-using vec = T __attribute__((ext_vector_type(N)));
-
-template <std::size_t N, class T>
-__device__ __host__ vec<T, N>* as_vec(T* x)
-{
-    return reinterpret_cast<vec<T, N>*>(x);
-}
-
-template <std::size_t N, class T>
-__device__ __host__ T* as_pointer(vec<T, N>* x)
-{
-    return reinterpret_cast<T*>(x);
-}
-
-template <std::size_t N, class... Ts>
-auto pack_vec(Ts... xs)
-{
-    return [=](auto f, std::size_t n) { return f(as_vec<4>(xs)[n]...); };
-}
+template <class... Ts>
+auto pack(Ts... xs) __device__
+{
+    return [=](auto f) { return f(xs...); };
+}
 template <class F, class... Arguments>
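The vector helpers removed here are not gone: they reappear, generalized, in the types header further down in this commit. What stays behind is pack, which captures an argument pack by value and replays it into a continuation, since C++14 gives no direct way to store a pack for later expansion inside a kernel lambda. A minimal standalone sketch of the idiom (host-only; the C++17 fold expression is just for the demo):

#include <cstdio>

template <class... Ts>
auto pack(Ts... xs)
{
    // Capture the pack by value; hand it back to any callable later.
    return [=](auto f) { return f(xs...); };
}

int main()
{
    auto args = pack(1, 2.0f, 3.0);
    // Expands back to f(1, 2.0f, 3.0)
    double sum = args([](auto... xs) { return (xs + ...); });
    std::printf("%g\n", sum); // prints 6
}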
@@ -258,6 +240,55 @@ void binary_broadcast_impl(
     });
 }
+template <class F, class... Arguments>
+void nary_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg, Arguments... args)
+{
+    // Locate the single axis of the broadcast argument with a non-zero stride
+    const auto& output_shape = result.get_shape();
+    const auto& b_shape      = barg.get_shape();
+    auto bdim =
+        std::distance(b_shape.strides().begin(),
+                      std::find_if(b_shape.strides().begin(),
+                                   b_shape.strides().end(),
+                                   [](auto x) { return x != 0; }));
+    auto bdim_len         = output_shape.lens()[bdim];
+    auto bdim_stride      = output_shape.strides()[bdim];
+    auto bdim_next_stride = bdim_stride * bdim_len;
+
+    const std::size_t vec_size     = 4;
+    const std::size_t nlocal       = 1024;
+    const std::size_t nglobal      = 256 * nlocal;
+    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    hip_vec_visit_all<vec_size>(result, barg, args...)(
+        [&](auto output, auto binput, auto... inputs) {
+            using type                  = typename decltype(output)::value_type;
+            const std::size_t nelements = output.size() / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load the broadcast input into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data one vector at a time
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b    = bp[bidx];
+                    auto out  = output.data()[i];
+                    pack(inputs.data()[i]...)([&](auto... xs) __device__ {
+                        for(std::size_t j = 0; j < vec_size; j++)
+                        {
+                            out[j] = f(xs[j]..., b);
+                        }
+                    });
+                    output.data()[i] = out;
+                }
+            });
+        });
+}
 template <class F, class... Arguments>
 void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, Arguments... args)
 {
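The index arithmetic in the inner loop deserves a note: ((i * vec_size) % bdim_next_stride) / bdim_stride recovers which element of the broadcast axis applies at a flat output position, and because the dispatch below only takes the vectorized path when the broadcast stride is a multiple of the vector width, all four lanes of a vector share one b. A small host-side check of that claim (standalone; names are hypothetical stand-ins for the kernel's variables):

#include <cassert>
#include <cstddef>

// Which element of the broadcast axis applies at flat output index `flat`?
std::size_t broadcast_index(std::size_t flat, std::size_t len, std::size_t stride)
{
    std::size_t next_stride = stride * len;
    return (flat % next_stride) / stride;
}

int main()
{
    const std::size_t vec_size = 4;
    const std::size_t len = 8, stride = 16; // stride % vec_size == 0
    for(std::size_t i = 0; i < 1000; i++)   // i counts vec4 chunks
    {
        auto b0 = broadcast_index(i * vec_size, len, stride);
        for(std::size_t j = 1; j < vec_size; j++)
            assert(broadcast_index(i * vec_size + j, len, stride) == b0);
    }
}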
@@ -385,15 +416,14 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
         assert(bshape.lens()[b_idx] == b_len);
         if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
         {
-            nary_broadcast_impl(stream, f, result, barg, args2...);
-            // const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
-            //     (arg1.get_shape().elements() % 4 == 0);
-            // if(divisible_by_4)
-            //     binary_broadcast_vec_impl(stream, f, result, arg1, arg);
-            // else
-            //     binary_broadcast_impl(stream, f, result, arg1, arg);
-            // return;
+            const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
+                                        (front_args(args...).get_shape().elements() % 4 == 0);
+            if(divisible_by_4)
+                nary_broadcast_vec_impl(stream, f, result, barg, args2...);
+            else
+                nary_broadcast_impl(stream, f, result, barg, args2...);
+            return;
         }
     }
 });
@@ -214,13 +214,13 @@ struct hip_shape
         return result;
     }
 };

 template <class T, std::size_t N>
 struct hip_tensor_view
 {
     using value_type = device_type<T>;
     __device__ __host__ hip_tensor_view() = default;
     __host__ hip_tensor_view(tensor_view<T> x) : d(device_cast(x.data())), s(x.get_shape()) {}
+    __host__ hip_tensor_view(T* x, const shape& ss) : d(x), s(ss) {}
     MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }
@@ -249,6 +249,12 @@ hip_tensor_view<T, N> make_hip_tensor_view(tensor_view<T> x)
     return x;
 }

+template <std::size_t N, std::size_t M, class T>
+hip_tensor_view<vec<device_type<T>, M>, N> make_hip_vec_tensor_view(tensor_view<T> x)
+{
+    return {as_vec<M>(device_cast(x.data())), x.get_shape()};
+}
+
 template <std::size_t N, std::size_t M, class T>
 hip_vector<hip_tensor_view<T, N>, M> make_hip_tensor_views(const std::vector<tensor_view<T>>& x)
 {
@@ -268,6 +274,16 @@ auto hip_visit_all(T&& x, Ts&&... xs)
     };
 }

+template <std::size_t N, class T, class... Ts>
+auto hip_vec_visit_all(T&& x, Ts&&... xs)
+{
+    return [&](auto f) {
+        visit_tensor_size(x.get_shape().lens().size(), [&](auto dim) {
+            visit_all(x, xs...)([&](auto... vs) { f(make_hip_vec_tensor_view<dim, N>(vs)...); });
+        });
+    };
+}
+
 template <std::size_t N, class T>
 auto hip_visit_all(const std::vector<T>& x)
 {
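hip_vec_visit_all follows the same shape as hip_visit_all: visit_tensor_size turns the runtime rank into a compile-time constant that make_hip_vec_tensor_view<dim, N> can take as a template argument, and the views handed to f keep the scalar shape while their data pointers are reinterpreted as vec<T, N>*, which is why nary_broadcast_vec_impl divides output.size() by vec_size. A standalone sketch of the rank-dispatch trick, assuming visit_tensor_size works roughly like this:

#include <cstddef>
#include <cstdio>
#include <stdexcept>
#include <type_traits>

// Turn a runtime rank into a std::integral_constant so the continuation
// can use it as a template argument.
template <class F>
void visit_rank(std::size_t n, F f)
{
    switch(n)
    {
    case 1: f(std::integral_constant<std::size_t, 1>{}); break;
    case 2: f(std::integral_constant<std::size_t, 2>{}); break;
    case 3: f(std::integral_constant<std::size_t, 3>{}); break;
    case 4: f(std::integral_constant<std::size_t, 4>{}); break;
    default: throw std::runtime_error("unsupported rank");
    }
}

int main()
{
    visit_rank(3, [](auto dim) {
        // decltype(dim)::value is a constant expression here
        std::printf("rank = %zu\n", decltype(dim)::value);
    });
}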
@@ -16,6 +16,27 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
+template <class T, std::size_t N>
+using vec = T __attribute__((ext_vector_type(N)));
+
+template <std::size_t N, class T>
+__device__ __host__ vec<T, N>* as_vec(T* x)
+{
+    return reinterpret_cast<vec<T, N>*>(x);
+}
+
+template <std::size_t N, class T>
+__device__ __host__ T* as_pointer(vec<T, N>* x)
+{
+    return reinterpret_cast<T*>(x);
+}
+
+template <std::size_t N, class... Ts>
+auto pack_vec(Ts... xs)
+{
+    return [=](auto f, std::size_t n) { return f(as_vec<N>(xs)[n]...); };
+}
+
 using gpu_half = __fp16;

 namespace detail {
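These are the helpers deleted from the nary header, now shared from here, with one fix along the way: pack_vec uses its width parameter N instead of the hardcoded as_vec<4>. They are thin wrappers over clang's ext_vector_type extension, which is what lets the kernels issue one wide load or store per four scalars. A host-only sketch of the reinterpretation (compiles with plain clang, no HIP needed):

#include <cstddef>
#include <cstdio>

template <class T, std::size_t N>
using vec = T __attribute__((ext_vector_type(N)));

template <std::size_t N, class T>
vec<T, N>* as_vec(T* x)
{
    return reinterpret_cast<vec<T, N>*>(x);
}

int main()
{
    float data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    // View the 8 floats as two float4 chunks; one assignment moves 4 lanes.
    auto* v = as_vec<4>(data);
    v[0]    = v[1];
    std::printf("%g %g %g %g\n", data[0], data[1], data[2], data[3]); // 4 5 6 7
}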
@@ -25,12 +46,19 @@ struct device_type
     using type = T;
 };

+template <class T, std::size_t N>
+struct device_type<T __attribute__((ext_vector_type(N)))>
+{
+    using type = typename device_type<T>::type __attribute__((ext_vector_type(N)));
+};
+
 template <>
 struct device_type<half>
 {
     using type = gpu_half;
 };

 template <class T>
 struct host_type
 {
@@ -38,7 +66,7 @@ struct host_type
 };

 template <>
-struct device_type<gpu_half>
+struct host_type<gpu_half>
 {
     using type = half;
 };
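The new partial specialization maps a vector's element type through device_type and re-applies the vector attribute, so a vec<half, 4> view on the host becomes a vec<__fp16, 4> on the device; the last hunk also fixes the reverse mapping, which had been declared as a second device_type specialization instead of host_type. A standalone sketch of the trait pattern, with double→float standing in for half→__fp16 (ext_vector_type only applies to arithmetic types, so arithmetic stand-ins are needed; requires clang):

#include <cstddef>
#include <type_traits>

template <class T, std::size_t N>
using vec = T __attribute__((ext_vector_type(N)));

template <class T>
struct device_type
{
    using type = T;
};

// Map the element type through the trait, then re-apply the vector attribute.
template <class T, std::size_t N>
struct device_type<T __attribute__((ext_vector_type(N)))>
{
    using type = typename device_type<T>::type __attribute__((ext_vector_type(N)));
};

template <>
struct device_type<double>
{
    using type = float; // stand-in for half -> __fp16
};

static_assert(std::is_same<device_type<vec<double, 4>>::type, vec<float, 4>>{},
              "vector types map element-wise");

int main() {}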