Unverified Commit 4a5a23a4 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Extend lifetimes in C++ API (#1139)

Helps avoid dangling references. This also deprecates the constructors that didnt take a lifetime annotation since its ambiguous the lifetime.
parent 8b4c417c
...@@ -4,7 +4,7 @@ CheckOptions: ...@@ -4,7 +4,7 @@ CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions - key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product' value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp - key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_' value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|DEPRECATED|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
- key: modernize-loop-convert.MinConfidence - key: modernize-loop-convert.MinConfidence
value: risky value: risky
- key: modernize-loop-convert.NamingStyle - key: modernize-loop-convert.NamingStyle
......
...@@ -15,6 +15,16 @@ namespace migraphx { ...@@ -15,6 +15,16 @@ namespace migraphx {
inline namespace api { // NOLINT inline namespace api { // NOLINT
#endif #endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template <int N> template <int N>
struct rank : rank<N - 1> struct rank : rank<N - 1>
{ {
...@@ -99,34 +109,22 @@ struct iota_iterator ...@@ -99,34 +109,22 @@ struct iota_iterator
return it; return it;
} }
// TODO: operator-> // TODO: operator->
reference operator*() const { return (*f)(index); } reference operator*() const { return f(index); }
};
template <class F, class Iterator> friend iota_iterator operator+(iota_iterator x, iota_iterator y)
inline iota_iterator<F, Iterator> operator+(iota_iterator<F, Iterator> x, {
iota_iterator<F, Iterator> y) return iota_iterator(x.index + y.index, x.f);
{ }
return iota_iterator<F, Iterator>(x.index + y.index, x.f);
}
template <class F, class Iterator> friend iota_iterator operator-(iota_iterator x, iota_iterator y)
inline iota_iterator<F, Iterator> operator-(iota_iterator<F, Iterator> x, {
iota_iterator<F, Iterator> y) return iota_iterator(x.index - y.index, x.f);
{ }
return iota_iterator<F, Iterator>(x.index - y.index, x.f);
}
template <class F, class Iterator> friend bool operator==(iota_iterator x, iota_iterator y) { return x.index == y.index; }
inline bool operator==(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator> friend bool operator!=(iota_iterator x, iota_iterator y) { return x.index != y.index; }
inline bool operator!=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) };
{
return x.index != y.index;
}
template <class Derived> template <class Derived>
struct array_base struct array_base
...@@ -136,8 +134,20 @@ struct array_base ...@@ -136,8 +134,20 @@ struct array_base
template <class T> template <class T>
using value_type_t = decltype(std::declval<T>()[0]); using value_type_t = decltype(std::declval<T>()[0]);
struct iterator_read
{
const Derived* self;
template <class D = Derived>
value_type_t<D> operator()(size_t pidx) const
{
return (*self)[pidx];
}
};
template <class T> template <class T>
using iterator_t = iota_iterator<typename T::iterator_read>; using iterator_t = iota_iterator<iterator_read>;
bool empty() const { return derived().size() == 0; }
template <class D = Derived> template <class D = Derived>
value_type_t<D> front() const value_type_t<D> front() const
...@@ -154,13 +164,13 @@ struct array_base ...@@ -154,13 +164,13 @@ struct array_base
template <class D = Derived> template <class D = Derived>
iterator_t<D> begin() const iterator_t<D> begin() const
{ {
return {0, {derived().get_handle_ptr()}}; return {0, {&derived()}};
} }
template <class D = Derived> template <class D = Derived>
iterator_t<D> end() const iterator_t<D> end() const
{ {
return {derived().size(), {derived().get_handle_ptr()}}; return {derived().size(), {&derived()}};
} }
}; };
...@@ -200,9 +210,25 @@ struct borrow ...@@ -200,9 +210,25 @@ struct borrow
{ {
}; };
template <class T>
struct share
{
share(std::shared_ptr<T> p) : ptr(std::move(p)) {}
template <class U>
std::shared_ptr<U> alias(U* p) const
{
return std::shared_ptr<U>{ptr, p};
}
private:
std::shared_ptr<T> ptr;
};
template <class Derived, class T, class D, D Deleter, class A, A Assigner> template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{ {
using handle_type = T;
handle_base() : m_handle(nullptr) {} handle_base() : m_handle(nullptr) {}
template <class F, class... Ts> template <class F, class... Ts>
void make_handle(F f, Ts&&... xs) void make_handle(F f, Ts&&... xs)
...@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle = std::shared_ptr<U>{ptr, [](U*) {}}; m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
} }
template <class U, class V>
void set_handle(U* ptr, share<V> b)
{
m_handle = std::shared_ptr<T>{ptr, [b](U*) {}};
}
share<T> share_handle() const { return {m_handle}; }
template <class U> template <class U>
void assign_to_handle(U* x) void assign_to_handle(U* x)
{ {
...@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std::shared_ptr<T> m_handle; std::shared_ptr<T> m_handle;
}; };
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template <class Base> template <class Base>
struct interface_base : Base struct interface_base : Base
{ {
...@@ -398,11 +443,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -398,11 +443,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{ {
shape() {} shape() {}
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); } shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
shape(migraphx_shape* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
shape(migraphx_shape* p, borrow) { this->set_handle(p, borrow{}); }
/// Construct a scalar shape /// Construct a scalar shape
shape(migraphx_shape_datatype_t type) shape(migraphx_shape_datatype_t type)
...@@ -479,10 +523,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -479,10 +523,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
argument() {} argument() {}
argument(migraphx_argument* p, borrow) { this->set_handle(p, borrow{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(argument);
argument(migraphx_argument* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape, void* pbuffer) argument(shape pshape, void* pbuffer)
...@@ -494,7 +537,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -494,7 +537,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr()); call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return {pout}; return {pout, this->share_handle()};
} }
char* data() const char* data() const
...@@ -526,9 +569,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target) ...@@ -526,9 +569,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{ {
target() {} target() {}
target(migraphx_target* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(target);
target(migraphx_target* p, borrow) { this->set_handle(p, borrow{}); }
/// Construct a target from its name /// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); } target(const char* name) { this->make_handle(&migraphx_target_create, name); }
...@@ -538,15 +579,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -538,15 +579,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
program_parameter_shapes() {} program_parameter_shapes() {}
program_parameter_shapes(migraphx_program_parameter_shapes* p, own) MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes);
{
this->set_handle(p, own{});
}
program_parameter_shapes(migraphx_program_parameter_shapes* p, borrow)
{
this->set_handle(p, borrow{});
}
size_t size() const size_t size() const
{ {
...@@ -559,7 +592,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -559,7 +592,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname); call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return {pout}; return {pout, this->share_handle()};
} }
std::vector<const char*> names() const std::vector<const char*> names() const
...@@ -576,10 +609,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -576,10 +609,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program /// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{ {
program_parameters(migraphx_program_parameters* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters);
program_parameters(migraphx_program_parameters* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); } program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
program_parameters() { this->make_handle(&migraphx_program_parameters_create); } program_parameters() { this->make_handle(&migraphx_program_parameters_create); }
...@@ -604,9 +636,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) ...@@ -604,9 +636,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
arguments(migraphx_arguments* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(arguments);
arguments(migraphx_arguments* p, borrow) { this->set_handle(p, borrow{}); }
size_t size() const size_t size() const
{ {
...@@ -619,27 +649,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -619,27 +649,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return {pout}; return {pout, this->share_handle()};
} }
struct iterator_read
{
migraphx_arguments* self;
argument operator()(size_t pidx) const
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return {pout};
}
};
}; };
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
shapes(migraphx_shapes* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(shapes);
shapes(migraphx_shapes* p, borrow) { this->set_handle(p, borrow{}); }
size_t size() const size_t size() const
{ {
...@@ -652,26 +668,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -652,26 +668,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return {pout}; return {pout, this->share_handle()};
} }
struct iterator_read
{
migraphx_shapes* self;
shape operator()(size_t pidx) const
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx);
return {pout};
}
};
}; };
struct operation : MIGRAPHX_HANDLE_BASE(operation) struct operation : MIGRAPHX_HANDLE_BASE(operation)
{ {
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(operation);
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs) operation(const char* name, const char* attributes = nullptr, Ts... xs)
...@@ -689,15 +692,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -689,15 +692,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction) struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{ {
instruction(migraphx_instruction* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(instruction);
}; };
struct instructions : MIGRAPHX_HANDLE_BASE(instructions) struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions);
instructions(migraphx_instructions* p, own) { this->set_handle(p, own{}); }
instructions(migraphx_instructions* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
instructions(Ts... xs) instructions(Ts... xs)
...@@ -711,33 +711,36 @@ struct module; ...@@ -711,33 +711,36 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules) struct modules : MIGRAPHX_HANDLE_BASE(modules)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(modules);
modules(migraphx_modules* p, own) { this->set_handle(p, own{}); }
modules(migraphx_modules* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
modules(Ts... xs) modules(Ts... xs)
{ {
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.mm...}; std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.get_handle_ptr()...};
this->make_handle(&migraphx_modules_create, a.data(), a.size()); this->make_handle(&migraphx_modules_create, a.data(), a.size());
} }
}; };
struct module struct module
{ {
migraphx_module_t mm; MIGRAPHX_DEPRECATED("Constructor without lifetime annotation is deprecated.")
module(migraphx_module* m) : mm(std::shared_ptr<migraphx_module*>(), m) {}
module(const migraphx_module_t& m) : mm(m) {} module(migraphx_module* m, borrow) : mm(std::shared_ptr<migraphx_module*>(), m) {}
void print() const { call(&migraphx_module_print, mm); } template <class T>
module(migraphx_module* m, share<T> b) : mm(b.alias(m))
{
}
void print() const { call(&migraphx_module_print, mm.get()); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args) instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{ {
migraphx_instruction_t op_ins; migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction, call(&migraphx_module_add_instruction,
&op_ins, &op_ins,
mm, mm.get(),
op.get_handle_ptr(), op.get_handle_ptr(),
args.get_handle_ptr()); args.get_handle_ptr());
return instruction(op_ins, own{}); return instruction(op_ins, own{});
...@@ -750,7 +753,7 @@ struct module ...@@ -750,7 +753,7 @@ struct module
migraphx_instruction_t op_ins; migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args, call(&migraphx_module_add_instruction_with_mod_args,
&op_ins, &op_ins,
mm, mm.get(),
op.get_handle_ptr(), op.get_handle_ptr(),
args.get_handle_ptr(), args.get_handle_ptr(),
module_args.get_handle_ptr()); module_args.get_handle_ptr());
...@@ -760,39 +763,53 @@ struct module ...@@ -760,39 +763,53 @@ struct module
instruction add_parameter(const std::string& name, shape s) instruction add_parameter(const std::string& name, shape s)
{ {
migraphx_instruction_t param_ins; migraphx_instruction_t param_ins;
call(&migraphx_module_add_parameter, &param_ins, mm, name.c_str(), s.get_handle_ptr()); call(
&migraphx_module_add_parameter, &param_ins, mm.get(), name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{}); return instruction(param_ins, own{});
} }
instruction add_return(const migraphx::instructions& args) instruction add_return(const migraphx::instructions& args)
{ {
migraphx_instruction_t ret_ins; migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm, args.get_handle_ptr()); call(&migraphx_module_add_return, &ret_ins, mm.get(), args.get_handle_ptr());
return instruction(ret_ins, own{}); return instruction(ret_ins, own{});
} }
migraphx_module_t get_handle_ptr() const { return mm.get(); }
private:
std::shared_ptr<migraphx_module> mm;
}; };
struct context struct context
{ {
migraphx_context_t ctx; context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
template <class T>
context(migraphx_context* p, share<T> b) : ctx(b.alias(p))
{
}
void finish() const { call(&migraphx_context_finish, ctx); } void finish() const { call(&migraphx_context_finish, ctx.get()); }
template <class T> template <class T>
T get_queue() T get_queue()
{ {
void* out; void* out;
call(&migraphx_context_get_queue, &out, ctx); call(&migraphx_context_get_queue, &out, ctx.get());
// TODO: check type here // TODO: check type here
return reinterpret_cast<T>(out); return reinterpret_cast<T>(out);
} }
private:
std::shared_ptr<migraphx_context> ctx;
}; };
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
compile_options() { this->make_handle(&migraphx_compile_options_create); } compile_options() { this->make_handle(&migraphx_compile_options_create); }
compile_options(migraphx_compile_options* p, own) { this->set_handle(p, own()); } MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options);
/// For targets with offloaded memory(such as the gpu), this will insert /// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the /// instructions during compilation to copy the input parameters to the
...@@ -816,9 +833,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -816,9 +833,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() { this->make_handle(&migraphx_program_create); } program() { this->make_handle(&migraphx_program_create); }
program(migraphx_program* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(program);
program(migraphx_program* p, borrow) { this->set_handle(p, borrow{}); }
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const void compile(const target& ptarget, const compile_options& poptions) const
...@@ -881,21 +896,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -881,21 +896,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
migraphx_module_t p_modu; migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr()); call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu}; return module{p_modu, this->share_handle()};
} }
context experimental_get_context() context experimental_get_context()
{ {
migraphx_context_t ctx; migraphx_context_t ctx;
call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr()); call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr());
return context{ctx}; return context{ctx, this->share_handle()};
} }
module create_module(const std::string& name) module create_module(const std::string& name)
{ {
migraphx_module_t p_modu; migraphx_module_t p_modu;
call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data()); call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return module{p_modu}; return module{p_modu, this->share_handle()};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return !(px == py); }
...@@ -904,10 +919,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -904,10 +919,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options // options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options) struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options);
file_options() { this->make_handle(&migraphx_file_options_create); } file_options() { this->make_handle(&migraphx_file_options_create); }
file_options(migraphx_file_options* p, own) { this->set_handle(p, own()); }
// set file format // set file format
void set_file_format(const char* format) void set_file_format(const char* format)
{ {
...@@ -947,7 +961,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -947,7 +961,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
onnx_options() { this->make_handle(&migraphx_onnx_options_create); } onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
onnx_options(migraphx_onnx_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options);
/// Make onnx parser treat an inputs with a certain dimensions /// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1029,7 +1043,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) ...@@ -1029,7 +1043,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{ {
tf_options() { this->make_handle(&migraphx_tf_options_create); } tf_options() { this->make_handle(&migraphx_tf_options_create); }
tf_options(migraphx_tf_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options);
/// Make tf parser treat an inputs with a certain dimensions /// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1082,7 +1096,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) ...@@ -1082,7 +1096,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{ {
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
quantize_op_names(migraphx_quantize_op_names* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names);
void add(const std::string& name) void add(const std::string& name)
{ {
...@@ -1107,12 +1121,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) ...@@ -1107,12 +1121,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{ {
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); } quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
quantize_int8_options(migraphx_quantize_int8_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options);
quantize_int8_options(migraphx_quantize_int8_options* p, borrow)
{
this->set_handle(p, borrow{});
}
/// Add an operator that should be quantized /// Add an operator that should be quantized
void add_op_name(const std::string& name) void add_op_name(const std::string& name)
......
...@@ -9,6 +9,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -9,6 +9,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
endfunction() endfunction()
add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR})
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR}) add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.hpp>
#include "test.hpp"
struct array2 : migraphx::array_base<array2>
{
std::vector<int> v;
array2() = default;
array2(std::initializer_list<int> x) : v(x) {}
std::size_t size() const { return v.size(); }
int operator[](std::size_t i) const { return v[i]; }
};
TEST_CASE(iterators)
{
array2 a = {1, 2, 3};
EXPECT(bool{std::equal(a.begin(), a.end(), a.v.begin())});
}
TEST_CASE(front_back)
{
array2 a = {1, 2, 3};
EXPECT(a.front() == 1);
EXPECT(a.back() == 3);
}
TEST_CASE(empty)
{
array2 a = {1, 2, 3};
EXPECT(not a.empty());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -60,16 +60,16 @@ TEST_CASE(if_then_else_op) ...@@ -60,16 +60,16 @@ TEST_CASE(if_then_else_op)
p.compile(migraphx::target("ref")); p.compile(migraphx::target("ref"));
auto outputs = auto outputs =
p.eval({{"cond", migraphx::argument(cond_s, &cond)}, {"x", x_arg}, {"y", y_arg}}); p.eval({{"cond", migraphx::argument(cond_s, &cond)}, {"x", x_arg}, {"y", y_arg}});
return outputs; return outputs[0];
}; };
// then branch // then branch
auto then_res = run_prog(true); auto then_res = run_prog(true);
CHECK(bool{then_res[0] == x_arg}); CHECK(bool{then_res == x_arg});
// else branch // else branch
auto else_res = run_prog(false); auto else_res = run_prog(false);
CHECK(bool{else_res[0] == y_arg}); CHECK(bool{else_res == y_arg});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment