Commit 67c92b83 authored by Paul
Browse files

Fix subwave implementation

parent 5be87179
......@@ -126,7 +126,7 @@ static std::size_t compute_subwave_size(context& ctx, std::size_t n)
{
std::size_t max_wavefront_size = ctx.get_current_device().get_wavefront_size();
std::size_t wavefront_size = 1;
while(wavefront_size < n and wavefront_size < max_wavefront_size)
while(wavefront_size <= n and wavefront_size < max_wavefront_size)
wavefront_size *= 2;
return wavefront_size;
}
......
......@@ -30,6 +30,8 @@
namespace migraphx {
// Compile-time test for "x has exactly one bit set".
// Returns false for 0, true for 1, 2, 4, 8, ... and false for every other value.
constexpr bool is_power_of_2(unsigned int x)
{
    // Zero has no bits set, so it is rejected before the bit trick below.
    if(x == 0)
        return false;
    // Subtracting 1 flips the lowest set bit and everything beneath it; the
    // bitwise AND is therefore zero only when that lowest bit was the sole bit set.
    return (x & (x - 1)) == 0;
}
#ifndef MIGRAPHX_HAS_DPP
#define MIGRAPHX_HAS_DPP 1
#endif
......@@ -86,6 +88,13 @@ __device__ T dpp_swizzle(T& x)
return dpp_op(x, [](auto i) { return __hip_ds_swizzle(i, Mask); });
}
// Returns the result of running `__shfl(word, SrcLane, Width)` over `x` via dpp_op.
//
// Template parameters:
//   SrcLane - lane index handed to `__shfl` as the source lane.
//   Width   - segment width handed to `__shfl`; rejected at compile time unless
//             it is a power of 2.
//   T       - type of the value being exchanged.
//
// NOTE(review): dpp_op is defined outside this block; presumably it applies the
// callable piecewise to the storage of `x` and reassembles a T - confirm there.
template <unsigned int SrcLane, unsigned int Width, class T>
__device__ T dpp_readlane(T& x)
{
    static_assert(is_power_of_2(Width), "Width must be a power of 2");
    // Per-piece operation: fetch the piece held by SrcLane within a Width-wide segment.
    auto read_src_lane = [](auto word) { return __shfl(word, SrcLane, Width); };
    return dpp_op(x, read_src_lane);
}
#endif // MIGRAPHX_HAS_DPP
} // namespace migraphx
......
......@@ -143,6 +143,10 @@ struct index
// Returns `local` reduced modulo `nlocal_subwave<SubWaveSize>()`.
//
// When MIGRAPHX_HAS_CONST_LOCAL is defined and a value-initialized object of
// nlocal()'s return type compares equal to SubWaveSize, `local` is returned
// directly instead.
// NOTE(review): this presumably relies on nlocal() returning a compile-time
// constant type in that configuration, so the modulo would be a no-op - confirm.
template <unsigned int SubWaveSize>
constexpr auto local_subwave() const
{
#if defined(MIGRAPHX_HAS_CONST_LOCAL)
    // Shortcut evaluated at compile time. There is deliberately no `else`: both
    // return statements stay in the function and must deduce the same type.
    if constexpr(decltype(nlocal()){} == SubWaveSize)
    {
        return local;
    }
#endif
    return local % nlocal_subwave<SubWaveSize>();
}
template <unsigned int SubWaveSize>
......
......@@ -32,8 +32,6 @@
namespace migraphx {
// Compile-time test for "x has exactly one bit set".
// Returns false for 0, true for 1, 2, 4, 8, ... and false for every other value.
constexpr bool is_power_of_2(unsigned int x)
{
    // Zero has no bits set, so it is rejected before the bit trick below.
    if(x == 0)
        return false;
    // Subtracting 1 flips the lowest set bit and everything beneath it; the
    // bitwise AND is therefore zero only when that lowest bit was the sole bit set.
    return (x & (x - 1)) == 0;
}
#if MIGRAPHX_HAS_DPP
template <unsigned int SubWaveSize, class T, class Op>
......@@ -41,42 +39,41 @@ __device__ void dpp_reduce(T& in, Op op)
{
static_assert(SubWaveSize <= __AMDGCN_WAVEFRONT_SIZE, "Too large subwave size");
static_assert(is_power_of_2(SubWaveSize), "SubWaveSize is not a power of 2");
T out{};
if constexpr(SubWaveSize > 1)
{
out = dpp_mov<dpp_row_shr(1)>(in);
auto out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 2)
{
out = dpp_mov<dpp_row_shr(2)>(in);
auto out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 4)
{
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
auto out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 8)
{
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
auto out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
}
#if __AMDGCN_WAVEFRONT_SIZE == 32
if constexpr(SubWaveSize > 16)
{
out = dpp_swizzle<0x1e0>(in);
auto out = dpp_swizzle<0x1e0>(in);
in = op(in, out);
}
#else
if constexpr(SubWaveSize > 16)
{
out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
auto out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 32)
{
out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
auto out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out);
}
#endif
......@@ -173,9 +170,11 @@ __device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
using type = decltype(index::invoke_loop(f, 0, _c<0>));
type x = init;
idx.local_subwave_stride<SubWaveSize>(
n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
n, [&](auto i, auto d) {
x = op(x, index::invoke_loop(f, i, d));
});
dpp_reduce<SubWaveSize>(x, op);
return x;
return dpp_readlane<SubWaveSize-1, SubWaveSize>(x);
}
template <class Op, class T, class Index, class F>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment