"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "92a3ae1bcbad99042e7d32a6461095a5d6d1b646"
Commit 6ea3abe0 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

simplify code according to comments.

parent 711356dc
...@@ -51,7 +51,6 @@ add_library(migraphx_gpu ...@@ -51,7 +51,6 @@ add_library(migraphx_gpu
convolution.cpp convolution.cpp
softmax.cpp softmax.cpp
logsoftmax.cpp logsoftmax.cpp
convert.cpp
contiguous.cpp contiguous.cpp
concat.cpp concat.cpp
relu.cpp relu.cpp
......
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_convert::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
}
argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
return device::convert(ctx.get_stream().get(), args[1], args[0]);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -6,7 +6,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,7 +6,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument convert(hipStream_t stream, const argument& result, const argument& arg) void convert(hipStream_t stream, const argument& result, const argument& arg)
{ {
result.visit([&](auto output) { result.visit([&](auto output) {
arg.visit([&](auto input) { arg.visit([&](auto input) {
...@@ -16,8 +16,6 @@ argument convert(hipStream_t stream, const argument& result, const argument& arg ...@@ -16,8 +16,6 @@ argument convert(hipStream_t stream, const argument& result, const argument& arg
result.get_shape().elements())([=](auto i) { output_ptr[i] = input_ptr[i]; }); result.get_shape().elements())([=](auto i) { output_ptr[i] = input_ptr[i]; });
}); });
}); });
return result;
} }
} // namespace device } // namespace device
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/op/convert.hpp> #include <migraphx/op/convert.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/convert.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -10,13 +12,19 @@ namespace gpu { ...@@ -10,13 +12,19 @@ namespace gpu {
struct context; struct context;
struct hip_convert struct hip_convert : unary_device<hip_convert, device::convert>
{ {
op::convert op; op::convert op;
std::string name() const { return "gpu::convert"; }
shape compute_shape(std::vector<shape> inputs) const; hip_convert(const op::convert& oper) : op(oper) { }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const; hip_convert(const op::convert&& oper) : op(std::move(oper)) { }
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
shape compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
}
}; };
} // namespace gpu } // namespace gpu
......
...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
argument convert(hipStream_t stream, const argument& result, const argument& arg); void convert(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment