Commit 9ab38607 authored by Paul's avatar Paul
Browse files

Fix swizzle for navi

parent 328fce97
...@@ -46,7 +46,7 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -46,7 +46,7 @@ __device__ void dpp_reduce(T& in, Op op)
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in); out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out); in = op(in, out);
#if __AMDGCN_WAVEFRONT_SIZE == 32 #if __AMDGCN_WAVEFRONT_SIZE == 32
out = dpp_swizzle<dpp_row_bcast(15)>(in); out = dpp_swizzle<0x1e0>(in);
in = op(in, out); in = op(in, out);
#else #else
out = dpp_mov<dpp_row_bcast(15), 0xa>(in); out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
...@@ -59,7 +59,7 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -59,7 +59,7 @@ __device__ void dpp_reduce(T& in, Op op)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1 #define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64 #elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \ #define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \ __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
...@@ -68,29 +68,29 @@ __device__ void dpp_reduce(T& in, Op op) ...@@ -68,29 +68,29 @@ __device__ void dpp_reduce(T& in, Op op)
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \ "s_nop 1\n" \
: "=v"(x) \ : "=v"(x) \
: "0"(x)) : "0"(x)); (void)f
#else #else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \ #define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \ __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \ "s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \ : "=v"(x) \
: "0"(x)) : "0"(x)); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
#endif #endif
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \ #define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \ __device__ inline void dpp_reduce(double& x, op f) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \ __device__ inline void dpp_reduce(float& x, op f) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \ __device__ inline void dpp_reduce(half& x, op f) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); } \
__device__ inline void dpp_reduce(int32_t& x, op) \ __device__ inline void dpp_reduce(int32_t& x, op f) \
{ \ { \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \ MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
} \ } \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); } __device__ inline void dpp_reduce(uint32_t& x, op f) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); }
// Note: when max and min are in int32_t, signed version of instruction needs to be used. // Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u) MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
...@@ -102,11 +102,7 @@ template <class Op, class T, class Index, class F> ...@@ -102,11 +102,7 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32 constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
constexpr index_int lanes_per_thread = 16;
#else
constexpr index_int lanes_per_thread = 64;
#endif
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = init;
......
...@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>> ...@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
}; };
template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>; template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::int32_type>; template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment