Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9ab38607
Commit
9ab38607
authored
Nov 11, 2023
by
Paul
Browse files
Fix swizzle for navi
parent
328fce97
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
18 deletions
+16
-18
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+14
-18
test/verify/test_reduce_op_small.cpp
test/verify/test_reduce_op_small.cpp
+2
-0
No files found.
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
9ab38607
...
...
@@ -46,7 +46,7 @@ __device__ void dpp_reduce(T& in, Op op)
out
=
dpp_mov
<
dpp_row_shr
(
8
),
0xf
,
0xc
>
(
in
);
in
=
op
(
in
,
out
);
#if __AMDGCN_WAVEFRONT_SIZE == 32
out
=
dpp_swizzle
<
dpp_row_bcast
(
15
)
>
(
in
);
out
=
dpp_swizzle
<
0x1e0
>
(
in
);
in
=
op
(
in
,
out
);
#else
out
=
dpp_mov
<
dpp_row_bcast
(
15
),
0xa
>
(
in
);
...
...
@@ -59,7 +59,7 @@ __device__ void dpp_reduce(T& in, Op op)
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
, f
) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
...
...
@@ -68,29 +68,29 @@ __device__ void dpp_reduce(T& in, Op op)
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
: "0"(x))
; (void)f
#else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins
, f
) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
: "0"(x)); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \
__device__ inline void dpp_reduce(double& x, op
f
) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64
, f
); } \
__device__ inline void dpp_reduce(float& x, op
f
) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32
, f
); } \
__device__ inline void dpp_reduce(half& x, op
f
) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16
, f
); } \
__device__ inline void dpp_reduce(int32_t& x, op
f
) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32); \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32
, f
); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
__device__ inline void dpp_reduce(uint32_t& x, op
f
) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32
, f
); }
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
...
...
@@ -102,11 +102,7 @@ template <class Op, class T, class Index, class F>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
{
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr
index_int
lanes_per_thread
=
16
;
#else
constexpr
index_int
lanes_per_thread
=
64
;
#endif
constexpr
index_int
lanes_per_thread
=
__AMDGCN_WAVEFRONT_SIZE
;
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
type
x
=
init
;
...
...
test/verify/test_reduce_op_small.cpp
View file @
9ab38607
...
...
@@ -46,11 +46,13 @@ struct test_reduce_op_small : verify_program<test_reduce_op_small<Op, Axis, T>>
};
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
1
,
migraphx
::
shape
::
float_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
3
,
migraphx
::
shape
::
float_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_mean
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_max
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_min
,
2
,
migraphx
::
shape
::
int32_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
3
,
migraphx
::
shape
::
half_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_sum
,
2
,
migraphx
::
shape
::
half_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_mean
,
2
,
migraphx
::
shape
::
half_type
>;
template
struct
test_reduce_op_small
<
migraphx
::
op
::
reduce_max
,
2
,
migraphx
::
shape
::
half_type
>;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment