Commit 308db690 authored by Paul's avatar Paul
Browse files

Format

parent 2dc6894c
...@@ -111,7 +111,7 @@ static std::string get_reduce_algo(context& ctx, const std::vector<shape>& input ...@@ -111,7 +111,7 @@ static std::string get_reduce_algo(context& ctx, const std::vector<shape>& input
[](auto len, auto stride) { return len == 1 ? init : stride; }); [](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2) if(min_stride > 2)
return "lane"; return "lane";
if (relements <= ctx.get_current_device().get_wavefront_size()) if(relements <= ctx.get_current_device().get_wavefront_size())
return "wave"; return "wave";
return "block"; return "block";
} }
...@@ -176,8 +176,9 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler> ...@@ -176,8 +176,9 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
{ {
auto subwave_size = compute_subwave_size(ctx, relements); auto subwave_size = compute_subwave_size(ctx, relements);
algo = "subwave<" + std::to_string(subwave_size) + ">"; algo = "subwave<" + std::to_string(subwave_size) + ">";
options.set_launch_params( options.set_launch_params(v,
v, compute_global_for(ctx, nelements * subwave_size, 256), ctx.get_current_device().get_wavefront_size()); compute_global_for(ctx, nelements * subwave_size, 256),
ctx.get_current_device().get_wavefront_size());
} }
} }
else if(algo == "lane") else if(algo == "lane")
...@@ -263,14 +264,15 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -263,14 +264,15 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto faxis = find_fast_axis({options.virtual_inputs.front()}); auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{}; vectorize vec{};
auto nelements = reduce_output_shape.elements(); auto nelements = reduce_output_shape.elements();
auto algo = v.get("algo", get_reduce_algo(ctx, options.virtual_inputs, reduction_shape.lens())); auto algo =
v.get("algo", get_reduce_algo(ctx, options.virtual_inputs, reduction_shape.lens()));
if(algo == "block" or algo == "wave") if(algo == "block" or algo == "wave")
{ {
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(reduce_output_shape.lens()[faxis] == 1) if(reduce_output_shape.lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs); vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = reduction_shape.elements() / vec.size; auto relements = reduction_shape.elements() / vec.size;
if (algo == "block") if(algo == "block")
{ {
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
if(relements >= block_size * 256) if(relements >= block_size * 256)
...@@ -282,8 +284,9 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler> ...@@ -282,8 +284,9 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{ {
auto subwave_size = compute_subwave_size(ctx, relements); auto subwave_size = compute_subwave_size(ctx, relements);
algo = "subwave<" + std::to_string(subwave_size) + ">"; algo = "subwave<" + std::to_string(subwave_size) + ">";
options.set_launch_params( options.set_launch_params(v,
v, compute_global_for(ctx, nelements * subwave_size, 256), ctx.get_current_device().get_wavefront_size()); compute_global_for(ctx, nelements * subwave_size, 256),
ctx.get_current_device().get_wavefront_size());
} }
} }
else if(algo == "lane") else if(algo == "lane")
......
...@@ -135,12 +135,21 @@ struct index ...@@ -135,12 +135,21 @@ struct index
constexpr auto ngroup() const { return nglobal() / max_nlocal(); } constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
template<unsigned int SubWaveSize> template <unsigned int SubWaveSize>
constexpr index_constant<SubWaveSize> nlocal_subwave() const { return {}; } constexpr index_constant<SubWaveSize> nlocal_subwave() const
template<unsigned int SubWaveSize> {
constexpr auto local_subwave() const { return local % nlocal_subwave<SubWaveSize>(); } return {};
template<unsigned int SubWaveSize> }
constexpr auto nwave() const { return max_nlocal() / nlocal_subwave<SubWaveSize>(); } template <unsigned int SubWaveSize>
constexpr auto local_subwave() const
{
return local % nlocal_subwave<SubWaveSize>();
}
template <unsigned int SubWaveSize>
constexpr auto nwave() const
{
return max_nlocal() / nlocal_subwave<SubWaveSize>();
}
constexpr index_constant<__AMDGCN_WAVEFRONT_SIZE> nlocal_wave() const { return {}; } constexpr index_constant<__AMDGCN_WAVEFRONT_SIZE> nlocal_wave() const { return {}; }
constexpr auto local_wave() const { return local % nlocal_wave(); } constexpr auto local_wave() const { return local % nlocal_wave(); }
......
...@@ -31,11 +31,7 @@ ...@@ -31,11 +31,7 @@
namespace migraphx { namespace migraphx {
constexpr bool is_power_of_2(unsigned int x) { return x > 0 && !(x & (x - 1)); }
constexpr bool is_power_of_2(unsigned int x)
{
return x > 0 && !(x & (x - 1));
}
#if MIGRAPHX_HAS_DPP #if MIGRAPHX_HAS_DPP
...@@ -134,14 +130,14 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u) ...@@ -134,14 +130,14 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE(op::max, v_max, _i) MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
MIGRAPHX_DPP_REDUCE(op::min, v_min, _i) MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
template <unsigned int SubWaveSize, class Op, class T, class Index, class F> template <unsigned int SubWaveSize, class Op, class T, class Index, class F>
__device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f) __device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
type x = init; type x = init;
idx.local_subwave_stride<SubWaveSize>(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); idx.local_subwave_stride<SubWaveSize>(
n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce<SubWaveSize>(x, op); dpp_reduce<SubWaveSize>(x, op);
return x; return x;
} }
...@@ -529,7 +525,7 @@ struct block_large ...@@ -529,7 +525,7 @@ struct block_large
} }
}; };
template<unsigned int SubWaveSize> template <unsigned int SubWaveSize>
struct subwave struct subwave
{ {
template <class Slicer> template <class Slicer>
...@@ -580,9 +576,11 @@ struct subwave ...@@ -580,9 +576,11 @@ struct subwave
template <class R, class F, class N, class... Ts> template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const __device__ auto inner_impl(F f, N n, Ts&&... xs) const
{ {
using max_iterations = decltype(idx.max_local_subwave_stride_iterations<SubWaveSize>(n)); using max_iterations =
decltype(idx.max_local_subwave_stride_iterations<SubWaveSize>(n));
inner_storage<R, max_iterations{}, N> storage; inner_storage<R, max_iterations{}, N> storage;
idx.local_subwave_stride<SubWaveSize>(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); }); idx.local_subwave_stride<SubWaveSize>(
n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
return storage; return storage;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment