Commit 2dc6894c authored by Paul's avatar Paul
Browse files

Add subwave reductions

parent df869fd8
...@@ -182,6 +182,8 @@ struct hip_device ...@@ -182,6 +182,8 @@ struct hip_device
std::size_t get_max_workitems_per_block() const { return device_props.maxThreadsPerBlock; } std::size_t get_max_workitems_per_block() const { return device_props.maxThreadsPerBlock; }
std::size_t get_wavefront_size() const { return device_props.warpSize; }
private: private:
std::size_t device_id = 0; std::size_t device_id = 0;
std::size_t current_stream = 0; std::size_t current_stream = 0;
......
...@@ -97,9 +97,10 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes) ...@@ -97,9 +97,10 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes)
} }
template <class ReduceLens> template <class ReduceLens>
static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens) static std::string get_reduce_algo(context& ctx, const std::vector<shape>& inputs, ReduceLens rlens)
{ {
const auto init = std::numeric_limits<std::size_t>::max(); const auto init = std::numeric_limits<std::size_t>::max();
auto relements = std::accumulate(rlens.begin(), rlens.end(), 1, std::multiplies<>{});
// The minimum stride // The minimum stride
auto min_stride = std::inner_product( auto min_stride = std::inner_product(
rlens.begin(), rlens.begin(),
...@@ -110,13 +111,24 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens ...@@ -110,13 +111,24 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens
[](auto len, auto stride) { return len == 1 ? init : stride; }); [](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2) if(min_stride > 2)
return "lane"; return "lane";
if (relements <= ctx.get_current_device().get_wavefront_size())
return "wave";
return "block"; return "block";
} }
static std::string get_reduce_algo(const std::vector<shape>& inputs) static std::string get_reduce_algo(context& ctx, const std::vector<shape>& inputs)
{ {
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens()); auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
return get_reduce_algo(inputs, rlens); return get_reduce_algo(ctx, inputs, rlens);
}
// Choose the number of lanes for a subwave reduction: the smallest power of
// two that is >= n, clamped to the device's wavefront size. Always returns
// at least 1 (for n <= 1 the loop body never runs).
static std::size_t compute_subwave_size(context& ctx, std::size_t n)
{
// Hardware wavefront width reported by the current device.
std::size_t max_wavefront_size = ctx.get_current_device().get_wavefront_size();
std::size_t wavefront_size = 1;
// Double until the subwave covers n elements or fills the whole wavefront.
while(wavefront_size < n and wavefront_size < max_wavefront_size)
wavefront_size *= 2;
return wavefront_size;
} }
struct simple_reduce_compiler : compiler<simple_reduce_compiler> struct simple_reduce_compiler : compiler<simple_reduce_compiler>
...@@ -145,18 +157,28 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler> ...@@ -145,18 +157,28 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
auto faxis = find_fast_axis({options.virtual_inputs.front()}); auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{}; vectorize vec{};
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); auto algo = v.get("algo", get_reduce_algo(ctx, options.virtual_inputs));
if(algo == "block") if(algo == "block" or algo == "wave")
{ {
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1) if(options.virtual_inputs.back().lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs); vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto block_size = compute_block_size(relements, 256); if(algo == "block")
if(relements >= block_size * 256) {
algo = "block_large"; auto block_size = compute_block_size(relements, 256);
options.set_launch_params( if(relements >= block_size * 256)
v, compute_global_for(ctx, nelements * block_size, 256), block_size); algo = "block_large";
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else
{
auto subwave_size = compute_subwave_size(ctx, relements);
algo = "subwave<" + std::to_string(subwave_size) + ">";
options.set_launch_params(
v, compute_global_for(ctx, nelements * subwave_size, 256), ctx.get_current_device().get_wavefront_size());
}
} }
else if(algo == "lane") else if(algo == "lane")
{ {
...@@ -241,18 +263,28 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -241,18 +263,28 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto faxis = find_fast_axis({options.virtual_inputs.front()}); auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{}; vectorize vec{};
auto nelements = reduce_output_shape.elements(); auto nelements = reduce_output_shape.elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduction_shape.lens())); auto algo = v.get("algo", get_reduce_algo(ctx, options.virtual_inputs, reduction_shape.lens()));
if(algo == "block") if(algo == "block" or algo == "wave")
{ {
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(reduce_output_shape.lens()[faxis] == 1) if(reduce_output_shape.lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs); vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = reduction_shape.elements() / vec.size; auto relements = reduction_shape.elements() / vec.size;
auto block_size = compute_block_size(relements, 256); if (algo == "block")
if(relements >= block_size * 256) {
algo = "block_large"; auto block_size = compute_block_size(relements, 256);
options.set_launch_params( if(relements >= block_size * 256)
v, compute_global_for(ctx, nelements * block_size, 256), block_size); algo = "block_large";
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else
{
auto subwave_size = compute_subwave_size(ctx, relements);
algo = "subwave<" + std::to_string(subwave_size) + ">";
options.set_launch_params(
v, compute_global_for(ctx, nelements * subwave_size, 256), ctx.get_current_device().get_wavefront_size());
}
} }
else if(algo == "lane") else if(algo == "lane")
{ {
......
...@@ -135,6 +135,13 @@ struct index ...@@ -135,6 +135,13 @@ struct index
constexpr auto ngroup() const { return nglobal() / max_nlocal(); } constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
// Number of lanes in a subwave, returned as a compile-time integral constant
// (the template parameter itself) so callers can use it in constexpr math.
template<unsigned int SubWaveSize>
constexpr index_constant<SubWaveSize> nlocal_subwave() const { return {}; }
// This thread's lane id within its subwave: the local id modulo the subwave size.
template<unsigned int SubWaveSize>
constexpr auto local_subwave() const { return local % nlocal_subwave<SubWaveSize>(); }
// Number of subwaves per workgroup: max workgroup size divided by the subwave
// size. Overloads the existing non-template nwave(), which divides by the full
// wavefront width instead.
template<unsigned int SubWaveSize>
constexpr auto nwave() const { return max_nlocal() / nlocal_subwave<SubWaveSize>(); }
constexpr index_constant<__AMDGCN_WAVEFRONT_SIZE> nlocal_wave() const { return {}; } constexpr index_constant<__AMDGCN_WAVEFRONT_SIZE> nlocal_wave() const { return {}; }
constexpr auto local_wave() const { return local % nlocal_wave(); } constexpr auto local_wave() const { return local % nlocal_wave(); }
constexpr auto nwave() const { return max_nlocal() / nlocal_wave(); } constexpr auto nwave() const { return max_nlocal() / nlocal_wave(); }
...@@ -164,6 +171,12 @@ struct index ...@@ -164,6 +171,12 @@ struct index
return max_stride_iterations(n, nlocal_wave()); return max_stride_iterations(n, nlocal_wave());
} }
// Upper bound on the iterations a lane performs when striding over n elements
// with a subwave of SubWaveSize lanes; used below to size per-lane inner storage.
template <unsigned int SubWaveSize, class N>
constexpr auto max_local_subwave_stride_iterations(N n) const
{
return max_stride_iterations(n, nlocal_subwave<SubWaveSize>());
}
template <class F, class I, class D> template <class F, class I, class D>
static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d)) static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d))
{ {
...@@ -254,10 +267,16 @@ struct index ...@@ -254,10 +267,16 @@ struct index
for_stride<false>(group, n, ngroup(), f); for_stride<false>(group, n, ngroup(), f);
} }
// Strided loop over n elements distributed across the lanes of this thread's
// subwave: each lane starts at its subwave lane id and advances by the subwave
// size, invoking f for every element it owns.
template <unsigned int SubWaveSize, class F, class N>
__device__ void local_subwave_stride(N n, F f) const
{
for_stride<true>(local_subwave<SubWaveSize>(), n, nlocal_subwave<SubWaveSize>(), f);
}
template <class F, class N> template <class F, class N>
__device__ void local_wave_stride(N n, F f) const __device__ void local_wave_stride(N n, F f) const
{ {
for_stride<false>(local_wave(), n, nlocal_wave(), f); for_stride<true>(local_wave(), n, nlocal_wave(), f);
} }
}; };
......
...@@ -31,30 +31,66 @@ ...@@ -31,30 +31,66 @@
namespace migraphx { namespace migraphx {
// Returns true iff x is a nonzero power of two, i.e. exactly one bit is set.
// (Clearing the lowest set bit of a power of two yields zero.)
constexpr bool is_power_of_2(unsigned int x)
{
    return x != 0 and (x & (x - 1)) == 0;
}
#if MIGRAPHX_HAS_DPP #if MIGRAPHX_HAS_DPP
template <class T, class Op> template <unsigned int SubWaveSize, class T, class Op>
__device__ void dpp_reduce(T& in, Op op) __device__ void dpp_reduce(T& in, Op op)
{ {
static_assert(SubWaveSize <= __AMDGCN_WAVEFRONT_SIZE, "Too large subwave size");
static_assert(is_power_of_2(SubWaveSize), "SubWaveSize is not a power of 2");
T out{}; T out{};
out = dpp_mov<dpp_row_shr(1)>(in); if constexpr(SubWaveSize > 1)
in = op(in, out); {
out = dpp_mov<dpp_row_shr(2)>(in); out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out); in = op(in, out);
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in); }
in = op(in, out); if constexpr(SubWaveSize > 2)
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in); {
in = op(in, out); out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 4)
{
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 8)
{
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
}
#if __AMDGCN_WAVEFRONT_SIZE == 32 #if __AMDGCN_WAVEFRONT_SIZE == 32
out = dpp_swizzle<dpp_row_bcast(15)>(in); if constexpr(SubWaveSize > 16)
in = op(in, out); {
out = dpp_swizzle<dpp_row_bcast(15)>(in);
in = op(in, out);
}
#else #else
out = dpp_mov<dpp_row_bcast(15), 0xa>(in); if constexpr(SubWaveSize > 16)
in = op(in, out); {
out = dpp_mov<dpp_row_bcast(31), 0xc>(in); out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out); in = op(in, out);
}
if constexpr(SubWaveSize > 32)
{
out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out);
}
#endif #endif
} }
// Backwards-compatible overload: reduce across the full wavefront by
// delegating to the subwave version with SubWaveSize equal to the
// compile-time wavefront width.
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
dpp_reduce<__AMDGCN_WAVEFRONT_SIZE>(in, op);
}
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK) #if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1 #define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
...@@ -98,17 +134,24 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u) ...@@ -98,17 +134,24 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE(op::max, v_max, _i) MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
MIGRAPHX_DPP_REDUCE(op::min, v_min, _i) MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
template <class Op, class T, class Index, class F>
__device__ auto wave_reduce(index idx, Op op, T init, Index n, F f) template <unsigned int SubWaveSize, class Op, class T, class Index, class F>
__device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
type x = init; type x = init;
idx.local_wave_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); idx.local_subwave_stride<SubWaveSize>(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op); dpp_reduce<SubWaveSize>(x, op);
return x; return x;
} }
// Reduce across a full wavefront: implemented as a subwave reduction whose
// subwave spans the entire wavefront. Kept so existing callers of
// wave_reduce continue to work unchanged.
template <class Op, class T, class Index, class F>
__device__ auto wave_reduce(index idx, Op op, T init, Index n, F f)
{
return subwave_reduce<__AMDGCN_WAVEFRONT_SIZE>(idx, op, init, n, f);
}
template <class Op, class T, class Index, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
...@@ -486,7 +529,8 @@ struct block_large ...@@ -486,7 +529,8 @@ struct block_large
} }
}; };
struct wave template<unsigned int SubWaveSize>
struct subwave
{ {
template <class Slicer> template <class Slicer>
struct reducer : reducer_base<reducer<Slicer>> struct reducer : reducer_base<reducer<Slicer>>
...@@ -515,7 +559,7 @@ struct wave ...@@ -515,7 +559,7 @@ struct wave
template <class Op, class T, class Read, class N, class... Ts> template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{ {
return wave_reduce(idx, op, init, n, [&](auto j, auto d) { return subwave_reduce<SubWaveSize>(idx, op, init, n, [&](auto j, auto d) {
return vec_reduce(read(xs(j, d)...), op); return vec_reduce(read(xs(j, d)...), op);
}); });
} }
...@@ -523,7 +567,7 @@ struct wave ...@@ -523,7 +567,7 @@ struct wave
template <class F> template <class F>
__device__ void outer(F f) const __device__ void outer(F f) const
{ {
if(idx.local_wave() == 0) if(idx.local_subwave<SubWaveSize>() == 0)
f(); f();
} }
...@@ -536,9 +580,9 @@ struct wave ...@@ -536,9 +580,9 @@ struct wave
template <class R, class F, class N, class... Ts> template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const __device__ auto inner_impl(F f, N n, Ts&&... xs) const
{ {
using max_iterations = decltype(idx.max_local_wave_stride_iterations(n)); using max_iterations = decltype(idx.max_local_subwave_stride_iterations<SubWaveSize>(n));
inner_storage<R, max_iterations{}, N> storage; inner_storage<R, max_iterations{}, N> storage;
idx.local_wave_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); }); idx.local_subwave_stride<SubWaveSize>(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
return storage; return storage;
} }
}; };
...@@ -554,13 +598,15 @@ struct wave ...@@ -554,13 +598,15 @@ struct wave
{ {
auto idx = make_index(); auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements(); constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal_wave(), [&](auto i) { idx.global_stride(nelements * idx.nlocal_subwave<SubWaveSize>(), [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal_wave()); const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal_subwave<SubWaveSize>());
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); })); f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
}); });
} }
}; };
using wave = subwave<__AMDGCN_WAVEFRONT_SIZE>;
struct lane struct lane
{ {
template <class Slicer> template <class Slicer>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment