Commit 12a79223 authored by Khalique's avatar Khalique
Browse files

formatting

parent 2ebb3515
...@@ -13,9 +13,9 @@ namespace gpu { ...@@ -13,9 +13,9 @@ namespace gpu {
namespace device { namespace device {
argument softmax(hipStream_t stream, argument softmax(hipStream_t stream,
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
std::vector<migraphx::argument> args, std::vector<migraphx::argument> args,
int axis) int axis)
{ {
auto lens = output_shape.lens(); auto lens = output_shape.lens();
std::size_t batch_size = std::accumulate( std::size_t batch_size = std::accumulate(
...@@ -35,17 +35,17 @@ argument softmax(hipStream_t stream, ...@@ -35,17 +35,17 @@ argument softmax(hipStream_t stream,
auto batch_max = input_ptr[row_start]; auto batch_max = input_ptr[row_start];
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
auto ind = row_start + j; auto ind = row_start + j;
auto hip_type_input = to_hip_type(input_ptr[ind]); auto hip_type_input = to_hip_type(input_ptr[ind]);
batch_max = std::max(to_hip_type(batch_max), hip_type_input); batch_max = std::max(to_hip_type(batch_max), hip_type_input);
output_ptr[ind] = ::exp(hip_type_input); output_ptr[ind] = ::exp(hip_type_input);
} }
auto batch_sum = output_ptr[row_start]; auto batch_sum = output_ptr[row_start];
for(std::size_t j = 1; j < n_dims; ++j) for(std::size_t j = 1; j < n_dims; ++j)
{ {
auto ind = row_start + j; auto ind = row_start + j;
batch_sum += output_ptr[ind]; batch_sum += output_ptr[ind];
} }
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
...@@ -59,7 +59,6 @@ argument softmax(hipStream_t stream, ...@@ -59,7 +59,6 @@ argument softmax(hipStream_t stream,
return args.back(); return args.back();
} }
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -11,9 +11,9 @@ namespace gpu { ...@@ -11,9 +11,9 @@ namespace gpu {
namespace device { namespace device {
argument softmax(hipStream_t stream, argument softmax(hipStream_t stream,
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
std::vector<migraphx::argument> args, std::vector<migraphx::argument> args,
int axis); int axis);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -12,8 +12,8 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const ...@@ -12,8 +12,8 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
} }
argument hip_softmax::compute(context& ctx, argument hip_softmax::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
return device::softmax(ctx.get_stream().get(), output_shape, args, 1); return device::softmax(ctx.get_stream().get(), output_shape, args, 1);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment