Commit 67c92b83 authored by Paul
Browse files

Fix subwave implementation

parent 5be87179
...@@ -126,7 +126,7 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n) ...@@ -126,7 +126,7 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n)
{ {
std::size_t max_wavefront_size = ctx.get_current_device().get_wavefront_size(); std::size_t max_wavefront_size = ctx.get_current_device().get_wavefront_size();
std::size_t wavefront_size = 1; std::size_t wavefront_size = 1;
while(wavefront_size < n and wavefront_size < max_wavefront_size) while(wavefront_size <= n and wavefront_size < max_wavefront_size)
wavefront_size *= 2; wavefront_size *= 2;
return wavefront_size; return wavefront_size;
} }
......
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
namespace migraphx { namespace migraphx {
// Compile-time test for "x is a power of two": true when x is non-zero and
// clearing its lowest set bit (x & (x - 1)) leaves nothing behind.
// Zero is reported as not a power of two.
constexpr bool is_power_of_2(unsigned int x)
{
    return x != 0 && (x & (x - 1)) == 0;
}
#ifndef MIGRAPHX_HAS_DPP #ifndef MIGRAPHX_HAS_DPP
#define MIGRAPHX_HAS_DPP 1 #define MIGRAPHX_HAS_DPP 1
#endif #endif
...@@ -86,6 +88,13 @@ __device__ T dpp_swizzle(T& x) ...@@ -86,6 +88,13 @@ __device__ T dpp_swizzle(T& x)
return dpp_op(x, [](auto i) { return __hip_ds_swizzle(i, Mask); }); return dpp_op(x, [](auto i) { return __hip_ds_swizzle(i, Mask); });
} }
// Reads x from a fixed source lane via the HIP __shfl intrinsic.
//
// Template parameters:
//   SrcLane - lane index handed to __shfl as the source lane.
//   Width   - shuffle group width handed to __shfl; must be a power of 2
//             (checked at compile time).
//   T       - type of the value being exchanged.
//
// NOTE(review): per HIP's documented __shfl semantics each lane should receive
// the value held by lane SrcLane of its own Width-wide group - confirm against
// the HIP warp shuffle documentation.
// NOTE(review): dpp_op is defined outside this hunk; presumably it applies the
// callable to x in pieces the intrinsic can handle - verify there.
template <unsigned int SrcLane, unsigned int Width, class T>
__device__ T dpp_readlane(T& x)
{
    static_assert(is_power_of_2(Width), "Width must be a power of 2");
    auto read_src_lane = [](auto word) { return __shfl(word, SrcLane, Width); };
    return dpp_op(x, read_src_lane);
}
#endif // MIGRAPHX_HAS_DPP #endif // MIGRAPHX_HAS_DPP
} // namespace migraphx } // namespace migraphx
......
...@@ -143,6 +143,10 @@ struct index ...@@ -143,6 +143,10 @@ struct index
template <unsigned int SubWaveSize> template <unsigned int SubWaveSize>
constexpr auto local_subwave() const constexpr auto local_subwave() const
{ {
#ifdef MIGRAPHX_HAS_CONST_LOCAL
if constexpr(decltype(nlocal()){} == SubWaveSize)
return local;
#endif
return local % nlocal_subwave<SubWaveSize>(); return local % nlocal_subwave<SubWaveSize>();
} }
template <unsigned int SubWaveSize> template <unsigned int SubWaveSize>
......
...@@ -32,8 +32,6 @@ ...@@ -32,8 +32,6 @@
namespace migraphx { namespace migraphx {
constexpr bool is_power_of_2(unsigned int x) { return x > 0 && !(x & (x - 1)); }
#if MIGRAPHX_HAS_DPP #if MIGRAPHX_HAS_DPP
template <unsigned int SubWaveSize, class T, class Op> template <unsigned int SubWaveSize, class T, class Op>
...@@ -41,42 +39,41 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -41,42 +39,41 @@ __device__ void dpp_reduce(T& in, Op op)
{ {
static_assert(SubWaveSize <= __AMDGCN_WAVEFRONT_SIZE, "Too large subwave size"); static_assert(SubWaveSize <= __AMDGCN_WAVEFRONT_SIZE, "Too large subwave size");
static_assert(is_power_of_2(SubWaveSize), "SubWaveSize is not a power of 2"); static_assert(is_power_of_2(SubWaveSize), "SubWaveSize is not a power of 2");
T out{};
if constexpr(SubWaveSize > 1) if constexpr(SubWaveSize > 1)
{ {
out = dpp_mov<dpp_row_shr(1)>(in); auto out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out); in = op(in, out);
} }
if constexpr(SubWaveSize > 2) if constexpr(SubWaveSize > 2)
{ {
out = dpp_mov<dpp_row_shr(2)>(in); auto out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out); in = op(in, out);
} }
if constexpr(SubWaveSize > 4) if constexpr(SubWaveSize > 4)
{ {
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in); auto out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out); in = op(in, out);
} }
if constexpr(SubWaveSize > 8) if constexpr(SubWaveSize > 8)
{ {
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in); auto out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out); in = op(in, out);
} }
#if __AMDGCN_WAVEFRONT_SIZE == 32 #if __AMDGCN_WAVEFRONT_SIZE == 32
if constexpr(SubWaveSize > 16) if constexpr(SubWaveSize > 16)
{ {
out = dpp_swizzle<0x1e0>(in); auto out = dpp_swizzle<0x1e0>(in);
in = op(in, out); in = op(in, out);
} }
#else #else
if constexpr(SubWaveSize > 16) if constexpr(SubWaveSize > 16)
{ {
out = dpp_mov<dpp_row_bcast(15), 0xa>(in); auto out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out); in = op(in, out);
} }
if constexpr(SubWaveSize > 32) if constexpr(SubWaveSize > 32)
{ {
out = dpp_mov<dpp_row_bcast(31), 0xc>(in); auto out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out); in = op(in, out);
} }
#endif #endif
...@@ -173,9 +170,11 @@ __device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f) ...@@ -173,9 +170,11 @@ __device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
type x = init; type x = init;
idx.local_subwave_stride<SubWaveSize>( idx.local_subwave_stride<SubWaveSize>(
n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); n, [&](auto i, auto d) {
x = op(x, index::invoke_loop(f, i, d));
});
dpp_reduce<SubWaveSize>(x, op); dpp_reduce<SubWaveSize>(x, op);
return x; return dpp_readlane<SubWaveSize-1, SubWaveSize>(x);
} }
template <class Op, class T, class Index, class F> template <class Op, class T, class Index, class F>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment