Commit 2b9f612f authored by Paul's avatar Paul
Browse files

Don't always unroll global stride

parent 6a42a385
...@@ -118,16 +118,14 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -118,16 +118,14 @@ struct reduce_compiler : compiler<reduce_compiler>
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
auto faxis = find_fast_axis({options.virtual_inputs.front()}); auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{}; vectorize vec{};
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block") if(algo == "block")
{ {
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
......
...@@ -160,7 +160,7 @@ struct index ...@@ -160,7 +160,7 @@ struct index
return f(i); return f(i);
} }
template <class F, class N, class Stride> template <bool Unroll, class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f) static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{ {
MIGRAPHX_ASSERT(start < stride); MIGRAPHX_ASSERT(start < stride);
...@@ -178,7 +178,7 @@ struct index ...@@ -178,7 +178,7 @@ struct index
invoke_loop(f, start, _c<0>); invoke_loop(f, start, _c<0>);
} }
} }
else else if constexpr(Unroll)
{ {
MIGRAPHX_STATIC_ASSERT_FOR(max_stride_iterations(n, stride) < 256) MIGRAPHX_STATIC_ASSERT_FOR(max_stride_iterations(n, stride) < 256)
{ {
...@@ -192,6 +192,15 @@ struct index ...@@ -192,6 +192,15 @@ struct index
}); });
} }
} }
else
{
index_int k = 0;
for(index_int i = start; i < n; i += stride)
{
invoke_loop(f, i, k);
k++;
}
}
} }
else else
{ {
...@@ -207,13 +216,13 @@ struct index ...@@ -207,13 +216,13 @@ struct index
template <class F, class N> template <class F, class N>
__device__ void global_stride(N n, F f) const __device__ void global_stride(N n, F f) const
{ {
for_stride(global, n, nglobal(), f); for_stride<false>(global, n, nglobal(), f);
} }
template <class F, class N> template <class F, class N>
__device__ void local_stride(N n, F f) const __device__ void local_stride(N n, F f) const
{ {
for_stride(local, n, nlocal(), f); for_stride<true>(local, n, nlocal(), f);
} }
}; };
......
...@@ -390,6 +390,7 @@ struct block ...@@ -390,6 +390,7 @@ struct block
struct lane struct lane
{ {
#if 0
template <class Slicer> template <class Slicer>
struct reducer struct reducer
{ {
...@@ -439,6 +440,73 @@ struct lane ...@@ -439,6 +440,73 @@ struct lane
{ {
return reducer<Slicer>{idx, slicer}; return reducer<Slicer>{idx, slicer};
} }
#else
// Reducer used by `lane`. Every loop defined here is a plain sequential
// `for` over j in [0, n); none of the member functions defined in this
// struct read `idx` or `slice`.
template <class Slicer>
struct reducer : reducer_base<reducer<Slicer>>
{
    index idx;
    Slicer slice;

    // Deferred element-wise expression: stores a callable `f` and evaluates
    // it only when invoked with (j, d).
    //   Size - type whose value-initialised instance is returned by rsize()
    //   F    - callable invocable as f(j, d)
    template <class Size, class F>
    struct inner_storage : inner_storage_tag
    {
        // Result type of f(0, _c<0>) with any reference removed.
        using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
        F f;
        // The size is carried purely in the type, so a value-initialised
        // Size is returned.
        constexpr Size rsize() const { return {}; }
        // Forwards both arguments to the stored callable.
        template <class U, class V>
        constexpr auto operator()(U j, V d) const
        {
            return f(j, d);
        }
    };

    // Wraps `f` in an inner_storage tagged with Size. The Size argument is
    // used only for type deduction; its value is ignored.
    //
    // This is `static` because it is called from the const member
    // inner_impl() and touches no reducer state; since C++14 `constexpr`
    // does not make a member function const, so a non-static, non-const
    // version cannot be called through a const `this`.
    template <class Size, class F>
    static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
    {
        // The leading empty braces initialise the inner_storage_tag base
        // explicitly; the second initialiser is the stored callable.
        return {{}, f};
    }

    // Left fold over j in [0, n): starting from `init`, combines the
    // accumulator with read(x(j, _c<0>), xs(j, _c<0>)...) using `op`.
    // The accumulator has the type of x(0, _c<0>) with references removed,
    // and `init` is converted to that type.
    // Returns the final accumulator (equal to the converted `init` when the
    // loop body never runs).
    template <class Op, class T, class Read, class N, class U, class... Us>
    __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
    {
        using type = remove_reference_t<decltype(x(0, _c<0>))>;
        type r     = init;
        for(index_int j = 0; j < n; j++)
        {
            r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
        }
        return r;
    }

    // Invokes `f` exactly once, unconditionally.
    template <class F>
    __device__ void outer(F f) const
    {
        f();
    }

    // Calls f(xs(j, _c<0>)...) for each j in [0, n), discarding any result.
    template <class F, class N, class... Ts>
    __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
    {
        for(index_int j = 0; j < n; j++)
        {
            f(xs(j, _c<0>)...);
        }
    }

    // Returns an inner_storage that evaluates f(xs(j, d)...) on demand
    // instead of computing anything here. `f` and `xs` are captured by copy.
    // NOTE(review): the template parameter R is not used in this body;
    // presumably it is part of the interface reducer_base expects - confirm.
    template <class R, class F, class N, class... Ts>
    __device__ auto inner_impl(F f, N n, Ts&&... xs) const
    {
        return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
    }
};
// Builds a reducer<Slicer> from the given index and slicer. The leading
// empty braces initialise the reducer's base-class subobject.
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
    using reducer_type = reducer<Slicer>;
    return reducer_type{{}, idx, slicer};
}
#endif
template <class Output, class F> template <class Output, class F>
static __device__ void run(F f) static __device__ void run(F f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment