Unverified commit e8be8548, authored by Paul Fultz II and committed by GitHub

Register all operators in migraphx (#604)

* Register ops for main migraphx

* Formatting

* Register cpu ops

* Formatting

* Show list of operators in the driver

* Formatting

* Simplify register

* Try to register gpu ops

* Fix compiler errors

* Register rest of the gpu operators

* Add some tests

* Formatting

* Fix gcc compiler warnings

* Formatting

* Fix tidy warnings

* Fix compile error

* Use correct op name

* Register layer norm

* Use const ref

* Make run const
parent 2c5d5fee
#ifndef MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
-#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
-#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
......
@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#include <array>
-#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/binary.hpp>
-#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
-#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
-#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
@@ -2,12 +2,9 @@
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSE_HPP
#include <array>
-#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
+#include <migraphx/argument.hpp>
+#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......
@@ -2,6 +2,9 @@
#define MIGRAPHX_GUARD_OPERATORS_UNARY_HPP
#include <migraphx/op/name.hpp>
+#include <migraphx/check_shapes.hpp>
+#include <migraphx/shape_for_each.hpp>
+#include <migraphx/argument.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#ifndef MIGRAPHX_GUARD_RTGLIB_UNDEFINED_HPP
#define MIGRAPHX_GUARD_RTGLIB_UNDEFINED_HPP
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct undefined
{
    std::string name() const { return "undefined"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(0);
        return {};
    }
    argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
-#ifndef MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP
-#define MIGRAPHX_GUARD_OPERATORS_ABNORMAL_OPS_HPP
-#include <array>
-#include <migraphx/operation.hpp>
-#include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
-#include <migraphx/config.hpp>
-#include <cmath>
-#include <utility>
+#ifndef MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP
+#define MIGRAPHX_GUARD_RTGLIB_UNKNOWN_HPP
+#include <migraphx/config.hpp>
+#include <migraphx/argument.hpp>
+#include <migraphx/check_shapes.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
-struct not_computable
-{
-    argument compute(const shape&, const std::vector<argument>&) const
-    {
-        MIGRAPHX_THROW("not computable");
-    }
-};
-struct undefined
-{
-    std::string name() const { return "undefined"; }
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        check_shapes{inputs, *this}.has(0);
-        return {};
-    }
-    argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
-};
struct unknown
{
    std::string op;
......
@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_UNSQUEEZE_HPP
#include <array>
-#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_HPP
#define MIGRAPHX_GUARD_OPERATORS_HPP
-#include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/acosh.hpp>
@@ -13,7 +12,7 @@
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
#include <migraphx/op/atanh.hpp>
-#include <migraphx/op/batch_norm.hpp>
+#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/capture.hpp>
@@ -86,6 +85,8 @@
#include <migraphx/op/tan.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unary.hpp>
+#include <migraphx/op/undefined.hpp>
+#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REGISTER_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_REGISTER_OP_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void register_op(const operation& op);
operation load_op(const std::string& name);
std::vector<std::string> get_operators();
template <class T>
int register_op()
{
    register_op(T{});
    return 0;
}
template <class T>
struct auto_register_op
{
    static int static_register;
    // This typedef ensures that the static member will be instantiated if
    // the class itself is instantiated
    using static_register_type =
        std::integral_constant<decltype(&static_register), &static_register>;
};
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
template <class T>
int auto_register_op<T>::static_register = register_op<T>(); // NOLINT
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#define MIGRAPHX_REGISTER_OP_NAME_DETAIL(x) migraphx_auto_register_##x
#define MIGRAPHX_REGISTER_OP_NAME(x) MIGRAPHX_REGISTER_OP_NAME_DETAIL(x)
#define MIGRAPHX_REGISTER_OP(...) \
void MIGRAPHX_REGISTER_OP_NAME(__LINE__)(migraphx::auto_register_op<__VA_ARGS__> x = \
migraphx::auto_register_op<__VA_ARGS__>{}) \
__attribute__((unused));
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
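For orientation, here is a minimal sketch of how an operator could be defined and registered through the header above. The op type my_identity is purely illustrative and is not part of this commit; it is modeled on op::undefined from the new undefined.hpp.

// Hypothetical example, not part of the commit: a context-free op registered
// through the new MIGRAPHX_REGISTER_OP macro.
#include <migraphx/register_op.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct my_identity
{
    std::string name() const { return "my_identity"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        // One input; the output shape is the same as the input shape.
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }
    argument compute(const shape&, const std::vector<argument>& args) const
    {
        // Pass the input argument straight through.
        return args.front();
    }
};
// Static registration: default-constructs my_identity and stores it in the
// global name -> operation table when this translation unit is loaded.
MIGRAPHX_REGISTER_OP(my_identity)

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

Once such a translation unit is linked in, load_op("my_identity") returns the type-erased operation and the name appears in get_operators().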
#include <migraphx/register_op.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_map<std::string, operation>& op_map()
{
    static std::unordered_map<std::string, operation> m;
    return m;
}
void register_op(const operation& op) { op_map()[op.name()] = op; }
operation load_op(const std::string& name) { return op_map().at(name); }
std::vector<std::string> get_operators()
{
    std::vector<std::string> result;
    std::transform(op_map().begin(), op_map().end(), std::back_inserter(result), [&](auto&& p) {
        return p.first;
    });
    std::sort(result.begin(), result.end());
    return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
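To illustrate the "Show list of operators in the driver" item from the commit message, here is a rough sketch of a standalone consumer of this API; the actual driver change is not shown in this excerpt, so this is only an assumed usage pattern.

// Hypothetical listing tool (not the driver code from this commit): prints
// every operator name currently in the registry, already sorted by
// get_operators().
#include <migraphx/register_op.hpp>
#include <iostream>

int main()
{
    for(const auto& name : migraphx::get_operators())
        std::cout << name << "\n";
    // Any listed name can be turned back into a type-erased operation
    // with migraphx::load_op(name).
}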
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
-#include <migraphx/op/batch_norm.hpp>
+#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
......
@@ -2,7 +2,7 @@
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp>
-#include <migraphx/op/batch_norm.hpp>
+#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
@@ -24,6 +24,7 @@
#include <migraphx/par_dfor.hpp>
#include <migraphx/clamp.hpp>
#include <migraphx/cpu/gemm.hpp>
+#include <migraphx/register_op.hpp>
#include <unordered_map>
#include <utility>
#include <iostream>
@@ -122,6 +123,7 @@ struct cpu_batch_norm_inference
        return output;
    }
};
+MIGRAPHX_REGISTER_OP(cpu_batch_norm_inference)
struct cpu_lrn
{
@@ -166,6 +168,7 @@ struct cpu_lrn
        return result;
    }
};
+MIGRAPHX_REGISTER_OP(cpu_lrn)
template <class V, class T, class... Ts>
void visit_quantize_impl(V&& v, T&& x, Ts&&... xs)
@@ -183,8 +186,12 @@ auto visit_quantize(T&& x, Ts&&... xs)
}
template <class Op>
-struct cpu_convolution
+struct cpu_convolution : auto_register_op<cpu_convolution<Op>>
{
+    cpu_convolution() = default;
+    cpu_convolution(Op pop) : op(std::move(pop)) {}
    Op op;
    template <class Self, class F>
@@ -256,8 +263,12 @@ struct cpu_convolution
};
template <class Op>
-struct cpu_deconvolution
+struct cpu_deconvolution : auto_register_op<cpu_deconvolution<Op>>
{
+    cpu_deconvolution() = default;
+    cpu_deconvolution(Op pop) : op(std::move(pop)) {}
    Op op;
    template <class Self, class F>
@@ -405,6 +416,7 @@ struct cpu_im2col
        return result;
    }
};
+MIGRAPHX_REGISTER_OP(cpu_im2col)
struct max_pool
{
@@ -431,8 +443,12 @@ struct avg_pool
};
template <class Op>
-struct cpu_pooling
+struct cpu_pooling : auto_register_op<cpu_pooling<Op>>
{
+    cpu_pooling() = default;
+    cpu_pooling(op::pooling pop) : op(std::move(pop)) {}
    op::pooling op;
    template <class Self, class F>
@@ -501,21 +517,19 @@ struct cpu_op
    {
        return migraphx::reflect(self.op, f);
    }
-    std::string name() const { return "cpu::" + op.name(); }
+    std::string name() const { return "cpu::op"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, const shape& output_shape, const std::vector<argument>& args) const
    {
        return op.compute(output_shape, args);
    }
-    friend bool operator==(const cpu_op& x, const cpu_op& y) { return x.op == y.op; }
-    friend bool operator==(const cpu_op& x, const operation& y)
-    {
-        if(x.name() != y.name())
-            return false;
-        return x == any_cast<cpu_op>(y);
-    }
-    friend bool operator==(const operation& x, const cpu_op& y) { return y == x; }
+    friend std::ostream& operator<<(std::ostream& os, const cpu_op& x)
+    {
+        os << "cpu::" << x.op;
+        return os;
+    }
};
+MIGRAPHX_REGISTER_OP(cpu_op)
struct cpu_pad
{
@@ -552,6 +566,7 @@ struct cpu_pad
        return result;
    }
};
+MIGRAPHX_REGISTER_OP(cpu_pad)
struct cpu_gemm
{
@@ -603,6 +618,7 @@ struct cpu_gemm
        return result;
    }
};
+MIGRAPHX_REGISTER_OP(cpu_gemm)
struct cpu_quant_gemm
{
@@ -669,6 +685,7 @@ struct cpu_quant_gemm
        return result;
    }
};
+MIGRAPHX_REGISTER_OP(cpu_gemm)
struct leaky_relu_op
{
@@ -693,8 +710,15 @@ struct elu_op
};
template <typename Op>
-struct cpu_unary
+struct cpu_unary : auto_register_op<cpu_unary<Op>>
{
+    cpu_unary() = default;
+    template <class T>
+    cpu_unary(T pop) : op(Op{std::move(pop)})
+    {
+    }
    Op op;
    template <class Self, class F>
@@ -723,8 +747,12 @@ struct cpu_unary
};
template <class Op>
-struct cpu_softmax
+struct cpu_softmax : auto_register_op<cpu_softmax<Op>>
{
+    cpu_softmax() = default;
+    cpu_softmax(Op pop) : op(std::move(pop)) {}
    Op op;
    template <class Self, class F>
@@ -828,6 +856,7 @@ struct cpu_rnn_var_sl_last_output
        return result;
    }
};
+MIGRAPHX_REGISTER_OP(cpu_rnn_var_sl_last_output)
struct cpu_apply
{
......
@@ -108,7 +108,7 @@ add_library(migraphx_gpu
    logsoftmax.cpp
    concat.cpp
    leaky_relu.cpp
-   batchnorm.cpp
+   batch_norm_inference.cpp
    write_literals.cpp
    rocblas.cpp
    abs.cpp
@@ -129,6 +129,85 @@ add_library(migraphx_gpu
    sync_device.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
function(register_migraphx_gpu_ops PREFIX)
    foreach(OP ${ARGN})
        register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp)
    endforeach()
endfunction()
register_migraphx_gpu_ops(hip_
acosh
acos
add
argmax
argmin
asinh
asin
atanh
atan
ceil
clip
concat
convert
cosh
cos
div
erf
exp
floor
gather
log
logsoftmax
max
min
mul
pad
pow
prelu
recip
reduce_max
reduce_mean
reduce_min
reduce_prod
reduce_sum
relu
round
rsqrt
sigmoid
sign
sinh
sin
softmax
sqdiff
sqrt
sub
tanh
tan
)
register_migraphx_gpu_ops(miopen_
abs
batch_norm_inference
contiguous
convolution
deconvolution
elu
int8_conv_pack
leaky_relu
lrn
pooling
quant_convolution
)
register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu)
# Workaround broken rocblas headers
......
-#include <migraphx/gpu/batchnorm.hpp>
+#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
......
@@ -18,6 +18,7 @@
#include <migraphx/gpu/device/mul_add_relu.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/instruction.hpp>
+#include <migraphx/register_op.hpp>
#include <migraphx/array.hpp>
#include <migraphx/op/clip.hpp>
#include <cmath>
@@ -45,6 +46,8 @@ struct fusion
        return result;
    }
+    fusion() = default;
    fusion(const shape& input)
    // : fp(make_fusion_plan(input))
    {
@@ -175,66 +178,82 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
struct hip_triadd : ternary_device<hip_triadd, &device::add>
{
};
+MIGRAPHX_REGISTER_OP(hip_triadd)
struct hip_triadd_clip : quinary_device<hip_triadd_clip, &device::add_clip>
{
};
+MIGRAPHX_REGISTER_OP(hip_triadd_clip)
struct hip_add_clip : quaternary_device<hip_add_clip, &device::add_clip>
{
};
+MIGRAPHX_REGISTER_OP(hip_add_clip)
struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu>
{
};
+MIGRAPHX_REGISTER_OP(hip_triadd_relu)
struct hip_triadd_sigmoid : ternary_device<hip_triadd_sigmoid, &device::add_sigmoid>
{
};
+MIGRAPHX_REGISTER_OP(hip_triadd_sigmoid)
struct hip_triadd_tanh : ternary_device<hip_triadd_tanh, &device::add_tanh>
{
};
+MIGRAPHX_REGISTER_OP(hip_triadd_tanh)
struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
{
};
+MIGRAPHX_REGISTER_OP(hip_add_relu)
struct hip_add_sigmoid : binary_device<hip_add_relu, &device::add_sigmoid>
{
};
+MIGRAPHX_REGISTER_OP(hip_add_sigmoid)
struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh>
{
};
+MIGRAPHX_REGISTER_OP(hip_add_tanh)
struct hip_layernorm : unary_device<hip_layernorm, &device::layernorm>
{
};
+MIGRAPHX_REGISTER_OP(hip_layernorm)
struct hip_gelu : unary_device<hip_gelu, &device::gelu>
{
};
+MIGRAPHX_REGISTER_OP(hip_gelu)
struct hip_add_gelu : binary_device<hip_add_gelu, &device::add_gelu>
{
};
+MIGRAPHX_REGISTER_OP(hip_add_gelu)
struct hip_gelu_new : unary_device<hip_gelu_new, &device::gelu_new>
{
};
+MIGRAPHX_REGISTER_OP(hip_gelu_new)
struct hip_add_gelu_new : binary_device<hip_add_gelu_new, &device::add_gelu_new>
{
};
+MIGRAPHX_REGISTER_OP(hip_add_gelu_new)
struct hip_mul_add : ternary_device<hip_mul_add, &device::mul_add>
{
};
+MIGRAPHX_REGISTER_OP(hip_mul_add)
struct hip_mul_add_relu : ternary_device<hip_mul_add_relu, &device::mul_add_relu>
{
};
+MIGRAPHX_REGISTER_OP(hip_mul_add_relu)
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
@@ -567,6 +586,8 @@ struct miopen_conv_bias
        return op::convolution::reflect(self.op, f);
    }
+    miopen_conv_bias() = default;
    miopen_conv_bias(op::convolution c, const shape& input, const shape& weights, const shape& b)
        : op(std::move(c)), f(input)
    {
@@ -598,6 +619,7 @@ struct miopen_conv_bias
        return shapes.size() - 1;
    }
};
+MIGRAPHX_REGISTER_OP(miopen_conv_bias)
struct miopen_conv_bias_relu
{
@@ -613,6 +635,8 @@ struct miopen_conv_bias_relu
        return op::convolution::reflect(self.op, f);
    }
+    miopen_conv_bias_relu() = default;
    miopen_conv_bias_relu(op::convolution c,
                          const shape& input,
                          const shape& weights,
@@ -648,6 +672,7 @@ struct miopen_conv_bias_relu
        return shapes.size() - 1;
    }
};
+MIGRAPHX_REGISTER_OP(miopen_conv_bias_relu)
template <class... Ms>
auto conv_bias(Ms... ms)
......
@@ -2,6 +2,7 @@
#include <migraphx/gpu/hip.hpp>
#include <migraphx/manage_ptr.hpp>
+#include <migraphx/register_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <miopen/miopen.h>
@@ -12,6 +13,14 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
+MIGRAPHX_REGISTER_OP(hip_allocate)
+MIGRAPHX_REGISTER_OP(hip_sync_device)
+MIGRAPHX_REGISTER_OP(hip_copy_to_gpu)
+MIGRAPHX_REGISTER_OP(hip_copy_from_gpu)
+MIGRAPHX_REGISTER_OP(hip_copy)
+MIGRAPHX_REGISTER_OP(hip_allocate_memory)
+MIGRAPHX_REGISTER_OP(hip_copy_literal)
using hip_ptr = MIGRAPHX_MANAGE_PTR(void, hipFree);
using hip_host_ptr = MIGRAPHX_MANAGE_PTR(void, hipHostUnregister);
......