Commit 6ea3abe0 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

simplify code according to comments.

parent 711356dc
......@@ -51,7 +51,6 @@ add_library(migraphx_gpu
convolution.cpp
softmax.cpp
logsoftmax.cpp
convert.cpp
contiguous.cpp
concat.cpp
relu.cpp
......
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_convert::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
}
argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::convert(ctx.get_stream().get(), args[1], args[0]);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -6,7 +6,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument convert(hipStream_t stream, const argument& result, const argument& arg)
void convert(hipStream_t stream, const argument& result, const argument& arg)
{
result.visit([&](auto output) {
arg.visit([&](auto input) {
......@@ -16,8 +16,6 @@ argument convert(hipStream_t stream, const argument& result, const argument& arg
result.get_shape().elements())([=](auto i) { output_ptr[i] = input_ptr[i]; });
});
});
return result;
}
} // namespace device
......
......@@ -3,6 +3,8 @@
#include <migraphx/shape.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/convert.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -10,13 +12,19 @@ namespace gpu {
struct context;
struct hip_convert
struct hip_convert : unary_device<hip_convert, device::convert>
{
op::convert op;
std::string name() const { return "gpu::convert"; }
shape compute_shape(std::vector<shape> inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
hip_convert(const op::convert& oper) : op(oper) { }
hip_convert(const op::convert&& oper) : op(std::move(oper)) { }
shape compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
}
};
} // namespace gpu
......
......@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument convert(hipStream_t stream, const argument& result, const argument& arg);
void convert(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment