Commit b50d336a authored by Paul's avatar Paul
Browse files

Use block_large for layernorm and softmax

parent 70009bcc
...@@ -46,8 +46,9 @@ template <index_int Axis, ...@@ -46,8 +46,9 @@ template <index_int Axis,
__device__ void generic_binary_layernorm( __device__ void generic_binary_layernorm(
F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs) F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) { block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2); auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
......
...@@ -539,6 +539,26 @@ struct lane ...@@ -539,6 +539,26 @@ struct lane
} }
}; };
// TODO: Remove these in the future when they can be selected in the compiler class
template<index_int RElements>
constexpr auto pick_block()
{
using nlocal = decltype(index{}.max_nlocal());
if constexpr(RElements < nlocal{}*256)
return block{};
else
return block_large{};
}
template<index_int RElements>
using auto_block = decltype(pick_block<RElements>());
template <class Input, index_int Axis>
constexpr auto reduce_elements_with_axis()
{
constexpr auto s = get_shape_c<Input>{};
return s.lens[Axis];
}
} // namespace reduce } // namespace reduce
template <class Algo, template <class Algo,
......
...@@ -32,7 +32,8 @@ namespace migraphx { ...@@ -32,7 +32,8 @@ namespace migraphx {
template <index_int Axis, class Input, class Output> template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input1, Output output) __device__ void softmax(Input input1, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input, Axis>()>;
block::template run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto input = r.inner(op::id{})(input1); auto input = r.inner(op::id{})(input1);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX #ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input1)[0], 0); const auto c = vec_at(r.slice(input1)[0], 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment