"ts/webui/src/static/style/logDrawer.scss" did not exist on "fc7ddcd0c83febfbbae76bc5065e1e9d6cd8f8c3"
Commit 2dc6894c authored by Paul's avatar Paul
Browse files

Add subwave reductions

parent df869fd8
......@@ -182,6 +182,8 @@ struct hip_device
std::size_t get_max_workitems_per_block() const { return device_props.maxThreadsPerBlock; }
std::size_t get_wavefront_size() const { return device_props.warpSize; }
private:
std::size_t device_id = 0;
std::size_t current_stream = 0;
......
......@@ -97,9 +97,10 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes)
}
template <class ReduceLens>
static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens)
static std::string get_reduce_algo(context& ctx, const std::vector<shape>& inputs, ReduceLens rlens)
{
const auto init = std::numeric_limits<std::size_t>::max();
auto relements = std::accumulate(rlens.begin(), rlens.end(), 1, std::multiplies<>{});
// The minimum stride
auto min_stride = std::inner_product(
rlens.begin(),
......@@ -110,13 +111,24 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens
[](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2)
return "lane";
if (relements <= ctx.get_current_device().get_wavefront_size())
return "wave";
return "block";
}
static std::string get_reduce_algo(const std::vector<shape>& inputs)
static std::string get_reduce_algo(context& ctx, const std::vector<shape>& inputs)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
return get_reduce_algo(inputs, rlens);
return get_reduce_algo(ctx, inputs, rlens);
}
static std::size_t compute_subwave_size(context& ctx, std::size_t n)
{
std::size_t max_wavefront_size = ctx.get_current_device().get_wavefront_size();
std::size_t wavefront_size = 1;
while(wavefront_size < n and wavefront_size < max_wavefront_size)
wavefront_size *= 2;
return wavefront_size;
}
struct simple_reduce_compiler : compiler<simple_reduce_compiler>
......@@ -145,18 +157,28 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block")
auto algo = v.get("algo", get_reduce_algo(ctx, options.virtual_inputs));
if(algo == "block" or algo == "wave")
{
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto block_size = compute_block_size(relements, 256);
if(relements >= block_size * 256)
algo = "block_large";
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
if(algo == "block")
{
auto block_size = compute_block_size(relements, 256);
if(relements >= block_size * 256)
algo = "block_large";
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else
{
auto subwave_size = compute_subwave_size(ctx, relements);
algo = "subwave<" + std::to_string(subwave_size) + ">";
options.set_launch_params(
v, compute_global_for(ctx, nelements * subwave_size, 256), ctx.get_current_device().get_wavefront_size());
}
}
else if(algo == "lane")
{
......@@ -241,18 +263,28 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
auto nelements = reduce_output_shape.elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduction_shape.lens()));
if(algo == "block")
auto algo = v.get("algo", get_reduce_algo(ctx, options.virtual_inputs, reduction_shape.lens()));
if(algo == "block" or algo == "wave")
{
// Vectorize if the axis is a reduction axis
if(reduce_output_shape.lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = reduction_shape.elements() / vec.size;
auto block_size = compute_block_size(relements, 256);
if(relements >= block_size * 256)
algo = "block_large";
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
if (algo == "block")
{
auto block_size = compute_block_size(relements, 256);
if(relements >= block_size * 256)
algo = "block_large";
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else
{
auto subwave_size = compute_subwave_size(ctx, relements);
algo = "subwave<" + std::to_string(subwave_size) + ">";
options.set_launch_params(
v, compute_global_for(ctx, nelements * subwave_size, 256), ctx.get_current_device().get_wavefront_size());
}
}
else if(algo == "lane")
{
......
......@@ -135,6 +135,13 @@ struct index
constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
template<unsigned int SubWaveSize>
constexpr index_constant<SubWaveSize> nlocal_subwave() const { return {}; }
template<unsigned int SubWaveSize>
constexpr auto local_subwave() const { return local % nlocal_subwave<SubWaveSize>(); }
template<unsigned int SubWaveSize>
constexpr auto nwave() const { return max_nlocal() / nlocal_subwave<SubWaveSize>(); }
constexpr index_constant<__AMDGCN_WAVEFRONT_SIZE> nlocal_wave() const { return {}; }
constexpr auto local_wave() const { return local % nlocal_wave(); }
constexpr auto nwave() const { return max_nlocal() / nlocal_wave(); }
......@@ -164,6 +171,12 @@ struct index
return max_stride_iterations(n, nlocal_wave());
}
template <unsigned int SubWaveSize, class N>
constexpr auto max_local_subwave_stride_iterations(N n) const
{
return max_stride_iterations(n, nlocal_subwave<SubWaveSize>());
}
template <class F, class I, class D>
static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d))
{
......@@ -254,10 +267,16 @@ struct index
for_stride<false>(group, n, ngroup(), f);
}
template <unsigned int SubWaveSize, class F, class N>
__device__ void local_subwave_stride(N n, F f) const
{
for_stride<true>(local_subwave<SubWaveSize>(), n, nlocal_subwave<SubWaveSize>(), f);
}
template <class F, class N>
__device__ void local_wave_stride(N n, F f) const
{
for_stride<false>(local_wave(), n, nlocal_wave(), f);
for_stride<true>(local_wave(), n, nlocal_wave(), f);
}
};
......
......@@ -31,30 +31,66 @@
namespace migraphx {
constexpr bool is_power_of_2(unsigned int x)
{
return x > 0 && !(x & (x - 1));
}
#if MIGRAPHX_HAS_DPP
template <class T, class Op>
template <unsigned int SubWaveSize, class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
static_assert(SubWaveSize <= __AMDGCN_WAVEFRONT_SIZE, "Too large subwave size");
static_assert(is_power_of_2(SubWaveSize), "SubWaveSize is not a power of 2");
T out{};
out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
if constexpr(SubWaveSize > 1)
{
out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 2)
{
out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 4)
{
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 8)
{
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
}
#if __AMDGCN_WAVEFRONT_SIZE == 32
out = dpp_swizzle<dpp_row_bcast(15)>(in);
in = op(in, out);
if constexpr(SubWaveSize > 16)
{
out = dpp_swizzle<dpp_row_bcast(15)>(in);
in = op(in, out);
}
#else
out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out);
out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out);
if constexpr(SubWaveSize > 16)
{
out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out);
}
if constexpr(SubWaveSize > 32)
{
out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out);
}
#endif
}
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
dpp_reduce<__AMDGCN_WAVEFRONT_SIZE>(in, op);
}
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
......@@ -98,17 +134,24 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
template <class Op, class T, class Index, class F>
__device__ auto wave_reduce(index idx, Op op, T init, Index n, F f)
template <unsigned int SubWaveSize, class Op, class T, class Index, class F>
__device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
{
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(index::invoke_loop(f, 0, _c<0>));
type x = init;
idx.local_wave_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op);
idx.local_subwave_stride<SubWaveSize>(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce<SubWaveSize>(x, op);
return x;
}
template <class Op, class T, class Index, class F>
__device__ auto wave_reduce(index idx, Op op, T init, Index n, F f)
{
return subwave_reduce<__AMDGCN_WAVEFRONT_SIZE>(idx, op, init, n, f);
}
template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{
......@@ -486,7 +529,8 @@ struct block_large
}
};
struct wave
template<unsigned int SubWaveSize>
struct subwave
{
template <class Slicer>
struct reducer : reducer_base<reducer<Slicer>>
......@@ -515,7 +559,7 @@ struct wave
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
return wave_reduce(idx, op, init, n, [&](auto j, auto d) {
return subwave_reduce<SubWaveSize>(idx, op, init, n, [&](auto j, auto d) {
return vec_reduce(read(xs(j, d)...), op);
});
}
......@@ -523,7 +567,7 @@ struct wave
template <class F>
__device__ void outer(F f) const
{
if(idx.local_wave() == 0)
if(idx.local_subwave<SubWaveSize>() == 0)
f();
}
......@@ -536,9 +580,9 @@ struct wave
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
using max_iterations = decltype(idx.max_local_wave_stride_iterations(n));
using max_iterations = decltype(idx.max_local_subwave_stride_iterations<SubWaveSize>(n));
inner_storage<R, max_iterations{}, N> storage;
idx.local_wave_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
idx.local_subwave_stride<SubWaveSize>(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
return storage;
}
};
......@@ -554,13 +598,15 @@ struct wave
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal_wave(), [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal_wave());
idx.global_stride(nelements * idx.nlocal_subwave<SubWaveSize>(), [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal_subwave<SubWaveSize>());
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
});
}
};
using wave = subwave<__AMDGCN_WAVEFRONT_SIZE>;
struct lane
{
template <class Slicer>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment