Commit ff3bd8e6 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 32b69ceb c310bc5c
...@@ -168,6 +168,12 @@ void copy(Range&& r, Iterator it) ...@@ -168,6 +168,12 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it); std::copy(r.begin(), r.end(), it);
} }
template <class Range>
auto reverse(Range& r)
{
return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin()));
}
template <class Range, class T> template <class Range, class T>
void replace(Range&& r, const T& old, const T& new_x) void replace(Range&& r, const T& old, const T& new_x)
{ {
......
...@@ -29,7 +29,15 @@ struct raw_data : raw_data_base ...@@ -29,7 +29,15 @@ struct raw_data : raw_data_base
friend Stream& operator<<(Stream& os, const Derived& d) friend Stream& operator<<(Stream& os, const Derived& d)
{ {
if(not d.empty()) if(not d.empty())
d.visit([&](auto x) { os << x; }); d.visit([&](auto x) { os << x; },
[&](auto&& xs) {
for(auto&& x : xs)
{
os << "{ ";
os << x;
os << " }, ";
}
});
return os; return os;
} }
...@@ -46,8 +54,18 @@ struct raw_data : raw_data_base ...@@ -46,8 +54,18 @@ struct raw_data : raw_data_base
if(derived.empty()) if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!"); MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape(); auto&& s = derived.get_shape();
auto&& buffer = derived.data(); s.visit_type([&](auto as) { v(*(as.from(derived.data()) + s.index(n))); });
s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); }); }
template <class Visitor, class TupleVisitor>
void visit(Visitor v, TupleVisitor tv) const
{
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
s.visit_type([&](auto as) { v(make_view(s, as.from(derived.data()))); },
[&] { tv(derived.get_sub_objects()); });
} }
/** /**
...@@ -60,12 +78,7 @@ struct raw_data : raw_data_base ...@@ -60,12 +78,7 @@ struct raw_data : raw_data_base
template <class Visitor> template <class Visitor>
void visit(Visitor v) const void visit(Visitor v) const
{ {
auto&& derived = static_cast<const Derived&>(*this); visit(v, [&](const auto&) { MIGRAPHX_THROW("Invalid tuple type"); });
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
auto&& buffer = derived.data();
s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
} }
/// Returns true if the raw data is only one element /// Returns true if the raw data is only one element
...@@ -156,43 +169,27 @@ struct raw_data : raw_data_base ...@@ -156,43 +169,27 @@ struct raw_data : raw_data_base
} }
}; };
template <class T, namespace detail {
class U, template <class V1, class V2, class... Ts>
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} && void visit_all_flatten(const shape& s, V1&& v1, V2&& v2, Ts&&... xs)
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{ {
auto&& xshape = x.get_shape(); s.visit_type([&](auto as) { v1(make_view(xs.get_shape(), as.from(xs.data()))...); },
auto&& yshape = y.get_shape(); [&] { v2(xs.get_sub_objects()...); });
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{
auto&& xbuffer = x.data();
auto&& ybuffer = y.data();
// TODO: Dont use tensor view for single values
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
}
return result;
} }
template <class T, template <class V1, class V2, class... Ts>
class U, auto visit_all_pack(const shape& s, V1&& v1, V2&& v2)
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{ {
return !(x == y); return [&](auto&&... xs) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
visit_all_flatten(s, v1, v2, xs...);
};
} }
namespace detail { template <class V1, class... Ts>
template <class V, class... Ts> auto visit_all_pack(const shape& s, V1&& v1)
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
{ {
s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); }); return visit_all_pack(s, v1, [](auto&&...) { MIGRAPHX_THROW("Invalid tuple type"); });
} }
} // namespace detail } // namespace detail
...@@ -215,10 +212,7 @@ auto visit_all(T&& x, Ts&&... xs) ...@@ -215,10 +212,7 @@ auto visit_all(T&& x, Ts&&... xs)
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...}; std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); })) if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
return [&](auto v) { return [&](auto... vs) { detail::visit_all_pack(s, vs...)(x, xs...); };
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...);
};
} }
template <class T> template <class T>
...@@ -240,6 +234,34 @@ auto visit_all(const std::vector<T>& x) ...@@ -240,6 +234,34 @@ auto visit_all(const std::vector<T>& x)
}; };
} }
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() and y.empty();
if(not result and xshape == yshape)
{
visit_all(x, y)([&](auto xview, auto yview) { result = xview == yview; },
[&](auto&& xs, auto&& ys) {
result = std::equal(xs.begin(), xs.end(), ys.begin(), ys.end());
});
}
return result;
}
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
return !(x == y);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -40,7 +40,7 @@ struct shape ...@@ -40,7 +40,7 @@ struct shape
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
{ {
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) tuple_type
}; };
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
...@@ -82,6 +82,8 @@ struct shape ...@@ -82,6 +82,8 @@ struct shape
{ {
} }
shape(const std::vector<shape>& subs);
static shape static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm); from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const; type_t type() const;
...@@ -179,11 +181,16 @@ struct shape ...@@ -179,11 +181,16 @@ struct shape
type_t type_enum() const { return get_type<type>{}; } type_t type_enum() const { return get_type<type>{}; }
}; };
template <class Visitor> template <class Visitor, class TupleVisitor>
static void visit(type_t t, Visitor v) static void visit(type_t t, Visitor v, TupleVisitor tv)
{ {
switch(t) switch(t)
{ {
case tuple_type:
{
tv();
return;
}
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \ #define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return; case x: v(as<t>()); return;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
...@@ -193,9 +200,15 @@ struct shape ...@@ -193,9 +200,15 @@ struct shape
} }
template <class Visitor> template <class Visitor>
void visit_type(Visitor v) const static void visit(type_t t, Visitor v)
{ {
visit(this->type(), v); return visit(t, v, [] { MIGRAPHX_THROW("Tuple cannot be visited."); });
}
template <class... Visitors>
void visit_type(Visitors... vs) const
{
visit(this->type(), vs...);
} }
template <class Visitor> template <class Visitor>
...@@ -209,6 +222,8 @@ struct shape ...@@ -209,6 +222,8 @@ struct shape
std::string type_string() const; std::string type_string() const;
static type_t parse_type(const std::string& s); static type_t parse_type(const std::string& s);
const std::vector<shape>& sub_shapes() const;
private: private:
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
......
...@@ -160,6 +160,7 @@ struct value ...@@ -160,6 +160,7 @@ struct value
binary(T* data, std::size_t s) : base(data, data + s) binary(T* data, std::size_t s) : base(data, data + s)
{ {
} }
explicit binary(std::size_t s) : base(s) {}
}; };
value() = default; value() = default;
......
...@@ -24,8 +24,57 @@ struct module_impl ...@@ -24,8 +24,57 @@ struct module_impl
{ {
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
std::list<instruction> instructions; std::list<instruction> instructions;
std::unordered_set<instruction*> instruction_set;
std::vector<std::string> input_names; std::vector<std::string> input_names;
std::string name; std::string name;
bool contains(instruction_ref ins) const
{
if(ins == instructions.end())
return false;
return instruction_set.count(std::addressof(*ins)) > 0;
}
template <class... Ts>
instruction_ref emplace(instruction_ref pos, Ts&&... xs)
{
// cppcheck-suppress redundantInitialization
auto r = instructions.emplace(pos, std::forward<Ts>(xs)...);
instruction_set.insert(std::addressof(*r));
return r;
}
instruction_ref insert(instruction_ref pos, const instruction& ins)
{
return emplace(pos, ins);
}
void push_front(const instruction& ins) { insert(instructions.begin(), ins); }
void push_back(const instruction& ins) { insert(instructions.end(), ins); }
template <class... Ts>
void emplace_front(Ts&&... xs)
{
emplace(instructions.begin(), std::forward<Ts>(xs)...);
}
template <class... Ts>
void emplace_back(Ts&&... xs)
{
emplace(instructions.end(), std::forward<Ts>(xs)...);
}
instruction_ref erase(instruction_ref pos)
{
instruction_set.erase(std::addressof(*pos));
return instructions.erase(pos);
}
instruction_ref erase(instruction_ref start, instruction_ref last)
{
std::for_each(start, last, [&](auto& ins) { instruction_set.erase(std::addressof(ins)); });
return instructions.erase(start, last);
}
}; };
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
...@@ -71,20 +120,19 @@ void module::assign(const module& m) ...@@ -71,20 +120,19 @@ void module::assign(const module& m)
if(ins->name() == "@literal") if(ins->name() == "@literal")
{ {
auto l = ins->get_literal(); auto l = ins->get_literal();
copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l}); copy_ins = impl->insert(impl->instructions.end(), instruction{l});
} }
else if(ins->name() == "@param") else if(ins->name() == "@param")
{ {
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter; auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto s = ins->get_shape(); auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(), copy_ins =
{builtin::param{name}, std::move(s), {}}); impl->insert(impl->instructions.end(), {builtin::param{name}, std::move(s), {}});
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
{ {
auto s = ins->get_shape(); auto s = ins->get_shape();
copy_ins = copy_ins = impl->insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
} }
else else
{ {
...@@ -127,7 +175,7 @@ instruction_ref module::insert_instruction(instruction_ref ins, ...@@ -127,7 +175,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
{ {
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)}); auto result = impl->insert(ins, {op, r, std::move(args)});
instruction::backreference(result); instruction::backreference(result);
assert(result->valid(begin())); assert(result->valid(begin()));
return result; return result;
...@@ -148,8 +196,7 @@ instruction_ref module::insert_instruction(instruction_ref ins, ...@@ -148,8 +196,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
{ {
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args); auto out_shape = compute_shape(op, args, module_args);
auto result = auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
impl->instructions.insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
instruction::backreference(result); instruction::backreference(result);
assert(result->valid(begin())); assert(result->valid(begin()));
return result; return result;
...@@ -222,7 +269,7 @@ instruction_ref module::remove_instruction(instruction_ref ins) ...@@ -222,7 +269,7 @@ instruction_ref module::remove_instruction(instruction_ref ins)
assert(has_instruction(ins)); assert(has_instruction(ins));
assert(ins->outputs().empty()); assert(ins->outputs().empty());
ins->clear_arguments(); ins->clear_arguments();
return impl->instructions.erase(ins); return impl->erase(ins);
} }
instruction_ref module::remove_instructions(instruction_ref first, instruction_ref last) instruction_ref module::remove_instructions(instruction_ref first, instruction_ref last)
...@@ -233,7 +280,7 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r ...@@ -233,7 +280,7 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r
assert(has_instruction(first)); assert(has_instruction(first));
std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); }); std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); })); assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); }));
return impl->instructions.erase(first, last); return impl->erase(first, last);
} }
instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst) instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
...@@ -252,13 +299,13 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d ...@@ -252,13 +299,13 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
instruction_ref module::add_literal(literal l) instruction_ref module::add_literal(literal l)
{ {
impl->instructions.emplace_front(std::move(l)); impl->emplace_front(std::move(l));
return impl->instructions.begin(); return impl->instructions.begin();
} }
instruction_ref module::add_outline(const shape& s) instruction_ref module::add_outline(const shape& s)
{ {
impl->instructions.push_front({builtin::outline{s}, s, {}}); impl->push_front({builtin::outline{s}, s, {}});
return impl->instructions.begin(); return impl->instructions.begin();
} }
...@@ -267,13 +314,13 @@ instruction_ref module::add_parameter(std::string name, shape s) ...@@ -267,13 +314,13 @@ instruction_ref module::add_parameter(std::string name, shape s)
assert(get_parameter_shape(name) == shape{}); assert(get_parameter_shape(name) == shape{});
impl->input_names.push_back(name); impl->input_names.push_back(name);
impl->instructions.push_front({builtin::param{std::move(name)}, std::move(s), {}}); impl->push_front({builtin::param{std::move(name)}, std::move(s), {}});
return impl->instructions.begin(); return impl->instructions.begin();
} }
instruction_ref module::add_return(std::vector<instruction_ref> args) instruction_ref module::add_return(std::vector<instruction_ref> args)
{ {
impl->instructions.push_back({builtin::returns{}, {}, std::move(args)}); impl->push_back({builtin::returns{}, {}, std::move(args)});
auto result = std::prev(impl->instructions.end()); auto result = std::prev(impl->instructions.end());
instruction::backreference(result); instruction::backreference(result);
assert(result->valid(begin())); assert(result->valid(begin()));
...@@ -350,13 +397,7 @@ std::unordered_map<std::string, shape> module::get_parameter_shapes() const ...@@ -350,13 +397,7 @@ std::unordered_map<std::string, shape> module::get_parameter_shapes() const
return result; return result;
} }
bool module::has_instruction(instruction_ref ins) const bool module::has_instruction(instruction_ref ins) const { return impl->contains(ins); }
{
return std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return std::addressof(*ins) == std::addressof(x);
}) != impl->instructions.end();
}
std::size_t module::size() const { return impl->instructions.size(); } std::size_t module::size() const { return impl->instructions.size(); }
instruction_ref module::begin() const { return impl->instructions.begin(); } instruction_ref module::begin() const { return impl->instructions.begin(); }
...@@ -364,6 +405,8 @@ instruction_ref module::end() const { return impl->instructions.end(); } ...@@ -364,6 +405,8 @@ instruction_ref module::end() const { return impl->instructions.end(); }
std::vector<shape> module::get_output_shapes() const std::vector<shape> module::get_output_shapes() const
{ {
if(impl->instructions.empty())
return {};
auto last_ins = impl->instructions.back(); auto last_ins = impl->instructions.back();
if(last_ins.name() == "@return") if(last_ins.name() == "@return")
{ {
......
...@@ -13,7 +13,7 @@ target_include_directories(migraphx_onnx PRIVATE include) ...@@ -13,7 +13,7 @@ target_include_directories(migraphx_onnx PRIVATE include)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx) set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_onnx) rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto) target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL")
target_link_libraries(migraphx_onnx PUBLIC migraphx) target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets( rocm_install_targets(
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
instruction_ref parse_prefix_scan_oper(const std::string& op_name,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args)
{
migraphx::argument in = args[1]->eval();
check_arg_empty(in, "PARSE_PREFIX_SCAN: axis - dynamic shape not supported");
std::vector<std::size_t> axis_in;
in.visit([&](auto input) { axis_in.assign(input.begin(), input.end()); });
int64_t axis = axis_in[0];
bool exclusive = false;
bool reverse = false;
if(contains(info.attributes, "exclusive"))
{
exclusive = parser.parse_value(info.attributes.at("exclusive")).at<bool>();
}
if(contains(info.attributes, "reverse"))
{
reverse = parser.parse_value(info.attributes.at("reverse")).at<bool>();
}
return info.add_instruction(
make_op(op_name, {{"axis", axis}, {"exclusive", exclusive}, {"reverse", reverse}}),
args[0]);
}
struct parse_prefix_scan_op : op_parser<parse_prefix_scan_op>
{
std::vector<op_desc> operators() const { return {{"CumSum", "prefix_scan_sum"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
return parse_prefix_scan_oper(opd.op_name, parser, std::move(info), std::move(args));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -15,25 +15,56 @@ ...@@ -15,25 +15,56 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void run_passes(module& modl, const std::vector<pass>& passes, tracer trace) void validate_pass(module& mod, const pass& p, tracer trace)
{ {
for(const auto& p : passes) (void)mod;
{ (void)p;
trace("Module: ", modl.name(), ", Pass: ", p.name()); (void)trace;
p.apply(modl);
trace(modl);
#ifndef NDEBUG #ifndef NDEBUG
trace("Validate ..."); trace("Validate ...");
auto invalid = modl.validate(); auto invalid = mod.validate();
if(invalid != modl.end()) if(invalid != mod.end())
{ {
auto index = std::distance(modl.begin(), invalid); auto index = std::distance(mod.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " + MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name()); std::to_string(index) + ": " + invalid->name());
} }
trace(); trace();
#endif #endif
}
void run_pass(module& mod, const pass& p, tracer trace)
{
trace("Module: ", mod.name(), ", Pass: ", p.name());
assert(mod.validate() == mod.end());
p.apply(mod);
trace(mod);
validate_pass(mod, p, trace);
}
void run_pass(program& prog, const pass& p, tracer trace)
{
trace("Pass: ", p.name());
p.apply(prog);
trace(prog);
}
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
for(const auto& p : passes)
{
run_pass(mod, p, trace);
}
}
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
for(const auto& p : passes)
{
auto mods = prog.get_modules();
for(const auto& mod : reverse(mods))
{
run_pass(*mod, p, trace);
}
run_pass(prog, p, trace);
} }
} }
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -26,13 +28,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -26,13 +28,12 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program_impl struct program_impl
{ {
// A map is used to keep references to modules of the program // A map is used to keep references to modules of the program
// all the modules are store in the depth-first order std::unordered_map<std::string, module> modules;
std::list<module> modules;
context ctx; context ctx;
std::string target_name; std::string target_name;
}; };
program::program() : impl(std::make_unique<program_impl>()) { impl->modules.push_back({"main"}); } program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
program::program(program&&) noexcept = default; program::program(program&&) noexcept = default;
program::~program() noexcept = default; program::~program() noexcept = default;
...@@ -65,11 +66,11 @@ void program::assign(const program& p) ...@@ -65,11 +66,11 @@ void program::assign(const program& p)
// build a map from old ins to new ins // build a map from old ins to new ins
// Build a map from old module to new module // Build a map from old module to new module
std::unordered_map<module_ref, module_ref> mod_map; std::unordered_map<module_ref, module_ref> mod_map;
std::transform(impl->modules.begin(), std::transform(
impl->modules.begin(),
impl->modules.end(), impl->modules.end(),
p.impl->modules.begin(),
std::inserter(mod_map, mod_map.begin()), std::inserter(mod_map, mod_map.begin()),
[](auto&& x, auto&& y) { return std::make_pair(&y, &x); }); [&](auto&& xp) { return std::make_pair(&p.impl->modules.at(xp.first), &xp.second); });
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto&& pp : mod_map) for(auto&& pp : mod_map)
...@@ -86,7 +87,7 @@ void program::assign(const program& p) ...@@ -86,7 +87,7 @@ void program::assign(const program& p)
// Update all references from all modules // Update all references from all modules
for(auto&& mp : impl->modules) for(auto&& mp : impl->modules)
{ {
for(auto ins : iterator_for(mp)) for(auto ins : iterator_for(mp.second))
instruction::replace_refs(ins, ins_map, mod_map); instruction::replace_refs(ins, ins_map, mod_map);
} }
} }
...@@ -144,14 +145,14 @@ void program::compile(const target& t, compile_options options) ...@@ -144,14 +145,14 @@ void program::compile(const target& t, compile_options options)
options.trace(*this); options.trace(*this);
options.trace(); options.trace();
auto mods = this->get_modules();
std::reverse(mods.begin(), mods.end());
auto&& passes = t.get_passes(this->impl->ctx, options); auto&& passes = t.get_passes(this->impl->ctx, options);
run_passes(*this, passes, options.trace);
for(const auto& mod : mods) auto mods = this->get_modules();
// Validate and finalize
for(const auto& mod : reverse(mods))
{ {
assert(mod->validate() == mod->end());
run_passes(*mod, passes, options.trace);
auto invalid = mod->validate(); auto invalid = mod->validate();
if(invalid != mod->end()) if(invalid != mod->end())
{ {
...@@ -306,7 +307,7 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -306,7 +307,7 @@ std::vector<argument> program::eval(parameter_map params) const
} }
} }
const int program_file_version = 4; const int program_file_version = 5;
value program::to_value() const value program::to_value() const
{ {
...@@ -316,14 +317,14 @@ value program::to_value() const ...@@ -316,14 +317,14 @@ value program::to_value() const
if(not this->impl->target_name.empty()) if(not this->impl->target_name.empty())
result["context"] = this->impl->ctx.to_value(); result["context"] = this->impl->ctx.to_value();
value module_vals = value::array{}; value module_vals = value::object{};
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : this->impl->modules) for(auto& mod : this->get_modules())
{ {
value mod_val; value mod_val;
value nodes; value nodes;
mod_val["name"] = mod.name(); mod_val["name"] = mod->name();
names = mod.print( names = mod->print(
[&](auto ins, auto ins_names) { [&](auto ins, auto ins_names) {
value node; value node;
node["output"] = ins_names.at(ins); node["output"] = ins_names.at(ins);
...@@ -358,7 +359,7 @@ value program::to_value() const ...@@ -358,7 +359,7 @@ value program::to_value() const
names); names);
mod_val["nodes"] = nodes; mod_val["nodes"] = nodes;
module_vals.push_back(mod_val); module_vals[mod->name()] = mod_val;
} }
result["modules"] = module_vals; result["modules"] = module_vals;
...@@ -371,12 +372,7 @@ static void mod_from_val(module_ref mod, ...@@ -371,12 +372,7 @@ static void mod_from_val(module_ref mod,
std::unordered_map<std::string, instruction_ref>& instructions, std::unordered_map<std::string, instruction_ref>& instructions,
const std::unordered_map<std::string, module_ref>& map_mods) const std::unordered_map<std::string, module_ref>& map_mods)
{ {
const auto* it = std::find_if(v.begin(), v.end(), [&](auto& mv) { const auto& module_val = v.at(mod->name());
return mv.at("name").template to<std::string>() == mod->name();
});
assert(it != v.end());
const auto& module_val = *it;
for(const value& node : module_val.at("nodes")) for(const value& node : module_val.at("nodes"))
{ {
instruction_ref output; instruction_ref output;
...@@ -455,15 +451,18 @@ void program::from_value(const value& v) ...@@ -455,15 +451,18 @@ void program::from_value(const value& v)
} }
auto module_vals = v.at("modules"); auto module_vals = v.at("modules");
std::unordered_map<std::string, module_ref> map_mods;
for(const auto& vv : module_vals) for(const auto& vv : module_vals)
{ {
const auto& name = vv.at("name").to<std::string>(); const auto& name = vv.get_key();
if(name == "main") if(name == "main")
continue; continue;
impl->modules.push_back({name}); impl->modules.emplace(name, name);
map_mods[name] = &impl->modules.back();
} }
std::unordered_map<std::string, module_ref> map_mods;
std::transform(impl->modules.begin(),
impl->modules.end(),
std::inserter(map_mods, map_mods.end()),
[&](auto&& pp) { return std::make_pair(pp.first, &pp.second); });
std::unordered_map<std::string, instruction_ref> map_insts; std::unordered_map<std::string, instruction_ref> map_insts;
auto* mm = get_main_module(); auto* mm = get_main_module();
...@@ -585,8 +584,8 @@ void program::debug_print() const { std::cout << *this << std::endl; } ...@@ -585,8 +584,8 @@ void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const void program::debug_print(instruction_ref ins) const
{ {
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& it) { if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
return (it.end() == ins); return (pp.second.end() == ins);
})) }))
{ {
std::cout << "End instruction" << std::endl; std::cout << "End instruction" << std::endl;
...@@ -594,7 +593,7 @@ void program::debug_print(instruction_ref ins) const ...@@ -594,7 +593,7 @@ void program::debug_print(instruction_ref ins) const
} }
else if(std::none_of(this->impl->modules.begin(), else if(std::none_of(this->impl->modules.begin(),
this->impl->modules.end(), this->impl->modules.end(),
[&](const auto& it) { return it.has_instruction(ins); })) [&](const auto& pp) { return pp.second.has_instruction(ins); }))
{ {
std::cout << "Instruction not part of program" << std::endl; std::cout << "Instruction not part of program" << std::endl;
return; return;
...@@ -615,9 +614,9 @@ void program::print( ...@@ -615,9 +614,9 @@ void program::print(
const std::function<void(instruction_ref, std::unordered_map<instruction_ref, std::string>)>& const std::function<void(instruction_ref, std::unordered_map<instruction_ref, std::string>)>&
print_func) const print_func) const
{ {
for(const auto& mod : this->impl->modules) for(const auto& pp : this->impl->modules)
{ {
names = mod.print(print_func, names); names = pp.second.print(print_func, names);
} }
} }
...@@ -647,74 +646,118 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const ...@@ -647,74 +646,118 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
{ {
for(auto& mod : this->impl->modules) for(auto& pp : this->impl->modules)
{ {
std::cout << mod.name() << ":" << std::endl; std::cout << pp.first << ":" << std::endl;
mod.annotate(os, a); pp.second.annotate(os, a);
} }
} }
const module* program::get_module(const std::string& name) const const module* program::get_module(const std::string& name) const { return &impl->modules.at(name); }
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
return &(*it);
}
module* program::create_module(const std::string& name) module* program::create_module(const std::string& name)
{ {
auto it = impl->modules.insert(impl->modules.end(), {name}); assert(not contains(impl->modules, name));
return &(*it); auto r = impl->modules.emplace(name, name);
return &(r.first->second);
} }
module* program::get_module(const std::string& name) module* program::get_module(const std::string& name) { return &impl->modules.at(name); }
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
return &(*it);
}
module* program::get_main_module() { return get_module("main"); } module* program::get_main_module() { return get_module("main"); }
const module* program::get_main_module() const { return get_module("main"); } const module* program::get_main_module() const { return get_module("main"); }
std::vector<const module*> program::get_modules() const template <class T>
std::vector<T*> generic_get_modules(T* mm)
{ {
const module* mm = this->get_main_module(); std::vector<T*> vec_modules;
std::vector<const module*> vec_modules;
vec_modules.push_back(mm); vec_modules.push_back(mm);
auto sub_modules = mm->get_sub_modules(); auto sub_modules = mm->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end()); vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end());
return vec_modules; return vec_modules;
} }
template <class Map, class T, class OutputIterator>
void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputIterator out)
{
std::unordered_set<std::string> used;
std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
return mod->name();
});
transform_if(m.begin(),
m.end(),
out,
[&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
}
std::vector<const module*> program::get_modules() const
{
auto result = generic_get_modules(this->get_main_module());
generic_get_unused_modules(impl->modules, result, std::back_inserter(result));
return result;
}
std::vector<module*> program::get_modules() std::vector<module*> program::get_modules()
{ {
module* mm = this->get_main_module(); auto result = generic_get_modules(this->get_main_module());
std::vector<module*> vec_modules; generic_get_unused_modules(impl->modules, result, std::back_inserter(result));
vec_modules.push_back(mm); return result;
auto sub_modules = mm->get_sub_modules(); }
vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end());
return vec_modules; template <class Map, class T>
bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name)
{
bool is_unused = false;
generic_get_unused_modules(m, mods, make_function_output_iterator([&](auto* mod) {
if(mod->name() == name)
is_unused = true;
}));
return is_unused;
}
template <class Map>
bool references_instruction(Map& m, const instruction& ins, const std::string& name)
{
return std::any_of(m.begin(), m.end(), [&](auto&& p) {
if(p.first == name)
return false;
return std::any_of(p.second.begin(), p.second.end(), [&](auto&& i) {
return std::any_of(i.inputs().begin(), i.inputs().end(), [&](auto&& j) {
return std::addressof(*j) == std::addressof(ins);
});
});
});
}
void program::remove_module(const std::string& name)
{
// cppcheck-suppress assertWithSideEffect
assert(is_unused_module(impl->modules, generic_get_modules(this->get_main_module()), name) &&
"Module used in program");
assert(std::none_of(
impl->modules.at(name).begin(),
impl->modules.at(name).end(),
[&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
"Instruction referenced in another module");
impl->modules.erase(name);
}
void program::remove_unused_modules()
{
std::vector<module*> unused;
generic_get_unused_modules(
impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused));
for(auto* m : unused)
this->remove_module(m->name());
} }
program& program::sort() program& program::sort()
{ {
for(auto& mod : this->impl->modules) for(auto& pp : this->impl->modules)
{ {
mod.sort(); pp.second.sort();
} }
return *this; return *this;
......
...@@ -398,7 +398,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -398,7 +398,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false); m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
m.def("from_gpu", &migraphx::gpu::from_gpu); m.def("from_gpu", &migraphx::gpu::from_gpu);
m.def("gpu_sync", &migraphx::gpu::gpu_sync); m.def("gpu_sync", [] { migraphx::gpu::gpu_sync(); });
#endif #endif
#ifdef VERSION_INFO #ifdef VERSION_INFO
......
...@@ -11,6 +11,9 @@ void raw_data_to_value(value& v, const RawData& rd) ...@@ -11,6 +11,9 @@ void raw_data_to_value(value& v, const RawData& rd)
{ {
value result; value result;
result["shape"] = migraphx::to_value(rd.get_shape()); result["shape"] = migraphx::to_value(rd.get_shape());
if(rd.get_shape().type() == shape::tuple_type)
result["sub"] = migraphx::to_value(rd.get_sub_objects());
else
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes()); result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
v = result; v = result;
} }
...@@ -25,8 +28,15 @@ void migraphx_from_value(const value& v, literal& l) ...@@ -25,8 +28,15 @@ void migraphx_from_value(const value& v, literal& l)
void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); } void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); }
void migraphx_from_value(const value& v, argument& a) void migraphx_from_value(const value& v, argument& a)
{ {
if(v.contains("data"))
{
literal l = migraphx::from_value<literal>(v); literal l = migraphx::from_value<literal>(v);
a = l.get_argument(); a = l.get_argument();
}
else
{
a = migraphx::from_value<std::vector<argument>>(v.at("sub"));
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -20,28 +20,36 @@ struct shape_impl ...@@ -20,28 +20,36 @@ struct shape_impl
return result; return result;
} }
shape_impl() : m_type(shape::float_type), m_standard(false) {} shape_impl() : m_type(shape::float_type) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {} shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
{
assert(t != shape::tuple_type);
}
shape_impl(shape::type_t t, std::vector<std::size_t> l) shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
assert(t != shape::tuple_type);
this->calculate_strides(); this->calculate_strides();
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s)) : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{ {
assert(t != shape::tuple_type);
assert(m_lens.size() == m_strides.size()); assert(m_lens.size() == m_strides.size());
// assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and // assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
// "At least one stride must be non-zero"); // "At least one stride must be non-zero");
m_standard = this->elements() == this->element_space() and m_standard = this->elements() == this->element_space() and
std::is_sorted(m_strides.rbegin(), m_strides.rend()); std::is_sorted(m_strides.rbegin(), m_strides.rend());
} }
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type; shape::type_t m_type;
std::vector<std::size_t> m_lens; std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides; std::vector<std::size_t> m_strides = {};
bool m_standard; std::vector<shape> m_shapes = {};
bool m_standard = false;
void calculate_strides() void calculate_strides()
{ {
...@@ -84,7 +92,7 @@ const std::vector<shape::type_t>& shape::types() ...@@ -84,7 +92,7 @@ const std::vector<shape::type_t>& shape::types()
{ {
static const std::vector<shape::type_t> result = { static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x, #define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR)}; MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
return result; return result;
} }
...@@ -92,6 +100,7 @@ std::string shape::name(shape::type_t t) ...@@ -92,6 +100,7 @@ std::string shape::name(shape::type_t t)
{ {
switch(t) switch(t)
{ {
case tuple_type: return "tuple_type";
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \ #define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
case x: return #x; case x: return #x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
...@@ -103,6 +112,7 @@ std::string shape::cpp_type(shape::type_t t) ...@@ -103,6 +112,7 @@ std::string shape::cpp_type(shape::type_t t)
{ {
switch(t) switch(t)
{ {
case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \ #define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
case x: return #t; case x: return #t;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
...@@ -123,6 +133,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -123,6 +133,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{ {
} }
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape shape::from_permutation(type_t t, shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l, const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm) const std::vector<int64_t>& perm)
...@@ -139,13 +151,24 @@ const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; ...@@ -139,13 +151,24 @@ const std::vector<std::size_t>& shape::strides() const { return impl->m_strides;
std::size_t shape::elements() const { return impl->elements(); } std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const std::size_t shape::bytes() const
{ {
if(this->sub_shapes().empty())
{
std::size_t n = 0; std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space(); return n * this->element_space();
}
else
{
return std::accumulate(this->sub_shapes().begin(),
this->sub_shapes().end(),
std::size_t{0},
[&](auto x, auto y) { return x + y.bytes(); });
}
} }
std::size_t shape::type_size() const std::size_t shape::type_size() const
{ {
std::size_t n = 0; std::size_t n = 0;
if(this->sub_shapes().empty())
this->visit_type([&](auto as) { n = as.size(); }); this->visit_type([&](auto as) { n = as.size(); });
return n; return n;
} }
...@@ -208,7 +231,10 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end ...@@ -208,7 +231,10 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
}); });
} }
bool shape::packed() const { return this->elements() == this->element_space(); } bool shape::packed() const
{
return this->sub_shapes().empty() and this->elements() == this->element_space();
}
bool shape::transposed() const bool shape::transposed() const
{ {
...@@ -242,7 +268,8 @@ bool shape::scalar() const ...@@ -242,7 +268,8 @@ bool shape::scalar() const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
// if any stride > 0, then accumulate will return false // if any stride > 0, then accumulate will return false
return std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0; return this->sub_shapes().empty() and
std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
} }
bool shape::standard() const { return impl->m_standard; } bool shape::standard() const { return impl->m_standard; }
...@@ -273,15 +300,23 @@ std::string shape::type_string() const { return name(this->type()); } ...@@ -273,15 +300,23 @@ std::string shape::type_string() const { return name(this->type()); }
bool operator==(const shape& x, const shape& y) bool operator==(const shape& x, const shape& y)
{ {
return x.type() == y.type() && x.lens() == y.lens() && x.strides() == y.strides(); return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
} }
bool operator!=(const shape& x, const shape& y) { return !(x == y); } bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x) std::ostream& operator<<(std::ostream& os, const shape& x)
{ {
if(x.sub_shapes().empty())
{
os << x.type_string() << ", "; os << x.type_string() << ", ";
os << "{" << to_string_range(x.lens()) << "}, "; os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}"; os << "{" << to_string_range(x.strides()) << "}";
}
else
{
os << "[" << to_string_range(x.sub_shapes()) << "]";
}
return os; return os;
} }
...@@ -289,23 +324,36 @@ shape::type_t shape::parse_type(const std::string& s) ...@@ -289,23 +324,36 @@ shape::type_t shape::parse_type(const std::string& s)
{ {
static const std::unordered_map<std::string, shape::type_t> m = { static const std::unordered_map<std::string, shape::type_t> m = {
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x}, #define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP)}; MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
tuple_type},
{"tuple", tuple_type}};
return m.at(s); return m.at(s);
} }
const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
void migraphx_to_value(value& v, const shape& s) void migraphx_to_value(value& v, const shape& s)
{ {
value result; value result;
result["type"] = migraphx::to_value(s.type_string()); result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens()); result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides()); result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
v = result; v = result;
} }
void migraphx_from_value(const value& v, shape& s) void migraphx_from_value(const value& v, shape& s)
{ {
s = shape{shape::parse_type(v.at("type").get_string()), auto t = v.at("type").get_string();
if(t == "tuple_type")
{
s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
}
else
{
s = shape{shape::parse_type(t),
v.at("lens").to_vector<std::size_t>(), v.at("lens").to_vector<std::size_t>(),
v.at("strides").to_vector<std::size_t>()}; v.at("strides").to_vector<std::size_t>()};
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -133,6 +133,7 @@ add_library(migraphx_gpu ...@@ -133,6 +133,7 @@ add_library(migraphx_gpu
logsoftmax.cpp logsoftmax.cpp
lrn.cpp lrn.cpp
leaky_relu.cpp leaky_relu.cpp
mlir_conv.cpp
pack_args.cpp pack_args.cpp
pack_int8_args.cpp pack_int8_args.cpp
pad.cpp pad.cpp
...@@ -148,6 +149,7 @@ add_library(migraphx_gpu ...@@ -148,6 +149,7 @@ add_library(migraphx_gpu
write_literals.cpp write_literals.cpp
) )
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
function(register_migraphx_gpu_ops PREFIX) function(register_migraphx_gpu_ops PREFIX)
foreach(OP ${ARGN}) foreach(OP ${ARGN})
register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp) register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp)
...@@ -259,6 +261,20 @@ endif() ...@@ -259,6 +261,20 @@ endif()
message(STATUS "clang-offload-bundler: ${MIGRAPHX_OFFLOADBUNDLER_BIN}") message(STATUS "clang-offload-bundler: ${MIGRAPHX_OFFLOADBUNDLER_BIN}")
message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}") message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR)
find_library(LIBMLIRMIOPEN MLIRMIOpenThin REQUIRED)
# REQUIRED is not supported before cmake 3.18
if(NOT LIBMLIRMIOPEN)
message(FATAL_ERROR "libMLIRMIOpenThin not found")
else()
message(STATUS "Build with libMLIRMIOpenThin: " ${LIBMLIRMIOPEN})
endif()
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR_MIOPEN_SUPPORT")
target_link_libraries(migraphx_gpu PUBLIC ${LIBMLIRMIOPEN})
endif()
# Get flags needed to compile hip # Get flags needed to compile hip
include(TargetFlags) include(TargetFlags)
target_flags(HIP_COMPILER_FLAGS hip::device) target_flags(HIP_COMPILER_FLAGS hip::device)
...@@ -286,7 +302,6 @@ endif() ...@@ -286,7 +302,6 @@ endif()
# Workaround broken rocblas headers # Workaround broken rocblas headers
target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1) target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_compile_options(migraphx_gpu PRIVATE -std=c++17)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
......
...@@ -48,17 +48,6 @@ std::string generate_index_ints(const std::vector<T>& v) ...@@ -48,17 +48,6 @@ std::string generate_index_ints(const std::vector<T>& v)
return "index_ints<" + to_string_range(v) + ">{}"; return "index_ints<" + to_string_range(v) + ">{}";
} }
std::string generate_cpp_type(shape::type_t t)
{
switch(t)
{
#define MIGRAPHX_GPU_GENERATE_TYPE_STRING(x, t) \
case shape::x: return #t;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GPU_GENERATE_TYPE_STRING)
}
MIGRAPHX_THROW("Invalid type");
}
std::string generate_make_shape(const shape& s) std::string generate_make_shape(const shape& s)
{ {
return "make_shape(" + generate_index_ints(s.lens()) + ", " + generate_index_ints(s.strides()) + return "make_shape(" + generate_index_ints(s.lens()) + ", " + generate_index_ints(s.strides()) +
...@@ -80,7 +69,7 @@ std::string generate_make_tensor(std::size_t n, const shape& s) ...@@ -80,7 +69,7 @@ std::string generate_make_tensor(std::size_t n, const shape& s)
{ {
return interpolate_string(make_tensor_template, return interpolate_string(make_tensor_template,
{{"n", std::to_string(n)}, {{"n", std::to_string(n)},
{"type", generate_cpp_type(s.type())}, {"type", shape::cpp_type(s.type())},
{"lens", generate_index_ints(s.lens())}, {"lens", generate_index_ints(s.lens())},
{"strides", generate_index_ints(s.strides())}}); {"strides", generate_index_ints(s.strides())}});
} }
......
...@@ -7,9 +7,33 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -7,9 +7,33 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void contiguous_nonstandard(hipStream_t stream, const argument& result, const argument& arg)
{
shape s{result.get_shape().type(), result.get_shape().lens()};
visit_all(result, arg)([&](auto output_v, auto input_v) {
hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) {
mi_gs_launch(stream,
standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; });
});
});
}
void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg)
{
index_int nelements = result.get_shape().elements();
visit_all(result, arg)([&](auto output_v, auto input_v) {
const auto* input = device_cast(input_v.data());
auto* output = device_cast(output_v.data());
gs_launch(stream, nelements)([=](auto i) __device__ { output[i] = input[i]; });
});
}
void contiguous(hipStream_t stream, const argument& result, const argument& arg) void contiguous(hipStream_t stream, const argument& result, const argument& arg)
{ {
nary(stream, result, arg)([](auto x) __device__ { return x; }); if(result.get_shape() == arg.get_shape() and result.get_shape().packed())
contiguous_packed(stream, result, arg);
else
contiguous_nonstandard(stream, result, arg);
} }
} // namespace device } // namespace device
......
...@@ -51,6 +51,50 @@ auto get_shape(const T& x) -> decltype(x.get_shape()) ...@@ -51,6 +51,50 @@ auto get_shape(const T& x) -> decltype(x.get_shape())
return x.get_shape(); return x.get_shape();
} }
template <class T>
struct is_hip_type : std::false_type
{
};
template <>
struct is_hip_type<float> : std::true_type
{
};
template <>
struct is_hip_type<half> : std::true_type
{
};
template <>
struct is_hip_type<bool> : std::true_type
{
};
template <>
struct is_hip_type<std::int8_t> : std::true_type
{
};
template <>
struct is_hip_type<std::uint8_t> : std::true_type
{
};
template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})>
void hip_visitor_invoke(T as, V&& v)
{
v(as);
}
template <class T, class V, MIGRAPHX_REQUIRES(not is_hip_type<typename T::type>{})>
void hip_visitor_invoke(T, V&&)
{
MIGRAPHX_THROW(std::string("Unsupported data type on GPU: ") + __PRETTY_FUNCTION__);
}
template <class V>
auto hip_visitor(V v)
{
return [=](auto as) { hip_visitor_invoke(as, v); };
}
template <class V, class F, class... Ts> template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{ {
...@@ -62,8 +106,9 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs) ...@@ -62,8 +106,9 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
static_cast<index_int>(get_shape(xs).lens().size())...}; static_cast<index_int>(get_shape(xs).lens().size())...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); })) if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same"); MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), visit_tensor_size(s.lens().size(), [&](auto ndim) {
[&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); }); s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
});
} }
template <class V, class F, class... Ts> template <class V, class F, class... Ts>
......
...@@ -16,6 +16,7 @@ rocblas_datatype get_type(shape::type_t type) ...@@ -16,6 +16,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r; case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r; case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r; case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::tuple_type:
case shape::bool_type: case shape::bool_type:
case shape::uint16_type: case shape::uint16_type:
case shape::int16_type: case shape::int16_type:
......
...@@ -15,6 +15,7 @@ namespace gpu { ...@@ -15,6 +15,7 @@ namespace gpu {
MIGRAPHX_REGISTER_OP(hip_allocate) MIGRAPHX_REGISTER_OP(hip_allocate)
MIGRAPHX_REGISTER_OP(hip_sync_device) MIGRAPHX_REGISTER_OP(hip_sync_device)
MIGRAPHX_REGISTER_OP(hip_sync_stream)
MIGRAPHX_REGISTER_OP(hip_copy_to_gpu) MIGRAPHX_REGISTER_OP(hip_copy_to_gpu)
MIGRAPHX_REGISTER_OP(hip_copy_from_gpu) MIGRAPHX_REGISTER_OP(hip_copy_from_gpu)
MIGRAPHX_REGISTER_OP(hip_copy) MIGRAPHX_REGISTER_OP(hip_copy)
...@@ -146,6 +147,8 @@ void gpu_sync() ...@@ -146,6 +147,8 @@ void gpu_sync()
MIGRAPHX_THROW("hip device synchronization failed: " + hip_error(status)); MIGRAPHX_THROW("hip device synchronization failed: " + hip_error(status));
} }
void gpu_sync(const context& ctx) { ctx.finish(); }
void hip_async_copy(context& ctx, const argument& src, const argument& dst, hipMemcpyKind kind) void hip_async_copy(context& ctx, const argument& src, const argument& dst, hipMemcpyKind kind)
{ {
std::size_t src_size = src.get_shape().bytes(); std::size_t src_size = src.get_shape().bytes();
......
...@@ -21,10 +21,20 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy); ...@@ -21,10 +21,20 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
struct hip_device struct hip_device
{ {
hip_device() { add_stream(); } hip_device()
{
device_props.gcnArchName[0] = '\0';
device_props.gcnArch = 0;
device_props.multiProcessorCount = 0;
add_stream();
}
hip_device(std::size_t id, std::size_t n) : device_id(id) hip_device(std::size_t id, std::size_t n) : device_id(id)
{ {
auto status = hipGetDeviceProperties(&device_props, device_id);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to allocate stream");
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
add_stream(); add_stream();
} }
...@@ -98,6 +108,16 @@ struct hip_device ...@@ -98,6 +108,16 @@ struct hip_device
return "gfx" + std::to_string(props.gcnArch); return "gfx" + std::to_string(props.gcnArch);
} }
void wait() const
{
if(s == nullptr)
return;
setup();
auto status = hipStreamSynchronize(s.get());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait.");
}
void wait(hipEvent_t event) void wait(hipEvent_t event)
{ {
setup(); setup();
...@@ -127,16 +147,29 @@ struct hip_device ...@@ -127,16 +147,29 @@ struct hip_device
stream& get_stream(std::size_t n) { return streams.at(n); } stream& get_stream(std::size_t n) { return streams.at(n); }
const stream& get_stream() const { return streams.at(current_stream); }
const stream& get_stream(std::size_t n) const { return streams.at(n); }
void set_stream(std::size_t n) { current_stream = n; } void set_stream(std::size_t n) { current_stream = n; }
std::size_t nstreams() const { return streams.size(); } std::size_t nstreams() const { return streams.size(); }
std::size_t stream_id() const { return current_stream; } std::size_t stream_id() const { return current_stream; }
std::string get_device_name() const { return device_props.gcnArchName; }
std::size_t get_device_major() const { return device_props.major; }
std::size_t get_device_minor() const { return device_props.minor; }
std::size_t get_cu_count() const { return device_props.multiProcessorCount; }
private: private:
std::size_t device_id = 0; std::size_t device_id = 0;
std::size_t current_stream = 0; std::size_t current_stream = 0;
std::vector<stream> streams; std::vector<stream> streams;
hipDeviceProp_t device_props;
public: public:
std::unordered_map<std::string, argument> preallocations{}; std::unordered_map<std::string, argument> preallocations{};
...@@ -155,9 +188,21 @@ struct context ...@@ -155,9 +188,21 @@ struct context
return *current_device; return *current_device;
} }
const hip_device& get_current_device() const
{
assert(current_device != nullptr);
return *current_device;
}
hip_device::stream& get_stream() { return get_current_device().get_stream(); } hip_device::stream& get_stream() { return get_current_device().get_stream(); }
hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); } hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); }
const hip_device::stream& get_stream() const { return get_current_device().get_stream(); }
const hip_device::stream& get_stream(std::size_t n) const
{
return get_current_device().get_stream(n);
}
void set_stream(std::size_t n) { get_current_device().set_stream(n); } void set_stream(std::size_t n) { get_current_device().set_stream(n); }
void create_events(std::size_t num_of_events) void create_events(std::size_t num_of_events)
...@@ -169,7 +214,7 @@ struct context ...@@ -169,7 +214,7 @@ struct context
hipEvent_t get_event(std::size_t i) const { return events.at(i).get(); } hipEvent_t get_event(std::size_t i) const { return events.at(i).get(); }
std::vector<argument> literals{}; std::vector<argument> literals{};
void finish() const { gpu_sync(); } void finish() const { get_stream().wait(); }
static hip_event_ptr create_event() static hip_event_ptr create_event()
{ {
......
...@@ -25,6 +25,7 @@ argument from_gpu(const argument& arg); ...@@ -25,6 +25,7 @@ argument from_gpu(const argument& arg);
void set_device(std::size_t id); void set_device(std::size_t id);
void gpu_sync(); void gpu_sync();
void gpu_sync(const context& ctx);
void gpu_copy(context& ctx, const argument& src, const argument& dst); void gpu_copy(context& ctx, const argument& src, const argument& dst);
void copy_to_gpu(context& ctx, const argument& src, const argument& dst); void copy_to_gpu(context& ctx, const argument& src, const argument& dst);
...@@ -82,6 +83,33 @@ struct hip_sync_device ...@@ -82,6 +83,33 @@ struct hip_sync_device
} }
}; };
struct hip_sync_stream
{
std::string tag{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.tag, "tag"));
}
std::string name() const { return "hip::sync_stream"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
if(inputs.empty())
return {};
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
gpu_sync(ctx);
if(args.empty())
return {};
return args.front();
}
};
struct hip_copy_to_gpu struct hip_copy_to_gpu
{ {
std::string name() const { return "hip::copy_to_gpu"; } std::string name() const { return "hip::copy_to_gpu"; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment