Commit ce4f0940 authored by Paul's avatar Paul
Browse files

Add dpp assembly for dpp_reduce

parent 4cc5393d
......@@ -81,7 +81,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
}
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
auto block_size = compute_block_size(relements, 256);
auto block_size = compute_block_size(ctx, relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
......
#ifndef MIGRAPHX_GUARD_KERNELS_PP_HPP
#define MIGRAPHX_GUARD_KERNELS_PP_HPP
#define MIGRAPHX_PP_PRIMITIVE_CAT(x, y) x##y
#define MIGRAPHX_PP_CAT(x, y) MIGRAPHX_PP_PRIMITIVE_CAT(x, y)
#define MIGRAPHX_PP_EAT(...)
#define MIGRAPHX_PP_EXPAND(...) __VA_ARGS__
#define MIGRAPHX_PP_REPEAT0(m, ...) m(0, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT1(m, ...) MIGRAPHX_PP_REPEAT0(m, __VA_ARGS__) m(1, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT2(m, ...) MIGRAPHX_PP_REPEAT1(m, __VA_ARGS__) m(2, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT3(m, ...) MIGRAPHX_PP_REPEAT2(m, __VA_ARGS__) m(3, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT4(m, ...) MIGRAPHX_PP_REPEAT3(m, __VA_ARGS__) m(4, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT5(m, ...) MIGRAPHX_PP_REPEAT4(m, __VA_ARGS__) m(5, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT6(m, ...) MIGRAPHX_PP_REPEAT5(m, __VA_ARGS__) m(6, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT7(m, ...) MIGRAPHX_PP_REPEAT6(m, __VA_ARGS__) m(7, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT8(m, ...) MIGRAPHX_PP_REPEAT7(m, __VA_ARGS__) m(8, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT9(m, ...) MIGRAPHX_PP_REPEAT8(m, __VA_ARGS__) m(9, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT10(m, ...) MIGRAPHX_PP_REPEAT9(m, __VA_ARGS__) m(10, __VA_ARGS__)
#define MIGRAPHX_PP_REPEAT(n, m, ...) MIGRAPHX_PP_PRIMITIVE_CAT(MIGRAPHX_PP_REPEAT, n)(m, __VA_ARGS__)
namespace migraphx {
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_PP_HPP
......@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/pp.hpp>
namespace migraphx {
......@@ -81,69 +82,79 @@ __device__ void dpp_reduce(T& in, Op op)
#endif
}
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
dpp_reduce<__AMDGCN_WAVEFRONT_SIZE>(in, op);
}
#if 1
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
(void)f; \
x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:15 row_mask:0xa\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x)); \
(void)f
#define MIGRAPHX_DPP_REDUCE_ASM_FUN(type, op, ins) \
template<unsigned int SubWaveSize> \
__device__ inline void dpp_reduce(type& x, op f) \
{ \
(void)f; \
x = 1; \
}
#else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins, f) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
: "=v"(x) \
: "0"(x)); \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y)
#define MIGRAPHX_DPP_IIF64(then, ...) then
#define MIGRAPHX_DPP_IIF32(then, ...) __VA_ARGS__
#define MIGRAPHX_DPP_IF_64(x) MIGRAPHX_PP_CAT(MIGRAPHX_DPP_IIF, x)
#define MIGRAPHX_DPP_WHEN_64(x) MIGRAPHX_DPP_IF_64(x)(MIGRAPHX_PP_EXPAND, MIGRAPHX_PP_EAT)
#define MIGRAPHX_DPP_REDUCE_ASM0(ins) #ins " %0 %0 %0 row_shr:1\n"
#define MIGRAPHX_DPP_REDUCE_ASM1(ins) #ins " %0 %0 %0 row_shr:2\n"
#define MIGRAPHX_DPP_REDUCE_ASM2(ins) #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n"
#define MIGRAPHX_DPP_REDUCE_ASM3(ins) #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n"
#define MIGRAPHX_DPP_REDUCE_ASM4(ins) #ins " %0 %0 %0 row_bcast:15 row_mask:0xa\n"
#define MIGRAPHX_DPP_REDUCE_ASM5(ins) #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n"
#define MIGRAPHX_DPP_REDUCE_ASM_REPEAT(i, ins) MIGRAPHX_PP_CAT(MIGRAPHX_DPP_REDUCE_ASM, i)(ins) "s_nop 1\n"
#define MIGRAPHX_DPP_REDUCE_ASM(n, x, ins, ...) { \
__asm__ volatile("s_nop 4\n" \
MIGRAPHX_PP_REPEAT(n, MIGRAPHX_DPP_REDUCE_ASM_REPEAT, ins) \
: "=v"(x) \
: "0"(x)); __VA_ARGS__ \
}
#if __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f) (void)f;
#else
#define MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f) \
auto y = dpp_swizzle<0x1e0>(x); \
x = f(x, y);
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
__device__ inline void dpp_reduce(double& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64, f); \
} \
__device__ inline void dpp_reduce(float& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32, f); \
} \
__device__ inline void dpp_reduce(half& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16, f); \
} \
__device__ inline void dpp_reduce(int32_t& x, op f) \
#define MIGRAPHX_DPP_REDUCE_ASM_FUN(type, op, ins) \
template<unsigned int SubWaveSize> \
__device__ inline void dpp_reduce(type& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##sign##32, f); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op f) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32, f); \
}
if constexpr(SubWaveSize == 2) MIGRAPHX_DPP_REDUCE_ASM(0, x, ins,); \
if constexpr(SubWaveSize == 4) MIGRAPHX_DPP_REDUCE_ASM(1, x, ins,); \
if constexpr(SubWaveSize == 8) MIGRAPHX_DPP_REDUCE_ASM(2, x, ins,); \
if constexpr(SubWaveSize == 16) MIGRAPHX_DPP_REDUCE_ASM(3, x, ins,); \
if constexpr(SubWaveSize == 32) MIGRAPHX_DPP_REDUCE_ASM(MIGRAPHX_DPP_IF_64(__AMDGCN_WAVEFRONT_SIZE)(4, 3), x, ins,MIGRAPHX_DPP_REDUCE_SWIZZLE(x, f)); \
MIGRAPHX_DPP_WHEN_64(__AMDGCN_WAVEFRONT_SIZE)(if constexpr(SubWaveSize == 64) MIGRAPHX_DPP_REDUCE_ASM(5, x, ins,)); \
}
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix, sign) \
MIGRAPHX_DPP_REDUCE_ASM_FUN(double, op, prefix##_f64); \
MIGRAPHX_DPP_REDUCE_ASM_FUN(float, op, prefix##_f32); \
MIGRAPHX_DPP_REDUCE_ASM_FUN(half, op, prefix##_f16); \
MIGRAPHX_DPP_REDUCE_ASM_FUN(int32_t, op, prefix##sign##32); \
MIGRAPHX_DPP_REDUCE_ASM_FUN(uint32_t, op, prefix##_u32);
// Note: when max and min are in int32_t, signed version of instruction needs to be used.
MIGRAPHX_DPP_REDUCE(op::sum, v_add, _u)
MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
#endif
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
dpp_reduce<__AMDGCN_WAVEFRONT_SIZE>(in, op);
}
template <unsigned int SubWaveSize, class Op, class T, class Index, class F>
__device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment