"docs/static/vscode:/vscode.git/clone" did not exist on "25c4c3b5fa472a07cab61a964a1ae632b4ee5925"
Commit 7dd2cf04 authored by Paul

Refactor logsoftmax

parent d15edcb6
@@ -12,65 +12,55 @@ namespace gpu {
 namespace device {
 argument logsoftmax(hipStream_t stream,
-                    const migraphx::shape& output_shape,
-                    std::vector<migraphx::argument> args,
+                    argument result,
+                    argument arg,
                     int axis)
 {
-    auto lens         = output_shape.lens();
+    auto lens         = result.get_shape().lens();
     auto num_in_batch = lens[axis];
     auto batch_lens   = lens;
     batch_lens[axis]  = 1;
-    migraphx::shape batch_shape{output_shape.type(), batch_lens};
-    visit_all(args.back(), args.front())([&](auto output, auto input) {
-        const auto* input_ptr = device_cast(input.data());
-        auto* output_ptr      = device_cast(output.data());
-        visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
-            hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
-            hip_tensor_descriptor<n_dim> desc_data(output_shape);
-            // each thread is for one item in the batch
-            gs_launch(stream, batch_shape.elements())([=](auto i) {
-                auto batch_idx = desc_batch.multi(i);
-                auto data_idx  = batch_idx;
-                // get max
-                auto batch_max = input_ptr[desc_data.linear(batch_idx)];
-                for(std::size_t j = 1; j < num_in_batch; ++j)
-                {
-                    data_idx[axis] = j;
-                    size_t idx     = desc_data.linear(data_idx);
-                    batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[idx]));
-                }
-                for(std::size_t j = 0; j < num_in_batch; ++j)
-                {
-                    data_idx[axis]  = j;
-                    size_t idx      = desc_data.linear(data_idx);
-                    output_ptr[idx] = input_ptr[idx] - batch_max;
-                }
-                auto batch_sum = ::exp(to_hip_type(output_ptr[desc_data.linear(batch_idx)]));
-                for(std::size_t j = 1; j < num_in_batch; ++j)
-                {
-                    data_idx[axis] = j;
-                    size_t idx = desc_data.linear(data_idx);
-                    batch_sum += ::exp(to_hip_type(output_ptr[idx]));
-                }
-                batch_sum = ::log(to_hip_type(batch_sum));
-                for(std::size_t j = 0; j < num_in_batch; ++j)
-                {
-                    data_idx[axis] = j;
-                    size_t idx = desc_data.linear(data_idx);
-                    output_ptr[idx] -= batch_sum;
-                }
-            });
+    shape batch_shape{result.get_shape().type(), batch_lens};
+    hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
+        // each thread is for one item in the batch
+        gs_launch(stream, batch_shape.elements())([=](auto i) {
+            auto batch_idx = batch.multi(i);
+            auto data_idx  = batch_idx;
+            // get max
+            auto batch_max = input[batch_idx];
+            for(std::size_t j = 1; j < num_in_batch; ++j)
+            {
+                data_idx[axis] = j;
+                batch_max = std::max(to_hip_type(batch_max), to_hip_type(input[data_idx]));
+            }
+            for(std::size_t j = 0; j < num_in_batch; ++j)
+            {
+                data_idx[axis]   = j;
+                output[data_idx] = input[data_idx] - batch_max;
+            }
+            auto batch_sum = ::exp(to_hip_type(output[batch_idx]));
+            for(std::size_t j = 1; j < num_in_batch; ++j)
+            {
+                data_idx[axis] = j;
+                batch_sum += ::exp(to_hip_type(output[data_idx]));
+            }
+            batch_sum = ::log(to_hip_type(batch_sum));
+            for(std::size_t j = 0; j < num_in_batch; ++j)
+            {
+                data_idx[axis] = j;
+                output[data_idx] -= batch_sum;
+            }
         });
     });
-    return args.back();
+    return result;
 }
 } // namespace device
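Both the old and the refactored kernel compute the numerically stable log-softmax along the chosen axis: per slice, subtract the maximum, exponentiate and sum the shifted values, take the log of the sum, and subtract it again, i.e. log_softmax(x)_k = (x_k - m) - log(sum_j exp(x_j - m)) with m = max_j x_j. The sketch below is a minimal host-side reference of that computation, not part of this commit; logsoftmax_ref, its parameters, and the row-major layout are illustrative assumptions rather than MIGraphX API.

// Host-side reference sketch (hypothetical, not from the commit) of the
// computation the kernel above performs for each position of batch_shape,
// i.e. each multi-index with the reduction axis collapsed to length 1.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<double> logsoftmax_ref(const std::vector<double>& input,
                                   const std::vector<std::size_t>& lens,
                                   std::size_t axis)
{
    std::vector<double> output(input.size());

    // Row-major strides for the full tensor
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t d = lens.size() - 1; d > 0; --d)
        strides[d - 1] = strides[d] * lens[d];

    const std::size_t num_in_batch = lens[axis];
    const std::size_t batches      = input.size() / num_in_batch;

    // Loop over every slice along `axis`, mirroring the batch_shape / gs_launch
    // iteration in the kernel (one GPU work-item per slice).
    for(std::size_t b = 0; b < batches; ++b)
    {
        // Linear offset of the first element of this slice
        std::size_t offset = 0;
        std::size_t rem    = b;
        for(std::size_t d = lens.size(); d-- > 0;)
        {
            if(d == axis)
                continue;
            offset += (rem % lens[d]) * strides[d];
            rem /= lens[d];
        }
        // 1) maximum over the axis
        double m = input[offset];
        for(std::size_t j = 1; j < num_in_batch; ++j)
            m = std::max(m, input[offset + j * strides[axis]]);
        // 2) subtract the max and accumulate exp of the shifted values
        //    (the kernel does these as two separate passes)
        double sum = 0.0;
        for(std::size_t j = 0; j < num_in_batch; ++j)
        {
            const std::size_t idx = offset + j * strides[axis];
            output[idx]           = input[idx] - m;
            sum += std::exp(output[idx]);
        }
        // 3) subtract log of the sum
        const double log_sum = std::log(sum);
        for(std::size_t j = 0; j < num_in_batch; ++j)
            output[offset + j * strides[axis]] -= log_sum;
    }
    return output;
}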
@@ -11,8 +11,8 @@ namespace gpu {
 namespace device {
 argument logsoftmax(hipStream_t stream,
-                    const migraphx::shape& output_shape,
-                    std::vector<migraphx::argument> args,
+                    argument result,
+                    argument arg,
                     int axis);
 } // namespace device
@@ -16,10 +16,10 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
 }
 argument hip_logsoftmax::compute(context& ctx,
-                                 const shape& output_shape,
+                                 const shape&,
                                  const std::vector<argument>& args) const
 {
-    return device::logsoftmax(ctx.get_stream().get(), output_shape, args, op.axis);
+    return device::logsoftmax(ctx.get_stream().get(), args[1], args[0], op.axis);
 }
 } // namespace gpu
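After the refactor, compute() no longer passes the output shape and argument list down; it forwards the preallocated output buffer (args[1]) and the input (args[0]) directly, and the device function reads the shape from the result argument. A quick sanity check for any log-softmax implementation, shown here against the hypothetical logsoftmax_ref sketch above (this test is an assumption, not part of the commit), is that exponentiating the result sums to 1 along the reduced axis:

// Hypothetical check using logsoftmax_ref from the sketch above:
// exp(log_softmax(x)) must sum to 1 along the reduced axis.
#include <cassert>
#include <cmath>
#include <vector>

int main()
{
    std::vector<std::size_t> lens{2, 3}; // a 2x3 tensor
    std::vector<double> x{1.0, 2.0, 3.0, -1.0, 0.0, 1.0};
    auto y = logsoftmax_ref(x, lens, 1); // reduce along axis 1

    for(std::size_t row = 0; row < 2; ++row)
    {
        double s = 0.0;
        for(std::size_t col = 0; col < 3; ++col)
            s += std::exp(y[row * 3 + col]);
        assert(std::abs(s - 1.0) < 1e-9);
    }
    return 0;
}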