Commit 2851a6e9 authored by Paul's avatar Paul
Browse files

Format

parent efa5dcce
...@@ -51,10 +51,11 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -51,10 +51,11 @@ struct softmax_compiler : compiler<softmax_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
auto axis = v.at("axis").to<int64_t>(); auto axis = v.at("axis").to<int64_t>();
auto block_size = compute_block_size(inputs[0].lens()[axis], 256); auto block_size = compute_block_size(inputs[0].lens()[axis], 256);
hip_compile_options options; hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), block_size), 256); options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements(), block_size), 256);
options.output = inputs.back(); options.output = inputs.back();
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = "softmax_kernel"; options.kernel_name = "softmax_kernel";
...@@ -70,7 +71,6 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -70,7 +71,6 @@ struct softmax_compiler : compiler<softmax_compiler>
} }
}; };
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -194,9 +194,8 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f) ...@@ -194,9 +194,8 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
template <class T, T... Xs, class F> template <class T, T... Xs, class F>
constexpr auto transform_i(integral_const_array<T, Xs...>, F f) constexpr auto transform_i(integral_const_array<T, Xs...>, F f)
{ {
return sequence_c<sizeof...(Xs)>([=](auto... is) { return sequence_c<sizeof...(Xs)>(
return integral_const_array<T, f(Xs, is)...>{}; [=](auto... is) { return integral_const_array<T, f(Xs, is)...>{}; });
});
} }
template <class T, T... Xs, class U, U... Ys, class F> template <class T, T... Xs, class U, U... Ys, class F>
......
...@@ -152,19 +152,19 @@ constexpr auto sliced(Slicer slicer, F f) ...@@ -152,19 +152,19 @@ constexpr auto sliced(Slicer slicer, F f)
}; };
} }
template<class Input, index_int Axis> template <class Input, index_int Axis>
constexpr auto compute_reduce_axis() constexpr auto compute_reduce_axis()
{ {
constexpr auto lens = transform_i(get_shape_c<Input>{}.lens, constexpr auto lens =
[](index_int x, index_int i) -> index_int { transform_i(get_shape_c<Input>{}.lens, [](index_int x, index_int i) -> index_int {
if(i == Axis) if(i == Axis)
return 1; return 1;
return x; return x;
}); });
return make_shape(lens, get_shape_c<Input>{}.strides); return make_shape(lens, get_shape_c<Input>{}.strides);
} }
template<class Input, index_int Axis> template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>()); using with_axis = decltype(compute_reduce_axis<Input, Axis>());
struct block struct block
...@@ -196,9 +196,7 @@ struct block ...@@ -196,9 +196,7 @@ struct block
{ {
return [=](auto f) { return [=](auto f) {
// TODO: Assert same elements // TODO: Assert same elements
idx.local_stride(x.elements(), [&](auto j) { idx.local_stride(x.elements(), [&](auto j) { f(x[j], xs[j]...); });
f(x[j], xs[j]...);
});
}; };
} }
}; };
......
...@@ -6,17 +6,15 @@ ...@@ -6,17 +6,15 @@
namespace migraphx { namespace migraphx {
template<index_int Axis, class Input, class Output> template <index_int Axis, class Input, class Output>
void softmax(Input input, Output output) void softmax(Input input, Output output)
{ {
reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) { reduce::block::run<reduce::with_axis<Input, Axis>>([&](auto, auto r) {
auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input); auto batch_max = r.reduce(op::max{}, lowest{}, op::id{})(input);
auto batch_sum = r.reduce(op::sum{}, 0, [&](auto x) { auto batch_sum =
return migraphx::exp(x - batch_max); r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
})(input); r.outer(output,
r.outer(output, input)([&](auto& y, auto x) { input)([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; });
y = migraphx::exp(x - batch_max) / batch_sum;
});
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment