Commit 2851a6e9 authored by Paul's avatar Paul
Browse files

Format

parent efa5dcce
......@@ -51,10 +51,11 @@ struct softmax_compiler : compiler<softmax_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
auto axis = v.at("axis").to<int64_t>();
auto axis = v.at("axis").to<int64_t>();
auto block_size = compute_block_size(inputs[0].lens()[axis], 256);
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), block_size), 256);
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements(), block_size), 256);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "softmax_kernel";
......@@ -70,7 +71,6 @@ struct softmax_compiler : compiler<softmax_compiler>
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -194,9 +194,8 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
template <class T, T... Xs, class F>
constexpr auto transform_i(integral_const_array<T, Xs...>, F f)
{
return sequence_c<sizeof...(Xs)>([=](auto... is) {
return integral_const_array<T, f(Xs, is)...>{};
});
return sequence_c<sizeof...(Xs)>(
[=](auto... is) { return integral_const_array<T, f(Xs, is)...>{}; });
}
template <class T, T... Xs, class U, U... Ys, class F>
......
......@@ -152,19 +152,19 @@ constexpr auto sliced(Slicer slicer, F f)
};
}
template<class Input, index_int Axis>
template <class Input, index_int Axis>
constexpr auto compute_reduce_axis()
{
constexpr auto lens = transform_i(get_shape_c<Input>{}.lens,
[](index_int x, index_int i) -> index_int {
if(i == Axis)
return 1;
return x;
});
constexpr auto lens =
transform_i(get_shape_c<Input>{}.lens, [](index_int x, index_int i) -> index_int {
if(i == Axis)
return 1;
return x;
});
return make_shape(lens, get_shape_c<Input>{}.strides);
}
template<class Input, index_int Axis>
template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>());
struct block
......@@ -196,9 +196,7 @@ struct block
{
return [=](auto f) {
// TODO: Assert same elements
idx.local_stride(x.elements(), [&](auto j) {
f(x[j], xs[j]...);
});
idx.local_stride(x.elements(), [&](auto j) { f(x[j], xs[j]...); });
};
}
};
......
......@@ -6,17 +6,15 @@
namespace migraphx {
template<index_int Axis, class Input, class Output>
template <index_int Axis, class Input, class Output>
void softmax(Input input, Output output)
{
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) {
return migraphx::exp(x - batch_max);
})(input);
r.outer(output, input)([&](auto& y, auto x) {
y = migraphx::exp(x - batch_max) / batch_sum;
});
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum =
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
r.outer(output,
input)([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; });
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment