Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8c73c72e
Unverified
Commit
8c73c72e
authored
Dec 05, 2023
by
Paul Fultz II
Committed by
GitHub
Dec 05, 2023
Browse files
Improve dpp reductions on navi (#2439)
parent
7d4b1719
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
29 deletions
+62
-29
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
+21
-7
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+36
-22
test/verify/test_reduce_op_small.cpp
test/verify/test_reduce_op_small.cpp
+5
-0
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/dpp.hpp
View file @
8c73c72e
...
@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
...
@@ -49,12 +49,8 @@ constexpr unsigned int dpp_row_bcast(unsigned int x)
return
y
;
return
y
;
}
}
template
<
unsigned
int
DppCtrl
,
template
<
class
T
,
class
F
>
unsigned
int
RowMask
=
0xf
,
__device__
T
dpp_op
(
T
&
x
,
F
f
)
unsigned
int
BankMask
=
0xf
,
bool
BoundCtrl
=
false
,
class
T
>
__device__
T
dpp_mov
(
T
&
x
)
{
{
static
const
index_int
n
=
sizeof
(
T
)
<
4
?
1
:
sizeof
(
T
)
/
4
;
static
const
index_int
n
=
sizeof
(
T
)
<
4
?
1
:
sizeof
(
T
)
/
4
;
union
type
union
type
...
@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
...
@@ -68,10 +64,28 @@ __device__ T dpp_mov(T& x)
input
.
data
=
x
;
input
.
data
=
x
;
for
(
index_int
i
=
0
;
i
<
n
;
i
++
)
for
(
index_int
i
=
0
;
i
<
n
;
i
++
)
{
{
output
.
reg
[
i
]
=
__hip_move_dpp
(
input
.
reg
[
i
],
DppCtrl
,
RowMask
,
BankMask
,
BoundCtrl
);
output
.
reg
[
i
]
=
f
(
input
.
reg
[
i
]
);
}
}
return
output
.
data
;
return
output
.
data
;
}
}
template
<
unsigned
int
DppCtrl
,
unsigned
int
RowMask
=
0xf
,
unsigned
int
BankMask
=
0xf
,
bool
BoundCtrl
=
false
,
class
T
>
__device__
T
dpp_mov
(
T
&
x
)
{
return
dpp_op
(
x
,
[](
auto
i
)
{
return
__hip_move_dpp
(
i
,
DppCtrl
,
RowMask
,
BankMask
,
BoundCtrl
);
});
}
template
<
unsigned
int
Mask
,
class
T
>
__device__
T
dpp_swizzle
(
T
&
x
)
{
return
dpp_op
(
x
,
[](
auto
i
)
{
return
__hip_ds_swizzle
(
i
,
Mask
);
});
}
#endif // MIGRAPHX_HAS_DPP
#endif // MIGRAPHX_HAS_DPP
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
8c73c72e
...
@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -45,7 +45,10 @@ __device__ void dpp_reduce(T& in, Op op)
in
=
op
(
in
,
out
);
in
=
op
(
in
,
out
);
out
=
dpp_mov
<
dpp_row_shr
(
8
),
0xf
,
0xc
>
(
in
);
out
=
dpp_mov
<
dpp_row_shr
(
8
),
0xf
,
0xc
>
(
in
);
in
=
op
(
in
,
out
);
in
=
op
(
in
,
out
);
#if __AMDGCN_WAVEFRONT_SIZE == 64
#if __AMDGCN_WAVEFRONT_SIZE == 32
out
=
dpp_swizzle
<
0x1e0
>
(
in
);
in
=
op
(
in
,
out
);
#else
out
=
dpp_mov
<
dpp_row_bcast
(
15
),
0xa
>
(
in
);
out
=
dpp_mov
<
dpp_row_bcast
(
15
),
0xa
>
(
in
);
in
=
op
(
in
,
out
);
in
=
op
(
in
,
out
);
out
=
dpp_mov
<
dpp_row_bcast
(
31
),
0xc
>
(
in
);
out
=
dpp_mov
<
dpp_row_bcast
(
31
),
0xc
>
(
in
);
...
@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -54,9 +57,11 @@ __device__ void dpp_reduce(T& in, Op op)
}
}
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
(void)f; \
x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
)
\
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
, f)
\
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
...
@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -65,29 +70,42 @@ __device__ void dpp_reduce(T& in, Op op)
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \
: "=v"(x) \
: "0"(x))
: "0"(x)); \
(void)f
#else
#else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
)
\
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
, f)
\
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \
: "=v"(x) \
: "0"(x))
: "0"(x)); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
#endif
#endif
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(double& x, op f) \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
{ \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); \
__device__ inline void dpp_reduce(int32_t& x, op) \
} \
{ \
__device__ inline void dpp_reduce(float& x, op f) \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
{ \
} \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
} \
__device__ inline void dpp_reduce(half& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); \
} \
__device__ inline void dpp_reduce(int32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); \
}
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
...
@@ -99,11 +117,7 @@ template <class Op, class T, class Index, class F>
...
@@ -99,11 +117,7 @@ template <class Op, class T, class Index, class F>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
{
{
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr
index_int
lanes_per_thread
=
__AMDGCN_WAVEFRONT_SIZE
;
constexpr
index_int
lanes_per_thread
=
16
;
#else
constexpr
index_int
lanes_per_thread
=
64
;
#endif
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
type
x
=
type
(
init
);
type
x
=
type
(
init
);
...
...
test/verify/test_reduce_op_small.cpp
View file @
8c73c72e
...
@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
...
@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
};
};
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
1
,
migraphx
::
shape
::
float_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
1
,
migraphx
::
shape
::
float_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
3
,
migraphx
::
shape
::
float_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_mean
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_mean
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_max
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_max
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_min
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_min
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
3
,
migraphx
::
shape
::
half_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
2
,
migraphx
::
shape
::
half_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
2
,
migraphx
::
shape
::
half_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_mean
,
2
,
migraphx
::
shape
::
half_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_mean
,
2
,
migraphx
::
shape
::
half_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_max
,
2
,
migraphx
::
shape
::
half_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_max
,
2
,
migraphx
::
shape
::
half_type
>;
...
@@ -60,6 +62,9 @@ template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::sh
...
@@ -60,6 +62,9 @@ template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::sh
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
2
,
2
,
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
3
,
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_mean
,
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_mean
,
2
,
2
,
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
migraphx
::
shape
::
fp8e4m3fnuz_type
>;
...
...
gaoqiong
@gaoqiong
mentioned in commit
6f37331e
·
Dec 14, 2023
mentioned in commit
6f37331e
mentioned in commit 6f37331eb867a83481a21b0dd7a6ed6bf81306a2
Toggle commit list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment