Commit 6ac41ed8 authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/jit-reduce-reg' into ck-gsg

parents 64c4c13b 543e166a
...@@ -127,6 +127,8 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -127,6 +127,8 @@ struct reduce_compiler : compiler<reduce_compiler>
vec = vectorize::elements(ctx, faxis, options.virtual_inputs); vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
if(relements > block_size * 256)
algo = "block_large";
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size); v, compute_global_for(ctx, nelements * block_size, 256), block_size);
} }
......
...@@ -46,8 +46,9 @@ template <index_int Axis, ...@@ -46,8 +46,9 @@ template <index_int Axis,
__device__ void generic_binary_layernorm( __device__ void generic_binary_layernorm(
F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs) F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
......
...@@ -128,7 +128,7 @@ template <class Op, class T, class Index, class F> ...@@ -128,7 +128,7 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(f(0)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal()]; __shared__ type buffer[idx.max_nlocal()];
type x = init; type x = init;
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
...@@ -388,6 +388,79 @@ struct block ...@@ -388,6 +388,79 @@ struct block
} }
}; };
struct block_large
{
template <class Slicer>
struct reducer : reducer_base<reducer<Slicer>>
{
index idx;
Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {f};
}
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
return block_reduce(idx, op, init, index_int{n}, [&](auto j, auto d) {
return vec_reduce(read(xs(j, d)...), op);
});
}
template <class F>
__device__ void outer(F f) const
{
if(idx.local == 0)
f();
}
template <class F, class N, class... Ts>
__device__ void inner_void_impl(F f, N n, Ts&&... xs) const
{
idx.local_stride(index_int{n}, [&](auto j, auto d) { f(xs(j, d)...); });
}
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{{}, idx, slicer};
}
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
});
}
};
struct lane struct lane
{ {
template <class Slicer> template <class Slicer>
...@@ -466,6 +539,26 @@ struct lane ...@@ -466,6 +539,26 @@ struct lane
} }
}; };
// TODO: Remove these in the future when they can be selected in the compiler class
template <index_int RElements>
constexpr auto pick_block()
{
using nlocal = decltype(index{}.max_nlocal());
if constexpr(RElements < nlocal{} * 256)
return block{};
else
return block_large{};
}
template <index_int RElements>
using auto_block = decltype(pick_block<RElements>());
template <class Input, index_int Axis>
constexpr auto reduce_elements_with_axis()
{
constexpr auto s = get_shape_c<Input>{};
return s.lens[Axis];
}
} // namespace reduce } // namespace reduce
template <class Algo, template <class Algo,
......
...@@ -32,7 +32,8 @@ namespace migraphx { ...@@ -32,7 +32,8 @@ namespace migraphx {
template <index_int Axis, class Input, class Output> template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input1, Output output) __device__ void softmax(Input input1, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input, Axis>()>;
block::template run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto input = r.inner(op::id{})(input1); auto input = r.inner(op::id{})(input1);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX #ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input1)[0], 0); const auto c = vec_at(r.slice(input1)[0], 0);
......
...@@ -76,3 +76,16 @@ struct test_reduce_mean_2 : verify_program<test_reduce_mean_2> ...@@ -76,3 +76,16 @@ struct test_reduce_mean_2 : verify_program<test_reduce_mean_2>
return p; return p;
}; };
}; };
struct test_large_reduce_mean : verify_program<test_large_reduce_mean>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 256 * 256 * 16}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment