"vscode:/vscode.git/clone" did not exist on "5b74f76434cb8e9177b82fd67ac6c60450a3aca9"
Unverified Commit 2a2c146c authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #278 from ROCmSoftwarePlatform/device-refactor

Refactor device
parents 15eb1987 8be483c5
......@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument logsoftmax(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
int axis);
argument logsoftmax(hipStream_t stream, argument result, argument arg, int axis);
} // namespace device
} // namespace gpu
......
......@@ -10,10 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument softmax(hipStream_t stream,
const migraphx::shape& output_shape,
std::vector<migraphx::argument> args,
int axis);
argument softmax(hipStream_t stream, argument result, argument arg, int axis);
} // namespace device
} // namespace gpu
......
......@@ -15,11 +15,10 @@ shape hip_logsoftmax::compute_shape(const std::vector<shape>& inputs) const
return op.compute_shape({inputs.at(0)});
}
argument hip_logsoftmax::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
argument
hip_logsoftmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::logsoftmax(ctx.get_stream().get(), output_shape, args, op.axis);
return device::logsoftmax(ctx.get_stream().get(), args[1], args[0], op.axis);
}
} // namespace gpu
......
......@@ -37,11 +37,9 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
return op.compute_shape({inputs.at(0)});
}
argument hip_softmax::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::softmax(ctx.get_stream().get(), output_shape, args, op.axis);
return device::softmax(ctx.get_stream().get(), args[1], args[0], op.axis);
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment