Unverified Commit e8be8548 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Register all operators in migraphx (#604)

* Register ops for main migraphx

* Formatting

* Register cpu ops

* Formatting

* Show list of operators in the driver

* Formatting

* Simplify regiter

* Try to register gpu ops

* Fix compiler errors

* Register rest of the gpu operators

* Add some tests

* Formatting

* Fix gcc compiler warnings

* Formatting

* Fix tidy warnings

* Fix compile error

* Use correct op name

* Register layer norm

* Use const ref

* Make run const
parent 2c5d5fee
......@@ -2,7 +2,7 @@ CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|FALLTHROUGH|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_GENERATE_|_DETAIL_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED'
value: 'DEBUG|FALLTHROUGH|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED'
- key: cppcoreguidelines-narrowing-conversions.WarnOnFloatingPointNarrowingConversion
value: 0
- key: modernize-loop-convert.MinConfidence
......
function(register_op TARGET_NAME)
set(options)
set(oneValueArgs HEADER)
set(multiValueArgs OPERATORS INCLUDES)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(MAKE_C_IDENTIFIER "${PARSE_HEADER}" BASE_NAME)
file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ops)
set(FILE_NAME ${CMAKE_CURRENT_BINARY_DIR}/ops/${BASE_NAME}.cpp)
file(WRITE "${FILE_NAME}" "")
foreach(INCLUDE ${PARSE_INCLUDES})
file(APPEND "${FILE_NAME}" "
#include <${INCLUDE}>
")
endforeach()
file(APPEND "${FILE_NAME}" "
#include <migraphx/register_op.hpp>
#include <${PARSE_HEADER}>
")
file(APPEND "${FILE_NAME}" "
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
")
foreach(OPERATOR ${PARSE_OPERATORS})
file(APPEND "${FILE_NAME}" "
MIGRAPHX_REGISTER_OP(${OPERATOR})
")
endforeach()
file(APPEND "${FILE_NAME}" "
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
")
target_sources(${TARGET_NAME} PRIVATE ${FILE_NAME})
endfunction()
include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers)
include(RegisterOp)
add_library(migraphx
auto_contiguous.cpp
......@@ -28,6 +29,7 @@ add_library(migraphx
schedule.cpp
serialize.cpp
pass_manager.cpp
register_op.cpp
simplify_algebra.cpp
simplify_reshapes.cpp
value.cpp
......@@ -36,6 +38,98 @@ add_library(migraphx
opt/memory_coloring_impl.cpp
)
rocm_set_soversion(migraphx ${MIGRAPHX_SO_VERSION})
function(register_migraphx_ops)
foreach(OP ${ARGN})
register_op(migraphx HEADER migraphx/op/${OP}.hpp OPERATORS op::${OP})
endforeach()
endfunction()
register_migraphx_ops(
abs
acosh
acos
add
argmax
argmin
asinh
asin
as_shape
atanh
atan
batch_norm_inference
broadcast
capture
ceil
clip
concat
contiguous
convert
convolution
cosh
cos
deconvolution
div
dot
elu
erf
exp
flatten
floor
gather
gru
identity
im2col
leaky_relu
load
log
logsoftmax
lrn
lstm
max
min
mul
multibroadcast
neg
outline
pad
pooling
pow
prelu
quant_convolution
quant_dot
recip
reduce_max
reduce_mean
reduce_min
reduce_prod
reduce_sum
relu
reshape
rnn
rnn_last_cell_output
rnn_last_hs_output
rnn_var_sl_last_output
round
rsqrt
scalar
sigmoid
sign
sinh
sin
slice
softmax
sqdiff
sqrt
squeeze
sub
tanh
tan
transpose
undefined
unknown
unsqueeze
)
register_op(migraphx HEADER migraphx/op/rnn_variable_seq_lens.hpp OPERATORS op::rnn_var_sl_shift_output op::rnn_var_sl_shift_sequence)
register_op(migraphx HEADER migraphx/builtin.hpp OPERATORS builtin::literal builtin::param builtin::returns)
rocm_clang_tidy_check(migraphx)
rocm_install_targets(
TARGETS migraphx
......
......@@ -15,6 +15,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
......@@ -319,6 +320,26 @@ struct perf : command<perf>
}
};
struct op : command<op>
{
bool show_ops = false;
void parse(argument_parser& ap)
{
ap(show_ops,
{"--list", "-l"},
ap.help("List all the operators of MIGraphX"),
ap.set_value(true));
}
void run() const
{
if(show_ops)
{
for(const auto& name : get_operators())
std::cout << name << std::endl;
}
}
};
struct main_command
{
static std::string get_command_help()
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......
......@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_AS_SHAPE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -2,6 +2,9 @@
#define MIGRAPHX_GUARD_OPERATORS_BINARY_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -2,12 +2,9 @@
#define MIGRAPHX_GUARD_OPERATORS_BROADCAST_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
......
......@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
......
......@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_CONCAT_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -2,7 +2,6 @@
#define MIGRAPHX_GUARD_OPERATORS_CONTIGUOUS_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment