"docs/vscode:/vscode.git/clone" did not exist on "99c39b40121a6ea6fcecc575d38bc6ff12b0d979"
Commit 7f7cbbc0 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 88351f31
...@@ -17,9 +17,9 @@ argument softmax(hipStream_t stream, ...@@ -17,9 +17,9 @@ argument softmax(hipStream_t stream,
std::vector<migraphx::argument> args, std::vector<migraphx::argument> args,
int axis) int axis)
{ {
auto lens = output_shape.lens(); auto lens = output_shape.lens();
auto batch_lens = lens; auto batch_lens = lens;
size_t n_dims = lens[axis]; size_t n_dims = lens[axis];
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{shape::int32_type, batch_lens}; migraphx::shape batch_shape{shape::int32_type, batch_lens};
...@@ -33,26 +33,27 @@ argument softmax(hipStream_t stream, ...@@ -33,26 +33,27 @@ argument softmax(hipStream_t stream,
// each thread is for one item in the batch // each thread is for one item in the batch
gs_launch(stream, batch_shape.elements())([=](auto i) { gs_launch(stream, batch_shape.elements())([=](auto i) {
auto batch_idx = desc_batch.multi(i); auto batch_idx = desc_batch.multi(i);
auto data_idx = batch_idx; auto data_idx = batch_idx;
// get max // get max
auto batch_max = input_ptr[desc_data.linear(batch_idx)]; auto batch_max = input_ptr[desc_data.linear(batch_idx)];
for(std::size_t j = 1; j < n_dims; ++j) for(std::size_t j = 1; j < n_dims; ++j)
{ {
data_idx[axis] = j; data_idx[axis] = j;
batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[desc_data.linear(data_idx)])); batch_max = std::max(to_hip_type(batch_max),
to_hip_type(input_ptr[desc_data.linear(data_idx)]));
} }
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
data_idx[axis] = j; data_idx[axis] = j;
auto idx = desc_data.linear(data_idx); auto idx = desc_data.linear(data_idx);
output_ptr[idx] = input_ptr[idx] - batch_max; output_ptr[idx] = input_ptr[idx] - batch_max;
} }
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
data_idx[axis] = j; data_idx[axis] = j;
auto idx = desc_data.linear(data_idx); auto idx = desc_data.linear(data_idx);
output_ptr[idx] = exp(to_hip_type(output_ptr[idx])); output_ptr[idx] = exp(to_hip_type(output_ptr[idx]));
} }
...@@ -65,8 +66,8 @@ argument softmax(hipStream_t stream, ...@@ -65,8 +66,8 @@ argument softmax(hipStream_t stream,
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
data_idx[axis] = j; data_idx[axis] = j;
auto idx = desc_data.linear(data_idx); auto idx = desc_data.linear(data_idx);
output_ptr[idx] = output_ptr[idx] / batch_sum; output_ptr[idx] = output_ptr[idx] / batch_sum;
} }
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment