Commit 2f268bc2 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into mlir-c

parents f75c5a38 aa7ff911
...@@ -15,6 +15,16 @@ namespace migraphx { ...@@ -15,6 +15,16 @@ namespace migraphx {
inline namespace api { // NOLINT inline namespace api { // NOLINT
#endif #endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template <int N> template <int N>
struct rank : rank<N - 1> struct rank : rank<N - 1>
{ {
...@@ -29,10 +39,7 @@ template <class T, class F, class... Ts> ...@@ -29,10 +39,7 @@ template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs) T* make(F f, Ts&&... xs)
{ {
T* result = nullptr; T* result = nullptr;
// cppcheck-suppress redundantInitialization auto e = f(&result, std::forward<Ts>(xs)...);
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto e = f(&result, std::forward<Ts>(xs)...);
if(e != migraphx_status_success) if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function"); throw std::runtime_error("Failed to call function");
return result; return result;
...@@ -41,9 +48,6 @@ T* make(F f, Ts&&... xs) ...@@ -41,9 +48,6 @@ T* make(F f, Ts&&... xs)
template <class F, class... Ts> template <class F, class... Ts>
void call(F f, Ts&&... xs) void call(F f, Ts&&... xs)
{ {
// cppcheck-suppress redundantInitialization
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto e = f(std::forward<Ts>(xs)...); auto e = f(std::forward<Ts>(xs)...);
if(e != migraphx_status_success) if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function"); throw std::runtime_error("Failed to call function");
...@@ -99,34 +103,22 @@ struct iota_iterator ...@@ -99,34 +103,22 @@ struct iota_iterator
return it; return it;
} }
// TODO: operator-> // TODO: operator->
reference operator*() const { return (*f)(index); } reference operator*() const { return f(index); }
};
template <class F, class Iterator> friend iota_iterator operator+(iota_iterator x, iota_iterator y)
inline iota_iterator<F, Iterator> operator+(iota_iterator<F, Iterator> x, {
iota_iterator<F, Iterator> y) return iota_iterator(x.index + y.index, x.f);
{ }
return iota_iterator<F, Iterator>(x.index + y.index, x.f);
}
template <class F, class Iterator> friend iota_iterator operator-(iota_iterator x, iota_iterator y)
inline iota_iterator<F, Iterator> operator-(iota_iterator<F, Iterator> x, {
iota_iterator<F, Iterator> y) return iota_iterator(x.index - y.index, x.f);
{ }
return iota_iterator<F, Iterator>(x.index - y.index, x.f);
}
template <class F, class Iterator> friend bool operator==(iota_iterator x, iota_iterator y) { return x.index == y.index; }
inline bool operator==(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator> friend bool operator!=(iota_iterator x, iota_iterator y) { return x.index != y.index; }
inline bool operator!=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) };
{
return x.index != y.index;
}
template <class Derived> template <class Derived>
struct array_base struct array_base
...@@ -136,8 +128,20 @@ struct array_base ...@@ -136,8 +128,20 @@ struct array_base
template <class T> template <class T>
using value_type_t = decltype(std::declval<T>()[0]); using value_type_t = decltype(std::declval<T>()[0]);
struct iterator_read
{
const Derived* self;
template <class D = Derived>
value_type_t<D> operator()(size_t pidx) const
{
return (*self)[pidx];
}
};
template <class T> template <class T>
using iterator_t = iota_iterator<typename T::iterator_read>; using iterator_t = iota_iterator<iterator_read>;
bool empty() const { return derived().size() == 0; }
template <class D = Derived> template <class D = Derived>
value_type_t<D> front() const value_type_t<D> front() const
...@@ -154,13 +158,13 @@ struct array_base ...@@ -154,13 +158,13 @@ struct array_base
template <class D = Derived> template <class D = Derived>
iterator_t<D> begin() const iterator_t<D> begin() const
{ {
return {0, {derived().get_handle_ptr()}}; return {0, {&derived()}};
} }
template <class D = Derived> template <class D = Derived>
iterator_t<D> end() const iterator_t<D> end() const
{ {
return {derived().size(), {derived().get_handle_ptr()}}; return {derived().size(), {&derived()}};
} }
}; };
...@@ -200,9 +204,25 @@ struct borrow ...@@ -200,9 +204,25 @@ struct borrow
{ {
}; };
template <class T>
struct share
{
share(std::shared_ptr<T> p) : ptr(std::move(p)) {}
template <class U>
std::shared_ptr<U> alias(U* p) const
{
return std::shared_ptr<U>{ptr, p};
}
private:
std::shared_ptr<T> ptr;
};
template <class Derived, class T, class D, D Deleter, class A, A Assigner> template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{ {
using handle_type = T;
handle_base() : m_handle(nullptr) {} handle_base() : m_handle(nullptr) {}
template <class F, class... Ts> template <class F, class... Ts>
void make_handle(F f, Ts&&... xs) void make_handle(F f, Ts&&... xs)
...@@ -231,6 +251,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -231,6 +251,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle = std::shared_ptr<U>{ptr, [](U*) {}}; m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
} }
template <class U, class V>
void set_handle(U* ptr, share<V> b)
{
m_handle = std::shared_ptr<T>{ptr, [b](U*) {}};
}
share<T> share_handle() const { return {m_handle}; }
template <class U> template <class U>
void assign_to_handle(U* x) void assign_to_handle(U* x)
{ {
...@@ -241,6 +269,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -241,6 +269,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std::shared_ptr<T> m_handle; std::shared_ptr<T> m_handle;
}; };
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template <class Base> template <class Base>
struct interface_base : Base struct interface_base : Base
{ {
...@@ -269,6 +308,7 @@ struct interface_base : Base ...@@ -269,6 +308,7 @@ struct interface_base : Base
T** y = reinterpret_cast<T**>(out); T** y = reinterpret_cast<T**>(out);
T* x = reinterpret_cast<T*>(input); T* x = reinterpret_cast<T*>(input);
assert(x != nullptr and y != nullptr and *y == nullptr); assert(x != nullptr and y != nullptr and *y == nullptr);
// cppcheck-suppress useSmartPointer
*y = new T(*x); // NOLINT *y = new T(*x); // NOLINT
}); });
}; };
...@@ -398,11 +438,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -398,11 +438,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{ {
shape() {} shape() {}
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); } shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
shape(migraphx_shape* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
shape(migraphx_shape* p, borrow) { this->set_handle(p, borrow{}); }
/// Construct a scalar shape /// Construct a scalar shape
shape(migraphx_shape_datatype_t type) shape(migraphx_shape_datatype_t type)
...@@ -479,10 +518,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -479,10 +518,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
argument() {} argument() {}
argument(migraphx_argument* p, borrow) { this->set_handle(p, borrow{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(argument);
argument(migraphx_argument* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape, void* pbuffer) argument(shape pshape, void* pbuffer)
...@@ -494,7 +532,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -494,7 +532,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr()); call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return {pout}; return {pout, this->share_handle()};
} }
char* data() const char* data() const
...@@ -526,9 +564,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target) ...@@ -526,9 +564,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{ {
target() {} target() {}
target(migraphx_target* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(target);
target(migraphx_target* p, borrow) { this->set_handle(p, borrow{}); }
/// Construct a target from its name /// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); } target(const char* name) { this->make_handle(&migraphx_target_create, name); }
...@@ -538,15 +574,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -538,15 +574,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
program_parameter_shapes() {} program_parameter_shapes() {}
program_parameter_shapes(migraphx_program_parameter_shapes* p, own) MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes);
{
this->set_handle(p, own{});
}
program_parameter_shapes(migraphx_program_parameter_shapes* p, borrow)
{
this->set_handle(p, borrow{});
}
size_t size() const size_t size() const
{ {
...@@ -559,7 +587,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -559,7 +587,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname); call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return {pout}; return {pout, this->share_handle()};
} }
std::vector<const char*> names() const std::vector<const char*> names() const
...@@ -576,10 +604,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -576,10 +604,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program /// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{ {
program_parameters(migraphx_program_parameters* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters);
program_parameters(migraphx_program_parameters* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); } program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
program_parameters() { this->make_handle(&migraphx_program_parameters_create); } program_parameters() { this->make_handle(&migraphx_program_parameters_create); }
...@@ -604,9 +631,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) ...@@ -604,9 +631,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
arguments(migraphx_arguments* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(arguments);
arguments(migraphx_arguments* p, borrow) { this->set_handle(p, borrow{}); }
size_t size() const size_t size() const
{ {
...@@ -619,27 +644,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -619,27 +644,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return {pout}; return {pout, this->share_handle()};
} }
struct iterator_read
{
migraphx_arguments* self;
argument operator()(size_t pidx) const
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return {pout};
}
};
}; };
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
shapes(migraphx_shapes* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(shapes);
shapes(migraphx_shapes* p, borrow) { this->set_handle(p, borrow{}); }
size_t size() const size_t size() const
{ {
...@@ -652,26 +663,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -652,26 +663,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return {pout}; return {pout, this->share_handle()};
} }
struct iterator_read
{
migraphx_shapes* self;
shape operator()(size_t pidx) const
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx);
return {pout};
}
};
}; };
struct operation : MIGRAPHX_HANDLE_BASE(operation) struct operation : MIGRAPHX_HANDLE_BASE(operation)
{ {
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(operation);
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs) operation(const char* name, const char* attributes = nullptr, Ts... xs)
...@@ -689,15 +687,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -689,15 +687,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction) struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{ {
instruction(migraphx_instruction* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(instruction);
}; };
struct instructions : MIGRAPHX_HANDLE_BASE(instructions) struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions);
instructions(migraphx_instructions* p, own) { this->set_handle(p, own{}); }
instructions(migraphx_instructions* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
instructions(Ts... xs) instructions(Ts... xs)
...@@ -711,33 +706,36 @@ struct module; ...@@ -711,33 +706,36 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules) struct modules : MIGRAPHX_HANDLE_BASE(modules)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(modules);
modules(migraphx_modules* p, own) { this->set_handle(p, own{}); }
modules(migraphx_modules* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
modules(Ts... xs) modules(Ts... xs)
{ {
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.mm...}; std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.get_handle_ptr()...};
this->make_handle(&migraphx_modules_create, a.data(), a.size()); this->make_handle(&migraphx_modules_create, a.data(), a.size());
} }
}; };
struct module struct module
{ {
migraphx_module_t mm; MIGRAPHX_DEPRECATED("Constructor without lifetime annotation is deprecated.")
module(migraphx_module* m) : mm(std::shared_ptr<migraphx_module*>(), m) {}
module(const migraphx_module_t& m) : mm(m) {} module(migraphx_module* m, borrow) : mm(std::shared_ptr<migraphx_module*>(), m) {}
void print() const { call(&migraphx_module_print, mm); } template <class T>
module(migraphx_module* m, share<T> b) : mm(b.alias(m))
{
}
void print() const { call(&migraphx_module_print, mm.get()); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args) instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{ {
migraphx_instruction_t op_ins; migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction, call(&migraphx_module_add_instruction,
&op_ins, &op_ins,
mm, mm.get(),
op.get_handle_ptr(), op.get_handle_ptr(),
args.get_handle_ptr()); args.get_handle_ptr());
return instruction(op_ins, own{}); return instruction(op_ins, own{});
...@@ -750,40 +748,72 @@ struct module ...@@ -750,40 +748,72 @@ struct module
migraphx_instruction_t op_ins; migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args, call(&migraphx_module_add_instruction_with_mod_args,
&op_ins, &op_ins,
mm, mm.get(),
op.get_handle_ptr(), op.get_handle_ptr(),
args.get_handle_ptr(), args.get_handle_ptr(),
module_args.get_handle_ptr()); module_args.get_handle_ptr());
return instruction(op_ins, own{}); return instruction(op_ins, own{});
} }
template <typename T>
instruction add_literal(const migraphx::shape& s, T* buffer)
{
migraphx_instruction_t literal_ins;
const auto* buffer_ptr = reinterpret_cast<const char*>(buffer);
call(&migraphx_module_add_literal, &literal_ins, mm.get(), s.get_handle_ptr(), buffer_ptr);
return instruction(literal_ins, own{});
}
instruction add_parameter(const std::string& name, shape s) instruction add_parameter(const std::string& name, shape s)
{ {
migraphx_instruction_t param_ins; migraphx_instruction_t param_ins;
call(&migraphx_module_add_parameter, &param_ins, mm, name.c_str(), s.get_handle_ptr()); call(
&migraphx_module_add_parameter, &param_ins, mm.get(), name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{}); return instruction(param_ins, own{});
} }
instruction add_return(const migraphx::instructions& args) instruction add_return(const migraphx::instructions& args)
{ {
migraphx_instruction_t ret_ins; migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm, args.get_handle_ptr()); call(&migraphx_module_add_return, &ret_ins, mm.get(), args.get_handle_ptr());
return instruction(ret_ins, own{}); return instruction(ret_ins, own{});
} }
migraphx_module_t get_handle_ptr() const { return mm.get(); }
private:
std::shared_ptr<migraphx_module> mm;
}; };
struct context struct context
{ {
migraphx_context_t ctx; context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
void finish() const { call(&migraphx_context_finish, ctx); } template <class T>
context(migraphx_context* p, share<T> b) : ctx(b.alias(p))
{
}
void finish() const { call(&migraphx_context_finish, ctx.get()); }
template <class T>
T get_queue()
{
void* out;
call(&migraphx_context_get_queue, &out, ctx.get());
// TODO: check type here
return reinterpret_cast<T>(out);
}
private:
std::shared_ptr<migraphx_context> ctx;
}; };
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
compile_options() { this->make_handle(&migraphx_compile_options_create); } compile_options() { this->make_handle(&migraphx_compile_options_create); }
compile_options(migraphx_compile_options* p, own) { this->set_handle(p, own()); } MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options);
/// For targets with offloaded memory(such as the gpu), this will insert /// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the /// instructions during compilation to copy the input parameters to the
...@@ -807,9 +837,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -807,9 +837,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() { this->make_handle(&migraphx_program_create); } program() { this->make_handle(&migraphx_program_create); }
program(migraphx_program* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(program);
program(migraphx_program* p, borrow) { this->set_handle(p, borrow{}); }
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const void compile(const target& ptarget, const compile_options& poptions) const
...@@ -872,21 +900,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -872,21 +900,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
migraphx_module_t p_modu; migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr()); call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu}; return module{p_modu, this->share_handle()};
} }
context experimental_get_context() context experimental_get_context()
{ {
migraphx_context_t ctx; migraphx_context_t ctx;
call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr()); call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr());
return context{ctx}; return context{ctx, this->share_handle()};
} }
module create_module(const std::string& name) module create_module(const std::string& name)
{ {
migraphx_module_t p_modu; migraphx_module_t p_modu;
call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data()); call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return module{p_modu}; return module{p_modu, this->share_handle()};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return !(px == py); }
...@@ -895,10 +923,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -895,10 +923,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options // options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options) struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options);
file_options() { this->make_handle(&migraphx_file_options_create); } file_options() { this->make_handle(&migraphx_file_options_create); }
file_options(migraphx_file_options* p, own) { this->set_handle(p, own()); }
// set file format // set file format
void set_file_format(const char* format) void set_file_format(const char* format)
{ {
...@@ -938,7 +965,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -938,7 +965,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
onnx_options() { this->make_handle(&migraphx_onnx_options_create); } onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
onnx_options(migraphx_onnx_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options);
/// Make onnx parser treat an inputs with a certain dimensions /// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1020,7 +1047,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) ...@@ -1020,7 +1047,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{ {
tf_options() { this->make_handle(&migraphx_tf_options_create); } tf_options() { this->make_handle(&migraphx_tf_options_create); }
tf_options(migraphx_tf_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options);
/// Make tf parser treat an inputs with a certain dimensions /// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1073,7 +1100,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) ...@@ -1073,7 +1100,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{ {
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
quantize_op_names(migraphx_quantize_op_names* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names);
void add(const std::string& name) void add(const std::string& name)
{ {
...@@ -1098,12 +1125,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) ...@@ -1098,12 +1125,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{ {
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); } quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
quantize_int8_options(migraphx_quantize_int8_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options);
quantize_int8_options(migraphx_quantize_int8_options* p, borrow)
{
this->set_handle(p, borrow{});
}
/// Add an operator that should be quantized /// Add an operator that should be quantized
void add_op_name(const std::string& name) void add_op_name(const std::string& name)
......
...@@ -212,6 +212,9 @@ def module(h): ...@@ -212,6 +212,9 @@ def module(h):
module_refs='std::vector<migraphx::module*>'), module_refs='std::vector<migraphx::module*>'),
fname='add_instruction', fname='add_instruction',
returns='migraphx::instruction_ref') returns='migraphx::instruction_ref')
h.method('add_literal',
api.params(shape='const migraphx::shape&', buffer='const char*'),
returns='migraphx::instruction_ref')
h.method('add_parameter', h.method('add_parameter',
api.params(name='const char*', shape='const migraphx::shape&'), api.params(name='const char*', shape='const migraphx::shape&'),
returns='migraphx::instruction_ref') returns='migraphx::instruction_ref')
...@@ -403,6 +406,7 @@ api.add_function('migraphx_quantize_int8', ...@@ -403,6 +406,7 @@ api.add_function('migraphx_quantize_int8',
@auto_handle(ref=True) @auto_handle(ref=True)
def context(h): def context(h):
h.method('finish', const=True) h.method('finish', const=True)
h.method('get_queue', returns='void*', fname='get_queue().unsafe_get')
@api.interface('migraphx_experimental_custom_op', @api.interface('migraphx_experimental_custom_op',
......
...@@ -29,7 +29,6 @@ void argument::assign_buffer(std::function<char*()> d) ...@@ -29,7 +29,6 @@ void argument::assign_buffer(std::function<char*()> d)
// Collect all shapes // Collect all shapes
std::unordered_map<std::size_t, shape> shapes; std::unordered_map<std::size_t, shape> shapes;
{ {
// cppcheck-suppress variableScope
std::size_t i = 0; std::size_t i = 0;
fix([&](auto self, auto ss) { fix([&](auto self, auto ss) {
if(ss.sub_shapes().empty()) if(ss.sub_shapes().empty())
...@@ -60,7 +59,6 @@ void argument::assign_buffer(std::function<char*()> d) ...@@ -60,7 +59,6 @@ void argument::assign_buffer(std::function<char*()> d)
} }
assert(offset == s.bytes()); assert(offset == s.bytes());
// cppcheck-suppress variableScope
std::size_t i = 0; std::size_t i = 0;
m_data = fix<data_t>([&](auto self, auto ss) { m_data = fix<data_t>([&](auto self, auto ss) {
data_t result; data_t result;
......
...@@ -8,10 +8,10 @@ ...@@ -8,10 +8,10 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(module& p) const void auto_contiguous::apply(module& m) const
{ {
std::string key = "require_std_shape"; std::string key = "require_std_shape";
for(auto ins : reverse_iterator_for(p)) for(auto ins : reverse_iterator_for(m))
{ {
auto&& attr = ins->get_operator().attributes(); auto&& attr = ins->get_operator().attributes();
if((attr.get(key, false))) if((attr.get(key, false)))
...@@ -23,18 +23,18 @@ void auto_contiguous::apply(module& p) const ...@@ -23,18 +23,18 @@ void auto_contiguous::apply(module& p) const
{ {
return in; return in;
} }
return p.insert_instruction(ins, make_op("contiguous"), in); return m.insert_instruction(ins, make_op("contiguous"), in);
}); });
if(new_args != args) if(new_args != args)
{ {
p.replace_instruction(ins, ins->get_operator(), new_args); m.replace_instruction(ins, ins->get_operator(), new_args);
} }
} }
} }
auto last = std::prev(p.end()); auto last = std::prev(m.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// for last instruction that is NOT a return // for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last) if(ins->outputs().empty() and ins != last)
...@@ -42,8 +42,8 @@ void auto_contiguous::apply(module& p) const ...@@ -42,8 +42,8 @@ void auto_contiguous::apply(module& p) const
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0) if(not s.standard() and s.elements() != 0)
{ {
auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins); auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
p.replace_instruction(ins, c); m.replace_instruction(ins, c);
} }
} }
} }
......
...@@ -28,7 +28,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const ...@@ -28,7 +28,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
{ {
params += " " + src.path.filename().string(); params += " " + src.path.filename().string();
if(out.empty()) if(out.empty())
out = src.path.stem().string() + ".o"; out = src.path.stem().string() + out_ext;
} }
} }
......
...@@ -9,26 +9,6 @@ ...@@ -9,26 +9,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Range, class Iterator>
std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
{
auto start_forward = start;
auto start_backwards = start;
std::size_t n = 0;
while(start_forward != last and start_backwards != last)
{
n++;
if(start_forward != r.end())
start_forward++;
if(start_backwards != r.begin())
start_backwards--;
}
if(start_forward == last)
return n;
else
return -n;
}
void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); } void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
void dead_code_elimination::apply(module& m) const void dead_code_elimination::apply(module& m) const
...@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const ...@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
if(i->get_shape().elements() == 0 and i->name().front() != '@' and if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity") i->name() != "undefined" and i->name() != "identity")
continue; continue;
assert(bidistance(m, i, last) > 0); assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
fix([&](auto self, auto leaf) { fix([&](auto self, auto leaf) {
if(not m.has_instruction(leaf)) if(not m.has_instruction(leaf))
return; return;
if(leaf->outputs().empty()) if(leaf->outputs().empty())
{ {
// Dont visit inputs twice
if(not visited.insert(leaf).second)
return;
std::unordered_set<instruction_ref> args(leaf->inputs().begin(), std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end()); leaf->inputs().end());
leaf->clear_arguments(); leaf->clear_arguments();
assert(bidistance(m, last, leaf) < 0); assert(std::distance(m.begin(), leaf) < std::distance(m.begin(), last));
assert(leaf != ins); assert(leaf != ins);
if(leaf->name() != "@param") if(leaf->name() != "@param")
m.move_instruction(leaf, m.end()); m.move_instruction(leaf, m.end());
......
...@@ -17,7 +17,7 @@ class marker_roctx ...@@ -17,7 +17,7 @@ class marker_roctx
std::function<int(const char*)> sym_roctx_range_push; std::function<int(const char*)> sym_roctx_range_push;
std::function<int()> sym_roctx_range_pop; std::function<int()> sym_roctx_range_pop;
uint64_t range_id; uint64_t range_id = 0;
public: public:
marker_roctx() marker_roctx()
......
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_allocation::apply(module& p) const void eliminate_allocation::apply(module& m) const
{ {
assert(alignment > 0); assert(alignment > 0);
std::size_t n = 0; std::size_t n = 0;
std::vector<std::pair<instruction_ref, std::size_t>> allocs; std::vector<std::pair<instruction_ref, std::size_t>> allocs;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != allocation_op) if(ins->name() != allocation_op)
continue; continue;
...@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const ...@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
} }
if(n > 0) if(n > 0)
{ {
auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}}); auto mem = m.add_parameter("memory", shape{shape::int8_type, {n}});
for(auto&& pp : allocs) for(auto&& pp : allocs)
{ {
auto ins = pp.first; auto ins = pp.first;
auto s = ins->get_shape(); auto s = ins->get_shape();
auto offset = pp.second; auto offset = pp.second;
p.replace_instruction( m.replace_instruction(
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem); ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
} }
} }
......
...@@ -11,7 +11,7 @@ namespace migraphx { ...@@ -11,7 +11,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Range> template <class Range>
void cse_range(module& p, Range&& r) void cse_range(module& m, Range&& r)
{ {
std::unordered_multimap<std::string, instruction_ref> instructions; std::unordered_multimap<std::string, instruction_ref> instructions;
std::unordered_set<instruction_ref> processed_ins; std::unordered_set<instruction_ref> processed_ins;
...@@ -30,24 +30,24 @@ void cse_range(module& p, Range&& r) ...@@ -30,24 +30,24 @@ void cse_range(module& p, Range&& r)
continue; continue;
if(*eq != *ins) if(*eq != *ins)
continue; continue;
p.replace_instruction(ins, eq); m.replace_instruction(ins, eq);
processed_ins.emplace(ins); processed_ins.emplace(ins);
std::vector<instruction_ref> outputs; std::vector<instruction_ref> outputs;
std::copy_if(eq->outputs().begin(), std::copy_if(eq->outputs().begin(),
eq->outputs().end(), eq->outputs().end(),
std::back_inserter(outputs), std::back_inserter(outputs),
[&](auto x) { return p.has_instruction(x); }); [&](auto x) { return m.has_instruction(x); });
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) { std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
return std::distance(eq, x) < std::distance(eq, y); return std::distance(eq, x) < std::distance(eq, y);
}); });
cse_range(p, outputs); cse_range(m, outputs);
} }
instructions.emplace(ins->name(), ins); instructions.emplace(ins->name(), ins);
} }
} }
void eliminate_common_subexpression::apply(module& p) const { cse_range(p, iterator_for(p)); } void eliminate_common_subexpression::apply(module& m) const { cse_range(m, iterator_for(m)); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_concat::apply(module& p) const void eliminate_concat::apply(module& m) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Look for the concat operator // Look for the concat operator
if(ins->name() != concat_opt.name()) if(ins->name() != concat_opt.name())
...@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const ...@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std::sort(sorted_allocations.begin(), std::sort(sorted_allocations.begin(),
sorted_allocations.end(), sorted_allocations.end(),
[&](instruction_ref x, instruction_ref y) { [&](instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y); return std::distance(m.begin(), x) < std::distance(m.begin(), y);
}); });
// Move "super" allocation to the front // Move "super" allocation to the front
auto first = sorted_allocations.front(); auto first = sorted_allocations.front();
auto super = p.move_instruction(last, first); auto super = m.move_instruction(last, first);
// Replace each allocation with a load // Replace each allocation with a load
std::size_t offset = 0; std::size_t offset = 0;
for(auto alloc : allocations) for(auto alloc : allocations)
{ {
op::load op{alloc->get_shape(), offset}; op::load op{alloc->get_shape(), offset};
p.replace_instruction(alloc, op, {super}); m.replace_instruction(alloc, op, {super});
offset += alloc->get_shape().bytes(); offset += alloc->get_shape().bytes();
} }
std::vector<instruction_ref> args = {super}; std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args)); std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraphx::make_op("identity"), args); m.replace_instruction(ins, migraphx::make_op("identity"), args);
} }
} }
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp> #include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -69,9 +70,11 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -69,9 +70,11 @@ static bool try_compute_shape(instruction_ref ins,
return try_compute_shape(ins, inputs, mods); return try_compute_shape(ins, inputs, mods);
} }
void eliminate_contiguous::apply(module& p) const void eliminate_contiguous::apply(module& m) const
{ {
for(auto ins : iterator_for(p)) std::vector<instruction_ref> const_instruction;
for(auto ins : iterator_for(m))
{ {
// return instruction should have inputs with standard shape // return instruction should have inputs with standard shape
if(ins->name() == "@return") if(ins->name() == "@return")
...@@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& p) const ...@@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& p) const
auto args = ins->inputs(); auto args = ins->inputs();
auto new_args = args; auto new_args = args;
auto mod_args = ins->module_inputs(); auto mod_args = ins->module_inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
{ {
if(arg->name() == op_name) if(arg->name() == op_name)
...@@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& p) const ...@@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& p) const
} }
else if(prev->can_eval()) else if(prev->can_eval())
{ {
auto c = op::contiguous{}; const_instruction.push_back(arg);
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(arg, l);
} }
} }
} }
} }
// Perform evaluations in parallel
std::vector<argument> literals(const_instruction.size());
par_for(const_instruction.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instruction[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
});
for(size_t i = 0; i < const_instruction.size(); i++)
{
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instruction[i], l);
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -8,21 +8,21 @@ ...@@ -8,21 +8,21 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void eliminate_identity::apply(module& p) const void eliminate_identity::apply(module& m) const
{ {
auto last = std::prev(p.end()); auto last = std::prev(m.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Skip the first instruction, since we always process the previous // Skip the first instruction, since we always process the previous
// instruction // instruction
if(ins == p.begin()) if(ins == m.begin())
continue; continue;
const auto i = std::prev(ins); const auto i = std::prev(ins);
if(i->name() == "identity") if(i->name() == "identity")
{ {
p.replace_instruction(i, i->inputs().front()); m.replace_instruction(i, i->inputs().front());
p.move_instruction(i, p.end()); m.move_instruction(i, m.end());
} }
if(ins == last) if(ins == last)
{ {
...@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const ...@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const instruction_ref& identity_input = ins->inputs().front(); const instruction_ref& identity_input = ins->inputs().front();
if(identity_input->outputs().size() == 1) if(identity_input->outputs().size() == 1)
{ {
p.move_instruction(identity_input, i); m.move_instruction(identity_input, i);
// since this is the last instruction, removing it only // since this is the last instruction, removing it only
// requires changing "last" and calling remove below // requires changing "last" and calling remove below
last = std::prev(last); last = std::prev(last);
...@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const ...@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break; break;
} }
} }
p.remove_instructions(std::next(last), p.end()); m.remove_instructions(std::next(last), m.end());
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -13,7 +13,7 @@ struct adjust_allocation ...@@ -13,7 +13,7 @@ struct adjust_allocation
{ {
allocation_model model; allocation_model model;
std::string name() const { return "adjust_allocation"; } std::string name() const { return "adjust_allocation"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct stream_race ...@@ -16,7 +16,7 @@ struct stream_race
instruction_ref before; instruction_ref before;
}; };
std::vector<stream_race> analyze_streams(const module& p, const stream_model& m); std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -13,7 +13,7 @@ struct module; ...@@ -13,7 +13,7 @@ struct module;
struct auto_contiguous struct auto_contiguous
{ {
std::string name() const { return "auto_contiguous"; } std::string name() const { return "auto_contiguous"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -33,7 +33,7 @@ struct check_context ...@@ -33,7 +33,7 @@ struct check_context
}; };
std::string name() const { return "check_context"; } std::string name() const { return "check_context"; }
void apply(module& p) const { p.insert_instruction(p.begin(), op{}); } void apply(module& m) const { m.insert_instruction(m.begin(), op{}); }
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -24,6 +24,7 @@ struct src_compiler ...@@ -24,6 +24,7 @@ struct src_compiler
std::string flags = ""; std::string flags = "";
std::string output = ""; std::string output = "";
std::string launcher = ""; std::string launcher = "";
std::string out_ext = ".o";
std::function<fs::path(fs::path)> process = nullptr; std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const; std::vector<char> compile(const std::vector<src_file>& srcs) const;
}; };
......
...@@ -19,7 +19,7 @@ struct eliminate_allocation ...@@ -19,7 +19,7 @@ struct eliminate_allocation
std::string allocation_op{}; std::string allocation_op{};
std::size_t alignment = 32; std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; } std::string name() const { return "eliminate_allocation"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct eliminate_common_subexpression struct eliminate_common_subexpression
{ {
std::string name() const { return "eliminate_common_subexpression"; } std::string name() const { return "eliminate_common_subexpression"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -18,7 +18,7 @@ struct eliminate_concat ...@@ -18,7 +18,7 @@ struct eliminate_concat
{ {
concat_optimization concat_opt; concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; } std::string name() const { return "eliminate_concat"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment