Commit 2f268bc2 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into mlir-c

parents f75c5a38 aa7ff911
......@@ -15,6 +15,16 @@ namespace migraphx {
inline namespace api { // NOLINT
#endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template <int N>
struct rank : rank<N - 1>
{
......@@ -29,10 +39,7 @@ template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs)
{
T* result = nullptr;
// cppcheck-suppress redundantInitialization
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto e = f(&result, std::forward<Ts>(xs)...);
auto e = f(&result, std::forward<Ts>(xs)...);
if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function");
return result;
......@@ -41,9 +48,6 @@ T* make(F f, Ts&&... xs)
template <class F, class... Ts>
void call(F f, Ts&&... xs)
{
// cppcheck-suppress redundantInitialization
// cppcheck-suppress redundantAssignment
// cppcheck-suppress unreadVariable
auto e = f(std::forward<Ts>(xs)...);
if(e != migraphx_status_success)
throw std::runtime_error("Failed to call function");
......@@ -99,34 +103,22 @@ struct iota_iterator
return it;
}
// TODO: operator->
reference operator*() const { return (*f)(index); }
};
reference operator*() const { return f(index); }
template <class F, class Iterator>
inline iota_iterator<F, Iterator> operator+(iota_iterator<F, Iterator> x,
iota_iterator<F, Iterator> y)
{
return iota_iterator<F, Iterator>(x.index + y.index, x.f);
}
friend iota_iterator operator+(iota_iterator x, iota_iterator y)
{
return iota_iterator(x.index + y.index, x.f);
}
template <class F, class Iterator>
inline iota_iterator<F, Iterator> operator-(iota_iterator<F, Iterator> x,
iota_iterator<F, Iterator> y)
{
return iota_iterator<F, Iterator>(x.index - y.index, x.f);
}
friend iota_iterator operator-(iota_iterator x, iota_iterator y)
{
return iota_iterator(x.index - y.index, x.f);
}
template <class F, class Iterator>
inline bool operator==(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
friend bool operator==(iota_iterator x, iota_iterator y) { return x.index == y.index; }
template <class F, class Iterator>
inline bool operator!=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index != y.index;
}
friend bool operator!=(iota_iterator x, iota_iterator y) { return x.index != y.index; }
};
template <class Derived>
struct array_base
......@@ -136,8 +128,20 @@ struct array_base
template <class T>
using value_type_t = decltype(std::declval<T>()[0]);
struct iterator_read
{
const Derived* self;
template <class D = Derived>
value_type_t<D> operator()(size_t pidx) const
{
return (*self)[pidx];
}
};
template <class T>
using iterator_t = iota_iterator<typename T::iterator_read>;
using iterator_t = iota_iterator<iterator_read>;
bool empty() const { return derived().size() == 0; }
template <class D = Derived>
value_type_t<D> front() const
......@@ -154,13 +158,13 @@ struct array_base
template <class D = Derived>
iterator_t<D> begin() const
{
return {0, {derived().get_handle_ptr()}};
return {0, {&derived()}};
}
template <class D = Derived>
iterator_t<D> end() const
{
return {derived().size(), {derived().get_handle_ptr()}};
return {derived().size(), {&derived()}};
}
};
......@@ -200,9 +204,25 @@ struct borrow
{
};
template <class T>
struct share
{
share(std::shared_ptr<T> p) : ptr(std::move(p)) {}
template <class U>
std::shared_ptr<U> alias(U* p) const
{
return std::shared_ptr<U>{ptr, p};
}
private:
std::shared_ptr<T> ptr;
};
template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{
using handle_type = T;
handle_base() : m_handle(nullptr) {}
template <class F, class... Ts>
void make_handle(F f, Ts&&... xs)
......@@ -231,6 +251,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
}
template <class U, class V>
void set_handle(U* ptr, share<V> b)
{
m_handle = std::shared_ptr<T>{ptr, [b](U*) {}};
}
share<T> share_handle() const { return {m_handle}; }
template <class U>
void assign_to_handle(U* x)
{
......@@ -241,6 +269,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std::shared_ptr<T> m_handle;
};
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template <class Base>
struct interface_base : Base
{
......@@ -269,6 +308,7 @@ struct interface_base : Base
T** y = reinterpret_cast<T**>(out);
T* x = reinterpret_cast<T*>(input);
assert(x != nullptr and y != nullptr and *y == nullptr);
// cppcheck-suppress useSmartPointer
*y = new T(*x); // NOLINT
});
};
......@@ -398,11 +438,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{
shape() {}
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
shape(migraphx_shape* p, own) { this->set_handle(p, own{}); }
shape(migraphx_shape* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
/// Construct a scalar shape
shape(migraphx_shape_datatype_t type)
......@@ -479,10 +518,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
argument() {}
argument(migraphx_argument* p, borrow) { this->set_handle(p, borrow{}); }
argument(migraphx_argument* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(argument);
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape, void* pbuffer)
......@@ -494,7 +532,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return {pout};
return {pout, this->share_handle()};
}
char* data() const
......@@ -526,9 +564,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
target() {}
target(migraphx_target* p, own) { this->set_handle(p, own{}); }
target(migraphx_target* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(target);
/// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); }
......@@ -538,15 +574,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
program_parameter_shapes() {}
program_parameter_shapes(migraphx_program_parameter_shapes* p, own)
{
this->set_handle(p, own{});
}
program_parameter_shapes(migraphx_program_parameter_shapes* p, borrow)
{
this->set_handle(p, borrow{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes);
size_t size() const
{
......@@ -559,7 +587,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return {pout};
return {pout, this->share_handle()};
}
std::vector<const char*> names() const
......@@ -576,10 +604,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{
program_parameters(migraphx_program_parameters* p, own) { this->set_handle(p, own{}); }
program_parameters(migraphx_program_parameters* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters);
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
program_parameters() { this->make_handle(&migraphx_program_parameters_create); }
......@@ -604,9 +631,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
arguments(migraphx_arguments* p, own) { this->set_handle(p, own{}); }
arguments(migraphx_arguments* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments);
size_t size() const
{
......@@ -619,27 +644,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return {pout};
return {pout, this->share_handle()};
}
struct iterator_read
{
migraphx_arguments* self;
argument operator()(size_t pidx) const
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return {pout};
}
};
};
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
shapes(migraphx_shapes* p, own) { this->set_handle(p, own{}); }
shapes(migraphx_shapes* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes);
size_t size() const
{
......@@ -652,26 +663,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return {pout};
return {pout, this->share_handle()};
}
struct iterator_read
{
migraphx_shapes* self;
shape operator()(size_t pidx) const
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx);
return {pout};
}
};
};
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); }
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(operation);
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
......@@ -689,15 +687,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{
instruction(migraphx_instruction* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction);
};
struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{
instructions(migraphx_instructions* p, own) { this->set_handle(p, own{}); }
instructions(migraphx_instructions* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions);
template <class... Ts>
instructions(Ts... xs)
......@@ -711,33 +706,36 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules)
{
modules(migraphx_modules* p, own) { this->set_handle(p, own{}); }
modules(migraphx_modules* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(modules);
template <class... Ts>
modules(Ts... xs)
{
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.mm...};
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.get_handle_ptr()...};
this->make_handle(&migraphx_modules_create, a.data(), a.size());
}
};
struct module
{
migraphx_module_t mm;
MIGRAPHX_DEPRECATED("Constructor without lifetime annotation is deprecated.")
module(migraphx_module* m) : mm(std::shared_ptr<migraphx_module*>(), m) {}
module(const migraphx_module_t& m) : mm(m) {}
module(migraphx_module* m, borrow) : mm(std::shared_ptr<migraphx_module*>(), m) {}
void print() const { call(&migraphx_module_print, mm); }
template <class T>
module(migraphx_module* m, share<T> b) : mm(b.alias(m))
{
}
void print() const { call(&migraphx_module_print, mm.get()); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction,
&op_ins,
mm,
mm.get(),
op.get_handle_ptr(),
args.get_handle_ptr());
return instruction(op_ins, own{});
......@@ -750,40 +748,72 @@ struct module
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args,
&op_ins,
mm,
mm.get(),
op.get_handle_ptr(),
args.get_handle_ptr(),
module_args.get_handle_ptr());
return instruction(op_ins, own{});
}
template <typename T>
instruction add_literal(const migraphx::shape& s, T* buffer)
{
migraphx_instruction_t literal_ins;
const auto* buffer_ptr = reinterpret_cast<const char*>(buffer);
call(&migraphx_module_add_literal, &literal_ins, mm.get(), s.get_handle_ptr(), buffer_ptr);
return instruction(literal_ins, own{});
}
instruction add_parameter(const std::string& name, shape s)
{
migraphx_instruction_t param_ins;
call(&migraphx_module_add_parameter, &param_ins, mm, name.c_str(), s.get_handle_ptr());
call(
&migraphx_module_add_parameter, &param_ins, mm.get(), name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{});
}
instruction add_return(const migraphx::instructions& args)
{
migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm, args.get_handle_ptr());
call(&migraphx_module_add_return, &ret_ins, mm.get(), args.get_handle_ptr());
return instruction(ret_ins, own{});
}
migraphx_module_t get_handle_ptr() const { return mm.get(); }
private:
std::shared_ptr<migraphx_module> mm;
};
struct context
{
migraphx_context_t ctx;
context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
void finish() const { call(&migraphx_context_finish, ctx); }
template <class T>
context(migraphx_context* p, share<T> b) : ctx(b.alias(p))
{
}
void finish() const { call(&migraphx_context_finish, ctx.get()); }
template <class T>
T get_queue()
{
void* out;
call(&migraphx_context_get_queue, &out, ctx.get());
// TODO: check type here
return reinterpret_cast<T>(out);
}
private:
std::shared_ptr<migraphx_context> ctx;
};
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{
compile_options() { this->make_handle(&migraphx_compile_options_create); }
compile_options(migraphx_compile_options* p, own) { this->set_handle(p, own()); }
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options);
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
......@@ -807,9 +837,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
program() { this->make_handle(&migraphx_program_create); }
program(migraphx_program* p, own) { this->set_handle(p, own{}); }
program(migraphx_program* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program);
/// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const
......@@ -872,21 +900,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu};
return module{p_modu, this->share_handle()};
}
context experimental_get_context()
{
migraphx_context_t ctx;
call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr());
return context{ctx};
return context{ctx, this->share_handle()};
}
module create_module(const std::string& name)
{
migraphx_module_t p_modu;
call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return module{p_modu};
return module{p_modu, this->share_handle()};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
......@@ -895,10 +923,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options);
file_options() { this->make_handle(&migraphx_file_options_create); }
file_options(migraphx_file_options* p, own) { this->set_handle(p, own()); }
// set file format
void set_file_format(const char* format)
{
......@@ -938,7 +965,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
onnx_options(migraphx_onnx_options* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options);
/// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
......@@ -1020,7 +1047,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
tf_options() { this->make_handle(&migraphx_tf_options_create); }
tf_options(migraphx_tf_options* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options);
/// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
......@@ -1073,7 +1100,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
quantize_op_names(migraphx_quantize_op_names* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names);
void add(const std::string& name)
{
......@@ -1098,12 +1125,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
quantize_int8_options(migraphx_quantize_int8_options* p, own) { this->set_handle(p, own{}); }
quantize_int8_options(migraphx_quantize_int8_options* p, borrow)
{
this->set_handle(p, borrow{});
}
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options);
/// Add an operator that should be quantized
void add_op_name(const std::string& name)
......
......@@ -212,6 +212,9 @@ def module(h):
module_refs='std::vector<migraphx::module*>'),
fname='add_instruction',
returns='migraphx::instruction_ref')
h.method('add_literal',
api.params(shape='const migraphx::shape&', buffer='const char*'),
returns='migraphx::instruction_ref')
h.method('add_parameter',
api.params(name='const char*', shape='const migraphx::shape&'),
returns='migraphx::instruction_ref')
......@@ -403,6 +406,7 @@ api.add_function('migraphx_quantize_int8',
@auto_handle(ref=True)
def context(h):
h.method('finish', const=True)
h.method('get_queue', returns='void*', fname='get_queue().unsafe_get')
@api.interface('migraphx_experimental_custom_op',
......
......@@ -29,7 +29,6 @@ void argument::assign_buffer(std::function<char*()> d)
// Collect all shapes
std::unordered_map<std::size_t, shape> shapes;
{
// cppcheck-suppress variableScope
std::size_t i = 0;
fix([&](auto self, auto ss) {
if(ss.sub_shapes().empty())
......@@ -60,7 +59,6 @@ void argument::assign_buffer(std::function<char*()> d)
}
assert(offset == s.bytes());
// cppcheck-suppress variableScope
std::size_t i = 0;
m_data = fix<data_t>([&](auto self, auto ss) {
data_t result;
......
......@@ -8,10 +8,10 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(module& p) const
void auto_contiguous::apply(module& m) const
{
std::string key = "require_std_shape";
for(auto ins : reverse_iterator_for(p))
for(auto ins : reverse_iterator_for(m))
{
auto&& attr = ins->get_operator().attributes();
if((attr.get(key, false)))
......@@ -23,18 +23,18 @@ void auto_contiguous::apply(module& p) const
{
return in;
}
return p.insert_instruction(ins, make_op("contiguous"), in);
return m.insert_instruction(ins, make_op("contiguous"), in);
});
if(new_args != args)
{
p.replace_instruction(ins, ins->get_operator(), new_args);
m.replace_instruction(ins, ins->get_operator(), new_args);
}
}
}
auto last = std::prev(p.end());
for(auto ins : iterator_for(p))
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
// for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last)
......@@ -42,8 +42,8 @@ void auto_contiguous::apply(module& p) const
shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0)
{
auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins);
p.replace_instruction(ins, c);
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
}
}
}
......
......@@ -28,7 +28,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
{
params += " " + src.path.filename().string();
if(out.empty())
out = src.path.stem().string() + ".o";
out = src.path.stem().string() + out_ext;
}
}
......
......@@ -9,26 +9,6 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Range, class Iterator>
std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
{
auto start_forward = start;
auto start_backwards = start;
std::size_t n = 0;
while(start_forward != last and start_backwards != last)
{
n++;
if(start_forward != r.end())
start_forward++;
if(start_backwards != r.begin())
start_backwards--;
}
if(start_forward == last)
return n;
else
return -n;
}
void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
void dead_code_elimination::apply(module& m) const
......@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity")
continue;
assert(bidistance(m, i, last) > 0);
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
fix([&](auto self, auto leaf) {
if(not m.has_instruction(leaf))
return;
if(leaf->outputs().empty())
{
// Dont visit inputs twice
if(not visited.insert(leaf).second)
return;
std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end());
leaf->clear_arguments();
assert(bidistance(m, last, leaf) < 0);
assert(std::distance(m.begin(), leaf) < std::distance(m.begin(), last));
assert(leaf != ins);
if(leaf->name() != "@param")
m.move_instruction(leaf, m.end());
......
......@@ -17,7 +17,7 @@ class marker_roctx
std::function<int(const char*)> sym_roctx_range_push;
std::function<int()> sym_roctx_range_pop;
uint64_t range_id;
uint64_t range_id = 0;
public:
marker_roctx()
......
......@@ -13,13 +13,13 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_allocation::apply(module& p) const
void eliminate_allocation::apply(module& m) const
{
assert(alignment > 0);
std::size_t n = 0;
std::vector<std::pair<instruction_ref, std::size_t>> allocs;
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(ins->name() != allocation_op)
continue;
......@@ -30,13 +30,13 @@ void eliminate_allocation::apply(module& p) const
}
if(n > 0)
{
auto mem = p.add_parameter("memory", shape{shape::int8_type, {n}});
auto mem = m.add_parameter("memory", shape{shape::int8_type, {n}});
for(auto&& pp : allocs)
{
auto ins = pp.first;
auto s = ins->get_shape();
auto offset = pp.second;
p.replace_instruction(
m.replace_instruction(
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
}
}
......
......@@ -11,7 +11,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Range>
void cse_range(module& p, Range&& r)
void cse_range(module& m, Range&& r)
{
std::unordered_multimap<std::string, instruction_ref> instructions;
std::unordered_set<instruction_ref> processed_ins;
......@@ -30,24 +30,24 @@ void cse_range(module& p, Range&& r)
continue;
if(*eq != *ins)
continue;
p.replace_instruction(ins, eq);
m.replace_instruction(ins, eq);
processed_ins.emplace(ins);
std::vector<instruction_ref> outputs;
std::copy_if(eq->outputs().begin(),
eq->outputs().end(),
std::back_inserter(outputs),
[&](auto x) { return p.has_instruction(x); });
[&](auto x) { return m.has_instruction(x); });
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
return std::distance(eq, x) < std::distance(eq, y);
});
cse_range(p, outputs);
cse_range(m, outputs);
}
instructions.emplace(ins->name(), ins);
}
}
void eliminate_common_subexpression::apply(module& p) const { cse_range(p, iterator_for(p)); }
void eliminate_common_subexpression::apply(module& m) const { cse_range(m, iterator_for(m)); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -13,9 +13,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_concat::apply(module& p) const
void eliminate_concat::apply(module& m) const
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
// Look for the concat operator
if(ins->name() != concat_opt.name())
......@@ -64,22 +64,22 @@ void eliminate_concat::apply(module& p) const
std::sort(sorted_allocations.begin(),
sorted_allocations.end(),
[&](instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y);
return std::distance(m.begin(), x) < std::distance(m.begin(), y);
});
// Move "super" allocation to the front
auto first = sorted_allocations.front();
auto super = p.move_instruction(last, first);
auto super = m.move_instruction(last, first);
// Replace each allocation with a load
std::size_t offset = 0;
for(auto alloc : allocations)
{
op::load op{alloc->get_shape(), offset};
p.replace_instruction(alloc, op, {super});
m.replace_instruction(alloc, op, {super});
offset += alloc->get_shape().bytes();
}
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraphx::make_op("identity"), args);
m.replace_instruction(ins, migraphx::make_op("identity"), args);
}
}
}
......
......@@ -6,6 +6,7 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/par_for.hpp>
#include <utility>
namespace migraphx {
......@@ -69,9 +70,11 @@ static bool try_compute_shape(instruction_ref ins,
return try_compute_shape(ins, inputs, mods);
}
void eliminate_contiguous::apply(module& p) const
void eliminate_contiguous::apply(module& m) const
{
for(auto ins : iterator_for(p))
std::vector<instruction_ref> const_instruction;
for(auto ins : iterator_for(m))
{
// return instruction should have inputs with standard shape
if(ins->name() == "@return")
......@@ -81,6 +84,7 @@ void eliminate_contiguous::apply(module& p) const
auto args = ins->inputs();
auto new_args = args;
auto mod_args = ins->module_inputs();
for(auto arg : ins->inputs())
{
if(arg->name() == op_name)
......@@ -93,15 +97,25 @@ void eliminate_contiguous::apply(module& p) const
}
else if(prev->can_eval())
{
auto c = op::contiguous{};
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = p.add_literal(r.get_shape(), r.data());
p.replace_instruction(arg, l);
const_instruction.push_back(arg);
}
}
}
}
// Perform evaluations in parallel
std::vector<argument> literals(const_instruction.size());
par_for(const_instruction.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instruction[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
});
for(size_t i = 0; i < const_instruction.size(); i++)
{
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instruction[i], l);
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -8,21 +8,21 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_identity::apply(module& p) const
void eliminate_identity::apply(module& m) const
{
auto last = std::prev(p.end());
for(auto ins : iterator_for(p))
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
// Skip the first instruction, since we always process the previous
// instruction
if(ins == p.begin())
if(ins == m.begin())
continue;
const auto i = std::prev(ins);
if(i->name() == "identity")
{
p.replace_instruction(i, i->inputs().front());
p.move_instruction(i, p.end());
m.replace_instruction(i, i->inputs().front());
m.move_instruction(i, m.end());
}
if(ins == last)
{
......@@ -31,7 +31,7 @@ void eliminate_identity::apply(module& p) const
const instruction_ref& identity_input = ins->inputs().front();
if(identity_input->outputs().size() == 1)
{
p.move_instruction(identity_input, i);
m.move_instruction(identity_input, i);
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
last = std::prev(last);
......@@ -40,7 +40,7 @@ void eliminate_identity::apply(module& p) const
break;
}
}
p.remove_instructions(std::next(last), p.end());
m.remove_instructions(std::next(last), m.end());
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -13,7 +13,7 @@ struct adjust_allocation
{
allocation_model model;
std::string name() const { return "adjust_allocation"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct stream_race
instruction_ref before;
};
std::vector<stream_race> analyze_streams(const module& p, const stream_model& m);
std::vector<stream_race> analyze_streams(const module& m, const stream_model& strmm);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -13,7 +13,7 @@ struct module;
struct auto_contiguous
{
std::string name() const { return "auto_contiguous"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -33,7 +33,7 @@ struct check_context
};
std::string name() const { return "check_context"; }
void apply(module& p) const { p.insert_instruction(p.begin(), op{}); }
void apply(module& m) const { m.insert_instruction(m.begin(), op{}); }
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -24,6 +24,7 @@ struct src_compiler
std::string flags = "";
std::string output = "";
std::string launcher = "";
std::string out_ext = ".o";
std::function<fs::path(fs::path)> process = nullptr;
std::vector<char> compile(const std::vector<src_file>& srcs) const;
};
......
......@@ -19,7 +19,7 @@ struct eliminate_allocation
std::string allocation_op{};
std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct eliminate_common_subexpression
{
std::string name() const { return "eliminate_common_subexpression"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -18,7 +18,7 @@ struct eliminate_concat
{
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment