Commit 6dc749f3 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add the GPU implementation for the logsoftmax operator

parent b4517d7d
...@@ -659,8 +659,7 @@ struct cpu_logsoftmax ...@@ -659,8 +659,7 @@ struct cpu_logsoftmax
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis); auto index = compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = output(idx.begin(), idx.end()) -= batch_sum[index];
input(idx.begin(), idx.end()) - batch_max[index] - batch_sum[index];
}); });
}); });
......
...@@ -26,6 +26,7 @@ add_library(migraphx_device ...@@ -26,6 +26,7 @@ add_library(migraphx_device
device/atan.cpp device/atan.cpp
device/add_relu.cpp device/add_relu.cpp
device/contiguous.cpp device/contiguous.cpp
device/logsoftmax.cpp
device/mul.cpp device/mul.cpp
device/concat.cpp device/concat.cpp
device/pad.cpp device/pad.cpp
......
...@@ -19,14 +19,14 @@ argument gather(hipStream_t stream, ...@@ -19,14 +19,14 @@ argument gather(hipStream_t stream,
int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis; int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) { visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) { args[1].visit([=](auto indices) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) { visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
const auto* indices_ptr = device_cast(indices.data()); const auto* indices_ptr = device_cast(indices.data());
auto* outptr = device_cast(output.data()); auto* outptr = device_cast(output.data());
const auto* inptr = device_cast(input.data()); const auto* inptr = device_cast(input.data());
hip_tensor_descriptor<ndim> desc_input(input.get_shape()); hip_tensor_descriptor<ndim> desc_input(input.get_shape());
hip_tensor_descriptor<ndim> desc_output(output.get_shape()); hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([&](auto i) {
auto lens = desc_output.multi(i); auto lens = desc_output.multi(i);
lens[axis_index] = indices_ptr[lens[axis_index]]; lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)]; outptr[i] = inptr[desc_input.linear(lens)];
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Computes log-softmax along `axis` on the GPU.
// The output shape is viewed as a 2-D [batch_size, n_dims] problem:
// every dimension before `axis` is folded into the batch, every dimension
// from `axis` onward is folded into the reduction row. One GPU thread
// handles one full row (max, shift, log-sum-exp, subtract).
// Returns args.back(), which holds the result buffer.
argument logsoftmax(hipStream_t stream,
                    const migraphx::shape& output_shape,
                    std::vector<migraphx::argument> args,
                    int axis)
{
    auto lens = output_shape.lens();
    // rows of the folded 2-D view: product of dims before the axis
    std::size_t batch_size = std::accumulate(
        lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<std::size_t>());
    // elements reduced per row: product of dims from the axis to the end
    std::size_t n_dims = std::accumulate(
        lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<std::size_t>());

    visit_all(args.back(), args.front())([&](auto output, auto input) {
        const auto* input_ptr = device_cast(input.data());
        auto* output_ptr      = device_cast(output.data());
        // each thread is for one item in the batch
        gs_launch(stream, batch_size)([=](auto i) {
            std::size_t row_start = i * n_dims;

            // 1) row maximum, for numerical stability of exp()
            auto batch_max = input_ptr[row_start];
            for(std::size_t j = 1; j < n_dims; ++j)
            {
                auto ind  = row_start + j;
                batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[ind]));
            }

            // 2) shift every element by the row max
            for(std::size_t j = 0; j < n_dims; ++j)
            {
                auto ind        = row_start + j;
                output_ptr[ind] = input_ptr[ind] - batch_max;
            }

            // 3) log(sum(exp(shifted))). NOTE: the first term must be
            // exponentiated like the rest; initializing the accumulator with
            // the raw shifted value (as before) made every row's result wrong.
            auto batch_sum = ::exp(to_hip_type(output_ptr[row_start]));
            for(std::size_t j = 1; j < n_dims; ++j)
            {
                auto ind = row_start + j;
                batch_sum += ::exp(to_hip_type(output_ptr[ind]));
            }
            batch_sum = ::log(to_hip_type(batch_sum));

            // 4) logsoftmax = (x - max) - log(sum(exp(x - max)))
            for(std::size_t j = 0; j < n_dims; ++j)
            {
                auto ind = row_start + j;
                output_ptr[ind] -= batch_sum;
            }
        });
    });

    return args.back();
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_LOGSOFTMAX_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// GPU device kernel entry point for the logsoftmax operator.
// stream:       HIP stream the kernel is launched on
// output_shape: shape of the result tensor
// args:         input tensor first, output buffer last (args.back() is returned)
// axis:         dimension along which log-softmax is reduced
argument logsoftmax(hipStream_t stream,
                    const migraphx::shape& output_shape,
                    std::vector<migraphx::argument> args,
                    int axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -22,7 +22,7 @@ namespace migraphx { ...@@ -22,7 +22,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
struct miopen_logsoftmax struct hip_logsoftmax
{ {
op::logsoftmax op; op::logsoftmax op;
std::string name() const { return "gpu::logsoftmax"; } std::string name() const { return "gpu::logsoftmax"; }
......
#include <migraphx/gpu/logsoftmax.hpp> #include <migraphx/gpu/logsoftmax.hpp>
#include <migraphx/gpu/device/log.hpp> #include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
...@@ -9,41 +9,17 @@ namespace migraphx { ...@@ -9,41 +9,17 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
shape miopen_logsoftmax::compute_shape(const std::vector<shape>& inputs) const shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)}); return op.compute_shape({inputs.at(0)});
} }
argument miopen_logsoftmax::compute(context& ctx, argument hip_logsoftmax::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
float alpha = 1; return device::logsoftmax(ctx.get_stream().get(), output_shape, args, op.axis);
float beta = 0;
// temporarily reshape the input to a(0)...a(axis-1)
// and a(axis)....a(n)
auto lens = output_shape.lens();
std::size_t batch_size = std::accumulate(
lens.begin(), lens.begin() + op.axis, std::size_t{1}, std::multiplies<std::size_t>());
std::size_t n_dims = std::accumulate(
lens.begin() + op.axis, lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
migraphx::shape comp_shape{output_shape.type(), {batch_size, n_dims, 1, 1}};
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
miopenSoftmaxForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
y_desc.get(),
args[1].implicit());
// call the device::log function to perform the log operation
device::log(ctx.get_stream().get(), args[1], args[0]);
return args[1];
} }
} // namespace gpu } // namespace gpu
......
...@@ -98,7 +98,7 @@ struct miopen_apply ...@@ -98,7 +98,7 @@ struct miopen_apply
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat"); add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax"); add_extend_op<miopen_softmax, op::softmax>("softmax");
add_extend_op<miopen_logsoftmax, op::logsoftmax>("logsoftmax"); add_extend_op<hip_logsoftmax, op::logsoftmax>("logsoftmax");
add_extend_op<hip_gather, op::gather>("gather"); add_extend_op<hip_gather, op::gather>("gather");
add_extend_op<hip_pad, op::pad>("pad"); add_extend_op<hip_pad, op::pad>("pad");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment