Commit 88351f31 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Change the GPU implementation of softmax to index through a tensor descriptor so it supports an arbitrary reduction axis (taken from the operator) instead of a hard-coded axis of 1, and tighten the axis bounds check in compute_shape.

parent 65faffa0
...@@ -30,7 +30,7 @@ struct softmax ...@@ -30,7 +30,7 @@ struct softmax
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).standard(); check_shapes{inputs}.has(1).standard();
if(axis < 0 || axis > inputs[0].lens().size()) if(axis < 0 || axis >= inputs[0].lens().size())
{ {
MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) + MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) +
" is out of range"); " is out of range");
......
...@@ -517,44 +517,6 @@ struct cpu_unary ...@@ -517,44 +517,6 @@ struct cpu_unary
} }
}; };
// struct softmax2d
// {
// std::string name() const { return "cpu::softmax2d"; }
// shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
// argument compute(context&, const shape& output_shape, std::vector<argument> args) const
// {
// argument result{output_shape};
// visit_all(result, args[0])([&](auto output, auto input) {
// using value_type = typename decltype(input)::value_type;
// auto nb = input.get_shape().lens()[0];
// auto nc = input.get_shape().lens()[1];
// auto nh = input.get_shape().lens()[2];
// auto nw = input.get_shape().lens()[3];
// dfor(nb, nh, nw)([&](std::size_t b, std::size_t i, std::size_t j) {
// value_type cmax = std::numeric_limits<value_type>::lowest();
// for(std::size_t c = 0; c < nc; c++)
// {
// cmax = std::max(cmax, input(b, c, i, j));
// }
// for(std::size_t c = 0; c < nc; c++)
// {
// output(b, c, i, j) = std::exp(input(b, c, i, j) - cmax);
// }
// value_type sum = value_type(0);
// for(std::size_t c = 0; c < nc; c++)
// {
// sum += output(b, c, i, j);
// }
// for(std::size_t c = 0; c < nc; c++)
// {
// output(b, c, i, j) = output(b, c, i, j) / sum;
// }
// });
// });
// return result;
// }
// };
struct cpu_softmax struct cpu_softmax
{ {
op::softmax op; op::softmax op;
......
...@@ -18,51 +18,58 @@ argument softmax(hipStream_t stream, ...@@ -18,51 +18,58 @@ argument softmax(hipStream_t stream,
int axis) int axis)
{ {
auto lens = output_shape.lens(); auto lens = output_shape.lens();
std::size_t batch_size = std::accumulate( auto batch_lens = lens;
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<std::size_t>()); size_t n_dims = lens[axis];
std::size_t n_dims = std::accumulate( batch_lens[axis] = 1;
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); migraphx::shape batch_shape{shape::int32_type, batch_lens};
migraphx::shape comp_shape{output_shape.type(), {batch_size, n_dims}};
visit_all(args.back(), args.front())([&](auto output, auto input) { visit_all(args.back(), args.front())([&](auto output, auto input) {
const auto* input_ptr = device_cast(input.data()); const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
hip_tensor_descriptor<n_dim> desc_data(output_shape);
// each thread is for one item in the batch // each thread is for one item in the batch
gs_launch(stream, batch_size)([=](auto i) { gs_launch(stream, batch_shape.elements())([=](auto i) {
std::size_t row_start = i * n_dims; auto batch_idx = desc_batch.multi(i);
// get max auto data_idx = batch_idx;
auto batch_max = input_ptr[row_start]; // get max
for(std::size_t j = 1; j < n_dims; ++j) auto batch_max = input_ptr[desc_data.linear(batch_idx)];
{ for(std::size_t j = 1; j < n_dims; ++j)
auto ind = row_start + j; {
batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[ind])); data_idx[axis] = j;
} batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[desc_data.linear(data_idx)]));
}
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
auto ind = row_start + j; data_idx[axis] = j;
output_ptr[ind] = input_ptr[ind] - batch_max; auto idx = desc_data.linear(data_idx);
} output_ptr[idx] = input_ptr[idx] - batch_max;
}
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
auto ind = row_start + j; data_idx[axis] = j;
output_ptr[ind] = exp(to_hip_type(input_ptr[ind])); auto idx = desc_data.linear(data_idx);
} output_ptr[idx] = exp(to_hip_type(output_ptr[idx]));
}
auto batch_sum = output_ptr[row_start]; auto batch_sum = output_ptr[desc_data.linear(batch_idx)];
for(std::size_t j = 1; j < n_dims; ++j) for(std::size_t j = 1; j < n_dims; ++j)
{ {
auto ind = row_start + j; data_idx[axis] = j;
batch_sum += output_ptr[ind]; batch_sum += output_ptr[desc_data.linear(data_idx)];
} }
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
auto ind = row_start + j; data_idx[axis] = j;
output_ptr[ind] /= batch_sum; auto idx = desc_data.linear(data_idx);
} output_ptr[idx] = output_ptr[idx] / batch_sum;
}
});
}); });
}); });
......
...@@ -15,7 +15,7 @@ argument hip_softmax::compute(context& ctx, ...@@ -15,7 +15,7 @@ argument hip_softmax::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
return device::softmax(ctx.get_stream().get(), output_shape, args, 1); return device::softmax(ctx.get_stream().get(), output_shape, args, op.axis);
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment