Commit 9ab38607 authored by Paul's avatar Paul
Browse files

Fix swizzle for navi

parent 328fce97
......@@ -46,7 +46,7 @@ __device__ void dpp_reduce(T& in, Op op)
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
#if __AMDGCN_WAVEFRONT_SIZE == 32
out = dpp_swizzle<dpp_row_bcast(15)>(in);
out = dpp_swizzle<0x1e0>(in);
in = op(in, out);
#else
out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
......@@ -59,7 +59,7 @@ __device__ void dpp_reduce(T& in, Op op)
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
......@@ -68,29 +68,29 @@ __device__ void dpp_reduce(T& in, Op op)
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
: "0"(x)); (void)f
#else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
: "0"(x)); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \
__device__ inline void dpp_reduce(double& x, op f) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); } \
__device__ inline void dpp_reduce(float& x, op f) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); } \
__device__ inline void dpp_reduce(half& x, op f) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); } \
__device__ inline void dpp_reduce(int32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
__device__ inline void dpp_reduce(uint32_t& x, op f) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); }
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
......@@ -102,11 +102,7 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16;
#else
constexpr index_int lanes_per_thread = 64;
#endif
constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init;
......
......@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
};
template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::float_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::int32_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment