Commit b50d336a authored by Paul's avatar Paul
Browse files

Use block_large for layernorm and softmax

parent 70009bcc
......@@ -46,8 +46,9 @@ template <index_int Axis,
__device__ void generic_binary_layernorm(
F compute, BinOp op, float eps, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input1, Axis>()>;
using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) {
block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
......
......@@ -539,6 +539,26 @@ struct lane
}
};
// TODO: Remove these in the future when they can be selected in the compiler class
template<index_int RElements>
constexpr auto pick_block()
{
using nlocal = decltype(index{}.max_nlocal());
if constexpr(RElements < nlocal{}*256)
return block{};
else
return block_large{};
}
template<index_int RElements>
using auto_block = decltype(pick_block<RElements>());
template <class Input, index_int Axis>
constexpr auto reduce_elements_with_axis()
{
constexpr auto s = get_shape_c<Input>{};
return s.lens[Axis];
}
} // namespace reduce
template <class Algo,
......
......@@ -32,7 +32,8 @@ namespace migraphx {
template <index_int Axis, class Input, class Output>
__device__ void softmax(Input input1, Output output)
{
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
using block = reduce::auto_block<reduce::reduce_elements_with_axis<Input, Axis>()>;
block::template run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto input = r.inner(op::id{})(input1);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
const auto c = vec_at(r.slice(input1)[0], 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment