Commit c81b53c5 authored by Paul's avatar Paul
Browse files

Allow inner to store registers as a result

parent 180dc7a0
...@@ -100,7 +100,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -100,7 +100,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{"layernorm", v.get("layernorm", std::string{"layernorm"})}, {"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}, {"axis", to_string(axis)},
{"eps", to_string(eps)}}); {"eps", to_string(eps)}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -136,16 +136,28 @@ struct index ...@@ -136,16 +136,28 @@ struct index
return (n - _c<1>) / stride + _c<1>; return (n - _c<1>) / stride + _c<1>;
} }
template <class N>
constexpr auto max_global_stride_iterations(N n) const
{
return max_stride_iterations(n, nglobal());
}
template <class N>
constexpr auto max_local_stride_iterations(N n) const
{
return max_stride_iterations(n, nlocal());
}
template <class F, class I, class D> template <class F, class I, class D>
static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d), void()) static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d))
{ {
f(i, d); return f(i, d);
} }
template <class F, class I, class D> template <class F, class I, class D>
static constexpr auto invoke_loop(F f, I i, D) -> decltype(f(i), void()) static constexpr auto invoke_loop(F f, I i, D) -> decltype(f(i))
{ {
f(i); return f(i);
} }
template <class F, class N, class Stride> template <class F, class N, class Stride>
...@@ -168,6 +180,7 @@ struct index ...@@ -168,6 +180,7 @@ struct index
} }
else else
{ {
static_assert(max_stride_iterations(n, stride) < 64);
sequence(max_stride_iterations(n, stride), [&](auto... ks) { sequence(max_stride_iterations(n, stride), [&](auto... ks) {
fold([&](auto d, auto k) { fold([&](auto d, auto k) {
auto i = start + stride * k; auto i = start + stride * k;
......
...@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f) ...@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#else #else
constexpr index_int lanes_per_thread = 64; constexpr index_int lanes_per_thread = 64;
#endif #endif
using type = decltype(f(0)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op); dpp_reduce(x, op);
const auto ldsidx = idx.local / lanes_per_thread; const auto ldsidx = idx.local / lanes_per_thread;
...@@ -131,7 +131,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f) ...@@ -131,7 +131,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
using type = decltype(f(0)); using type = decltype(f(0));
__shared__ type buffer[idx.max_nlocal()]; __shared__ type buffer[idx.max_nlocal()];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
buffer[idx.local] = x; buffer[idx.local] = x;
__syncthreads(); __syncthreads();
...@@ -167,6 +167,22 @@ constexpr auto reduce_slice(Input input, T i) ...@@ -167,6 +167,22 @@ constexpr auto reduce_slice(Input input, T i)
namespace reduce { namespace reduce {
struct inner_storage_tag
{
};
template<class R, class F>
struct storage_access : F
{
using type = R;
};
template<class R, class F>
constexpr storage_access<R, F> make_storage_access(F f)
{
return {{f}};
}
template <class Slicer, class F> template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f) constexpr auto sliced(Slicer slicer, F f)
{ {
...@@ -191,23 +207,156 @@ constexpr auto compute_reduce_axis() ...@@ -191,23 +207,156 @@ constexpr auto compute_reduce_axis()
template <class Input, index_int Axis> template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>()); using with_axis = decltype(compute_reduce_axis<Input, Axis>());
template<class Derived>
struct reducer_base
{
template<class T>
__device__ auto make_inner_slice(T x) const
{
if constexpr(is_base_of<inner_storage_tag, T>{})
{
return x;
}
else
{
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& {
return t[i];
});
}
}
template<class T, class... Ts>
constexpr auto get_size(T&& x, [[maybe_unused]] Ts&&... xs) const
{
MIGRAPHX_ASSERT(get_size(x) == get_size(xs...));
return get_size(x);
}
template<class T, class... Ts>
constexpr auto get_size(T&& x) const
{
if constexpr(is_base_of<inner_storage_tag, T>{})
{
return x.rsize();
}
else
{
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
return t.size();
}
}
template<class F>
__device__ auto inner_sliced(F f) const
{
return [=](auto&&... xs) {
return f(get_size(xs...), make_inner_slice(xs)...);
};
}
template<class T>
static __device__ typename T::type& decl_inner_storage(const T&);
template <class F>
__device__ auto inner(F f) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
using result_type = decltype(f(decl_inner_storage(xs)...));
auto&& derived = static_cast<const Derived&>(*this);
if constexpr(is_void<result_type>{})
{
derived.inner_void_impl(f, n, xs...);
}
else
{
return derived.template inner_impl<result_type>(f, n, xs...);
}
});
}
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
auto&& derived = static_cast<const Derived&>(*this);
return derived.reduce_impl(op, init, read, n, xs...);
});
}
template <class Op, class T>
__device__ auto reduce(Op op, T init) const
{
return this->reduce(op, init, op::id{});
}
template <class F>
__device__ void outer(F f) const
{
f();
}
template <class Input>
constexpr auto elements() const
{
auto&& derived = static_cast<const Derived&>(*this);
using reduce_type = decltype(derived.slice(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>();
else
return relements;
}
};
struct block struct block
{ {
template <class Slicer> template <class Slicer>
struct reducer struct reducer : reducer_base<reducer<Slicer>>
{ {
index idx; index idx;
Slicer slice; Slicer slice;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const template <class T, index_int N, class Size>
struct inner_storage : inner_storage_tag
{ {
return sliced(slice, [=](auto x, auto... xs) { using type = T;
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) { array<T, N> arr;
return vec_reduce(read(x[j], xs[j]...), op); constexpr Size rsize() const {return {};}
}); template <class U, class V>
constexpr auto& operator()(U, V d) const
{
return arr[d];
}
template <class U, class V>
constexpr auto& operator()(U, V d)
{
return arr[d];
}
};
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
return block_reduce(idx, op, init, n, [&](auto j, auto d) {
return vec_reduce(read(xs(j, d)...), op);
}); });
} }
// template <class Op, class T, class Read>
// __device__ auto reduce(Op op, T init, Read read) const
// {
// return sliced(slice, [=](auto x, auto... xs) {
// return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
// return vec_reduce(read(x[j], xs[j]...), op);
// });
// });
// }
template <class F> template <class F>
__device__ void outer(F f) const __device__ void outer(F f) const
{ {
...@@ -215,31 +364,26 @@ struct block ...@@ -215,31 +364,26 @@ struct block
f(); f();
} }
template <class F> template <class F, class N, class... Ts>
__device__ auto inner(F f) const __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
{ {
return sliced(slice, [=](auto x, auto... xs) { idx.local_stride(n, [&](auto j, auto d) { f(xs(j, d)...); });
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
});
} }
template <class Input> template <class R, class F, class N, class... Ts>
constexpr auto elements() const __device__ auto inner_impl(F f, N n, Ts&&... xs) const
{ {
using reduce_type = decltype(slice(Input{})); using max_iterations = decltype(idx.max_local_stride_iterations(n));
using value_type = typename Input::type; inner_storage<R, max_iterations{}, N> storage;
constexpr auto relements = get_shape_c<reduce_type>{}.elements(); idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
if constexpr(vec_size<value_type>() > 1) return storage;
return relements * vec_size<value_type>();
else
return relements;
} }
}; };
template <class Slicer> template <class Slicer>
static __device__ auto make(index idx, Slicer slicer) static __device__ auto make(index idx, Slicer slicer)
{ {
return reducer<Slicer>{idx, slicer}; return reducer<Slicer>{{}, idx, slicer};
} }
template <class Output, class F> template <class Output, class F>
......
...@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible); ...@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible); MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible); MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
template<class T>
struct remove_cv
{
using type = T;
};
template<class T>
struct remove_cv<const T> : remove_cv<T>
{
};
template<class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};
template<class T>
using remove_cv_t = typename remove_cv<T>::type;
template <class T> template <class T>
struct remove_reference struct remove_reference
{ {
...@@ -168,6 +187,9 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*> ...@@ -168,6 +187,9 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template <class T> template <class T>
using add_pointer_t = typename add_pointer<T>::type; using add_pointer_t = typename add_pointer<T>::type;
template<class T>
struct is_void : is_same<void, remove_cv_t<T>> {};
template <class... Ts> template <class... Ts>
struct common_type; struct common_type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment