Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ce4f0940
Commit
ce4f0940
authored
Dec 07, 2023
by
Paul
Browse files
Add dpp assembly for dpp_reduce
parent
4cc5393d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
92 additions
and
52 deletions
+92
-52
src/targets/gpu/jit/layernorm.cpp
src/targets/gpu/jit/layernorm.cpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp
src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp
+29
-0
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+62
-51
No files found.
src/targets/gpu/jit/layernorm.cpp
View file @
ce4f0940
...
@@ -81,7 +81,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
...
@@ -81,7 +81,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
}
}
auto
relements
=
inputs
[
0
].
lens
()[
axis
]
/
vec
.
size
;
auto
relements
=
inputs
[
0
].
lens
()[
axis
]
/
vec
.
size
;
auto
nelements
=
(
inputs
.
back
().
elements
()
/
inputs
[
0
].
lens
()[
axis
]);
auto
nelements
=
(
inputs
.
back
().
elements
()
/
inputs
[
0
].
lens
()[
axis
]);
auto
block_size
=
compute_block_size
(
relements
,
256
);
auto
block_size
=
compute_block_size
(
ctx
,
relements
,
256
);
hip_compile_options
options
;
hip_compile_options
options
;
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pp.hpp
0 → 100644
View file @
ce4f0940
#ifndef MIGRAPHX_GUARD_KERNELS_PP_HPP
#define MIGRAPHX_GUARD_KERNELS_PP_HPP
// Minimal preprocessor metaprogramming utilities for the GPU kernel headers
// (used, e.g., to generate repeated DPP inline-asm statements in reduce.hpp).
// Paste two tokens together WITHOUT macro-expanding the arguments first.
#define MIGRAPHX_PP_PRIMITIVE_CAT(x, y) x##y
// Paste two tokens together AFTER macro-expanding both arguments.
#define MIGRAPHX_PP_CAT(x, y) MIGRAPHX_PP_PRIMITIVE_CAT(x, y)
// Swallow all arguments (expands to nothing). Together with
// MIGRAPHX_PP_EXPAND this lets callers conditionally emit or suppress tokens.
#define MIGRAPHX_PP_EAT(...)
// Expand to the arguments unchanged.
#define MIGRAPHX_PP_EXPAND(...) __VA_ARGS__
// MIGRAPHX_PP_REPEATn(m, ...) expands m(i, ...) for every i in 0..n
// inclusive; each level prepends the previous one and appends m(n, ...).
#define MIGRAPHX_PP_REPEAT0(m, ...) m(0, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT1(m, ...) MIGRAPHX_PP_REPEAT0(m, __VA_ARGS__) m(1, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT2(m, ...) MIGRAPHX_PP_REPEAT1(m, __VA_ARGS__) m(2, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT3(m, ...) MIGRAPHX_PP_REPEAT2(m, __VA_ARGS__) m(3, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT4(m, ...) MIGRAPHX_PP_REPEAT3(m, __VA_ARGS__) m(4, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT5(m, ...) MIGRAPHX_PP_REPEAT4(m, __VA_ARGS__) m(5, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT6(m, ...) MIGRAPHX_PP_REPEAT5(m, __VA_ARGS__) m(6, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT7(m, ...) MIGRAPHX_PP_REPEAT6(m, __VA_ARGS__) m(7, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT8(m, ...) MIGRAPHX_PP_REPEAT7(m, __VA_ARGS__) m(8, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT9(m, ...) MIGRAPHX_PP_REPEAT8(m, __VA_ARGS__) m(9, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT10(m, ...) MIGRAPHX_PP_REPEAT9(m, __VA_ARGS__) m(10, __VA_ARGS__)
// Dispatch on the count: expands to m(0, ...), m(1, ...), ..., m(n, ...).
// Uses PRIMITIVE_CAT, so n must be a literal 0..10 (matching the REPEATn
// macros above), not an expression or another macro to be expanded first.
#define MIGRAPHX_PP_REPEAT(n, m, ...) MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_REPEAT, n)(m, __VA_ARGS__)
// Empty namespace: kept so this header matches the project-wide convention
// of declaring the migraphx namespace in every kernel header.
namespace
migraphx
{
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_PP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
ce4f0940
...
@@ -28,6 +28,7 @@
...
@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/pp.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -81,69 +82,79 @@ __device__ void dpp_reduce(T& in, Op op)
...
@@ -81,69 +82,79 @@ __device__ void dpp_reduce(T& in, Op op)
#endif
#endif
}
}
template
<
class
T
,
class
Op
>
#if 1
__device__
void
dpp_reduce
(
T
&
in
,
Op
op
)
{
dpp_reduce
<
__AMDGCN_WAVEFRONT_SIZE
>
(
in
,
op
);
}
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
#define MIGRAPHX_DPP_REDUCE_ASM_FUN(type, op, ins) \
template<unsigned int SubWaveSize> \
__device__ inline void dpp_reduce(type& x, op f) \
{ \
(void)f; \
(void)f; \
x = 1
x = 1; \
#elif __AMDGCN_WAVEFRONT_SIZE == 64
}
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:15 row_mask:0xa\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x)); \
(void)f
#else
#else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
#define MIGRAPHX_DPP_IIF64(then, ...) then
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
#define MIGRAPHX_DPP_IIF32(then, ...) __VA_ARGS__
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
#define MIGRAPHX_DPP_IF_64(x) MIGRAPHX_PP_CAT(MIGRAPHX_DPP_IIF, x)
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
#define MIGRAPHX_DPP_WHEN_64(x) MIGRAPHX_DPP_IF_64(x)(MIGRAPHX_PP_EXPAND, MIGRAPHX_PP_EAT)
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
#define MIGRAPHX_DPP_REDUCE_ASM0(ins) #ins " %0 %0 %0 row_shr:1\n"
#define MIGRAPHX_DPP_REDUCE_ASM1(ins) #ins " %0 %0 %0 row_shr:2\n"
#define MIGRAPHX_DPP_REDUCE_ASM2(ins) #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n"
#define MIGRAPHX_DPP_REDUCE_ASM3(ins) #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n"
#define MIGRAPHX_DPP_REDUCE_ASM4(ins) #ins " %0 %0 %0 row_bcast:15 row_mask:0xa\n"
#define MIGRAPHX_DPP_REDUCE_ASM5(ins) #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n"
#define MIGRAPHX_DPP_REDUCE_ASM_REPEAT(i, ins) MIGRAPHX_PP_CAT(MIGRAPHX_DPP_REDUCE_ASM, i)(ins) "s_nop 1\n"
#define MIGRAPHX_DPP_REDUCE_ASM(n, x, ins, ...) { \
__asm__ volatile("s_nop 4\n" \
MIGRAPHX_PP_REPEAT(n, MIGRAPHX_DPP_REDUCE_ASM_REPEAT, ins) \
: "=v"(x) \
: "=v"(x) \
: "0"(x)); \
: "0"(x)); __VA_ARGS__ \
}
#if __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f) (void)f;
#else
#define MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f) \
auto y = dpp_swizzle<0x1e0>(x); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
x = f(x, y)
;
#endif
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM_FUN(type, op, ins) \
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
template<unsigned int SubWaveSize> \
__device__ inline void dpp_reduce(double& x, op f) \
__device__ inline void dpp_reduce(type& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); \
} \
__device__ inline void dpp_reduce(float& x, op f) \
{ \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); \
if constexpr(SubWaveSize == 2) MIGRAPHX_DPP_REDUCE_ASM(0, x, ins,); \
} \
if constexpr(SubWaveSize == 4) MIGRAPHX_DPP_REDUCE_ASM(1, x, ins,); \
__device__ inline void dpp_reduce(half& x, op f) \
if constexpr(SubWaveSize == 8) MIGRAPHX_DPP_REDUCE_ASM(2, x, ins,); \
{ \
if constexpr(SubWaveSize == 16) MIGRAPHX_DPP_REDUCE_ASM(3, x, ins,); \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); \
if constexpr(SubWaveSize == 32) MIGRAPHX_DPP_REDUCE_ASM(MIGRAPHX_DPP_IF_64(__AMDGCN_WAVEFRONT_SIZE)(4, 3), x, ins,MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f)); \
} \
MIGRAPHX_DPP_WHEN_64(__AMDGCN_WAVEFRONT_SIZE)(if constexpr(SubWaveSize == 64) MIGRAPHX_DPP_REDUCE_ASM(5, x, ins,)); \
__device__ inline void dpp_reduce(int32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); \
}
}
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
MIGRAPHX_DPP_REDUCE_ASM_FUN(double, op, prefix##_f64); \
MIGRAPHX_DPP_REDUCE_ASM_FUN(float, op, prefix##_f32); \
MIGRAPHX_DPP_REDUCE_ASM_FUN(half, op, prefix##_f16); \
MIGRAPHX_DPP_REDUCE_ASM_FUN(int32_t, op, prefix##sign##32); \
MIGRAPHX_DPP_REDUCE_ASM_FUN(uint32_t, op, prefix##_u32);
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
sum
,
v_add
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
product
,
v_mul
,
_u
)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
max
,
v_max
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
MIGRAPHX_DPP_REDUCE
(
op
::
min
,
v_min
,
_i
)
#endif
template
<
class
T
,
class
Op
>
__device__
void
dpp_reduce
(
T
&
in
,
Op
op
)
{
dpp_reduce
<
__AMDGCN_WAVEFRONT_SIZE
>
(
in
,
op
);
}
template
<
unsigned
int
SubWaveSize
,
class
Op
,
class
T
,
class
Index
,
class
F
>
template
<
unsigned
int
SubWaveSize
,
class
Op
,
class
T
,
class
Index
,
class
F
>
__device__
auto
subwave_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
__device__
auto
subwave_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment