"awq/vscode:/vscode.git/clone" did not exist on "0e77dbc1f8c434da6f814a115a9efb147620f5b9"
Commit 88351f31 authored by Shucai Xiao

Change the GPU implementation of softmax to reduce along the operator's configured axis.

parent 65faffa0
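The diff touches four spots: the shape check for the softmax axis is tightened from `axis > rank` to `axis >= rank`; a long-dead, commented-out CPU `softmax2d` block is deleted; the GPU device kernel is rewritten to reduce along an arbitrary `axis` using multi-dimensional indexing instead of assuming the reduced dimension is flattened into contiguous rows; and the GPU wrapper now forwards `op.axis` instead of a hardcoded `1`.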
@@ -30,7 +30,7 @@ struct softmax
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs}.has(1).standard();
-        if(axis < 0 || axis > inputs[0].lens().size())
+        if(axis < 0 || axis >= inputs[0].lens().size())
         {
             MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) +
                            " is out of range");
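Why `>=`: for a rank-r input the valid axes are 0 through r-1, so `axis == r` must be rejected, and the old comparison let it through. A tiny standalone illustration (hypothetical values, not MIGraphX code):

```cpp
#include <cstddef>
#include <vector>

int main()
{
    std::vector<std::size_t> lens = {2, 3, 4, 5}; // rank-4 shape, e.g. NCHW
    int axis = 4;                                 // one past the last valid axis (3)

    // Old check: axis == lens.size() slips through and later indexes past the end.
    bool rejected_old = (axis < 0 || axis > static_cast<int>(lens.size()));  // false (bug)
    // New check: anything outside [0, lens.size()) is rejected.
    bool rejected_new = (axis < 0 || axis >= static_cast<int>(lens.size())); // true
    return (!rejected_old && rejected_new) ? 0 : 1;
}
```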
@@ -517,44 +517,6 @@ struct cpu_unary
     }
 };
-// struct softmax2d
-// {
-//     std::string name() const { return "cpu::softmax2d"; }
-//     shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
-//     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
-//     {
-//         argument result{output_shape};
-//         visit_all(result, args[0])([&](auto output, auto input) {
-//             using value_type = typename decltype(input)::value_type;
-//             auto nb          = input.get_shape().lens()[0];
-//             auto nc          = input.get_shape().lens()[1];
-//             auto nh          = input.get_shape().lens()[2];
-//             auto nw          = input.get_shape().lens()[3];
-//             dfor(nb, nh, nw)([&](std::size_t b, std::size_t i, std::size_t j) {
-//                 value_type cmax = std::numeric_limits<value_type>::lowest();
-//                 for(std::size_t c = 0; c < nc; c++)
-//                 {
-//                     cmax = std::max(cmax, input(b, c, i, j));
-//                 }
-//                 for(std::size_t c = 0; c < nc; c++)
-//                 {
-//                     output(b, c, i, j) = std::exp(input(b, c, i, j) - cmax);
-//                 }
-//                 value_type sum = value_type(0);
-//                 for(std::size_t c = 0; c < nc; c++)
-//                 {
-//                     sum += output(b, c, i, j);
-//                 }
-//                 for(std::size_t c = 0; c < nc; c++)
-//                 {
-//                     output(b, c, i, j) = output(b, c, i, j) / sum;
-//                 }
-//             });
-//         });
-//         return result;
-//     }
-// };
 struct cpu_softmax
 {
     op::softmax op;
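The deleted block was already commented out: a CPU `softmax2d` hardwired to 4-D NCHW input that always reduced over the channel axis. With the operator now axis-generic it is pure dead weight. For reference, a minimal CPU sketch of the numerically stable, axis-generic computation both implementations aim at (standalone C++, not the MIGraphX API):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax along `axis` of a dense, row-major tensor
// with dimensions `lens`. Each slice fixes every coordinate except `axis`.
std::vector<float> softmax_axis(const std::vector<float>& in,
                                const std::vector<std::size_t>& lens,
                                std::size_t axis)
{
    std::size_t stride = 1; // linear step between neighbors along `axis`
    for(std::size_t d = axis + 1; d < lens.size(); ++d)
        stride *= lens[d];
    const std::size_t n = lens[axis];

    std::vector<float> out(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        if((i / stride) % n != 0) // visit each slice once, at its j == 0 element
            continue;
        float m = in[i]; // max over the slice, subtracted for stability
        for(std::size_t j = 1; j < n; ++j)
            m = std::max(m, in[i + j * stride]);
        float sum = 0.0f;
        for(std::size_t j = 0; j < n; ++j)
        {
            out[i + j * stride] = std::exp(in[i + j * stride] - m);
            sum += out[i + j * stride];
        }
        for(std::size_t j = 0; j < n; ++j)
            out[i + j * stride] /= sum;
    }
    return out;
}
```

With `lens = {N, C, H, W}` and `axis = 1` this reproduces what the commented-out `softmax2d` computed per (b, h, w) position.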
@@ -18,51 +18,58 @@ argument softmax(hipStream_t stream,
                  int axis)
 {
     auto lens = output_shape.lens();
-    std::size_t batch_size = std::accumulate(
-        lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<std::size_t>());
-    std::size_t n_dims = std::accumulate(
-        lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
-    migraphx::shape comp_shape{output_shape.type(), {batch_size, n_dims}};
+    auto batch_lens  = lens;
+    size_t n_dims    = lens[axis];
+    batch_lens[axis] = 1;
+    migraphx::shape batch_shape{shape::int32_type, batch_lens};
     visit_all(args.back(), args.front())([&](auto output, auto input) {
         const auto* input_ptr = device_cast(input.data());
         auto* output_ptr      = device_cast(output.data());
-        // each thread is for one item in the batch
-        gs_launch(stream, batch_size)([=](auto i) {
-            std::size_t row_start = i * n_dims;
-            // get max
-            auto batch_max = input_ptr[row_start];
-            for(std::size_t j = 1; j < n_dims; ++j)
-            {
-                auto ind  = row_start + j;
-                batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[ind]));
-            }
-            for(std::size_t j = 0; j < n_dims; ++j)
-            {
-                auto ind        = row_start + j;
-                output_ptr[ind] = input_ptr[ind] - batch_max;
-            }
-            for(std::size_t j = 0; j < n_dims; ++j)
-            {
-                auto ind        = row_start + j;
-                output_ptr[ind] = exp(to_hip_type(input_ptr[ind]));
-            }
-            auto batch_sum = output_ptr[row_start];
-            for(std::size_t j = 1; j < n_dims; ++j)
-            {
-                auto ind = row_start + j;
-                batch_sum += output_ptr[ind];
-            }
-            for(std::size_t j = 0; j < n_dims; ++j)
-            {
-                auto ind = row_start + j;
-                output_ptr[ind] /= batch_sum;
-            }
+        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
+            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
+            hip_tensor_descriptor<n_dim> desc_data(output_shape);
+            // each thread is for one item in the batch
+            gs_launch(stream, batch_shape.elements())([=](auto i) {
+                auto batch_idx = desc_batch.multi(i);
+                auto data_idx  = batch_idx;
+                // get max
+                auto batch_max = input_ptr[desc_data.linear(batch_idx)];
+                for(std::size_t j = 1; j < n_dims; ++j)
+                {
+                    data_idx[axis] = j;
+                    batch_max      = std::max(to_hip_type(batch_max),
+                                         to_hip_type(input_ptr[desc_data.linear(data_idx)]));
+                }
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    data_idx[axis]  = j;
+                    auto idx        = desc_data.linear(data_idx);
+                    output_ptr[idx] = input_ptr[idx] - batch_max;
+                }
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    data_idx[axis]  = j;
+                    auto idx        = desc_data.linear(data_idx);
+                    output_ptr[idx] = exp(to_hip_type(output_ptr[idx]));
+                }
+                auto batch_sum = output_ptr[desc_data.linear(batch_idx)];
+                for(std::size_t j = 1; j < n_dims; ++j)
+                {
+                    data_idx[axis] = j;
+                    batch_sum += output_ptr[desc_data.linear(data_idx)];
+                }
+                for(std::size_t j = 0; j < n_dims; ++j)
+                {
+                    data_idx[axis]  = j;
+                    auto idx        = desc_data.linear(data_idx);
+                    output_ptr[idx] = output_ptr[idx] / batch_sum;
+                }
+            });
         });
     });
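The key change: instead of flattening everything from `axis` onward into contiguous rows (`row_start + j`), each thread now owns one element of `batch_shape`, the data shape with the softmax axis collapsed to 1, so `batch_shape.elements()` is the number of independent slices. A thread turns its linear id into a multi-index with `desc_batch.multi(i)` and walks its slice by varying `data_idx[axis]`, mapping back to a storage offset with `desc_data.linear(...)`. Because `batch_idx[axis]` is always 0, `desc_data.linear(batch_idx)` is the slice's first element, which seeds the max and the sum. (`batch_shape` is constructed with `int32_type`, but only its geometry appears to be used.) Incidentally, the rewrite also restores the effect of the max subtraction: the old code wrote `input - batch_max` into the output but then overwrote it with `exp(input)`, discarding the stabilization, while the new code applies `exp` to the already-shifted `output_ptr` values. The descriptor internals are not shown in this diff; the sketch below assumes `multi`/`linear` are plain row-major conversions:

```cpp
#include <cstddef>
#include <vector>

// Assumed semantics of hip_tensor_descriptor: row-major strides,
// multi(i) maps a linear offset to a multi-index, linear(idx) inverts it.
struct descriptor
{
    std::vector<std::size_t> lens, strides;
    explicit descriptor(const std::vector<std::size_t>& l) : lens(l), strides(l.size(), 1)
    {
        for(std::size_t d = l.size(); d-- > 1;)
            strides[d - 1] = strides[d] * lens[d];
    }
    std::vector<std::size_t> multi(std::size_t i) const
    {
        std::vector<std::size_t> idx(lens.size());
        for(std::size_t d = 0; d < lens.size(); ++d)
        {
            idx[d] = i / strides[d];
            i %= strides[d];
        }
        return idx;
    }
    std::size_t linear(const std::vector<std::size_t>& idx) const
    {
        std::size_t i = 0;
        for(std::size_t d = 0; d < idx.size(); ++d)
            i += idx[d] * strides[d];
        return i;
    }
};
```

For example, `descriptor({2, 3, 4}).multi(7)` yields `{0, 1, 3}` and `linear({0, 1, 3})` returns 7; a softmax over axis 1 of that shape would run over 2 * 1 * 4 = 8 independent slices.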
@@ -15,7 +15,7 @@ argument hip_softmax::compute(context& ctx,
                               const shape& output_shape,
                               const std::vector<argument>& args) const
 {
-    return device::softmax(ctx.get_stream().get(), output_shape, args, 1);
+    return device::softmax(ctx.get_stream().get(), output_shape, args, op.axis);
 }
 } // namespace gpu
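This wrapper fix closes the loop: `hip_softmax::compute` previously pinned the device call to axis 1, which happens to match the common NCHW channel softmax but silently ignored the axis stored on the operator. Forwarding `op.axis` makes the GPU path honor the same attribute the shape check above validates.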