Commit c81b53c5 authored by Paul's avatar Paul
Browse files

Allow inner to store registers as a result

parent 180dc7a0
......@@ -100,7 +100,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)},
{"eps", to_string(eps)}});
return compile_hip_code_object(src, options);
}
......
......@@ -136,16 +136,28 @@ struct index
return (n - _c<1>) / stride + _c<1>;
}
template <class N>
constexpr auto max_global_stride_iterations(N n) const
{
return max_stride_iterations(n, nglobal());
}
template <class N>
constexpr auto max_local_stride_iterations(N n) const
{
return max_stride_iterations(n, nlocal());
}
template <class F, class I, class D>
static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d), void())
static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d))
{
f(i, d);
return f(i, d);
}
template <class F, class I, class D>
static constexpr auto invoke_loop(F f, I i, D) -> decltype(f(i), void())
static constexpr auto invoke_loop(F f, I i, D) -> decltype(f(i))
{
f(i);
return f(i);
}
template <class F, class N, class Stride>
......@@ -168,6 +180,7 @@ struct index
}
else
{
static_assert(max_stride_iterations(n, stride) < 64);
sequence(max_stride_iterations(n, stride), [&](auto... ks) {
fold([&](auto d, auto k) {
auto i = start + stride * k;
......
......@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#else
constexpr index_int lanes_per_thread = 64;
#endif
using type = decltype(f(0));
using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op);
const auto ldsidx = idx.local / lanes_per_thread;
......@@ -131,7 +131,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
using type = decltype(f(0));
__shared__ type buffer[idx.max_nlocal()];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
buffer[idx.local] = x;
__syncthreads();
......@@ -167,6 +167,22 @@ constexpr auto reduce_slice(Input input, T i)
namespace reduce {
struct inner_storage_tag
{
};
template<class R, class F>
struct storage_access : F
{
using type = R;
};
template<class R, class F>
constexpr storage_access<R, F> make_storage_access(F f)
{
return {{f}};
}
template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f)
{
......@@ -191,23 +207,156 @@ constexpr auto compute_reduce_axis()
template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>());
template<class Derived>
struct reducer_base
{
template<class T>
__device__ auto make_inner_slice(T x) const
{
if constexpr(is_base_of<inner_storage_tag, T>{})
{
return x;
}
else
{
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
return make_storage_access<typename decltype(t)::type>([=](auto i, auto...) -> auto& {
return t[i];
});
}
}
template<class T, class... Ts>
constexpr auto get_size(T&& x, [[maybe_unused]] Ts&&... xs) const
{
MIGRAPHX_ASSERT(get_size(x) == get_size(xs...));
return get_size(x);
}
template<class T, class... Ts>
constexpr auto get_size(T&& x) const
{
if constexpr(is_base_of<inner_storage_tag, T>{})
{
return x.rsize();
}
else
{
auto&& derived = static_cast<const Derived&>(*this);
auto t = derived.slice(x);
return t.size();
}
}
template<class F>
__device__ auto inner_sliced(F f) const
{
return [=](auto&&... xs) {
return f(get_size(xs...), make_inner_slice(xs)...);
};
}
template<class T>
static __device__ typename T::type& decl_inner_storage(const T&);
template <class F>
__device__ auto inner(F f) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
using result_type = decltype(f(decl_inner_storage(xs)...));
auto&& derived = static_cast<const Derived&>(*this);
if constexpr(is_void<result_type>{})
{
derived.inner_void_impl(f, n, xs...);
}
else
{
return derived.template inner_impl<result_type>(f, n, xs...);
}
});
}
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
auto&& derived = static_cast<const Derived&>(*this);
return derived.reduce_impl(op, init, read, n, xs...);
});
}
template <class Op, class T>
__device__ auto reduce(Op op, T init) const
{
return this->reduce(op, init, op::id{});
}
template <class F>
__device__ void outer(F f) const
{
f();
}
template <class Input>
constexpr auto elements() const
{
auto&& derived = static_cast<const Derived&>(*this);
using reduce_type = decltype(derived.slice(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>();
else
return relements;
}
};
struct block
{
template <class Slicer>
struct reducer
struct reducer : reducer_base<reducer<Slicer>>
{
index idx;
Slicer slice;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
template <class T, index_int N, class Size>
struct inner_storage : inner_storage_tag
{
return sliced(slice, [=](auto x, auto... xs) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return vec_reduce(read(x[j], xs[j]...), op);
});
using type = T;
array<T, N> arr;
constexpr Size rsize() const {return {};}
template <class U, class V>
constexpr auto& operator()(U, V d) const
{
return arr[d];
}
template <class U, class V>
constexpr auto& operator()(U, V d)
{
return arr[d];
}
};
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
return block_reduce(idx, op, init, n, [&](auto j, auto d) {
return vec_reduce(read(xs(j, d)...), op);
});
}
// template <class Op, class T, class Read>
// __device__ auto reduce(Op op, T init, Read read) const
// {
// return sliced(slice, [=](auto x, auto... xs) {
// return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
// return vec_reduce(read(x[j], xs[j]...), op);
// });
// });
// }
template <class F>
__device__ void outer(F f) const
{
......@@ -215,31 +364,26 @@ struct block
f();
}
template <class F>
__device__ auto inner(F f) const
template <class F, class N, class... Ts>
__device__ void inner_void_impl(F f, N n, Ts&&... xs) const
{
return sliced(slice, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
});
idx.local_stride(n, [&](auto j, auto d) { f(xs(j, d)...); });
}
template <class Input>
constexpr auto elements() const
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
using reduce_type = decltype(slice(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>();
else
return relements;
using max_iterations = decltype(idx.max_local_stride_iterations(n));
inner_storage<R, max_iterations{}, N> storage;
idx.local_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
return storage;
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{idx, slicer};
return reducer<Slicer>{{}, idx, slicer};
}
template <class Output, class F>
......
......@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
template<class T>
struct remove_cv
{
using type = T;
};
template<class T>
struct remove_cv<const T> : remove_cv<T>
{
};
template<class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};
template<class T>
using remove_cv_t = typename remove_cv<T>::type;
template <class T>
struct remove_reference
{
......@@ -168,6 +187,9 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template <class T>
using add_pointer_t = typename add_pointer<T>::type;
template<class T>
struct is_void : is_same<void, remove_cv_t<T>> {};
template <class... Ts>
struct common_type;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment