#include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace gpu { namespace device { argument softmax(hipStream_t stream, const migraphx::shape& output_shape, std::vector args, int axis) { auto lens = output_shape.lens(); std::size_t batch_size = std::accumulate( lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies()); std::size_t n_dims = std::accumulate( lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies()); migraphx::shape comp_shape{output_shape.type(), {batch_size, n_dims}}; visit_all(args.back(), args.front())([&](auto output, auto input) { const auto* input_ptr = device_cast(input.data()); auto* output_ptr = device_cast(output.data()); // each thread is for one item in the batch gs_launch(stream, batch_size)([=](auto i) { std::size_t row_start = i * n_dims; // get max auto batch_max = input_ptr[row_start]; for(std::size_t j = 0; j < n_dims; ++j) { auto ind = row_start + j; auto hip_type_input = to_hip_type(input_ptr[ind]); batch_max = std::max(to_hip_type(batch_max), hip_type_input); output_ptr[ind] = ::exp(hip_type_input); } auto batch_sum = output_ptr[row_start]; for(std::size_t j = 1; j < n_dims; ++j) { auto ind = row_start + j; batch_sum += output_ptr[ind]; } for(std::size_t j = 0; j < n_dims; ++j) { auto ind = row_start + j; output_ptr[ind] /= batch_sum; } }); }); return args.back(); } } // namespace device } // namespace gpu } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx