Unverified Commit 251cdd74 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add initial experimental custom op (#1109)

This creates a custom op which has name() and compute_shape() methods. 
parent cd165ebd
...@@ -148,13 +148,15 @@ jobs: ...@@ -148,13 +148,15 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
python-version: 3.6 python-version: 3.8
- name: Install pyflakes - name: Install pyflakes
run: pip install pyflakes==2.3.1 mypy==0.931 run: pip install pyflakes==2.4.0 mypy==0.931
- name: Run pyflakes - name: Run pyflakes
run: | run: |
pyflakes --version
pyflakes examples/ tools/ src/ test/ doc/ pyflakes examples/ tools/ src/ test/ doc/
mypy --version
mypy tools/api.py mypy tools/api.py
linux: linux:
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
...@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
struct experimental_custom_op
{
std::string name;
experimental_custom_op() = default;
experimental_custom_op(std::string pname) : name(std::move(pname)) {}
};
template <class CustomOp>
struct custom_operation
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
CustomOp op;
std::string name() const { return op.xobject.name; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
};
template <class CustomOp>
void register_custom_op(const CustomOp& op)
{
register_op(custom_operation<CustomOp>{op});
}
migraphx::context get_context(const program& p) { return p.get_context(); } migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx } // namespace migraphx
...@@ -238,12 +272,60 @@ void destroy(T* x) ...@@ -238,12 +272,60 @@ void destroy(T* x)
{ {
delete x; // NOLINT delete x; // NOLINT
} }
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
{
manage_generic_ptr() = default;
manage_generic_ptr(std::nullptr_t) {}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
}
~manage_generic_ptr()
{
if(data != nullptr)
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
};
extern "C" struct migraphx_shape; extern "C" struct migraphx_shape;
struct migraphx_shape struct migraphx_shape
{ {
template <class... Ts> template <class... Ts>
migraphx_shape(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_shape(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::shape object; migraphx::shape object;
...@@ -253,7 +335,8 @@ extern "C" struct migraphx_argument; ...@@ -253,7 +335,8 @@ extern "C" struct migraphx_argument;
struct migraphx_argument struct migraphx_argument
{ {
template <class... Ts> template <class... Ts>
migraphx_argument(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_argument(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::argument object; migraphx::argument object;
...@@ -263,7 +346,8 @@ extern "C" struct migraphx_target; ...@@ -263,7 +346,8 @@ extern "C" struct migraphx_target;
struct migraphx_target struct migraphx_target
{ {
template <class... Ts> template <class... Ts>
migraphx_target(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_target(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::target object; migraphx::target object;
...@@ -273,7 +357,8 @@ extern "C" struct migraphx_program_parameter_shapes; ...@@ -273,7 +357,8 @@ extern "C" struct migraphx_program_parameter_shapes;
struct migraphx_program_parameter_shapes struct migraphx_program_parameter_shapes
{ {
template <class... Ts> template <class... Ts>
migraphx_program_parameter_shapes(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_program_parameter_shapes(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::unordered_map<std::string, migraphx::shape> object; std::unordered_map<std::string, migraphx::shape> object;
...@@ -283,7 +368,8 @@ extern "C" struct migraphx_program_parameters; ...@@ -283,7 +368,8 @@ extern "C" struct migraphx_program_parameters;
struct migraphx_program_parameters struct migraphx_program_parameters
{ {
template <class... Ts> template <class... Ts>
migraphx_program_parameters(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_program_parameters(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::unordered_map<std::string, migraphx::argument> object; std::unordered_map<std::string, migraphx::argument> object;
...@@ -293,7 +379,8 @@ extern "C" struct migraphx_arguments; ...@@ -293,7 +379,8 @@ extern "C" struct migraphx_arguments;
struct migraphx_arguments struct migraphx_arguments
{ {
template <class... Ts> template <class... Ts>
migraphx_arguments(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_arguments(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<migraphx::argument> object; std::vector<migraphx::argument> object;
...@@ -303,7 +390,8 @@ extern "C" struct migraphx_shapes; ...@@ -303,7 +390,8 @@ extern "C" struct migraphx_shapes;
struct migraphx_shapes struct migraphx_shapes
{ {
template <class... Ts> template <class... Ts>
migraphx_shapes(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_shapes(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<migraphx::shape> object; std::vector<migraphx::shape> object;
...@@ -343,7 +431,8 @@ extern "C" struct migraphx_module; ...@@ -343,7 +431,8 @@ extern "C" struct migraphx_module;
struct migraphx_module struct migraphx_module
{ {
template <class... Ts> template <class... Ts>
migraphx_module(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_module(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::module object; migraphx::module object;
...@@ -353,7 +442,8 @@ extern "C" struct migraphx_program; ...@@ -353,7 +442,8 @@ extern "C" struct migraphx_program;
struct migraphx_program struct migraphx_program
{ {
template <class... Ts> template <class... Ts>
migraphx_program(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_program(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::program object; migraphx::program object;
...@@ -363,7 +453,8 @@ extern "C" struct migraphx_operation; ...@@ -363,7 +453,8 @@ extern "C" struct migraphx_operation;
struct migraphx_operation struct migraphx_operation
{ {
template <class... Ts> template <class... Ts>
migraphx_operation(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_operation(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::operation object; migraphx::operation object;
...@@ -373,7 +464,8 @@ extern "C" struct migraphx_onnx_options; ...@@ -373,7 +464,8 @@ extern "C" struct migraphx_onnx_options;
struct migraphx_onnx_options struct migraphx_onnx_options
{ {
template <class... Ts> template <class... Ts>
migraphx_onnx_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_onnx_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::onnx_options object; migraphx::onnx_options object;
...@@ -383,7 +475,8 @@ extern "C" struct migraphx_file_options; ...@@ -383,7 +475,8 @@ extern "C" struct migraphx_file_options;
struct migraphx_file_options struct migraphx_file_options
{ {
template <class... Ts> template <class... Ts>
migraphx_file_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_file_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::file_options object; migraphx::file_options object;
...@@ -393,7 +486,8 @@ extern "C" struct migraphx_compile_options; ...@@ -393,7 +486,8 @@ extern "C" struct migraphx_compile_options;
struct migraphx_compile_options struct migraphx_compile_options
{ {
template <class... Ts> template <class... Ts>
migraphx_compile_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_compile_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::compile_options object; migraphx::compile_options object;
...@@ -403,7 +497,8 @@ extern "C" struct migraphx_tf_options; ...@@ -403,7 +497,8 @@ extern "C" struct migraphx_tf_options;
struct migraphx_tf_options struct migraphx_tf_options
{ {
template <class... Ts> template <class... Ts>
migraphx_tf_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_tf_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::tf_options object; migraphx::tf_options object;
...@@ -413,7 +508,8 @@ extern "C" struct migraphx_quantize_op_names; ...@@ -413,7 +508,8 @@ extern "C" struct migraphx_quantize_op_names;
struct migraphx_quantize_op_names struct migraphx_quantize_op_names
{ {
template <class... Ts> template <class... Ts>
migraphx_quantize_op_names(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_quantize_op_names(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<std::string> object; std::vector<std::string> object;
...@@ -423,7 +519,8 @@ extern "C" struct migraphx_quantize_int8_options; ...@@ -423,7 +519,8 @@ extern "C" struct migraphx_quantize_int8_options;
struct migraphx_quantize_int8_options struct migraphx_quantize_int8_options
{ {
template <class... Ts> template <class... Ts>
migraphx_quantize_int8_options(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_quantize_int8_options(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::quantize_int8_options object; migraphx::quantize_int8_options object;
...@@ -433,12 +530,41 @@ extern "C" struct migraphx_context; ...@@ -433,12 +530,41 @@ extern "C" struct migraphx_context;
struct migraphx_context struct migraphx_context
{ {
template <class... Ts> template <class... Ts>
migraphx_context(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_context(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::context object; migraphx::context object;
}; };
extern "C" struct migraphx_experimental_custom_op;
struct migraphx_experimental_custom_op
{
template <class... Ts>
migraphx_experimental_custom_op(void* p,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
{
}
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
object_ptr = nullptr;
migraphx::experimental_custom_op xobject;
migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing.");
auto api_error_result =
compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute_shape.");
return (&out)->object;
}
};
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{ {
auto api_error_result = migraphx::try_([&] { destroy((shape)); }); auto api_error_result = migraphx::try_([&] { destroy((shape)); });
...@@ -1564,3 +1690,51 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont ...@@ -1564,3 +1690,51 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont
}); });
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op)
{
auto api_error_result = migraphx::try_([&] { destroy((experimental_custom_op)); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output,
const_migraphx_experimental_custom_op_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* name)
{
auto api_error_result = migraphx::try_([&] {
*experimental_custom_op =
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
{
auto api_error_result = migraphx::try_([&] { (obj)->compute_shape_f = (input); });
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op)
{
auto api_error_result = migraphx::try_([&] {
if(experimental_custom_op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter experimental_custom_op: Null pointer");
migraphx::register_custom_op((*experimental_custom_op));
});
return api_error_result;
}
...@@ -103,6 +103,17 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int ...@@ -103,6 +103,17 @@ typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int
typedef struct migraphx_context* migraphx_context_t; typedef struct migraphx_context* migraphx_context_t;
typedef const struct migraphx_context* const_migraphx_context_t; typedef const struct migraphx_context* const_migraphx_context_t;
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj,
migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input); migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
...@@ -422,6 +433,26 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -422,6 +433,26 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_status migraphx_context_finish(const_migraphx_context_t context); migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
migraphx_status
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output,
const_migraphx_experimental_custom_op_t input);
migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* name);
migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
...@@ -15,6 +15,16 @@ namespace migraphx { ...@@ -15,6 +15,16 @@ namespace migraphx {
inline namespace api { // NOLINT inline namespace api { // NOLINT
#endif #endif
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
template <class T, class F, class... Ts> template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs) T* make(F f, Ts&&... xs)
{ {
...@@ -231,6 +241,138 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -231,6 +241,138 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std::shared_ptr<T> m_handle; std::shared_ptr<T> m_handle;
}; };
template <class Base>
struct interface_base : Base
{
interface_base() : Base() {}
protected:
template <class F>
static migraphx_status try_(F f) // NOLINT
{
try
{
f();
return migraphx_status_success;
}
catch(...)
{
return migraphx_status_unknown_error;
}
}
template <class F, class T, class... Ts>
void make_interface(F f, T& obj, Ts&&... xs)
{
auto copy = [](void** out, void* input) {
return try_([&] {
T** y = reinterpret_cast<T**>(out);
T* x = reinterpret_cast<T*>(input);
assert(x != nullptr and y != nullptr and *y == nullptr);
*y = new T(*x); // NOLINT
});
};
auto del = [](void* input) {
return try_([&] {
T* x = reinterpret_cast<T*>(input);
delete x; // NOLINT
});
};
this->make_handle(f, &obj, copy, del, std::forward<Ts>(xs)...);
}
template <class T, class Setter, class F>
void set_fp(Setter setter, F pf)
{
static F f = pf;
(void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status {
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); });
});
}
template <class T, class Setter, class F>
void set_auto_fp(Setter setter, F f)
{
return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) {
auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...);
});
}
struct no_out_arg
{
};
template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs)
{
f(reinterpret_cast<T*>(obj), no_out_arg{}, xs...);
}
template <class T,
class F,
class R,
class X,
class... Xs,
class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs)
{
f(*reinterpret_cast<T*>(obj), result, xs...);
}
template <class F, class T, class... Ts>
void auto_invoke(F f, T* out, Ts&&... xs)
{
auto_assign(rank<2>{}, out, f(std::forward<Ts>(xs)...));
}
template <class F, class T, class... Ts>
void auto_invoke(F f, no_out_arg, Ts&&... xs)
{
f(std::forward<Ts>(xs)...);
}
template <class T, class = std::enable_if_t<std::is_fundamental<T>{} or std::is_enum<T>{}>>
T auto_convert_param(rank<0>, T x)
{
return x;
}
template <class T>
auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
{
return as_handle<T>{x};
}
template <class T>
auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
{
return as_handle<T>{x, borrow{}};
}
template <class T, class U>
void auto_assign(rank<0>, T* out, U x)
{
return *out = x;
}
template <class T, class U>
auto auto_assign(rank<1>, T* out, U x) -> decltype(x.assign_to_handle(out))
{
x.assign_to_handle(out);
}
};
// NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \
this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \
[](T& x, auto... xs) { return x.name(xs...); })
template <class Base, class T>
using require_interface =
std::enable_if_t<std::is_base_of<Base, T>{} and not std::is_same<T, Base>{} and
std::is_copy_constructible<T>{} and std::is_final<T>{}>;
#ifdef DOXYGEN #ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<> #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else #else
...@@ -988,6 +1130,32 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op ...@@ -988,6 +1130,32 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
options.get_handle_ptr()); options.get_handle_ptr());
} }
struct experimental_custom_op_base
{
virtual std::string name() const = 0;
virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default;
};
struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
{
template <class T>
experimental_custom_op(T& obj)
{
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
}
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
};
template <class T, class = require_interface<experimental_custom_op_base, T>>
void register_experimental_custom_op(T& obj)
{
experimental_custom_op op{obj};
op.register_op();
}
#ifndef DOXYGEN #ifndef DOXYGEN
} // namespace api } // namespace api
#endif #endif
......
...@@ -403,3 +403,13 @@ api.add_function('migraphx_quantize_int8', ...@@ -403,3 +403,13 @@ api.add_function('migraphx_quantize_int8',
@auto_handle(ref=True) @auto_handle(ref=True)
def context(h): def context(h):
h.method('finish', const=True) h.method('finish', const=True)
@api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op')
def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*'))
h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape')
h.method('register', invoke='migraphx::register_custom_op($@)')
...@@ -11,6 +11,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -11,6 +11,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
endfunction() endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR}) add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR}) add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
struct simple_custom_op final : migraphx::experimental_custom_op_base
{
virtual std::string name() const override { return "simple_custom_op"; }
virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
{
return inputs.front();
}
};
TEST_CASE(register_custom_op)
{
simple_custom_op simple_op;
migraphx::register_experimental_custom_op(simple_op);
auto op = migraphx::operation("simple_custom_op");
EXPECT(op.name() == "simple_custom_op");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/migraphx.hpp> #include <migraphx/migraphx.hpp>
#include <migraphx/rank.hpp>
#include "test.hpp" #include "test.hpp"
template <class T> template <class T>
......
...@@ -102,6 +102,10 @@ header_function = Template(''' ...@@ -102,6 +102,10 @@ header_function = Template('''
${error_type} ${name}(${params}); ${error_type} ${name}(${params});
''') ''')
function_pointer_typedef = Template('''
typedef ${error_type} (*${fname})(${params});
''')
c_api_impl = Template(''' c_api_impl = Template('''
extern "C" ${error_type} ${name}(${params}) extern "C" ${error_type} ${name}(${params})
{ {
...@@ -136,18 +140,23 @@ class CFunction: ...@@ -136,18 +140,23 @@ class CFunction:
self.va_end = ['va_end({});'.format(name)] self.va_end = ['va_end({});'.format(name)]
self.add_param('...', '') self.add_param('...', '')
def substitute(self, form: Template) -> str: def substitute(self, form: Template, **kwargs) -> str:
return form.substitute(error_type=error_type, return form.substitute(error_type=error_type,
try_wrap=try_wrap, try_wrap=try_wrap,
name=self.name, name=self.name,
params=', '.join(self.params), params=', '.join(self.params),
body=";\n ".join(self.body), body=";\n ".join(self.body),
va_start="\n ".join(self.va_start), va_start="\n ".join(self.va_start),
va_end="\n ".join(self.va_end)) va_end="\n ".join(self.va_end),
**kwargs)
def generate_header(self) -> str: def generate_header(self) -> str:
return self.substitute(header_function) return self.substitute(header_function)
def generate_function_pointer(self, name: Optional[str] = None) -> str:
return self.substitute(function_pointer_typedef,
fname=name or self.name)
def generate_body(self) -> str: def generate_body(self) -> str:
return self.substitute(c_api_impl) return self.substitute(c_api_impl)
...@@ -163,7 +172,9 @@ class Parameter: ...@@ -163,7 +172,9 @@ class Parameter:
name: str, name: str,
type: str, type: str,
optional: bool = False, optional: bool = False,
returns: bool = False) -> None: returns: bool = False,
virtual: bool = False,
this: bool = False) -> None:
self.name = name self.name = name
self.type = Type(type) self.type = Type(type)
self.optional = optional self.optional = optional
...@@ -175,7 +186,11 @@ class Parameter: ...@@ -175,7 +186,11 @@ class Parameter:
self.cpp_read = '${name}' self.cpp_read = '${name}'
self.cpp_write = '${name}' self.cpp_write = '${name}'
self.returns = returns self.returns = returns
self.virtual = virtual
self.this = this
self.bad_param_check: Optional[BadParam] = None self.bad_param_check: Optional[BadParam] = None
self.virtual_read: Optional[List[str]] = None
self.virtual_write: Optional[str] = None
def get_name(self, prefix: Optional[str] = None) -> str: def get_name(self, prefix: Optional[str] = None) -> str:
if prefix: if prefix:
...@@ -248,6 +263,48 @@ class Parameter: ...@@ -248,6 +263,48 @@ class Parameter:
raise ValueError("Error for {}: write cannot be a string".format( raise ValueError("Error for {}: write cannot be a string".format(
self.type.str())) self.type.str()))
def virtual_arg(self, prefix: Optional[str] = None) -> List[str]:
read = self.virtual_read
if not read and len(self.write) >= len(self.cparams):
read = [
Template(w.partition('=')[2]).safe_substitute(result='${name}')
for w in self.write
]
if not read:
raise ValueError("No virtual_read parameter provided for: " +
self.type.str())
if isinstance(read, str):
raise ValueError(
"Error for {}: virtual_read cannot be a string".format(
self.type.str()))
return [self.substitute(r, prefix=prefix) for r in read]
def virtual_param(self, prefix: Optional[str] = None) -> str:
return self.substitute('${type} ${name}', prefix=prefix)
def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [
'&{prefix}{n}'.format(prefix=prefix or '', n=n)
for t, n in self.cparams
]
def virtual_output_declarations(self,
prefix: Optional[str] = None) -> List[str]:
return [
'std::remove_pointer_t<{type}> {prefix}{n};'.format(
type=Type(t).str(), prefix=prefix or '', n=n)
for t, n in self.cparams
]
def virtual_output(self, prefix: Optional[str] = None) -> str:
write = self.virtual_write
if not write:
if '*' in self.read or '->' in self.read:
write = Template(self.read).safe_substitute(name='(&${name})')
else:
write = self.read
return self.substitute(write, prefix=prefix)
def cpp_param(self, prefix: Optional[str] = None) -> str: def cpp_param(self, prefix: Optional[str] = None) -> str:
return self.substitute('${cpptype} ${name}', prefix=prefix) return self.substitute('${cpptype} ${name}', prefix=prefix)
...@@ -311,6 +368,7 @@ class Function: ...@@ -311,6 +368,7 @@ class Function:
invoke: Optional[str] = None, invoke: Optional[str] = None,
fname: Optional[str] = None, fname: Optional[str] = None,
return_name: Optional[str] = None, return_name: Optional[str] = None,
virtual: bool = False,
**kwargs) -> None: **kwargs) -> None:
self.name = name self.name = name
self.params = params or [] self.params = params or []
...@@ -321,6 +379,10 @@ class Function: ...@@ -321,6 +379,10 @@ class Function:
self.return_name = return_name or 'out' self.return_name = return_name or 'out'
self.returns = Parameter(self.return_name, returns, self.returns = Parameter(self.return_name, returns,
returns=True) if returns else None returns=True) if returns else None
for p in self.params:
p.virtual = virtual
if self.returns:
self.returns.virtual = virtual
def share_params(self) -> None: def share_params(self) -> None:
if self.shared_size == True: if self.shared_size == True:
...@@ -556,6 +618,9 @@ def params(virtual: Optional[Dict[str, str]] = None, ...@@ -556,6 +618,9 @@ def params(virtual: Optional[Dict[str, str]] = None,
return result return result
gparams = params
def add_function(name: str, *args, **kwargs) -> Function: def add_function(name: str, *args, **kwargs) -> Function:
f = Function(name, *args, **kwargs) f = Function(name, *args, **kwargs)
functions.append(f) functions.append(f)
...@@ -627,7 +692,7 @@ extern "C" struct ${ctype}; ...@@ -627,7 +692,7 @@ extern "C" struct ${ctype};
struct ${ctype} { struct ${ctype} {
template<class... Ts> template<class... Ts>
${ctype}(Ts&&... xs) ${ctype}(Ts&&... xs)
: object(std::forward<Ts>(xs)...) : object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{} {}
${cpptype} object; ${cpptype} object;
}; };
...@@ -656,6 +721,55 @@ void destroy(T* x) ...@@ -656,6 +721,55 @@ void destroy(T* x)
{ {
delete x; // NOLINT delete x; // NOLINT
} }
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
{
manage_generic_ptr() = default;
manage_generic_ptr(std::nullptr_t)
{
}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
}
~manage_generic_ptr()
{
if(data != nullptr)
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
};
''' '''
cpp_handle_preamble = ''' cpp_handle_preamble = '''
...@@ -718,30 +832,40 @@ def add_handle(name: str, ...@@ -718,30 +832,40 @@ def add_handle(name: str,
ctype: str, ctype: str,
cpptype: str, cpptype: str,
destroy: Optional[str] = None, destroy: Optional[str] = None,
ref: Optional[bool] = None) -> None: ref=False,
skip_def=False) -> None:
opaque_type = ctype + '_t' opaque_type = ctype + '_t'
const_opaque_type = 'const_' + opaque_type const_opaque_type = 'const_' + opaque_type
def handle_wrap(p): def handle_wrap(p: Parameter):
t = Type(opaque_type) t = Type(opaque_type)
if p.type.is_const(): if p.type.is_const():
t = Type('const_' + opaque_type) t = Type('const_' + opaque_type)
if p.returns: # p.read = 'object_cast<${ctype}>(&(${name}))'
if p.virtual:
p.add_param(t)
elif p.returns:
p.add_param(t.add_pointer()) p.add_param(t.add_pointer())
if p.type.is_reference():
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(&(${result}))']
elif p.type.is_pointer():
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(${result})']
else:
p.cpp_write = '${cpptype}(${name})'
p.write = ['*${name} = allocate<${ctype}>(${result})']
else: else:
p.add_param(t) p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
if p.type.is_reference():
p.virtual_read = ['object_cast<${ctype}>(&(${name}))']
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(&(${result}))']
elif p.type.is_pointer():
p.virtual_read = ['object_cast<${ctype}>(${result})']
p.cpp_write = '${cpptype}(${name}, false)'
p.write = ['*${name} = object_cast<${ctype}>(${result})']
else:
p.virtual_read = ['object_cast<${ctype}>(&(${name}))']
p.cpp_write = '${cpptype}(${name})'
p.write = ['*${name} = allocate<${ctype}>(${result})']
if skip_def:
p.read = '*${name}'
else:
p.read = '${name}->object' p.read = '${name}->object'
p.cpp_read = '${name}.get_handle_ptr()' p.cpp_read = '${name}.get_handle_ptr()'
type_map[cpptype] = handle_wrap type_map[cpptype] = handle_wrap
if not ref: if not ref:
...@@ -753,7 +877,8 @@ def add_handle(name: str, ...@@ -753,7 +877,8 @@ def add_handle(name: str,
invoke='*output = *input') invoke='*output = *input')
add_handle_preamble() add_handle_preamble()
c_header_preamble.append(handle_typedef.substitute(locals())) c_header_preamble.append(handle_typedef.substitute(locals()))
c_api_body_preamble.append(handle_definition.substitute(locals())) if not skip_def:
c_api_body_preamble.append(handle_definition.substitute(locals()))
@cwrap('std::vector') @cwrap('std::vector')
...@@ -763,30 +888,32 @@ def vector_c_wrap(p: Parameter) -> None: ...@@ -763,30 +888,32 @@ def vector_c_wrap(p: Parameter) -> None:
if not inner: if not inner:
return return
t = inner.add_pointer() t = inner.add_pointer()
if p.type.is_reference():
if p.type.is_const():
t = t.add_const()
if p.returns: if p.returns:
if p.type.is_reference(): if p.type.is_reference():
if p.type.is_const():
t = t.add_const()
p.add_param(t.add_pointer()) p.add_param(t.add_pointer())
p.add_size_param() p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr', p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer') 'Null pointer')
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.write = [
'*${name} = ${result}.data()', '*${size} = ${result}.size()'
]
else: else:
p.add_param(t) p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.write = [
'std::copy(${result}.begin(), ${result}.end(), ${name})'
]
else: else:
p.add_param(t) p.add_param(t)
p.add_size_param() p.add_size_param()
p.bad_param('${name} == nullptr and ${size} != 0', 'Null pointer') p.bad_param('${name} == nullptr and ${size} != 0', 'Null pointer')
p.read = '${type}(${name}, ${name}+${size})'
p.read = '${type}(${name}, ${name}+${size})'
p.cpp_write = '${type}(${name}, ${name}+${size})'
p.virtual_read = ['${name}.data()', '${name}.size()']
if p.type.is_reference():
p.write = [
'*${name} = ${result}.data()', '*${size} = ${result}.size()'
]
else:
p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})']
@cwrap('std::string') @cwrap('std::string')
...@@ -796,34 +923,34 @@ def string_c_wrap(p: Parameter) -> None: ...@@ -796,34 +923,34 @@ def string_c_wrap(p: Parameter) -> None:
if p.type.is_reference(): if p.type.is_reference():
p.add_param(t.add_pointer()) p.add_param(t.add_pointer())
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name})'
p.write = ['*${name} = ${result}.c_str()']
else: else:
p.add_param(t) p.add_param(t)
p.add_param('size_t', p.name + '_size') p.add_param('size_t', p.name + '_size')
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
p.cpp_write = '${type}(${name})'
p.write = [
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it = \'\\0\''
]
else: else:
p.add_param(t) p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
p.read = '${type}(${name})'
p.read = '${type}(${name})'
p.cpp_write = '${type}(${name})'
p.virtual_read = ['${name}.c_str()']
if p.type.is_reference():
p.write = ['*${name} = ${result}.c_str()']
else:
p.write = [
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it = \'\\0\''
]
class Handle: class Handle:
def __init__(self, def __init__(self, name: str, ctype: str, cpptype: str, **kwargs) -> None:
name: str,
ctype: str,
cpptype: str,
ref: Optional[bool] = None) -> None:
self.name = name self.name = name
self.ctype = ctype self.ctype = ctype
self.cpptype = cpptype self.cpptype = cpptype
self.opaque_type = self.ctype + '_t'
self.cpp_class = CPPClass(name, ctype) self.cpp_class = CPPClass(name, ctype)
add_handle(name, ctype, cpptype, ref=ref) add_handle(name, ctype, cpptype, **kwargs)
cpp_type_map[cpptype] = name cpp_type_map[cpptype] = name
def cname(self, name: str) -> str: def cname(self, name: str) -> str:
...@@ -833,6 +960,7 @@ class Handle: ...@@ -833,6 +960,7 @@ class Handle:
return Template(s).safe_substitute(name=self.name, return Template(s).safe_substitute(name=self.name,
ctype=self.ctype, ctype=self.ctype,
cpptype=self.cpptype, cpptype=self.cpptype,
opaque_type=self.opaque_type,
**kwargs) **kwargs)
def constructor(self, def constructor(self,
...@@ -887,6 +1015,137 @@ class Handle: ...@@ -887,6 +1015,137 @@ class Handle:
cpp_classes.append(self.cpp_class) cpp_classes.append(self.cpp_class)
interface_handle_definition = Template('''
extern "C" struct ${ctype};
struct ${ctype} {
template<class... Ts>
${ctype}(void* p, ${copier} c, ${deleter} d, Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
{}
manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr;
${cpptype} xobject;
${functions}
};
''')
c_api_virtual_impl = Template('''
${return_type} ${name}(${params}) const
{
${output_decls}
if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing.");
auto api_error_result = ${fname}(${args});
if (api_error_result != ${success})
throw std::runtime_error("Error in ${name}.");
return ${output};
}
''')
def generate_virtual_impl(f: Function, fname: str) -> str:
success = success_type
name = f.name
return_type = 'void'
output_decls = ''
output = ''
largs = []
lparams = []
if f.returns:
return_type = f.returns.type.str()
output_decls = '\n'.join(f.returns.virtual_output_declarations())
largs += f.returns.virtual_output_args()
output = f.returns.virtual_output()
largs += [arg for p in f.params for arg in p.virtual_arg()]
lparams += [p.virtual_param() for p in f.params if not p.this]
args = ', '.join(largs)
params = ', '.join(lparams)
return c_api_virtual_impl.substitute(locals())
class Interface(Handle):
def __init__(self, name: str, ctype: str, cpptype: str) -> None:
super().__init__(name, ctype, cpptype, skip_def=True)
self.ifunctions: List[Function] = []
self.members: List[str] = []
def mname(self, name: str) -> str:
return name + "_f"
def constructor( # type: ignore
self,
name: str,
params: Optional[List[Parameter]] = None,
**kwargs) -> 'Interface':
create = self.substitute('allocate<${opaque_type}>($@)')
initial_params = gparams(obj='void*',
c=self.cname('copy'),
d=self.cname('delete'))
add_function(self.cname(name),
params=initial_params + (params or []),
invoke=create,
returns=self.opaque_type,
return_name=self.name,
**kwargs)
return self
def method(self, *args, **kwargs) -> 'Interface':
super().method(*args, **kwargs)
return self
def virtual(self,
name: str,
params: Optional[List[Parameter]] = None,
const: Optional[bool] = None,
**kwargs) -> 'Interface':
# Add this parameter to the function
this = Parameter('obj', 'void*', this=True)
this.virtual_read = ['object_ptr.data']
f = Function(name,
params=[this] + (params or []),
virtual=True,
**kwargs)
self.ifunctions.append(f)
add_function(self.cname('set_' + name),
params=gparams(obj=self.opaque_type,
input=self.cname(name)),
invoke='${{obj}}->{name} = ${{input}}'.format(
name=self.mname(name)))
return self
def generate_function(self, f: Function):
cname = self.cname(f.name)
mname = self.mname(f.name)
function = generate_virtual_impl(f, fname=mname)
return f"{cname} {mname} = nullptr;{function}"
def generate(self):
required_functions = [
Function('copy',
params=gparams(out='void**', input='void*'),
virtual=True),
Function('delete', params=gparams(input='void*'), virtual=True)
]
for f in self.ifunctions + required_functions:
f.update()
c_header_preamble.extend([
f.get_cfunction().generate_function_pointer(self.cname(f.name))
for f in self.ifunctions + required_functions
])
function_list = [self.generate_function(f) for f in self.ifunctions]
ctype = self.ctype
cpptype = self.cpptype
copier = self.cname('copy')
deleter = self.cname('delete')
functions = '\n'.join(function_list)
c_api_body_preamble.append(
interface_handle_definition.substitute(locals()))
def handle(ctype: str, def handle(ctype: str,
cpptype: str, cpptype: str,
name: Optional[str] = None, name: Optional[str] = None,
...@@ -906,6 +1165,23 @@ def handle(ctype: str, ...@@ -906,6 +1165,23 @@ def handle(ctype: str,
return with_handle return with_handle
def interface(ctype: str, cpptype: str,
name: Optional[str] = None) -> Callable:
def with_interface(f):
n = name or f.__name__
h = Interface(n, ctype, cpptype)
f(h)
h.generate()
@wraps(f)
def decorated(*args, **kwargs):
return f(*args, **kwargs)
return decorated
return with_interface
def template_eval(template, **kwargs): def template_eval(template, **kwargs):
start = '<%' start = '<%'
end = '%>' end = '%>'
...@@ -928,7 +1204,7 @@ def run(args: List[str]) -> None: ...@@ -928,7 +1204,7 @@ def run(args: List[str]) -> None:
else: else:
sys.stdout.write(generate_c_header()) sys.stdout.write(generate_c_header())
sys.stdout.write(generate_c_api_body()) sys.stdout.write(generate_c_api_body())
sys.stdout.write(generate_cpp_header()) # sys.stdout.write(generate_cpp_header())
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
...@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -212,6 +213,39 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
struct experimental_custom_op
{
std::string name;
experimental_custom_op() = default;
experimental_custom_op(std::string pname) : name(std::move(pname)) {}
};
template <class CustomOp>
struct custom_operation
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
CustomOp op;
std::string name() const { return op.xobject.name; }
shape compute_shape(std::vector<shape> inputs) const
{
return op.compute_shape(std::move(inputs));
}
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
};
template <class CustomOp>
void register_custom_op(const CustomOp& op)
{
register_op(custom_operation<CustomOp>{op});
}
migraphx::context get_context(const program& p) { return p.get_context(); } migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment