Commit c9e391fe authored by Paul's avatar Paul
Browse files

Refactor softmax

parent 8399b302
......@@ -12,69 +12,51 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
/// Numerically stable softmax along `axis`, computed on the GPU.
///
/// @param stream  HIP stream the kernel is launched on.
/// @param result  Output buffer; its shape determines the iteration space.
/// @param arg     Input tensor (same logical shape as `result`).
/// @param axis    Dimension the softmax is reduced over.
/// @return        `result`, now holding softmax(arg) along `axis`.
argument softmax(hipStream_t stream, argument result, argument arg, int axis)
{
    auto lens       = result.get_shape().lens();
    auto batch_lens = lens;
    size_t n_dims   = lens[axis];
    // Collapse the softmax axis to 1: each element of batch_shape is one
    // independent reduction (one "batch item").
    batch_lens[axis] = 1;
    shape batch_shape{result.get_shape().type(), batch_lens};
    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
        // each thread is for one item in the batch
        gs_launch(stream, batch_shape.elements())([=](auto i) {
            auto batch_idx = batch.multi(i);
            auto data_idx  = batch_idx;
            // Pass 1: find the max along the axis (subtracting it before
            // exp() avoids overflow — the standard stability trick).
            auto batch_max = input[batch_idx];
            for(std::size_t j = 1; j < n_dims; ++j)
            {
                data_idx[axis] = j;
                batch_max = std::max(to_hip_type(batch_max), to_hip_type(input[data_idx]));
            }
            // Pass 2: exponentiate the shifted values.
            for(std::size_t j = 0; j < n_dims; ++j)
            {
                data_idx[axis] = j;
                output[data_idx] = exp(to_hip_type(input[data_idx] - batch_max));
            }
            // Pass 3: sum the exponentials. batch_idx has the axis at 0, so
            // the sum starts from element 0 and accumulates from j = 1.
            auto batch_sum = output[batch_idx];
            for(std::size_t j = 1; j < n_dims; ++j)
            {
                data_idx[axis] = j;
                batch_sum += output[data_idx];
            }
            // Pass 4: normalize so the axis sums to 1.
            for(std::size_t j = 0; j < n_dims; ++j)
            {
                data_idx[axis] = j;
                output[data_idx] = output[data_idx] / batch_sum;
            }
        });
    });
    return result;
}
} // namespace device
......
......@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Computes softmax(arg) along `axis` into `result` on `stream`; returns `result`.
argument softmax(hipStream_t stream, argument result, argument arg, int axis);
} // namespace device
} // namespace gpu
......
......@@ -38,10 +38,10 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
}
/// Dispatches the device softmax kernel.
///
/// @param ctx   GPU context providing the HIP stream.
/// @param args  args[0] is the input tensor; args[1] is the preallocated
///              output buffer (NOTE(review): assumes the standard MIGraphX
///              convention that the last argument is the output — confirm
///              against compute_shape/allocation).
/// @return      The output argument filled with the softmax result.
argument hip_softmax::compute(context& ctx,
                              const shape&,
                              const std::vector<argument>& args) const
{
    return device::softmax(ctx.get_stream().get(), args[1], args[0], op.axis);
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment