Commit ff3bd8e6 authored by Khalique Ahmed
Browse files

manual merge

parents 32b69ceb c310bc5c
......@@ -168,6 +168,12 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it);
}
template <class Range>
auto reverse(Range& r)
{
    // Build a reversed view over r: the reversed sequence starts at the
    // reverse of end() and finishes at the reverse of begin().
    auto rfirst = std::make_reverse_iterator(r.end());
    auto rlast  = std::make_reverse_iterator(r.begin());
    return range(rfirst, rlast);
}
template <class Range, class T>
void replace(Range&& r, const T& old, const T& new_x)
{
......
......@@ -29,7 +29,15 @@ struct raw_data : raw_data_base
friend Stream& operator<<(Stream& os, const Derived& d)
{
if(not d.empty())
d.visit([&](auto x) { os << x; });
d.visit([&](auto x) { os << x; },
[&](auto&& xs) {
for(auto&& x : xs)
{
os << "{ ";
os << x;
os << " }, ";
}
});
return os;
}
......@@ -45,9 +53,19 @@ struct raw_data : raw_data_base
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
auto&& buffer = derived.data();
s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); });
auto&& s = derived.get_shape();
s.visit_type([&](auto as) { v(*(as.from(derived.data()) + s.index(n))); });
}
// Visits the underlying data: v is invoked with a tensor view for plain
// (non-tuple) types, tv is invoked with the sub-objects for tuple types.
// Throws if this object holds no data.
template <class Visitor, class TupleVisitor>
void visit(Visitor v, TupleVisitor tv) const
{
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
// Dispatch on the shape's type enum: tensor types call v, tuple_type calls tv.
s.visit_type([&](auto as) { v(make_view(s, as.from(derived.data()))); },
[&] { tv(derived.get_sub_objects()); });
}
/**
......@@ -60,12 +78,7 @@ struct raw_data : raw_data_base
template <class Visitor>
void visit(Visitor v) const
{
auto&& derived = static_cast<const Derived&>(*this);
if(derived.empty())
MIGRAPHX_THROW("Visiting empty data!");
auto&& s = derived.get_shape();
auto&& buffer = derived.data();
s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
visit(v, [&](const auto&) { MIGRAPHX_THROW("Invalid tuple type"); });
}
/// Returns true if the raw data is only one element
......@@ -156,43 +169,27 @@ struct raw_data : raw_data_base
}
};
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
namespace detail {
template <class V1, class V2, class... Ts>
void visit_all_flatten(const shape& s, V1&& v1, V2&& v2, Ts&&... xs)
{
auto&& xshape = x.get_shape();
auto&& yshape = y.get_shape();
bool result = x.empty() && y.empty();
if(not result && xshape == yshape)
{
auto&& xbuffer = x.data();
auto&& ybuffer = y.data();
// TODO: Dont use tensor view for single values
xshape.visit_type([&](auto as) {
auto xview = make_view(xshape, as.from(xbuffer));
auto yview = make_view(yshape, as.from(ybuffer));
result = xview == yview;
});
}
return result;
s.visit_type([&](auto as) { v1(make_view(xs.get_shape(), as.from(xs.data()))...); },
[&] { v2(xs.get_sub_objects()...); });
}
template <class T,
class U,
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
template <class V1, class V2, class... Ts>
auto visit_all_pack(const shape& s, V1&& v1, V2&& v2)
{
return !(x == y);
return [&](auto&&... xs) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
visit_all_flatten(s, v1, v2, xs...);
};
}
namespace detail {
template <class V, class... Ts>
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
template <class V1, class... Ts>
auto visit_all_pack(const shape& s, V1&& v1)
{
s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); });
return visit_all_pack(s, v1, [](auto&&...) { MIGRAPHX_THROW("Invalid tuple type"); });
}
} // namespace detail
......@@ -215,10 +212,7 @@ auto visit_all(T&& x, Ts&&... xs)
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
return [&](auto v) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...);
};
return [&](auto... vs) { detail::visit_all_pack(s, vs...)(x, xs...); };
}
template <class T>
......@@ -240,6 +234,34 @@ auto visit_all(const std::vector<T>& x)
};
}
// Deep equality for raw-data objects: equal when both are empty, or when
// shapes match and the contents compare equal (elementwise for tensors,
// sub-object-wise for tuples).
template <class T,
          class U,
          MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
                            std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
    // Two empty objects always compare equal.
    if(x.empty() and y.empty())
        return true;
    // Mismatched shapes can never be equal.
    if(x.get_shape() != y.get_shape())
        return false;
    bool equal = false;
    visit_all(x, y)([&](auto xview, auto yview) { equal = xview == yview; },
                    [&](auto&& xs, auto&& ys) {
                        equal = std::equal(xs.begin(), xs.end(), ys.begin(), ys.end());
                    });
    return equal;
}
// Inequality is defined as the negation of operator== so the two stay in sync.
template <class T,
          class U,
          MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
                            std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
    return not(x == y);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -40,7 +40,7 @@ struct shape
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
{
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) tuple_type
};
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
......@@ -82,6 +82,8 @@ struct shape
{
}
shape(const std::vector<shape>& subs);
static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const;
......@@ -179,11 +181,16 @@ struct shape
type_t type_enum() const { return get_type<type>{}; }
};
template <class Visitor>
static void visit(type_t t, Visitor v)
template <class Visitor, class TupleVisitor>
static void visit(type_t t, Visitor v, TupleVisitor tv)
{
switch(t)
{
case tuple_type:
{
tv();
return;
}
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
......@@ -193,9 +200,15 @@ struct shape
}
template <class Visitor>
void visit_type(Visitor v) const
static void visit(type_t t, Visitor v)
{
visit(this->type(), v);
return visit(t, v, [] { MIGRAPHX_THROW("Tuple cannot be visited."); });
}
// Visits this shape's own type, forwarding any additional visitors (such as
// a tuple visitor) to the static visit overloads.
template <class... Visitors>
void visit_type(Visitors... vs) const
{
visit(this->type(), vs...);
}
template <class Visitor>
......@@ -209,6 +222,8 @@ struct shape
std::string type_string() const;
static type_t parse_type(const std::string& s);
const std::vector<shape>& sub_shapes() const;
private:
std::shared_ptr<const shape_impl> impl;
......
......@@ -160,6 +160,7 @@ struct value
binary(T* data, std::size_t s) : base(data, data + s)
{
}
explicit binary(std::size_t s) : base(s) {}
};
value() = default;
......
......@@ -24,8 +24,57 @@ struct module_impl
{
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
std::unordered_set<instruction*> instruction_set;
std::vector<std::string> input_names;
std::string name;
// Returns true if the given instruction belongs to this module.
// The end iterator is rejected explicitly since it cannot be dereferenced.
bool contains(instruction_ref ins) const
{
if(ins == instructions.end())
return false;
return instruction_set.count(std::addressof(*ins)) > 0;
}
// Constructs an instruction in place before pos and records its address in
// instruction_set so contains() can answer membership in O(1).
template <class... Ts>
instruction_ref emplace(instruction_ref pos, Ts&&... xs)
{
// cppcheck-suppress redundantInitialization
auto r = instructions.emplace(pos, std::forward<Ts>(xs)...);
instruction_set.insert(std::addressof(*r));
return r;
}
// Inserts a copy of ins before pos, keeping instruction_set in sync.
instruction_ref insert(instruction_ref pos, const instruction& ins)
{
return emplace(pos, ins);
}
// Convenience wrappers for insertion at either end of the instruction list.
void push_front(const instruction& ins) { insert(instructions.begin(), ins); }
void push_back(const instruction& ins) { insert(instructions.end(), ins); }
// In-place construction at the front of the instruction list.
template <class... Ts>
void emplace_front(Ts&&... xs)
{
emplace(instructions.begin(), std::forward<Ts>(xs)...);
}
// In-place construction at the back of the instruction list.
template <class... Ts>
void emplace_back(Ts&&... xs)
{
emplace(instructions.end(), std::forward<Ts>(xs)...);
}
// Removes a single instruction, unregistering its address first so the
// membership set never holds a dangling pointer.
instruction_ref erase(instruction_ref pos)
{
instruction_set.erase(std::addressof(*pos));
return instructions.erase(pos);
}
// Removes the half-open range [start, last), unregistering each instruction.
instruction_ref erase(instruction_ref start, instruction_ref last)
{
std::for_each(start, last, [&](auto& ins) { instruction_set.erase(std::addressof(ins)); });
return instructions.erase(start, last);
}
};
// Convenience accessor for the operation stored in an instruction.
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
......@@ -71,20 +120,19 @@ void module::assign(const module& m)
if(ins->name() == "@literal")
{
auto l = ins->get_literal();
copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l});
copy_ins = impl->insert(impl->instructions.end(), instruction{l});
}
else if(ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(),
{builtin::param{name}, std::move(s), {}});
copy_ins =
impl->insert(impl->instructions.end(), {builtin::param{name}, std::move(s), {}});
}
else if(ins->name() == "@outline")
{
auto s = ins->get_shape();
copy_ins =
impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
auto s = ins->get_shape();
copy_ins = impl->insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
}
else
{
......@@ -127,7 +175,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
{
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
auto result = impl->insert(ins, {op, r, std::move(args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
......@@ -148,8 +196,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
{
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
auto result =
impl->instructions.insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
......@@ -222,7 +269,7 @@ instruction_ref module::remove_instruction(instruction_ref ins)
assert(has_instruction(ins));
assert(ins->outputs().empty());
ins->clear_arguments();
return impl->instructions.erase(ins);
return impl->erase(ins);
}
instruction_ref module::remove_instructions(instruction_ref first, instruction_ref last)
......@@ -233,7 +280,7 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r
assert(has_instruction(first));
std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); }));
return impl->instructions.erase(first, last);
return impl->erase(first, last);
}
instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
......@@ -252,13 +299,13 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
instruction_ref module::add_literal(literal l)
{
impl->instructions.emplace_front(std::move(l));
impl->emplace_front(std::move(l));
return impl->instructions.begin();
}
instruction_ref module::add_outline(const shape& s)
{
impl->instructions.push_front({builtin::outline{s}, s, {}});
impl->push_front({builtin::outline{s}, s, {}});
return impl->instructions.begin();
}
......@@ -267,13 +314,13 @@ instruction_ref module::add_parameter(std::string name, shape s)
assert(get_parameter_shape(name) == shape{});
impl->input_names.push_back(name);
impl->instructions.push_front({builtin::param{std::move(name)}, std::move(s), {}});
impl->push_front({builtin::param{std::move(name)}, std::move(s), {}});
return impl->instructions.begin();
}
instruction_ref module::add_return(std::vector<instruction_ref> args)
{
impl->instructions.push_back({builtin::returns{}, {}, std::move(args)});
impl->push_back({builtin::returns{}, {}, std::move(args)});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result);
assert(result->valid(begin()));
......@@ -350,13 +397,7 @@ std::unordered_map<std::string, shape> module::get_parameter_shapes() const
return result;
}
bool module::has_instruction(instruction_ref ins) const
{
return std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return std::addressof(*ins) == std::addressof(x);
}) != impl->instructions.end();
}
// Membership test delegates to the impl's O(1) address set.
bool module::has_instruction(instruction_ref ins) const { return impl->contains(ins); }
std::size_t module::size() const { return impl->instructions.size(); }
instruction_ref module::begin() const { return impl->instructions.begin(); }
......@@ -364,6 +405,8 @@ instruction_ref module::end() const { return impl->instructions.end(); }
std::vector<shape> module::get_output_shapes() const
{
if(impl->instructions.empty())
return {};
auto last_ins = impl->instructions.back();
if(last_ins.name() == "@return")
{
......
......@@ -13,7 +13,7 @@ target_include_directories(migraphx_onnx PRIVATE include)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL")
target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets(
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses an ONNX prefix-scan node (e.g. CumSum) into the corresponding
// migraphx operator. Inputs: args[0] is the data tensor, args[1] is the
// (required, compile-time constant) axis tensor. The optional "exclusive"
// and "reverse" attributes default to false per the ONNX specification.
instruction_ref parse_prefix_scan_oper(const std::string& op_name,
                                       const onnx_parser& parser,
                                       onnx_parser::node_info info,
                                       std::vector<instruction_ref> args)
{
    // Guard against malformed models that omit the required axis input.
    if(args.size() < 2)
        MIGRAPHX_THROW("PARSE_PREFIX_SCAN: axis input is required");
    migraphx::argument in = args[1]->eval();
    check_arg_empty(in, "PARSE_PREFIX_SCAN: axis - dynamic shape not supported");
    // Use a signed element type: ONNX allows negative axes, which would wrap
    // to huge values if stored in std::size_t.
    std::vector<int64_t> axis_in;
    in.visit([&](auto input) { axis_in.assign(input.begin(), input.end()); });
    if(axis_in.empty())
        MIGRAPHX_THROW("PARSE_PREFIX_SCAN: axis input is empty");
    int64_t axis   = axis_in[0];
    bool exclusive = false;
    bool reverse   = false;
    if(contains(info.attributes, "exclusive"))
    {
        exclusive = parser.parse_value(info.attributes.at("exclusive")).at<bool>();
    }
    if(contains(info.attributes, "reverse"))
    {
        reverse = parser.parse_value(info.attributes.at("reverse")).at<bool>();
    }
    return info.add_instruction(
        make_op(op_name, {{"axis", axis}, {"exclusive", exclusive}, {"reverse", reverse}}),
        args[0]);
}
// Registers the prefix-scan parsers with the ONNX op-parser framework.
// Currently maps the ONNX "CumSum" node to the "prefix_scan_sum" operator.
struct parse_prefix_scan_op : op_parser<parse_prefix_scan_op>
{
std::vector<op_desc> operators() const { return {{"CumSum", "prefix_scan_sum"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
// Delegate to the shared free function that handles all prefix-scan ops.
return parse_prefix_scan_oper(opd.op_name, parser, std::move(info), std::move(args));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -15,25 +15,56 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void run_passes(module& modl, const std::vector<pass>& passes, tracer trace)
// Checks that a pass left the module in a valid state (debug builds only).
// In release builds this is a no-op; the void-casts silence
// unused-parameter warnings when NDEBUG is defined.
void validate_pass(module& mod, const pass& p, tracer trace)
{
(void)mod;
(void)p;
(void)trace;
#ifndef NDEBUG
trace("Validate ...");
auto invalid = mod.validate();
if(invalid != mod.end())
{
// Report the offending instruction by its position in the module.
auto index = std::distance(mod.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
}
trace();
#endif
}
// Runs a single pass over a module: traces, applies it, then validates the
// result (validation is active in debug builds only).
void run_pass(module& mod, const pass& p, tracer trace)
{
trace("Module: ", mod.name(), ", Pass: ", p.name());
// The module must already be valid before the pass runs.
assert(mod.validate() == mod.end());
p.apply(mod);
trace(mod);
validate_pass(mod, p, trace);
}
// Runs a single pass over the whole program via the program-level apply
// overload; no per-module validation is performed here.
void run_pass(program& prog, const pass& p, tracer trace)
{
trace("Pass: ", p.name());
p.apply(prog);
trace(prog);
}
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
for(const auto& p : passes)
{
trace("Module: ", modl.name(), ", Pass: ", p.name());
p.apply(modl);
trace(modl);
run_pass(mod, p, trace);
}
}
#ifndef NDEBUG
trace("Validate ...");
auto invalid = modl.validate();
if(invalid != modl.end())
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
for(const auto& p : passes)
{
auto mods = prog.get_modules();
for(const auto& mod : reverse(mods))
{
auto index = std::distance(modl.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
run_pass(*mod, p, trace);
}
trace();
#endif
run_pass(prog, p, trace);
}
}
......
......@@ -9,6 +9,8 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp>
#include <iostream>
#include <sstream>
......@@ -26,13 +28,12 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program_impl
{
// A map is used to keep references to modules of the program
// all the modules are store in the depth-first order
std::list<module> modules;
std::unordered_map<std::string, module> modules;
context ctx;
std::string target_name;
};
program::program() : impl(std::make_unique<program_impl>()) { impl->modules.push_back({"main"}); }
program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
program::program(program&&) noexcept = default;
program::~program() noexcept = default;
......@@ -65,11 +66,11 @@ void program::assign(const program& p)
// build a map from old ins to new ins
// Build a map from old module to new module
std::unordered_map<module_ref, module_ref> mod_map;
std::transform(impl->modules.begin(),
impl->modules.end(),
p.impl->modules.begin(),
std::inserter(mod_map, mod_map.begin()),
[](auto&& x, auto&& y) { return std::make_pair(&y, &x); });
std::transform(
impl->modules.begin(),
impl->modules.end(),
std::inserter(mod_map, mod_map.begin()),
[&](auto&& xp) { return std::make_pair(&p.impl->modules.at(xp.first), &xp.second); });
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto&& pp : mod_map)
......@@ -86,7 +87,7 @@ void program::assign(const program& p)
// Update all references from all modules
for(auto&& mp : impl->modules)
{
for(auto ins : iterator_for(mp))
for(auto ins : iterator_for(mp.second))
instruction::replace_refs(ins, ins_map, mod_map);
}
}
......@@ -144,14 +145,14 @@ void program::compile(const target& t, compile_options options)
options.trace(*this);
options.trace();
auto mods = this->get_modules();
std::reverse(mods.begin(), mods.end());
auto&& passes = t.get_passes(this->impl->ctx, options);
run_passes(*this, passes, options.trace);
for(const auto& mod : mods)
auto mods = this->get_modules();
// Validate and finalize
for(const auto& mod : reverse(mods))
{
assert(mod->validate() == mod->end());
run_passes(*mod, passes, options.trace);
auto invalid = mod->validate();
if(invalid != mod->end())
{
......@@ -306,7 +307,7 @@ std::vector<argument> program::eval(parameter_map params) const
}
}
const int program_file_version = 4;
const int program_file_version = 5;
value program::to_value() const
{
......@@ -316,14 +317,14 @@ value program::to_value() const
if(not this->impl->target_name.empty())
result["context"] = this->impl->ctx.to_value();
value module_vals = value::array{};
value module_vals = value::object{};
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : this->impl->modules)
for(auto& mod : this->get_modules())
{
value mod_val;
value nodes;
mod_val["name"] = mod.name();
names = mod.print(
mod_val["name"] = mod->name();
names = mod->print(
[&](auto ins, auto ins_names) {
value node;
node["output"] = ins_names.at(ins);
......@@ -358,7 +359,7 @@ value program::to_value() const
names);
mod_val["nodes"] = nodes;
module_vals.push_back(mod_val);
module_vals[mod->name()] = mod_val;
}
result["modules"] = module_vals;
......@@ -371,12 +372,7 @@ static void mod_from_val(module_ref mod,
std::unordered_map<std::string, instruction_ref>& instructions,
const std::unordered_map<std::string, module_ref>& map_mods)
{
const auto* it = std::find_if(v.begin(), v.end(), [&](auto& mv) {
return mv.at("name").template to<std::string>() == mod->name();
});
assert(it != v.end());
const auto& module_val = *it;
const auto& module_val = v.at(mod->name());
for(const value& node : module_val.at("nodes"))
{
instruction_ref output;
......@@ -455,15 +451,18 @@ void program::from_value(const value& v)
}
auto module_vals = v.at("modules");
std::unordered_map<std::string, module_ref> map_mods;
for(const auto& vv : module_vals)
{
const auto& name = vv.at("name").to<std::string>();
const auto& name = vv.get_key();
if(name == "main")
continue;
impl->modules.push_back({name});
map_mods[name] = &impl->modules.back();
impl->modules.emplace(name, name);
}
std::unordered_map<std::string, module_ref> map_mods;
std::transform(impl->modules.begin(),
impl->modules.end(),
std::inserter(map_mods, map_mods.end()),
[&](auto&& pp) { return std::make_pair(pp.first, &pp.second); });
std::unordered_map<std::string, instruction_ref> map_insts;
auto* mm = get_main_module();
......@@ -585,8 +584,8 @@ void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const
{
std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& it) {
return (it.end() == ins);
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
return (pp.second.end() == ins);
}))
{
std::cout << "End instruction" << std::endl;
......@@ -594,7 +593,7 @@ void program::debug_print(instruction_ref ins) const
}
else if(std::none_of(this->impl->modules.begin(),
this->impl->modules.end(),
[&](const auto& it) { return it.has_instruction(ins); }))
[&](const auto& pp) { return pp.second.has_instruction(ins); }))
{
std::cout << "Instruction not part of program" << std::endl;
return;
......@@ -615,9 +614,9 @@ void program::print(
const std::function<void(instruction_ref, std::unordered_map<instruction_ref, std::string>)>&
print_func) const
{
for(const auto& mod : this->impl->modules)
for(const auto& pp : this->impl->modules)
{
names = mod.print(print_func, names);
names = pp.second.print(print_func, names);
}
}
......@@ -647,74 +646,118 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
{
for(auto& mod : this->impl->modules)
for(auto& pp : this->impl->modules)
{
std::cout << mod.name() << ":" << std::endl;
mod.annotate(os, a);
std::cout << pp.first << ":" << std::endl;
pp.second.annotate(os, a);
}
}
const module* program::get_module(const std::string& name) const
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
return &(*it);
}
const module* program::get_module(const std::string& name) const { return &impl->modules.at(name); }
module* program::create_module(const std::string& name)
{
auto it = impl->modules.insert(impl->modules.end(), {name});
return &(*it);
assert(not contains(impl->modules, name));
auto r = impl->modules.emplace(name, name);
return &(r.first->second);
}
module* program::get_module(const std::string& name)
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
return &(*it);
}
// Lookup by name. NOTE(review): modules.at() throws std::out_of_range for an
// unknown name, unlike a find-based lookup returning nullptr -- confirm no
// caller relies on a null result for missing modules.
module* program::get_module(const std::string& name) { return &impl->modules.at(name); }
module* program::get_main_module() { return get_module("main"); }
const module* program::get_main_module() const { return get_module("main"); }
std::vector<const module*> program::get_modules() const
template <class T>
std::vector<T*> generic_get_modules(T* mm)
{
const module* mm = this->get_main_module();
std::vector<const module*> vec_modules;
std::vector<T*> vec_modules;
vec_modules.push_back(mm);
auto sub_modules = mm->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end());
return vec_modules;
}
// Writes to out a pointer to every module in map m that does not appear in
// mods. Reachability is determined purely by module name.
template <class Map, class T, class OutputIterator>
void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputIterator out)
{
// Collect the names of all modules that are in use.
std::unordered_set<std::string> used;
std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
return mod->name();
});
// Emit each module whose name was not marked as used.
transform_if(m.begin(),
m.end(),
out,
[&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
}
std::vector<const module*> program::get_modules() const
{
auto result = generic_get_modules(this->get_main_module());
generic_get_unused_modules(impl->modules, result, std::back_inserter(result));
return result;
}
std::vector<module*> program::get_modules()
{
module* mm = this->get_main_module();
std::vector<module*> vec_modules;
vec_modules.push_back(mm);
auto sub_modules = mm->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end());
auto result = generic_get_modules(this->get_main_module());
generic_get_unused_modules(impl->modules, result, std::back_inserter(result));
return result;
}
return vec_modules;
// Returns true if the module called name is not among the used modules mods.
template <class Map, class T>
bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name)
{
bool is_unused = false;
// Stream the unused set through an output iterator and flag a name match,
// avoiding the need to materialize a container of unused modules.
generic_get_unused_modules(m, mods, make_function_output_iterator([&](auto* mod) {
if(mod->name() == name)
is_unused = true;
}));
return is_unused;
}
// Returns true if any instruction in a module other than name uses ins as an
// input. Instruction identity is compared by address, matching the
// address-based bookkeeping used in module_impl.
template <class Map>
bool references_instruction(Map& m, const instruction& ins, const std::string& name)
{
return std::any_of(m.begin(), m.end(), [&](auto&& p) {
if(p.first == name)
return false;
return std::any_of(p.second.begin(), p.second.end(), [&](auto&& i) {
return std::any_of(i.inputs().begin(), i.inputs().end(), [&](auto&& j) {
return std::addressof(*j) == std::addressof(ins);
});
});
});
}
// Removes the named module from the program. Debug builds assert that the
// module is neither reachable from the main module nor referenced by another
// module's instructions; release builds erase unconditionally.
void program::remove_module(const std::string& name)
{
// cppcheck-suppress assertWithSideEffect
assert(is_unused_module(impl->modules, generic_get_modules(this->get_main_module()), name) &&
"Module used in program");
assert(std::none_of(
impl->modules.at(name).begin(),
impl->modules.at(name).end(),
[&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
"Instruction referenced in another module");
impl->modules.erase(name);
}
// Deletes every module that is not reachable from the main module.
void program::remove_unused_modules()
{
std::vector<module*> unused;
generic_get_unused_modules(
impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused));
// remove_module re-checks (in debug builds) that each module is still unused.
for(auto* m : unused)
this->remove_module(m->name());
}
program& program::sort()
{
for(auto& mod : this->impl->modules)
for(auto& pp : this->impl->modules)
{
mod.sort();
pp.second.sort();
}
return *this;
......
......@@ -398,7 +398,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
m.def("from_gpu", &migraphx::gpu::from_gpu);
m.def("gpu_sync", &migraphx::gpu::gpu_sync);
m.def("gpu_sync", [] { migraphx::gpu::gpu_sync(); });
#endif
#ifdef VERSION_INFO
......
......@@ -11,8 +11,11 @@ void raw_data_to_value(value& v, const RawData& rd)
{
value result;
result["shape"] = migraphx::to_value(rd.get_shape());
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
v = result;
if(rd.get_shape().type() == shape::tuple_type)
result["sub"] = migraphx::to_value(rd.get_sub_objects());
else
result["data"] = migraphx::value::binary(rd.data(), rd.get_shape().bytes());
v = result;
}
void migraphx_to_value(value& v, const literal& l) { raw_data_to_value(v, l); }
......@@ -25,8 +28,15 @@ void migraphx_from_value(const value& v, literal& l)
void migraphx_to_value(value& v, const argument& a) { raw_data_to_value(v, a); }
void migraphx_from_value(const value& v, argument& a)
{
literal l = migraphx::from_value<literal>(v);
a = l.get_argument();
if(v.contains("data"))
{
literal l = migraphx::from_value<literal>(v);
a = l.get_argument();
}
else
{
a = migraphx::from_value<std::vector<argument>>(v.at("sub"));
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -20,28 +20,36 @@ struct shape_impl
return result;
}
shape_impl() : m_type(shape::float_type), m_standard(false) {}
shape_impl() : m_type(shape::float_type) {}
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true) {}
// Scalar shape of the given element type: one element with stride 0.
// Tuple shapes must go through the sub-shape constructor below.
shape_impl(shape::type_t t) : m_type(t), m_lens({1}), m_strides({0}), m_standard(true)
{
assert(t != shape::tuple_type);
}
// Shape from dimension lengths; strides are derived, so the layout is
// standard (packed, row-major).
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
assert(t != shape::tuple_type);
this->calculate_strides();
assert(m_lens.size() == m_strides.size());
}
// Shape from explicit lengths and strides; standard-ness must be computed
// since arbitrary strides may describe a transposed or broadcast layout.
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{
assert(t != shape::tuple_type);
assert(m_lens.size() == m_strides.size());
// assert(std::any_of(m_strides.begin(), m_strides.end(), [](auto x) { return x > 0; }) and
// "At least one stride must be non-zero");
// Standard means packed (no holes) with non-increasing strides.
m_standard = this->elements() == this->element_space() and
std::is_sorted(m_strides.rbegin(), m_strides.rend());
}
// Tuple shape: holds sub-shapes only; lens/strides stay empty.
shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
shape::type_t m_type;
std::vector<std::size_t> m_lens;
std::vector<std::size_t> m_strides;
bool m_standard;
std::vector<std::size_t> m_lens = {};
std::vector<std::size_t> m_strides = {};
std::vector<shape> m_shapes = {};
bool m_standard = false;
void calculate_strides()
{
......@@ -84,7 +92,7 @@ const std::vector<shape::type_t>& shape::types()
{
static const std::vector<shape::type_t> result = {
#define MIGRAPHX_GENERATE_TYPE_VECTOR(x, t) x,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR)};
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GENERATE_TYPE_VECTOR) tuple_type};
return result;
}
......@@ -92,6 +100,7 @@ std::string shape::name(shape::type_t t)
{
switch(t)
{
case tuple_type: return "tuple_type";
#define MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE(x, t) \
case x: return #x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_NAME_CASE)
......@@ -103,6 +112,7 @@ std::string shape::cpp_type(shape::type_t t)
{
switch(t)
{
case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
case x: return #t;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)
......@@ -123,6 +133,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
{
}
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm)
......@@ -139,14 +151,25 @@ const std::vector<std::size_t>& shape::strides() const { return impl->m_strides;
std::size_t shape::elements() const { return impl->elements(); }
std::size_t shape::bytes() const
{
std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space();
if(this->sub_shapes().empty())
{
std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); });
return n * this->element_space();
}
else
{
return std::accumulate(this->sub_shapes().begin(),
this->sub_shapes().end(),
std::size_t{0},
[&](auto x, auto y) { return x + y.bytes(); });
}
}
std::size_t shape::type_size() const
{
std::size_t n = 0;
this->visit_type([&](auto as) { n = as.size(); });
if(this->sub_shapes().empty())
this->visit_type([&](auto as) { n = as.size(); });
return n;
}
std::size_t shape::index(std::initializer_list<std::size_t> l) const
......@@ -208,7 +231,10 @@ void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end
});
}
bool shape::packed() const { return this->elements() == this->element_space(); }
bool shape::packed() const
{
    // Tuple shapes are never considered packed; for tensor shapes, packed
    // means the elements exactly fill the allocated element space.
    if(not this->sub_shapes().empty())
        return false;
    return this->elements() == this->element_space();
}
bool shape::transposed() const
{
......@@ -242,7 +268,8 @@ bool shape::scalar() const
{
assert(this->lens().size() == this->strides().size());
// if any stride > 0, then accumulate will return false
return std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
return this->sub_shapes().empty() and
std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
}
bool shape::standard() const { return impl->m_standard; }
......@@ -273,15 +300,23 @@ std::string shape::type_string() const { return name(this->type()); }
// Two shapes are equal when they share the same implementation object (fast
// path) or agree on type, lens, strides, and — for tuples — sub-shapes.
// NOTE(review): the pre-merge return statement preceded the merged one and
// short-circuited it, ignoring sub_shapes; it has been removed.
bool operator==(const shape& x, const shape& y)
{
    return x.impl == y.impl or (x.type() == y.type() and x.lens() == y.lens() and
                                x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
}
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
// Stream a human-readable form of the shape:
//   "type, {lens}, {strides}" for ordinary shapes,
//   "[sub_shape, ...]"        for tuple shapes.
// NOTE(review): the pre-merge unconditional prints ran before the merged
// if/else, emitting the ordinary form twice for non-tuple shapes; they have
// been removed.
std::ostream& operator<<(std::ostream& os, const shape& x)
{
    if(x.sub_shapes().empty())
    {
        os << x.type_string() << ", ";
        os << "{" << to_string_range(x.lens()) << "}, ";
        os << "{" << to_string_range(x.strides()) << "}";
    }
    else
    {
        os << "[" << to_string_range(x.sub_shapes()) << "]";
    }
    return os;
}
......@@ -289,23 +324,36 @@ shape::type_t shape::parse_type(const std::string& s)
{
static const std::unordered_map<std::string, shape::type_t> m = {
#define MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP(x, t) {#x, x}, {#t, x},
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP)};
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_TYPE_STRING_MAP){"tuple_type",
tuple_type},
{"tuple", tuple_type}};
return m.at(s);
}
const std::vector<shape>& shape::sub_shapes() const { return impl->m_shapes; }
// Serialize a shape into a value: type string, lens, strides, and — so that
// tuple shapes round-trip — sub_shapes.
// NOTE(review): the pre-merge field assignments (without sub_shapes) and a
// duplicate `v = result;` preceded the merged block; they were redundant and
// have been removed.
void migraphx_to_value(value& v, const shape& s)
{
    value result;
    result["type"]       = migraphx::to_value(s.type_string());
    result["lens"]       = migraphx::to_value(s.lens());
    result["strides"]    = migraphx::to_value(s.strides());
    result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
    v                    = result;
}
// Deserialize a shape from a value. Tuple shapes are rebuilt from their
// serialized sub-shapes; every other type from type/lens/strides.
// NOTE(review): the pre-merge assignment ran before the merged tuple-aware
// branch, doing the work twice (and parse_type would be asked for a tuple's
// string); it has been removed.
void migraphx_from_value(const value& v, shape& s)
{
    auto t = v.at("type").get_string();
    if(t == "tuple_type")
    {
        s = shape{migraphx::from_value<std::vector<migraphx::shape>>(v.at("sub_shapes"))};
    }
    else
    {
        s = shape{shape::parse_type(t),
                  v.at("lens").to_vector<std::size_t>(),
                  v.at("strides").to_vector<std::size_t>()};
    }
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -133,6 +133,7 @@ add_library(migraphx_gpu
logsoftmax.cpp
lrn.cpp
leaky_relu.cpp
mlir_conv.cpp
pack_args.cpp
pack_int8_args.cpp
pad.cpp
......@@ -148,6 +149,7 @@ add_library(migraphx_gpu
write_literals.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
function(register_migraphx_gpu_ops PREFIX)
foreach(OP ${ARGN})
register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp)
......@@ -259,6 +261,20 @@ endif()
message(STATUS "clang-offload-bundler: ${MIGRAPHX_OFFLOADBUNDLER_BIN}")
message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR)
find_library(LIBMLIRMIOPEN MLIRMIOpenThin REQUIRED)
# REQUIRED is not supported before cmake 3.18
if(NOT LIBMLIRMIOPEN)
message(FATAL_ERROR "libMLIRMIOpenThin not found")
else()
message(STATUS "Build with libMLIRMIOpenThin: " ${LIBMLIRMIOPEN})
endif()
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_MLIR_MIOPEN_SUPPORT")
target_link_libraries(migraphx_gpu PUBLIC ${LIBMLIRMIOPEN})
endif()
# Get flags needed to compile hip
include(TargetFlags)
target_flags(HIP_COMPILER_FLAGS hip::device)
......@@ -286,7 +302,6 @@ endif()
# Workaround broken rocblas headers
target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_compile_options(migraphx_gpu PRIVATE -std=c++17)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
......
......@@ -48,17 +48,6 @@ std::string generate_index_ints(const std::vector<T>& v)
return "index_ints<" + to_string_range(v) + ">{}";
}
// Map a shape type enum to the spelling of its C++ type (the macro's second
// argument, e.g. "float"). Throws for values not covered by the visit list.
// NOTE(review): this commit deletes this helper in favor of shape::cpp_type
// (see generate_make_tensor below), which also handles tuple_type.
std::string generate_cpp_type(shape::type_t t)
{
    switch(t)
    {
// Expands to one `case shape::x: return #t;` per (enum, cpp-type) pair
#define MIGRAPHX_GPU_GENERATE_TYPE_STRING(x, t) \
    case shape::x: return #t;
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GPU_GENERATE_TYPE_STRING)
    }
    MIGRAPHX_THROW("Invalid type");
}
std::string generate_make_shape(const shape& s)
{
return "make_shape(" + generate_index_ints(s.lens()) + ", " + generate_index_ints(s.strides()) +
......@@ -80,7 +69,7 @@ std::string generate_make_tensor(std::size_t n, const shape& s)
{
return interpolate_string(make_tensor_template,
{{"n", std::to_string(n)},
{"type", generate_cpp_type(s.type())},
{"type", shape::cpp_type(s.type())},
{"lens", generate_index_ints(s.lens())},
{"strides", generate_index_ints(s.strides())}});
}
......
......@@ -7,9 +7,33 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Copy `arg` into `result` when the layouts are not both packed and equal.
// A fresh shape built from the result's type and lens (default strides —
// presumably the standard packed layout; confirm against the shape ctor) is
// used to drive the launch, one thread per element; indexing through each
// view translates that logical index into the view's own strides.
void contiguous_nonstandard(hipStream_t stream, const argument& result, const argument& arg)
{
    shape s{result.get_shape().type(), result.get_shape().lens()};
    visit_all(result, arg)([&](auto output_v, auto input_v) {
        hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) {
            mi_gs_launch(stream,
                        standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; });
        });
    });
}
// Flat element-wise copy for the case where input and output already share
// the same packed layout: one thread per element, no index translation.
void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg)
{
    const index_int n = result.get_shape().elements();
    visit_all(result, arg)([&](auto out_view, auto in_view) {
        auto* dst       = device_cast(out_view.data());
        const auto* src = device_cast(in_view.data());
        gs_launch(stream, n)([=](auto i) __device__ { dst[i] = src[i]; });
    });
}
// Make `arg` contiguous into `result`. When both shapes are identical and
// packed, a flat element-wise copy suffices; otherwise use the shape-aware
// copy that translates indices through each tensor's strides.
// NOTE(review): a leftover pre-merge `nary(...)` identity copy ran before the
// merged dispatch, performing the copy twice; it has been removed.
void contiguous(hipStream_t stream, const argument& result, const argument& arg)
{
    if(result.get_shape() == arg.get_shape() and result.get_shape().packed())
        contiguous_packed(stream, result, arg);
    else
        contiguous_nonstandard(stream, result, arg);
}
} // namespace device
......
......@@ -51,6 +51,50 @@ auto get_shape(const T& x) -> decltype(x.get_shape())
return x.get_shape();
}
// Trait marking the element types the GPU code paths support. Anything not
// specialized here (e.g. double or the 16/32-bit integer types) stays
// false_type and is rejected at visit time by hip_visitor_invoke below.
template <class T>
struct is_hip_type : std::false_type
{
};
template <>
struct is_hip_type<float> : std::true_type
{
};
template <>
struct is_hip_type<half> : std::true_type
{
};
template <>
struct is_hip_type<bool> : std::true_type
{
};
template <>
struct is_hip_type<std::int8_t> : std::true_type
{
};
template <>
struct is_hip_type<std::uint8_t> : std::true_type
{
};
// Forward the type wrapper to the visitor when T::type is a GPU-supported
// element type (selected via SFINAE on is_hip_type).
template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})>
void hip_visitor_invoke(T as, V&& v)
{
    v(as);
}
// Overload chosen for unsupported element types: fail loudly at runtime,
// embedding the instantiated type in the message via __PRETTY_FUNCTION__
// so the offending type is visible in the error text.
template <class T, class V, MIGRAPHX_REQUIRES(not is_hip_type<typename T::type>{})>
void hip_visitor_invoke(T, V&&)
{
    MIGRAPHX_THROW(std::string("Unsupported data type on GPU: ") + __PRETTY_FUNCTION__);
}
// Wrap a visitor so that visiting an unsupported GPU element type throws
// (via hip_visitor_invoke) instead of instantiating v for that type.
template <class V>
auto hip_visitor(V v)
{
    return [=](auto as) { hip_visitor_invoke(as, v); };
}
template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
......@@ -62,8 +106,9 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
static_cast<index_int>(get_shape(xs).lens().size())...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(),
[&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
visit_tensor_size(s.lens().size(), [&](auto ndim) {
s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
});
}
template <class V, class F, class... Ts>
......
......@@ -16,6 +16,7 @@ rocblas_datatype get_type(shape::type_t type)
case shape::uint8_type: return rocblas_datatype_u8_r;
case shape::int32_type: return rocblas_datatype_i32_r;
case shape::uint32_type: return rocblas_datatype_u32_r;
case shape::tuple_type:
case shape::bool_type:
case shape::uint16_type:
case shape::int16_type:
......
......@@ -15,6 +15,7 @@ namespace gpu {
MIGRAPHX_REGISTER_OP(hip_allocate)
MIGRAPHX_REGISTER_OP(hip_sync_device)
MIGRAPHX_REGISTER_OP(hip_sync_stream)
MIGRAPHX_REGISTER_OP(hip_copy_to_gpu)
MIGRAPHX_REGISTER_OP(hip_copy_from_gpu)
MIGRAPHX_REGISTER_OP(hip_copy)
......@@ -146,6 +147,8 @@ void gpu_sync()
MIGRAPHX_THROW("hip device synchronization failed: " + hip_error(status));
}
void gpu_sync(const context& ctx) { ctx.finish(); }
void hip_async_copy(context& ctx, const argument& src, const argument& dst, hipMemcpyKind kind)
{
std::size_t src_size = src.get_shape().bytes();
......
......@@ -21,10 +21,20 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
struct hip_device
{
// Default-construct a device without querying HIP: zero out the device
// properties that the accessors below read (arch name, gcnArch, CU count)
// so they return well-defined values, then open one stream.
// NOTE(review): the pre-merge one-line definition duplicated this
// constructor (a redefinition as flattened); only the merged version that
// initializes device_props is kept.
hip_device()
{
    device_props.gcnArchName[0]      = '\0';
    device_props.gcnArch             = 0;
    device_props.multiProcessorCount = 0;
    add_stream();
}
// Construct device `id` with `n` streams, querying the HIP device properties
// consumed by the accessors (device name, major/minor, CU count).
hip_device(std::size_t id, std::size_t n) : device_id(id)
{
    auto status = hipGetDeviceProperties(&device_props, device_id);
    if(status != hipSuccess)
        // Fixed message: previously said "Failed to allocate stream", a
        // copy-paste that did not match the failing call.
        MIGRAPHX_THROW("Failed to get device properties");
    for(std::size_t i = 0; i < n; i++)
        add_stream();
}
......@@ -97,6 +107,16 @@ struct hip_device
// MIGRAPHX_THROW("Unable to get hip device properties");
return "gfx" + std::to_string(props.gcnArch);
}
// Block the host until all work queued on this stream has completed.
// A null stream handle means nothing was ever created/queued, so there is
// nothing to wait for.
void wait() const
{
    if(s == nullptr)
        return;
    setup();
    if(hipStreamSynchronize(s.get()) != hipSuccess)
        MIGRAPHX_THROW("Failed to wait.");
}
void wait(hipEvent_t event)
{
......@@ -127,16 +147,29 @@ struct hip_device
// Stream access: the no-arg overloads return the currently selected stream,
// the indexed overloads a specific stream (at() bounds-checks).
stream& get_stream(std::size_t n) { return streams.at(n); }
const stream& get_stream() const { return streams.at(current_stream); }
const stream& get_stream(std::size_t n) const { return streams.at(n); }
// Select which stream the no-arg get_stream() overloads return.
void set_stream(std::size_t n) { current_stream = n; }
std::size_t nstreams() const { return streams.size(); }
std::size_t stream_id() const { return current_stream; }
// Properties below come from device_props, filled by hipGetDeviceProperties
// in the (id, n) constructor (zeroed defaults in the default constructor).
std::string get_device_name() const { return device_props.gcnArchName; }
std::size_t get_device_major() const { return device_props.major; }
std::size_t get_device_minor() const { return device_props.minor; }
std::size_t get_cu_count() const { return device_props.multiProcessorCount; }
private:
std::size_t device_id = 0;
std::size_t current_stream = 0;
std::vector<stream> streams;
hipDeviceProp_t device_props;
public:
std::unordered_map<std::string, argument> preallocations{};
......@@ -155,9 +188,21 @@ struct context
return *current_device;
}
// Read-only access to the active device; a device must already be selected
// (asserted, not checked in release builds).
const hip_device& get_current_device() const
{
    assert(current_device != nullptr);
    return *current_device;
}
// Stream accessors forwarded to the current device: no-arg overloads return
// the device's currently selected stream, indexed overloads a specific one.
hip_device::stream& get_stream() { return get_current_device().get_stream(); }
hip_device::stream& get_stream(std::size_t n) { return get_current_device().get_stream(n); }
const hip_device::stream& get_stream() const { return get_current_device().get_stream(); }
const hip_device::stream& get_stream(std::size_t n) const
{
    return get_current_device().get_stream(n);
}
void create_events(std::size_t num_of_events)
......@@ -169,7 +214,7 @@ struct context
hipEvent_t get_event(std::size_t i) const { return events.at(i).get(); }
std::vector<argument> literals{};
// Block until all work queued on the current stream has completed.
// NOTE(review): the pre-merge definition called gpu_sync() (a whole-device
// synchronize) and duplicated this one; the merged per-stream wait is kept.
void finish() const { get_stream().wait(); }
static hip_event_ptr create_event()
{
......
......@@ -25,6 +25,7 @@ argument from_gpu(const argument& arg);
void set_device(std::size_t id);
void gpu_sync();
void gpu_sync(const context& ctx);
void gpu_copy(context& ctx, const argument& src, const argument& dst);
void copy_to_gpu(context& ctx, const argument& src, const argument& dst);
......@@ -82,6 +83,33 @@ struct hip_sync_device
}
};
// Operator that synchronizes the GPU and forwards its first input unchanged.
// With no inputs it produces an empty shape/argument. The tag field only
// exists to distinguish instances through reflection/serialization.
struct hip_sync_stream
{
    std::string tag{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.tag, "tag"));
    }

    std::string name() const { return "hip::sync_stream"; }

    // Output shape mirrors the first input; no inputs yields an empty shape.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        return inputs.empty() ? shape{} : inputs.front();
    }

    // Wait for the device, then pass the first argument through (if any).
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
    {
        gpu_sync(ctx);
        return args.empty() ? argument{} : args.front();
    }
};
struct hip_copy_to_gpu
{
std::string name() const { return "hip::copy_to_gpu"; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment