Commit 5aac7a8c authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/jit-reduce-reg' into ck-gsg

parents 51c34a3b 4f12db9e
...@@ -118,16 +118,14 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -118,16 +118,14 @@ struct reduce_compiler : compiler<reduce_compiler>
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
auto faxis = find_fast_axis({options.virtual_inputs.front()}); auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{}; vectorize vec{};
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs)); auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block") if(algo == "block")
{ {
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
...@@ -166,7 +164,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -166,7 +164,7 @@ struct reduce_compiler : compiler<reduce_compiler>
auto reduce_elements = get_reduce_elements(ins->inputs()); auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type(); auto reduce_type = ins->inputs().front()->get_shape().type();
v["reduction"] = "op::sum{}"; v["reduction"] = "op::sum{}";
std::string mean = "op::mean{" + std::to_string(reduce_elements) + "}"; std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half // Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384) if(reduce_type == shape::half_type and reduce_elements > 16384)
v["read"] = "compose(" + mean + ", op::convert_to<float>{})"; v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
......
...@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l ...@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
#define MIGRAPHX_WARN(...) #define MIGRAPHX_WARN(...)
#endif #endif
// Asserts a compile-time condition and then opens an `if constexpr` over the
// same condition, so the statement block the caller writes after the macro is
// only instantiated when the condition holds.
// NOTE(review): as written the static_assert hard-fails compilation whenever
// the condition is false, which would make the `if constexpr` guard redundant;
// presumably a non-asserting variant exists under another preprocessor branch
// of this header (cf. the #if/#endif around MIGRAPHX_WARN above) — confirm
// against the full debug.hpp.
#define MIGRAPHX_STATIC_ASSERT_FOR(...) \
static_assert(__VA_ARGS__); \
if constexpr(__VA_ARGS__)
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP #endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
...@@ -163,6 +163,30 @@ struct index ...@@ -163,6 +163,30 @@ struct index
} }
template <class F, class N, class Stride> template <class F, class N, class Stride>
// Compile-time-unrolled strided traversal: visits start, start+stride, ...
// below n, invoking f with each index plus an iteration counter.
// `sequence` expands one step per element of max_stride_iterations(n, stride)
// (presumably the compile-time upper bound on the trip count — TODO confirm
// against the helper's definition), and `fold` threads an integral-constant
// accumulator d through the steps so f receives the iteration number as a
// compile-time value. Each step keeps a runtime bounds check (i < n) because
// n need not be an exact multiple of stride.
static constexpr void for_stride_loop_unroll(index_int start, N n, Stride stride, F f)
{
sequence(max_stride_iterations(n, stride), [&](auto... ks) {
fold([&](auto d, auto k) {
auto i = start + stride * k;
if(i < n)
invoke_loop(f, i, d);
// d counts completed steps; adding _c<1> keeps it an integral constant
return d + _c<1>;
})(_c<0>, ks...);
});
}
template <class F, class N, class Stride>
// Plain (non-unrolled) strided traversal: visits start, start+stride, ...
// below n, passing each index to f together with its zero-based runtime
// iteration count.
static constexpr void for_stride_loop(index_int start, N n, Stride stride, F f)
{
index_int iteration = 0;
index_int i         = start;
while(i < n)
{
invoke_loop(f, i, iteration);
++iteration;
i += stride;
}
}
template <bool Unroll, class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f) static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{ {
MIGRAPHX_ASSERT(start < stride); MIGRAPHX_ASSERT(start < stride);
...@@ -180,40 +204,34 @@ struct index ...@@ -180,40 +204,34 @@ struct index
invoke_loop(f, start, _c<0>); invoke_loop(f, start, _c<0>);
} }
} }
else else if constexpr(Unroll)
{ {
static_assert(max_stride_iterations(n, stride) < 128); MIGRAPHX_STATIC_ASSERT_FOR(max_stride_iterations(n, stride) < 256)
sequence(max_stride_iterations(n, stride), [&](auto... ks) { {
fold([&](auto d, auto k) { for_stride_loop_unroll(start, n, stride, f);
auto i = start + stride * k;
if(i < n)
invoke_loop(f, i, d);
return d + _c<1>;
})(_c<0>, ks...);
});
} }
} }
else else
{ {
index_int k = 0; for_stride_loop(start, n, stride, f);
for(index_int i = start; i < n; i += stride) }
{
invoke_loop(f, i, k);
k++;
} }
else
{
for_stride_loop(start, n, stride, f);
} }
} }
template <class F, class N> template <class F, class N>
__device__ void global_stride(N n, F f) const __device__ void global_stride(N n, F f) const
{ {
for_stride(global, n, nglobal(), f); for_stride<false>(global, n, nglobal(), f);
} }
template <class F, class N> template <class F, class N>
__device__ void local_stride(N n, F f) const __device__ void local_stride(N n, F f) const
{ {
for_stride(local, n, nlocal(), f); for_stride<true>(local, n, nlocal(), f);
} }
template <class F, class N> template <class F, class N>
......
...@@ -66,13 +66,22 @@ struct convert_to ...@@ -66,13 +66,22 @@ struct convert_to
} }
}; };
template <index_int N>
struct mean struct mean
{ {
index_int item_num = 1;
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x) const
{
using type = vec_type<T>;
if constexpr(is_floating_point<type>{})
{ {
return x / static_cast<T>(item_num); constexpr type d = 1.0 / N;
return x * d;
}
else
{
return x / static_cast<type>(N);
}
} }
}; };
......
...@@ -391,22 +391,40 @@ struct block ...@@ -391,22 +391,40 @@ struct block
struct lane struct lane
{ {
template <class Slicer> template <class Slicer>
struct reducer struct reducer : reducer_base<reducer<Slicer>>
{ {
index idx; index idx;
Slicer slice; Slicer slice;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{ {
return sliced(slice, [=](auto x, auto... xs) { return f(j, d);
using type = typename decltype(x)::type; }
};
template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {f};
}
template <class Op, class T, class Read, class N, class U, class... Us>
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{
using type = remove_reference_t<decltype(x(0, _c<0>))>;
type r = init; type r = init;
for(index_int j = 0; j < x.get_shape().elements(); j++) for(index_int j = 0; j < n; j++)
{ {
r = op(r, read(x[j], xs[j]...)); r = op(r, read(x(j, _c<0>), xs(j, _c<0>)...));
} }
return r; return r;
});
} }
template <class F> template <class F>
...@@ -415,29 +433,25 @@ struct lane ...@@ -415,29 +433,25 @@ struct lane
f(); f();
} }
template <class F> template <class F, class N, class... Ts>
__device__ auto inner(F f) const __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
{ {
return sliced(slice, [=](auto x, auto... xs) { for(index_int j = 0; j < n; j++)
for(index_int j = 0; j < x.get_shape().elements(); j++)
{ {
f(x[j], xs[j]...); f(xs(j, _c<0>)...);
} }
});
} }
template <class Input> template <class R, class F, class N, class... Ts>
constexpr auto elements() const __device__ auto inner_impl(F f, N n, Ts&&... xs) const
{ {
using reduce_type = decltype(slice(Input{})); return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
return get_shape_c<reduce_type>{}.elements();
} }
}; };
template <class Slicer> template <class Slicer>
static __device__ auto make(index idx, Slicer slicer) static __device__ auto make(index idx, Slicer slicer)
{ {
return reducer<Slicer>{idx, slicer}; return reducer<Slicer>{{}, idx, slicer};
} }
template <class Output, class F> template <class Output, class F>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment