Commit 8724242e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

further refactoring of softmax and logsoftmax.

parent 6d1c23e9
......@@ -11,24 +11,24 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument logsoftmax(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
void logsoftmax(hipStream_t stream,
const argument& result,
const argument& arg,
int axis)
{
auto lens = output_shape.lens();
auto lens = result.get_shape().lens();
auto n_dims = lens[axis];
auto batch_lens = lens;
batch_lens[axis] = 1;
migraphx::shape batch_shape{output_shape.type(), batch_lens};
migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
visit_all(args.back(), args.front())([&](auto output, auto input) {
visit_all(result, arg)([&](auto output, auto input) {
const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data());
visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
hip_tensor_descriptor<n_dim> desc_data(output_shape);
hip_tensor_descriptor<n_dim> desc_data(result.get_shape());
// use one block for items in one batch.
// opt 1, load all data to lds then use the same approach as
......@@ -142,8 +142,6 @@ argument logsoftmax(hipStream_t stream,
});
});
});
return args.back();
}
} // namespace device
......
......@@ -12,23 +12,23 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument softmax(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
void softmax(hipStream_t stream,
const argument& result,
const argument& arg,
int axis)
{
auto lens = output_shape.lens();
auto lens = result.get_shape().lens();
auto batch_lens = lens;
size_t n_dims = lens[axis];
batch_lens[axis] = 1;
migraphx::shape batch_shape{output_shape.type(), batch_lens};
migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
visit_all(args.back(), args.front())([&](auto output, auto input) {
visit_all(result, arg)([&](auto output, auto input) {
const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data());
visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
hip_tensor_descriptor<n_dim> desc_data(output_shape);
hip_tensor_descriptor<n_dim> desc_data(result.get_shape());
// use one block for items in one batch.
const size_t max_block_size = 1024;
......@@ -139,8 +139,6 @@ argument softmax(hipStream_t stream,
});
});
});
return args.back();
}
} // namespace device
......
......@@ -10,9 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument logsoftmax(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
void logsoftmax(hipStream_t stream,
const argument& result,
const argument& arg,
int axis);
} // namespace device
......
......@@ -10,10 +10,10 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Computes softmax of `arg` along dimension `axis`, writing the result into
// the preallocated `result` buffer on `stream`. Mirrors the logsoftmax
// declaration: output is delivered through `result`, so no return value.
void softmax(hipStream_t stream,
             const argument& result,
             const argument& arg,
             int axis);
} // namespace device
} // namespace gpu
......
......@@ -16,10 +16,11 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
}
// Launches the device logsoftmax kernel. By convention args.back() is the
// preallocated output buffer and args.front() is the input tensor; the
// output shape is recovered from args.back() inside the device code, so the
// shape parameter is intentionally unnamed/unused.
argument hip_logsoftmax::compute(context& ctx,
                                 const shape&,
                                 const std::vector<argument>& args) const
{
    device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
    // Return the buffer the kernel wrote into so downstream ops can chain it.
    return args.back();
}
} // namespace gpu
......
......@@ -38,10 +38,11 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
}
// Launches the device softmax kernel. Same calling convention as
// hip_logsoftmax::compute: args.back() is the preallocated output,
// args.front() the input; the shape parameter is unused (shape information
// travels with the argument objects themselves).
argument hip_softmax::compute(context& ctx,
                              const shape&,
                              const std::vector<argument>& args) const
{
    device::softmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
    // Return the written output buffer for downstream consumers.
    return args.back();
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment