Commit 0c1df49c authored by Paul's avatar Paul
Browse files

Add dpp reduce

parent b3935928
#include <migraphx/gpu/device/reduce_sum.hpp> #include <migraphx/gpu/device/reduce_sum.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp> #include <migraphx/gpu/device/visit.hpp>
#include <migraphx/requires.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -9,13 +10,14 @@ namespace device { ...@@ -9,13 +10,14 @@ namespace device {
struct sum struct sum
{ {
template <class T> template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y) const MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{ {
return x + y; return x + y;
} }
}; };
#ifdef MIGRAPHX_NO_DPP
template <std::size_t N, class Op, class T, class F> template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{ {
...@@ -37,10 +39,113 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f) ...@@ -37,10 +39,113 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
} }
return buffer[0]; return buffer[0];
} }
#else
constexpr unsigned int dpp_row_shr(unsigned int x)
{
return 0x110 | x;
}
constexpr unsigned int dpp_row_bcast(unsigned int x)
{
unsigned int y = 0;
switch(x)
{
case 15:
y = 0x142;
break;
case 31:
y = 0x143;
break;
default:
throw std::runtime_error("Unknown bcast");
}
return y;
}
template<unsigned int DppCtrl, unsigned int RowMask = 0xf, unsigned int BankMask = 0xf, bool BoundCtrl= false, class T>
__device__ T dpp_mov(T& x)
{
static const std::size_t n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
union type
{
uint32_t reg[n];
T data;
};
type output;
type input;
input.data = x;
for(std::size_t i = 0; i < n;i++)
{
output.reg[i] = __llvm_amdgcn_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
}
return output.data;
}
template<class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
T out;
out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out);
out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out);
}
__device__ void dpp_reduce(float& x, sum)
{
__asm__ volatile("s_nop 4\n"
"v_add_f32 %0 %0 %0 row_shr:1\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_shr:2\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_shr:4 bank_mask:0xe\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_shr:8 bank_mask:0xc\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_bcast:15 row_mask:0xa\n"
"s_nop 1\n"
"v_add_f32 %0 %0 %0 row_bcast:31 row_mask:0xc\n"
"s_nop 1\n"
: "=v"(x)
: "0"(x));
}
template <std::size_t N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
{
using type = decltype(f(idx.local));
const auto std::size_t wave = 64;
MIGRAPHX_DEVICE_SHARED type buffer[N/64];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op);
const auto ldsidx = idx.local / 64;
if((idx.local % 64) == 63)
{
buffer[ldsidx] = x;
}
__syncthreads();
type y = 0;
for(std::size_t i = 0; i < idx.nlocal()/64;i++)
{
y += buffer[i];
}
return y;
}
#endif
constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size) constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{ {
size_t block_size = 1; size_t block_size = 64;
while(block_size < max_block_size and block_size < n) while(block_size < max_block_size and block_size < n)
block_size *= 2; block_size *= 2;
return block_size; return block_size;
......
...@@ -3457,4 +3457,28 @@ struct test_reduce_sum : verify_program<test_reduce_sum> ...@@ -3457,4 +3457,28 @@ struct test_reduce_sum : verify_program<test_reduce_sum>
}; };
}; };
struct test_reduce_sum_int : verify_program<test_reduce_sum_int>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::int32_type, {3, 4, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
};
};
struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::half_type, {3, 4, 8, 8}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment