Commit 8724242e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Further refactoring of softmax and logsoftmax.

parent 6d1c23e9
...@@ -11,24 +11,24 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,24 +11,24 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument logsoftmax(hipStream_t stream, void logsoftmax(hipStream_t stream,
const migraphx::shape& output_shape, const argument& result,
std::vector<migraphx::argument> args, const argument& arg,
int axis) int axis)
{ {
auto lens = output_shape.lens(); auto lens = result.get_shape().lens();
auto n_dims = lens[axis]; auto n_dims = lens[axis];
auto batch_lens = lens; auto batch_lens = lens;
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{output_shape.type(), batch_lens}; migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
visit_all(args.back(), args.front())([&](auto output, auto input) { visit_all(result, arg)([&](auto output, auto input) {
const auto* input_ptr = device_cast(input.data()); const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) { visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
hip_tensor_descriptor<n_dim> desc_batch(batch_shape); hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
hip_tensor_descriptor<n_dim> desc_data(output_shape); hip_tensor_descriptor<n_dim> desc_data(result.get_shape());
// use one block for items in one batch. // use one block for items in one batch.
// opt 1, load all data to lds then use the same approach as // opt 1, load all data to lds then use the same approach as
...@@ -142,8 +142,6 @@ argument logsoftmax(hipStream_t stream, ...@@ -142,8 +142,6 @@ argument logsoftmax(hipStream_t stream,
}); });
}); });
}); });
return args.back();
} }
} // namespace device } // namespace device
......
...@@ -12,23 +12,23 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -12,23 +12,23 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument softmax(hipStream_t stream, void softmax(hipStream_t stream,
const migraphx::shape& output_shape, const argument& result,
std::vector<migraphx::argument> args, const argument& arg,
int axis) int axis)
{ {
auto lens = output_shape.lens(); auto lens = result.get_shape().lens();
auto batch_lens = lens; auto batch_lens = lens;
size_t n_dims = lens[axis]; size_t n_dims = lens[axis];
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{output_shape.type(), batch_lens}; migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
visit_all(args.back(), args.front())([&](auto output, auto input) { visit_all(result, arg)([&](auto output, auto input) {
const auto* input_ptr = device_cast(input.data()); const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data()); auto* output_ptr = device_cast(output.data());
visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) { visit_tensor_size(batch_shape.lens().size(), [&](auto n_dim) {
hip_tensor_descriptor<n_dim> desc_batch(batch_shape); hip_tensor_descriptor<n_dim> desc_batch(batch_shape);
hip_tensor_descriptor<n_dim> desc_data(output_shape); hip_tensor_descriptor<n_dim> desc_data(result.get_shape());
// use one block for items in one batch. // use one block for items in one batch.
const size_t max_block_size = 1024; const size_t max_block_size = 1024;
...@@ -139,8 +139,6 @@ argument softmax(hipStream_t stream, ...@@ -139,8 +139,6 @@ argument softmax(hipStream_t stream,
}); });
}); });
}); });
return args.back();
} }
} // namespace device } // namespace device
......
...@@ -10,9 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,9 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument logsoftmax(hipStream_t stream, void logsoftmax(hipStream_t stream,
const migraphx::shape& output_shape, const argument& result,
std::vector<migraphx::argument> args, const argument& arg,
int axis); int axis);
} // namespace device } // namespace device
......
...@@ -10,9 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,9 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument softmax(hipStream_t stream, void softmax(hipStream_t stream,
const migraphx::shape& output_shape, const argument& result,
std::vector<migraphx::argument> args, const argument& arg,
int axis); int axis);
} // namespace device } // namespace device
......
...@@ -16,10 +16,11 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const ...@@ -16,10 +16,11 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
} }
argument hip_logsoftmax::compute(context& ctx, argument hip_logsoftmax::compute(context& ctx,
const shape& output_shape, const shape&,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
return device::logsoftmax(ctx.get_stream().get(), output_shape, args, op.axis); device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
} }
} // namespace gpu } // namespace gpu
......
...@@ -38,10 +38,11 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const ...@@ -38,10 +38,11 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
} }
argument hip_softmax::compute(context& ctx, argument hip_softmax::compute(context& ctx,
const shape& output_shape, const shape&,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
return device::softmax(ctx.get_stream().get(), output_shape, args, op.axis); device::softmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment