Unverified Commit 818e6e53 authored by Paul Fultz II, committed by GitHub

Merge branch 'develop' into mlir-c

parents d716c516 cb18b0b5
@@ -236,6 +236,11 @@ void print_program(const program& p) { std::cout << p << std::endl; }
 void print_module(const module& m) { std::cout << m << std::endl; }
 
+migraphx::instruction_ref add_allocation(module& m, const migraphx::shape& s)
+{
+    return m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}}), {});
+}
+
 struct experimental_custom_op
 {
     std::string name;
@@ -260,7 +265,12 @@ struct custom_operation
         return op.compute_shape(std::move(inputs));
     }
 
-    argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
+    // TODO: Compute method with module_args
+    argument
+    compute(migraphx::context ctx, migraphx::shape output_shape, std::vector<argument> inputs) const
+    {
+        return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
+    }
 };
 
 template <class CustomOp>
@@ -577,6 +587,24 @@ struct migraphx_experimental_custom_op
     manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
         object_ptr = nullptr;
     migraphx::experimental_custom_op xobject;
+
+    migraphx_experimental_custom_op_compute compute_f = nullptr;
+    migraphx::argument compute(migraphx::context ctx,
+                               migraphx::shape output,
+                               std::vector<migraphx::argument> inputs) const
+    {
+        std::remove_pointer_t<migraphx_argument_t> out;
+        if(compute_f == nullptr)
+            throw std::runtime_error("compute function is missing.");
+        auto api_error_result = compute_f(&out,
+                                          object_ptr.data,
+                                          object_cast<migraphx_context_t>(&(ctx)),
+                                          object_cast<migraphx_shape_t>(&(output)),
+                                          object_cast<migraphx_arguments_t>(&(inputs)));
+        if(api_error_result != migraphx_status_success)
+            throw std::runtime_error("Error in compute.");
+        return (&out)->object;
+    }
+
     migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
     migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
     {
@@ -1141,6 +1169,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou
     return api_error_result;
 }
 
+extern "C" migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
+                                                          migraphx_module_t module,
+                                                          const_migraphx_shape_t s)
+{
+    auto api_error_result = migraphx::try_([&] {
+        if(module == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
+        if(s == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
+        *out = allocate<migraphx_instruction_t>(
+            migraphx::add_allocation((module->object), (s->object)));
+    });
+    return api_error_result;
+}
+
 extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
 {
     auto api_error_result = migraphx::try_([&] { destroy((program)); });
@@ -1772,6 +1815,14 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
     return api_error_result;
 }
 
+extern "C" migraphx_status
+migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
+                                            migraphx_experimental_custom_op_compute input)
+{
+    auto api_error_result = migraphx::try_([&] { (obj)->compute_f = (input); });
+    return api_error_result;
+}
+
 extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
     migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
 {
......
@@ -129,6 +129,12 @@ typedef const struct migraphx_context* const_migraphx_context_t;
 typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
 typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
 
+typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
+                                                                   void* obj,
+                                                                   migraphx_context_t ctx,
+                                                                   migraphx_shape_t output,
+                                                                   migraphx_arguments_t inputs);
+
 typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
                                                                          void* obj,
                                                                          migraphx_shapes_t inputs);
@@ -295,6 +301,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
                                            migraphx_module_t module,
                                            migraphx_instructions_t args);
 
+migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
+                                               migraphx_module_t module,
+                                               const_migraphx_shape_t s);
+
 migraphx_status migraphx_program_destroy(migraphx_program_t program);
 
 migraphx_status migraphx_program_assign_to(migraphx_program_t output,
@@ -477,6 +487,10 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
                                        migraphx_experimental_custom_op_delete d,
                                        const char* name);
 
+migraphx_status
+migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
+                                            migraphx_experimental_custom_op_compute input);
+
 migraphx_status migraphx_experimental_custom_op_set_compute_shape(
     migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
......
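Taken together, the new declarations suggest client-side wiring roughly like the following. This is a minimal sketch, not part of the diff: `my_compute` and `wire_up` are invented names, the callback body is elided, and the returned `migraphx_status` codes are left unchecked for brevity.

```cpp
#include <migraphx/migraphx.h>

/* Hypothetical callback matching the new function-pointer typedef above;
 * `obj` is the user-data pointer the custom op was created with. */
migraphx_status my_compute(migraphx_argument_t out,
                           void* obj,
                           migraphx_context_t ctx,
                           migraphx_shape_t output,
                           migraphx_arguments_t inputs)
{
    /* ... write the result into `out`, reading from `inputs` ... */
    return migraphx_status_success;
}

void wire_up(migraphx_experimental_custom_op_t op,
             migraphx_module_t mod,
             const_migraphx_shape_t s)
{
    /* Attach the compute callback; pairs with set_compute_shape below. */
    migraphx_experimental_custom_op_set_compute(op, &my_compute);

    /* Insert an explicit allocation instruction for shape `s`. */
    migraphx_instruction_t ins;
    migraphx_module_add_allocation(&ins, mod, s);
}
```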
@@ -401,11 +401,14 @@ struct interface_base : Base
         return x;
     }
 
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
     template <class T>
     auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
     {
         return as_handle<T>{x};
     }
+#pragma GCC diagnostic pop
 
     template <class T>
     auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
@@ -565,6 +568,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
         return pout;
     }
 
+    template <typename T>
+    std::vector<T> as_vector() const
+    {
+        size_t vector_len = this->get_shape().bytes() / sizeof(T);
+        T* buffer_ptr     = reinterpret_cast<T*>(this->data());
+        return {buffer_ptr, buffer_ptr + vector_len};
+    }
+
     /// Generate an argument using random data
     static argument generate(shape ps, size_t pseed = 0)
     {
@@ -802,13 +813,20 @@ struct module
         return instruction(ret_ins, own{});
     }
 
+    instruction add_allocation(const migraphx::shape& s)
+    {
+        migraphx_instruction_t ret_ins;
+        call(&migraphx_module_add_allocation, &ret_ins, mm.get(), s.get_handle_ptr());
+        return instruction(ret_ins, own{});
+    }
+
     migraphx_module_t get_handle_ptr() const { return mm.get(); }
 
     private:
     std::shared_ptr<migraphx_module> mm;
 };
 
-struct context
+struct context : handle_lookup<context, migraphx_context>
 {
     context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
 
@@ -1177,9 +1195,10 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
 struct experimental_custom_op_base
 {
     virtual std::string name() const = 0;
+    virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
     virtual shape compute_shape(shapes inputs) const = 0;
     virtual ~experimental_custom_op_base() = default;
 };
 
 struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
@@ -1189,6 +1208,7 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
     {
         this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
         MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
+        MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
     }
 
     void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
......
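On the C++ side, the new `compute` pure virtual on `experimental_custom_op_base` pairs with the lifted interface above. A minimal sketch of a user-defined op, assuming only the types shown in this header (`pass_through_op` is an invented example that simply hands back its input):

```cpp
#include <migraphx/migraphx.hpp>

struct pass_through_op : migraphx::experimental_custom_op_base
{
    std::string name() const override { return "pass_through"; }

    migraphx::shape compute_shape(migraphx::shapes inputs) const override
    {
        return inputs.front(); // output shape mirrors the first input
    }

    migraphx::argument compute(migraphx::context ctx,
                               migraphx::shape output,
                               migraphx::arguments inputs) const override
    {
        // A real op would fill an output buffer here; the argument::as_vector<T>()
        // helper added above gives typed access to an argument's data if needed.
        return inputs.front();
    }
};
```

Registration would presumably go through the template constructor and `register_op()` shown above, along the lines of `migraphx::experimental_custom_op op(pass_through_op{}); op.register_op();`.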
@@ -244,6 +244,10 @@ def module(h):
     h.method('add_return',
              api.params(args='std::vector<migraphx::instruction_ref>'),
              returns='migraphx::instruction_ref')
+    h.method('add_allocation',
+             api.params(s='const migraphx::shape&'),
+             invoke='migraphx::add_allocation($@)',
+             returns='migraphx::instruction_ref')
 
 
 @auto_handle()
@@ -436,6 +440,11 @@ def context(h):
                 'migraphx::experimental_custom_op')
 def experimental_custom_op(h):
     h.constructor('create', api.params(name='const char*'))
+    h.virtual('compute',
+              api.params(ctx='migraphx::context',
+                         output='migraphx::shape',
+                         inputs='std::vector<migraphx::argument>'),
+              returns='migraphx::argument')
     h.virtual('compute_shape',
               api.params(inputs='std::vector<migraphx::shape>'),
               returns='migraphx::shape')
......
@@ -183,7 +183,7 @@ struct module
     void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
 
-    std::vector<module_ref> get_sub_modules() const;
+    std::vector<module_ref> get_sub_modules(bool shallow = false) const;
 
     module& sort();
     ins_dep_map calc_implicit_deps() const;
......
@@ -38,6 +38,7 @@ struct module_pass_manager
     module_pass_manager(const module_pass_manager&) = delete;
     virtual module& get_module() = 0;
     virtual module* create_module(const std::string& name) = 0;
+    virtual module* get_common_parent() = 0;
    virtual void run_pass(const pass& p) = 0;
 
     protected:
......
@@ -132,6 +132,8 @@ struct program
     std::vector<const module*> get_modules() const;
     std::vector<module*> get_modules();
 
+    std::unordered_multimap<module_ref, module_ref> get_module_tree();
+
     void remove_module(const std::string& name);
     void remove_unused_modules();
......
@@ -216,10 +216,16 @@ void replace(Range&& r, const T& old, const T& new_x)
     std::replace(r.begin(), r.end(), old, new_x);
 }
 
-template <class R1, class R2>
-bool equal(R1&& r1, R2&& r2)
+template <class R1, class R2, class... Predicate>
+bool equal(R1&& r1, R2&& r2, Predicate... pred)
 {
-    return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end());
+    return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...);
+}
+
+template <class Range>
+auto distance(Range&& r)
+{
+    return std::distance(r.begin(), r.end());
 }
 
 template <class R>
......
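The variadic `Predicate` pack means `equal` now forwards zero or one comparator to `std::equal`, and `distance` wraps the usual begin/end pair. A small usage sketch against the functions above (the tolerance value is arbitrary):

```cpp
#include <cmath>
#include <vector>
#include <migraphx/ranges.hpp>

bool approx_equal(const std::vector<float>& a, const std::vector<float>& b)
{
    // One trailing predicate forwards to std::equal's five-argument overload;
    // calling migraphx::equal(a, b) with no predicate behaves as before.
    return migraphx::equal(
        a, b, [](float x, float y) { return std::fabs(x - y) < 1e-4f; });
}

auto count_elems(const std::vector<float>& a)
{
    return migraphx::distance(a); // same as std::distance(a.begin(), a.end())
}
```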
@@ -821,17 +821,20 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
     });
 }
 
-std::vector<module_ref> module::get_sub_modules() const
+std::vector<module_ref> module::get_sub_modules(bool shallow) const
 {
     std::vector<module_ref> vec_modules;
     for(auto ins : iterator_for(*this))
     {
         const auto& mod_args = ins->module_inputs();
         vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end());
-        for(const auto& smod : mod_args)
-        {
-            auto sub_mods = smod->get_sub_modules();
-            vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
+        if(not shallow)
+        {
+            for(const auto& smod : mod_args)
+            {
+                auto sub_mods = smod->get_sub_modules();
+                vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
+            }
         }
     }
......
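The new `shallow` flag stops the recursion after one level, which is exactly what `get_module_tree` below relies on. Illustratively (module names are hypothetical):

```cpp
void example(migraphx::module* main_mod)
{
    // Suppose main_mod has an instruction referencing sub_mod, and sub_mod in
    // turn references inner_mod:
    auto direct = main_mod->get_sub_modules(true); // {sub_mod}
    auto full   = main_mod->get_sub_modules();     // {sub_mod, inner_mod}
}
```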
@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace)
 
 struct module_pm : module_pass_manager
 {
-    module* mod;
-    program* prog;
-    tracer* t;
+    module* mod = nullptr;
+    tracer* t = nullptr;
+    module* common_parent = nullptr;
+    program* prog = nullptr;
 
-    module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr)
-        : mod(pmod), prog(pprog), t(pt)
-    {
-    }
+    module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
 
     template <class... Ts>
     void trace(Ts&&... xs) const
@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager
         assert(prog);
         return prog->create_module(name);
     }
+    virtual module* get_common_parent() override { return common_parent; }
     virtual void run_pass(const pass& p) override
     {
         assert(mod);
@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
         trace = tracer{std::cout};
     for(const auto& p : passes)
     {
-        module_pm{&mod, nullptr, &trace}.run_pass(p);
+        module_pm{&mod, &trace}.run_pass(p);
    }
 }
@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
 {
     if(enabled(MIGRAPHX_TRACE_PASSES{}))
         trace = tracer{std::cout};
+    std::unordered_set<module_ref> visited;
     for(const auto& p : passes)
     {
         auto mods = prog.get_modules();
+        auto tree = prog.get_module_tree();
+        visited.clear();
         for(const auto& mod : reverse(mods))
         {
             if(mod->bypass())
                 continue;
-            module_pm{mod, &prog, &trace}.run_pass(p);
+            if(not visited.insert(mod).second)
+                continue;
+            module_pm mpm{mod, &trace};
+            mpm.prog = &prog;
+            auto parents  = range(tree.equal_range(mod));
+            auto nparents = distance(parents);
+            if(nparents == 0)
+                mpm.common_parent = nullptr;
+            else if(nparents == 1)
+                mpm.common_parent = parents.begin()->second;
+            else
+                // Just set the common parent to the main module when there are
                // multiple parents for now
+                // TODO: Compute the common parent
+                mpm.common_parent = prog.get_main_module();
+            mpm.run_pass(p);
         }
         run_pass(prog, p, trace);
     }
......
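The parent resolution above counts entries in the child-to-parent multimap: zero matches means a root module, one match is the unique parent, and multiple matches currently fall back to the main module (the TODO notes that a proper common ancestor is still pending). The same logic as a standalone helper, shown here purely for illustration (`resolve_common_parent` is a made-up name; the diff inlines this in `run_passes`):

```cpp
#include <iterator>
#include <migraphx/program.hpp>

migraphx::module* resolve_common_parent(migraphx::program& prog, migraphx::module_ref mod)
{
    auto tree = prog.get_module_tree(); // child -> parent entries
    auto pr   = tree.equal_range(mod);
    auto n    = std::distance(pr.first, pr.second);
    if(n == 0)
        return nullptr;                // top-level module: no parent
    if(n == 1)
        return pr.first->second;       // unique parent
    return prog.get_main_module();     // multiple parents: conservative fallback
}
```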
@@ -869,6 +869,23 @@ std::vector<module*> program::get_modules()
     return result;
 }
 
+template <class Module, class Map>
+void generic_insert_module_tree(Module* pm, Map& m)
+{
+    for(auto* sm : pm->get_sub_modules(true))
+    {
+        m.insert(std::make_pair(sm, pm));
+        generic_insert_module_tree(sm, m);
+    }
+}
+
+std::unordered_multimap<module_ref, module_ref> program::get_module_tree()
+{
+    std::unordered_multimap<module_ref, module_ref> result;
+    generic_insert_module_tree(this->get_main_module(), result);
+    return result;
+}
+
 template <class Map, class T>
 bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name)
 {
......
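Because `get_module_tree` records one entry per (submodule, parent) edge, a module referenced from two places appears twice, which is exactly what the pass manager's multiple-parent branch detects. A quick way to inspect the map (illustrative sketch):

```cpp
#include <iostream>
#include <migraphx/program.hpp>

void dump_module_tree(migraphx::program& prog)
{
    for(const auto& p : prog.get_module_tree())
        std::cout << p.first->name() << " has parent " << p.second->name() << "\n";
}
```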
@@ -61,9 +61,7 @@ struct shape_impl
     {
         assert(t != shape::tuple_type);
         assert(m_lens.size() == m_strides.size());
-        // assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
-        //        "At least one stride must be non-zero");
-        m_standard = this->elements() == this->element_space() and
+        m_standard = this->elements() == this->element_space() and not skips() and
                      std::is_sorted(m_strides.rbegin(), m_strides.rend());
     }
@@ -110,6 +108,15 @@ struct shape_impl
             m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
     }
 
+    // Does the shape skip over elements?
+    bool skips() const
+    {
+        assert(m_lens.size() == m_strides.size());
+        if(elements() == 1)
+            return false;
+        return std::none_of(m_strides.begin(), m_strides.end(), [](auto x) { return x == 1; });
+    }
+
     std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
 };
@@ -260,7 +267,8 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
 bool shape::packed() const
 {
-    return this->sub_shapes().empty() and this->elements() == this->element_space();
+    return this->sub_shapes().empty() and not impl->skips() and
+           this->elements() == this->element_space();
 }
 
 bool shape::transposed() const
@@ -285,10 +293,8 @@ bool shape::transposed() const
 bool shape::broadcasted() const
 {
     assert(this->lens().size() == this->strides().size());
-    return std::accumulate(this->strides().begin(),
-                           this->strides().end(),
-                           std::size_t{1},
-                           std::multiplies<std::size_t>()) == 0;
+    return std::any_of(
+        this->strides().begin(), this->strides().end(), [](auto x) { return x == 0; });
 }
 
 bool shape::scalar() const
......
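The effect of `skips()` on the public predicates can be seen with a few concrete shapes; the expected values below follow directly from the definitions above, using the existing lens/strides constructor:

```cpp
#include <migraphx/shape.hpp>

// Innermost stride 1, contiguous: packed() is true, skips() is false.
migraphx::shape packed_s{migraphx::shape::float_type, {2, 2}, {2, 1}};

// No stride equals 1, so the view skips over elements: not packed.
migraphx::shape strided_s{migraphx::shape::float_type, {2, 2}, {4, 2}};

// A zero stride repeats data along that axis: broadcasted() is true.
migraphx::shape bcast_s{migraphx::shape::float_type, {2, 2}, {0, 1}};
```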
@@ -163,7 +163,6 @@ add_library(migraphx_gpu
     convolution.cpp
     deconvolution.cpp
     device_name.cpp
-    eliminate_workspace.cpp
     elu.cpp
     fuse_mlir.cpp
     fuse_ops.cpp
......
@@ -48,6 +48,7 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/array.hpp>
+#include <migraphx/make_op.hpp>
 #include <migraphx/op/clip.hpp>
 #include <cmath>
 #include <set>
@@ -701,6 +702,7 @@ struct miopen_fusion
         return args.back();
     }
 };
+MIGRAPHX_REGISTER_OP(miopen_fusion)
 
 struct miopen_conv_bias
 {
@@ -1005,9 +1007,43 @@ struct find_commutative_broadcast
 };
 
 } // namespace
 
+struct find_contiguous
+{
+    auto matcher() const { return match::name("gpu::contiguous"); }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        m.replace_instruction(
+            ins,
+            make_op("gpu::precompile_op", {{"op", to_value(make_op("contiguous"))}}),
+            ins->inputs());
+    }
+};
+
+struct find_contiguous_pointwise
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(precompile_name("pointwise")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins    = r.result;
+        auto pw     = ins->inputs().front();
+        auto alloc  = ins->inputs().back();
+        auto args   = pw->inputs();
+        args.back() = alloc;
+        m.replace_instruction(ins, pw->get_operator(), args, pw->module_inputs());
+    }
+};
+
 void fuse_ops::apply(module& m) const
 {
-    match::find_matches(m, find_gelu{}, find_gelu_new{fast_math});
+    match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
     run_passes(m, {dead_code_elimination{}});
     match::find_matches(m, find_triadd{});
     match::find_matches(m,
@@ -1029,6 +1065,7 @@ void fuse_ops::apply(module& m) const
                         find_gemm_add{},
                         find_gemm_pointwise{},
                         find_commutative_broadcast{});
+    match::find_matches(m, find_contiguous{});
 }
 
 } // namespace gpu
......
@@ -37,7 +37,6 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 MIGRAPHX_REGISTER_OP(hip_allocate)
-MIGRAPHX_REGISTER_OP(hip_sync_device)
 MIGRAPHX_REGISTER_OP(hip_sync_stream)
 MIGRAPHX_REGISTER_OP(hip_copy_to_gpu)
 MIGRAPHX_REGISTER_OP(hip_copy_from_gpu)
......
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_WORKSPACE_HPP
-#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_WORKSPACE_HPP
-
-#include <string>
-#include <migraphx/instruction_ref.hpp>
-#include <migraphx/config.hpp>
-
-namespace migraphx {
-inline namespace MIGRAPHX_INLINE_NS {
-
-struct module;
-
-namespace gpu {
-
-struct eliminate_workspace
-{
-    std::string name() const { return "eliminate_workspace"; }
-    void apply(module& m) const;
-};
-
-} // namespace gpu
-} // namespace MIGRAPHX_INLINE_NS
-} // namespace migraphx
-
-#endif
@@ -59,12 +59,11 @@ argument get_preallocation(context& ctx, const std::string& id);
 struct hip_allocate
 {
     shape s;
-    std::string tag{};
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.s, "shape"), f(self.tag, "tag"));
+        return pack(f(self.s, "shape"));
     }
 
     std::string name() const { return "hip::allocate"; }
@@ -79,42 +78,8 @@ struct hip_allocate
     }
 };
 
-struct hip_sync_device
-{
-    std::string tag{};
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.tag, "tag"));
-    }
-
-    std::string name() const { return "hip::sync_device"; }
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        if(inputs.empty())
-            return {};
-        return inputs.front();
-    }
-    argument compute(context&, const shape&, const std::vector<argument>& args) const
-    {
-        gpu_sync();
-        if(args.empty())
-            return {};
-        return args.front();
-    }
-};
-
 struct hip_sync_stream
 {
-    std::string tag{};
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.tag, "tag"));
-    }
-
     std::string name() const { return "hip::sync_stream"; }
 
     shape compute_shape(const std::vector<shape>& inputs) const
......
@@ -79,7 +79,7 @@ static std::vector<std::string> get_op_names(const module& m)
 
 struct pointwise_compiler : compiler<pointwise_compiler>
 {
-    std::vector<std::string> names() const { return {"pointwise"}; }
+    std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
 
     static std::size_t oversubscribe_if(bool b)
     {
@@ -114,34 +114,45 @@ struct pointwise_compiler : compiler<pointwise_compiler>
         return compile_hip_code_object(src, options);
     }
 
-    compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
+    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
-        assert(not ins->module_inputs().empty());
-        auto* pm = ins->module_inputs().front();
-        run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
-        cpp_generator g;
-        g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
-        g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
-        g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
-        g.add_point_op("sign",
-                       "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
-        g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
-        g.add_point_op("less", "migraphx::abs(${0} < ${1})");
-        g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
-        g.add_point_op("not", "migraphx::abs(not ${0})");
-        // Add explict conversions
-        g.fresult(
-            [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
-        auto name = g.create_function(
-            g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
-        std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
-        auto op_names      = get_op_names(*pm);
-        op_names.push_back("kernel");
-        auto op_name_string = join_strings(op_names, "_");
-        return replace(
-            compile_op(ctx,
-                       to_shapes(ins->inputs()),
-                       {{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
+        if(op.name() == "contiguous")
+        {
+            return replace(compile_op(
+                ctx,
+                to_shapes(ins->inputs()),
+                {{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
+        }
+        else
+        {
+            assert(not ins->module_inputs().empty());
+            auto* pm = ins->module_inputs().front();
+            run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
+            cpp_generator g;
+            g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
+            g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
+            g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
+            g.add_point_op("sign",
+                           "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
+            g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
+            g.add_point_op("less", "migraphx::abs(${0} < ${1})");
+            g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
+            g.add_point_op("not", "migraphx::abs(not ${0})");
+            // Add explicit conversions
+            g.fresult([](const shape& s) {
+                return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
+            });
+            auto name = g.create_function(
+                g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
+            std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
+            auto op_names      = get_op_names(*pm);
+            op_names.push_back("kernel");
+            auto op_name_string = join_strings(op_names, "_");
+            return replace(compile_op(
+                ctx,
+                to_shapes(ins->inputs()),
+                {{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
+        }
     }
 };
 
 } // namespace gpu
......
@@ -49,7 +49,7 @@ constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
 {
     for(; first != last; ++first)
     {
-        init = op(std::move(init), *first);
+        init = op(static_cast<T&&>(init), *first);
     }
     return init;
 }
@@ -64,6 +64,20 @@ constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
     return d_first;
 }
 
+template <class InputIt, class OutputIt, class UnaryPredicate>
+constexpr OutputIt copy_if(InputIt first, InputIt last, OutputIt d_first, UnaryPredicate pred)
+{
+    for(; first != last; ++first)
+    {
+        if(pred(*first))
+        {
+            *d_first = *first;
+            ++d_first;
+        }
+    }
+    return d_first;
+}
+
 template <class Iterator, class Compare>
 constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
 {
@@ -115,6 +129,24 @@ constexpr Iterator find(Iterator first, Iterator last, const T& value)
     return find_if(first, last, [&](const auto& x) { return x == value; });
 }
 
+template <class InputIt, class UnaryPredicate>
+constexpr bool any_of(InputIt first, InputIt last, UnaryPredicate p)
+{
+    return find_if(first, last, p) != last;
+}
+
+template <class InputIt, class UnaryPredicate>
+constexpr bool none_of(InputIt first, InputIt last, UnaryPredicate p)
+{
+    return find_if(first, last, p) == last;
+}
+
+template <class InputIt, class UnaryPredicate>
+constexpr bool all_of(InputIt first, InputIt last, UnaryPredicate p)
+{
+    return none_of(first, last, [=](auto&& x) { return not p(x); });
+}
+
 template <class Iterator1, class Iterator2>
 constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last)
 {
......
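The kernel-side `any_of`/`none_of`/`all_of` are defined through `find_if`, with `all_of` reduced to `none_of` of the negated predicate. That identity is easy to sanity-check on the host with stand-in definitions of the same shape (this sketch deliberately does not include the device header; the `my_` prefix marks the stand-ins):

```cpp
#include <array>

template <class It, class P>
constexpr bool my_any_of(It first, It last, P p)
{
    for(; first != last; ++first)
        if(p(*first))
            return true;
    return false;
}

template <class It, class P>
constexpr bool my_all_of(It first, It last, P p)
{
    // Same reduction as the header: all_of == not any_of(not pred).
    return not my_any_of(first, last, [=](auto&& x) { return not p(x); });
}

constexpr std::array<int, 3> xs{2, 4, 6};
static_assert(my_any_of(xs.begin(), xs.end(), [](int x) { return x > 5; }));
static_assert(my_all_of(xs.begin(), xs.end(), [](int x) { return x % 2 == 0; }));
```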
@@ -41,8 +41,15 @@ struct implicit_conversion_op
     template <index_int N, class U>
     constexpr operator vec<U, N>() const
     {
-        static_assert(vec_size<T>() == N, "Vector mismatch size");
-        return __builtin_convertvector(x, vec<U, N>);
+        if constexpr(vec_size<T>() == 0)
+        {
+            return x;
+        }
+        else
+        {
+            static_assert(vec_size<T>() == N, "Vector mismatch size");
+            return __builtin_convertvector(x, vec<U, N>);
+        }
     }
 
     template <class U>
......