Commit 308db690 authored by Paul's avatar Paul
Browse files

Format

parent 2dc6894c
......@@ -111,7 +111,7 @@ static std::string get_reduce_algo(context& ctx, const std::vector<shape>& input
[](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2)
return "lane";
if (relements <= ctx.get_current_device().get_wavefront_size())
if(relements <= ctx.get_current_device().get_wavefront_size())
return "wave";
return "block";
}
......@@ -176,8 +176,9 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
{
auto subwave_size = compute_subwave_size(ctx, relements);
algo = "subwave<" + std::to_string(subwave_size) + ">";
options.set_launch_params(
v, compute_global_for(ctx, nelements * subwave_size, 256), ctx.get_current_device().get_wavefront_size());
options.set_launch_params(v,
compute_global_for(ctx, nelements * subwave_size, 256),
ctx.get_current_device().get_wavefront_size());
}
}
else if(algo == "lane")
......@@ -263,14 +264,15 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
auto nelements = reduce_output_shape.elements();
auto algo = v.get("algo", get_reduce_algo(ctx, options.virtual_inputs, reduction_shape.lens()));
auto algo =
v.get("algo", get_reduce_algo(ctx, options.virtual_inputs, reduction_shape.lens()));
if(algo == "block" or algo == "wave")
{
// Vectorize if the axis is a reduction axis
if(reduce_output_shape.lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = reduction_shape.elements() / vec.size;
if (algo == "block")
if(algo == "block")
{
auto block_size = compute_block_size(relements, 256);
if(relements >= block_size * 256)
......@@ -282,8 +284,9 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{
auto subwave_size = compute_subwave_size(ctx, relements);
algo = "subwave<" + std::to_string(subwave_size) + ">";
options.set_launch_params(
v, compute_global_for(ctx, nelements * subwave_size, 256), ctx.get_current_device().get_wavefront_size());
options.set_launch_params(v,
compute_global_for(ctx, nelements * subwave_size, 256),
ctx.get_current_device().get_wavefront_size());
}
}
else if(algo == "lane")
......
......@@ -135,12 +135,21 @@ struct index
constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
template<unsigned int SubWaveSize>
constexpr index_constant<SubWaveSize> nlocal_subwave() const { return {}; }
template<unsigned int SubWaveSize>
constexpr auto local_subwave() const { return local % nlocal_subwave<SubWaveSize>(); }
template<unsigned int SubWaveSize>
constexpr auto nwave() const { return max_nlocal() / nlocal_subwave<SubWaveSize>(); }
template <unsigned int SubWaveSize>
constexpr index_constant<SubWaveSize> nlocal_subwave() const
{
return {};
}
template <unsigned int SubWaveSize>
constexpr auto local_subwave() const
{
return local % nlocal_subwave<SubWaveSize>();
}
template <unsigned int SubWaveSize>
constexpr auto nwave() const
{
return max_nlocal() / nlocal_subwave<SubWaveSize>();
}
constexpr index_constant<__AMDGCN_WAVEFRONT_SIZE> nlocal_wave() const { return {}; }
constexpr auto local_wave() const { return local % nlocal_wave(); }
......
......@@ -31,11 +31,7 @@
namespace migraphx {
constexpr bool is_power_of_2(unsigned int x)
{
return x > 0 && !(x & (x - 1));
}
constexpr bool is_power_of_2(unsigned int x) { return x > 0 && !(x & (x - 1)); }
#if MIGRAPHX_HAS_DPP
......@@ -134,14 +130,14 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
template <unsigned int SubWaveSize, class Op, class T, class Index, class F>
__device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
{
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(index::invoke_loop(f, 0, _c<0>));
type x = init;
idx.local_subwave_stride<SubWaveSize>(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
idx.local_subwave_stride<SubWaveSize>(
n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce<SubWaveSize>(x, op);
return x;
}
......@@ -529,7 +525,7 @@ struct block_large
}
};
template<unsigned int SubWaveSize>
template <unsigned int SubWaveSize>
struct subwave
{
template <class Slicer>
......@@ -580,9 +576,11 @@ struct subwave
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
using max_iterations = decltype(idx.max_local_subwave_stride_iterations<SubWaveSize>(n));
using max_iterations =
decltype(idx.max_local_subwave_stride_iterations<SubWaveSize>(n));
inner_storage<R, max_iterations{}, N> storage;
idx.local_subwave_stride<SubWaveSize>(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
idx.local_subwave_stride<SubWaveSize>(
n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
return storage;
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment