"torchvision/tv_tensors/_bounding_boxes.py" did not exist on "30b879fc68a6970d5c82afeb5d7e0b00a3771967"
Unverified commit 8c73c72e authored by Paul Fultz II, committed by GitHub

Improve dpp reductions on navi (#2439)

parent 7d4b1719
@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
     return y;
 }
 
-template <unsigned int DppCtrl,
-          unsigned int RowMask  = 0xf,
-          unsigned int BankMask = 0xf,
-          bool BoundCtrl        = false,
-          class T>
-__device__ T dpp_mov(T& x)
+template <class T, class F>
+__device__ T dpp_op(T& x, F f)
 {
     static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
     union type
@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
     input.data = x;
     for(index_int i = 0; i < n; i++)
     {
-        output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
+        output.reg[i] = f(input.reg[i]);
     }
     return output.data;
 }
 
+template <unsigned int DppCtrl,
+          unsigned int RowMask  = 0xf,
+          unsigned int BankMask = 0xf,
+          bool BoundCtrl        = false,
+          class T>
+__device__ T dpp_mov(T& x)
+{
+    return dpp_op(x,
+                  [](auto i) { return __hip_move_dpp(i, DppCtrl, RowMask, BankMask, BoundCtrl); });
+}
+
+template <unsigned int Mask, class T>
+__device__ T dpp_swizzle(T& x)
+{
+    return dpp_op(x, [](auto i) { return __hip_ds_swizzle(i, Mask); });
+}
+
 #endif // MIGRAPHX_HAS_DPP
 } // namespace migraphx
......
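The hunks above factor the per-register loop out of dpp_mov into a generic dpp_op helper, so both dpp_mov (cross-lane moves via __hip_move_dpp) and the new dpp_swizzle (via __hip_ds_swizzle) share the same union-based register walk. Below is a minimal host-side sketch of that pattern, not the device code itself: the HIP builtins are device-only, so an ordinary lambda stands in for the cross-lane intrinsic, and host_dpp_op is a hypothetical name used only for illustration.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Host-side analogue of dpp_op: reinterpret a value as 32-bit registers and
// apply a functor to each register (mirrors the union type-punning the kernel
// helper uses for types larger or smaller than 4 bytes).
template <class T, class F>
T host_dpp_op(const T& x, F f)
{
    constexpr std::size_t n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
    union type
    {
        std::uint32_t reg[n];
        T data;
    };
    type input{};
    type output{};
    input.data = x;
    for(std::size_t i = 0; i < n; i++)
        output.reg[i] = f(input.reg[i]);
    return output.data;
}

int main()
{
    // A double spans two 32-bit registers; the functor runs once per half.
    double d = 3.25;
    double r = host_dpp_op(d, [](std::uint32_t reg) { return reg; });
    std::printf("%f\n", r); // prints 3.250000
}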
@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
     in  = op(in, out);
     out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
     in  = op(in, out);
-#if __AMDGCN_WAVEFRONT_SIZE == 64
+#if __AMDGCN_WAVEFRONT_SIZE == 32
+    out = dpp_swizzle<0x1e0>(in);
+    in  = op(in, out);
+#else
     out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
     in  = op(in, out);
     out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
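On wave32 the row_bcast steps are replaced by a single ds_swizzle. By my reading of the bit-masked swizzle encoding (an assumption worth checking against the gfx10 ISA guide, not something stated in this commit), with offset bit 15 clear the source lane is ((lane & and_mask) | or_mask) ^ xor_mask, where and_mask is offset[4:0], or_mask is offset[9:5], and xor_mask is offset[14:10]. For offset 0x1e0 that makes every lane read lane 15, so the partial reduction held in lane 15 after the row_shr steps is folded into the second row and the last lane of the wavefront ends up with the full 32-lane result. The small host-side check below just decodes the offset; it is illustrative code, not part of the diff.

#include <cstdio>

// Decode the bit-masked ds_swizzle_b32 offset used above (0x1e0) and print
// the source lane each of the 32 lanes would read from.
constexpr unsigned swizzle_src_lane(unsigned lane, unsigned offset)
{
    const unsigned and_mask = offset & 0x1fu;         // offset[4:0]
    const unsigned or_mask  = (offset >> 5) & 0x1fu;  // offset[9:5]
    const unsigned xor_mask = (offset >> 10) & 0x1fu; // offset[14:10]
    return ((lane & and_mask) | or_mask) ^ xor_mask;
}

int main()
{
    // For 0x1e0: and_mask = 0, or_mask = 15, xor_mask = 0, so every lane
    // reads from lane 15 of its 32-lane group.
    for(unsigned lane = 0; lane < 32; lane++)
        std::printf("lane %2u <- lane %u\n", lane, swizzle_src_lane(lane, 0x1e0));
}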
@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op)
 }
 #if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
 // NOLINTNEXTLINE
-#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
+#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
+    (void)f;                               \
+    x = 1
 #elif __AMDGCN_WAVEFRONT_SIZE == 64
-#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
+#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
     __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op)
                      "s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
                      "s_nop 1\n" \
                      : "=v"(x) \
-                     : "0"(x))
+                     : "0"(x)); \
+    (void)f
 #else
-#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
+#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
     __asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
                      "s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
-                     "s_nop 1\n" \
-                     "s_nop 1\n" \
                      : "=v"(x) \
-                     : "0"(x))
+                     : "0"(x)); \
+    auto y = dpp_swizzle<0x1e0>(x); \
+    x      = f(x, y)
 #endif
 // NOLINTNEXTLINE
 #define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
-    __device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
-    __device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
-    __device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
-    __device__ inline void dpp_reduce(int32_t& x, op) \
+    __device__ inline void dpp_reduce(double& x, op f) \
+    { \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); \
+    } \
+    __device__ inline void dpp_reduce(float& x, op f) \
+    { \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); \
+    } \
+    __device__ inline void dpp_reduce(half& x, op f) \
+    { \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); \
+    } \
+    __device__ inline void dpp_reduce(int32_t& x, op f) \
     { \
-        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
     } \
-    __device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
+    __device__ inline void dpp_reduce(uint32_t& x, op f) \
+    { \
+        MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); \
+    }
 
 // Note: when max and min are in int32_t, signed version of instruction needs to be used.
 MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
@@ -99,11 +117,7 @@ template <class Op, class T, class Index, class F>
 __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
 {
     MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
-#if __AMDGCN_WAVEFRONT_SIZE == 32
-    constexpr index_int lanes_per_thread = 16;
-#else
-    constexpr index_int lanes_per_thread = 64;
-#endif
+    constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
     using type = decltype(index::invoke_loop(f, 0, _c<0>));
     __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
     type x = type(init);
......
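With dpp_reduce now covering a full wavefront on both wave32 and wave64, block_reduce can set lanes_per_thread to __AMDGCN_WAVEFRONT_SIZE directly: each wavefront stages one partial result in the shared buffer (sized max_nlocal / lanes_per_thread), and those partials are then folded together. The sketch below is a host-side analogue of that staging structure under assumed names (wave_size, values), not the kernel code itself.

#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    constexpr int wave_size = 32; // stands in for __AMDGCN_WAVEFRONT_SIZE on Navi
    const int nlocal        = 256;
    std::vector<int> values(nlocal, 1);

    // Stage one partial result per wavefront, as dpp_reduce plus the LDS
    // store would on the device.
    std::vector<int> buffer(nlocal / wave_size);
    for(int w = 0; w < nlocal / wave_size; w++)
        buffer[w] = std::accumulate(values.begin() + w * wave_size,
                                    values.begin() + (w + 1) * wave_size, 0);

    // Final combine over the staged per-wavefront partials.
    int total = std::accumulate(buffer.begin(), buffer.end(), 0);
    std::printf("%d\n", total); // prints 256
}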
@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
 };
 
 template struct test_reduce_op_small<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
+template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::float_type>;
 template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::int32_type>;
 template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::int32_type>;
 template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::int32_type>;
 template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::int32_type>;
+template struct test_reduce_op_small<migraphx::op::reduce_sum, 3, migraphx::shape::half_type>;
 template struct test_reduce_op_small<migraphx::op::reduce_sum, 2, migraphx::shape::half_type>;
 template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::shape::half_type>;
 template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>;
@@ -60,6 +62,9 @@ template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::sh
 template struct test_reduce_op_small<migraphx::op::reduce_sum,
                                      2,
                                      migraphx::shape::fp8e4m3fnuz_type>;
+template struct test_reduce_op_small<migraphx::op::reduce_sum,
+                                     3,
+                                     migraphx::shape::fp8e4m3fnuz_type>;
 template struct test_reduce_op_small<migraphx::op::reduce_mean,
                                      2,
                                      migraphx::shape::fp8e4m3fnuz_type>;
......