Commit 328fce97 authored by Paul
Browse files

Improve dpp reductions on navi

parent 3c160a3f
...@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x) ...@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
return y; return y;
} }
template <unsigned int DppCtrl, template <class T, class F>
unsigned int RowMask = 0xf, __device__ T dpp_op(T& x, F f)
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
{ {
static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4; static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
union type union type
...@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x) ...@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
input.data = x; input.data = x;
for(index_int i = 0; i < n; i++) for(index_int i = 0; i < n; i++)
{ {
output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl); output.reg[i] = f(input.reg[i]);
} }
return output.data; return output.data;
} }
// Forwards `x` to dpp_op together with a callable that invokes
// __hip_move_dpp(value, DppCtrl, RowMask, BankMask, BoundCtrl) on whatever
// value dpp_op hands it, and returns dpp_op's result.
//
// Template parameters are passed to __hip_move_dpp unchanged and in the
// order listed; RowMask and BankMask default to 0xf, BoundCtrl to false.
template <unsigned int DppCtrl,
          unsigned int RowMask  = 0xf,
          unsigned int BankMask = 0xf,
          bool BoundCtrl        = false,
          class T>
__device__ T dpp_mov(T& x)
{
    // Generic parameter: the callable accepts whatever type dpp_op passes in.
    auto move_reg = [](auto reg) {
        return __hip_move_dpp(reg, DppCtrl, RowMask, BankMask, BoundCtrl);
    };
    return dpp_op(x, move_reg);
}
// Forwards `x` to dpp_op together with a callable that invokes
// __hip_ds_swizzle(value, Mask) on whatever value dpp_op hands it, and
// returns dpp_op's result.
//
// NOTE(review): Mask reaches __hip_ds_swizzle unchanged; confirm that callers
// supply a ds_swizzle pattern here rather than a DPP control value.
template <unsigned int Mask, class T>
__device__ T dpp_swizzle(T& x)
{
    // Generic parameter: the callable accepts whatever type dpp_op passes in.
    auto swizzle_reg = [](auto reg) { return __hip_ds_swizzle(reg, Mask); };
    return dpp_op(x, swizzle_reg);
}
#endif // MIGRAPHX_HAS_DPP #endif // MIGRAPHX_HAS_DPP
} // namespace migraphx } // namespace migraphx
......
...@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
in = op(in, out); in = op(in, out);
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in); out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out); in = op(in, out);
#if __AMDGCN_WAVEFRONT_SIZE == 64 #if __AMDGCN_WAVEFRONT_SIZE == 32
out = dpp_swizzle<dpp_row_bcast(15)>(in);
in = op(in, out);
#else
out = dpp_mov<dpp_row_bcast(15), 0xa>(in); out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out); in = op(in, out);
out = dpp_mov<dpp_row_bcast(31), 0xc>(in); out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment