"vscode:/vscode.git/clone" did not exist on "96ff131ba3d158b391bad5590f72ef1744211e40"
Commit ccb23388 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Combine operator implementation to simplify code.

parent 5e8db473
#ifndef MIGRAPHX_GUARD_RTGLIB_UNARY_HPP
#define MIGRAPHX_GUARD_RTGLIB_UNARY_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template<class Derived>
struct oper
{
std::string name() const { return get_type_name<Derived>(); }
};
template<class Derived, void(*F)(hipStream_t, const argument&, const argument&)>
struct unary_device : oper<Derived>
{
shape compute_shape(const std::vector<shape>& inputs) const {
check_shapes{inputs, *this}.has(2);
return inputs.at(0);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const {
F(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
template<class Derived, void(*F)(hipStream_t, const argument&, const argument&, const argument&)>
struct binary_device : oper<Derived>
{
shape compute_shape(const std::vector<shape>& inputs) const {
check_shapes{inputs, *this}.has(3);
return inputs.at(0);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const {
F(ctx.get_stream().get(), args[2], args[1], args[0]);
return args[2];
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/gpu/max.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_max::compute_shape(const std::vector<shape>& inputs) const
{
// check_shapes{inputs, *this}.has(3).standard();
check_shapes{inputs, *this}.has(3);
return inputs.at(0);
}
argument hip_max::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::max(ctx.get_stream().get(), args[2], args[0], args[1]);
return args[2];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/min.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_min::compute_shape(const std::vector<shape>& inputs) const
{
// check_shapes{inputs, *this}.has(3).standard();
check_shapes{inputs, *this}.has(3);
return inputs.at(0);
}
argument hip_min::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::min(ctx.get_stream().get(), args[2], args[0], args[1]);
return args[2];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/mul.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_mul::compute_shape(const std::vector<shape>& inputs) const
{
// check_shapes{inputs, *this}.has(3).standard();
check_shapes{inputs, *this}.has(3);
return inputs.at(0);
}
argument hip_mul::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::mul(ctx.get_stream().get(), args[2], args[0], args[1]);
return args[2];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/sin.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_sin::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2);
return inputs.at(0);
}
argument hip_sin::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::sin(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/sinh.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_sinh::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2);
return inputs.at(0);
}
argument hip_sinh::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::sinh(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/tan.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_tan::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2);
return inputs.at(0);
}
argument hip_tan::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::tan(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment