Commit bb0fff52 authored by Paul's avatar Paul
Browse files

Format

parent c9bc461c
...@@ -52,8 +52,8 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -52,8 +52,8 @@ struct softmax_compiler : compiler<softmax_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
auto axis = v.at("axis").to<int64_t>(); auto axis = v.at("axis").to<int64_t>();
auto relements = inputs[0].lens()[axis]; auto relements = inputs[0].lens()[axis];
auto nelements = inputs.back().elements() / relements; auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
hip_compile_options options; hip_compile_options options;
options.set_launch_params( options.set_launch_params(
......
...@@ -14,7 +14,7 @@ __device__ void softmax(Input input, Output output) ...@@ -14,7 +14,7 @@ __device__ void softmax(Input input, Output output)
auto batch_sum = auto batch_sum =
r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input); r.reduce(op::sum{}, 0, [&](auto x) { return migraphx::exp(x - batch_max); })(input);
r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output, r.inner([&](auto& y, auto x) { y = migraphx::exp(x - batch_max) / batch_sum; })(output,
input); input);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment