Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
#include <migraphx/serialize.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/literal.hpp>
#include <nlohmann/json.hpp>
#include <migraphx/json.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using json = nlohmann::json;
void value_to_json(const value& val, json& j);
migraphx::value value_from_json(const json& j);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace nlohmann {
template <>
struct adl_serializer<migraphx::value>
{
static void to_json(json& j, const migraphx::value& val) { migraphx::value_to_json(val, j); }
static void from_json(const json& j, migraphx::value& val)
{
val = migraphx::value_from_json(j);
}
};
} // namespace nlohmann
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using json = nlohmann::json;
template <class T>
void value_to_json(const T& x, json& j)
{
j = x;
}
void value_to_json(const value::binary& x, json& j)
{
j = json::object();
j["bytes"] = std::vector<int>(x.begin(), x.end());
}
void value_to_json(const std::vector<value>& x, json& j)
{
for(const auto& v : x)
{
if(v.get_key().empty())
{
j.push_back(v);
}
else
{
j[v.get_key()] = v.without_key();
}
}
}
void value_to_json(std::nullptr_t&, json& j) { j = {}; }
void value_to_json(const value& val, json& j)
{
if(val.is_array())
{
j = json::array();
}
if(val.is_object())
{
j = json::object();
}
val.visit([&](auto v) { value_to_json(v, j); });
}
migraphx::value value_from_json(const json& j)
{
migraphx::value val;
json::value_t type = j.type();
switch(type)
{
case json::value_t::null: val = nullptr; break;
case json::value_t::boolean: val = j.get<bool>(); break;
case json::value_t::number_float: val = j.get<double>(); break;
case json::value_t::number_integer: val = j.get<int64_t>(); break;
case json::value_t::number_unsigned: val = j.get<uint64_t>(); break;
case json::value_t::string: val = j.get<std::string>(); break;
case json::value_t::array:
val = migraphx::value::array{};
std::transform(j.begin(), j.end(), std::back_inserter(val), [&](const json& jj) {
return jj.get<value>();
});
break;
case json::value_t::object:
if(j.contains("bytes") and j.size() == 1)
{
val = migraphx::value::binary{j["bytes"].get<std::vector<std::uint8_t>>()};
}
else
{
val = migraphx::value::object{};
for(const auto& item : j.items())
{
const auto& key = item.key();
const json& jv = item.value();
val[key] = jv.get<value>();
}
}
break;
case json::value_t::binary: MIGRAPHX_THROW("Convert JSON to Value: binary type not supported!");
case json::value_t::discarded:
MIGRAPHX_THROW("Convert JSON to Value: discarded type not supported!");
}
return val;
}
std::string to_json_string(const value& val)
{
json j = val;
return j.dump();
}
std::string to_pretty_json_string(const value& val, std::size_t indent)
{
json j = val;
return j.dump(indent);
}
migraphx::value from_json_string(const char* str, std::size_t size)
{
json j = json::parse(str, str + size);
return j.get<value>();
}
migraphx::value from_json_string(const std::string& str)
{
json j = json::parse(str);
return j.get<value>();
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/load_save.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/json.hpp>
#include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp>
#include <fstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
program load(const std::string& filename, const file_options& options)
{
return load_buffer(read_buffer(filename), options);
}
program load_buffer(const std::vector<char>& buffer, const file_options& options)
{
return load_buffer(buffer.data(), buffer.size(), options);
}
program load_buffer(const char* buffer, std::size_t size, const file_options& options)
{
program p;
if(options.format == "msgpack")
{
p.from_value(from_msgpack(buffer, size));
}
else if(options.format == "json")
{
p.from_value(from_json_string(buffer, size));
}
else
{
MIGRAPHX_THROW("Unknown format: " + options.format);
}
return p;
}
void save(const program& p, const std::string& filename, const file_options& options)
{
write_buffer(filename, save_buffer(p, options));
}
std::vector<char> save_buffer(const program& p, const file_options& options)
{
value v = p.to_value();
std::vector<char> buffer;
if(options.format == "msgpack")
{
buffer = to_msgpack(v);
}
else if(options.format == "json")
{
std::string s = to_json_string(v);
buffer = std::vector<char>(s.begin(), s.end());
}
else
{
MIGRAPHX_THROW("Unknown format: " + options.format);
}
return buffer;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); }
template <class F>
operation make_op_generic(const std::string& name, F for_each)
{
auto op = load_op(name);
// Merge values
value w = op.to_value();
for_each([&](const auto& key, const auto& x) {
if(not w.contains(key))
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
MIGRAPHX_THROW("No key '" + key + "' in " + name);
w.at(key) = x;
});
op.from_value(w);
return op;
}
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v)
{
return make_op_generic(name, [&](auto f) {
for(auto&& [key, x] : v)
f(key, x);
});
}
operation make_op_from_value(const std::string& name, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object for make_op: " + name);
return make_op_generic(name, [&](auto f) {
for(auto&& x : v)
f(x.get_key(), x.without_key());
});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <iterator>
#include <migraphx/module.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/make_op.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <set>
#include <utility>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE)
struct module_impl
{
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
std::unordered_set<instruction*> instruction_set;
std::string name;
uint32_t nparams = 0;
bool bypass = false;
bool contains(instruction_ref ins) const
{
if(is_end(ins, instructions.end()))
return false;
return instruction_set.count(std::addressof(*ins)) > 0;
}
template <class... Ts>
instruction_ref emplace(instruction_ref pos, Ts&&... xs)
{
// cppcheck-suppress redundantInitialization
auto r = instructions.emplace(pos, std::forward<Ts>(xs)...);
instruction_set.insert(std::addressof(*r));
return r;
}
instruction_ref insert(instruction_ref pos, const instruction& ins)
{
return emplace(pos, ins);
}
void clear()
{
instructions.clear();
instruction_set.clear();
nparams = 0;
}
void push_front(const instruction& ins) { insert(instructions.begin(), ins); }
void push_back(const instruction& ins) { insert(instructions.end(), ins); }
template <class... Ts>
void emplace_front(Ts&&... xs)
{
emplace(instructions.begin(), std::forward<Ts>(xs)...);
}
template <class... Ts>
void emplace_back(Ts&&... xs)
{
emplace(instructions.end(), std::forward<Ts>(xs)...);
}
instruction_ref erase(instruction_ref pos)
{
instruction_set.erase(std::addressof(*pos));
return instructions.erase(pos);
}
instruction_ref erase(instruction_ref start, instruction_ref last)
{
std::for_each(start, last, [&](auto& ins) { instruction_set.erase(std::addressof(ins)); });
return instructions.erase(start, last);
}
};
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
module::module(const std::string& name) : impl(std::make_unique<module_impl>())
{
impl->name = name;
}
module::module(module&&) noexcept = default;
module::~module() noexcept = default;
// copy constructor
module::module(const module& m) { assign(m); }
// copy assignment operator
module& module::operator=(module m)
{
std::swap(m.impl, this->impl);
return *this;
}
std::string module::name() const { return impl->name; }
bool module::bypass() const { return impl->bypass; }
void module::set_bypass(bool b) { impl->bypass = b; }
void module::assign(const module& m)
{
// copy the impl
if(!impl)
impl = std::make_unique<module_impl>();
*impl = *m.impl;
// clear instructions
if(!impl->instructions.empty())
{
impl->clear();
}
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(m))
{
instruction_ref copy_ins{};
if(ins->name() == "@literal")
{
auto l = ins->get_literal();
copy_ins = impl->insert(impl->instructions.end(), instruction{l});
}
else if(ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto order = any_cast<builtin::param>(ins->get_operator()).order;
auto s = ins->get_shape();
copy_ins = impl->insert(impl->instructions.end(),
{builtin::param{name, order}, std::move(s), {}});
}
else if(ins->name() == "@outline")
{
auto s = ins->get_shape();
copy_ins = impl->insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
}
else
{
// if there are sub_module inputs, need to make a copy of the submodule
auto module_args = ins->module_inputs();
// retrieve its mapped input
auto inputs = ins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(ins_map, i) ? ins_map[i] : i;
});
if(ins->name() == "@return")
{
copy_ins = add_return(copy_inputs);
}
else
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs, module_args);
}
}
ins_map[ins] = copy_ins;
}
}
instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
}
instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args)
{
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
auto result = impl->insert(ins, {op, r, std::move(args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
instruction_ref module::add_instruction(const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
return insert_instruction(
impl->instructions.end(), op, std::move(args), std::move(module_args));
}
instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
instruction::replace(ins, op, r, std::move(args));
assert(ins->valid(begin()));
return ins;
}
instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST
{
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args));
assert(ins->valid(begin()));
return ins;
}
instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep)
{
assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);
if(ins == std::prev(this->end()))
{
return replace_instruction(ins, make_op("identity"), rep);
}
// TODO: Should it be an error if the output is empty?
if(ins->outputs().empty())
{
return rep;
}
// Make a copy of outputs which can be changed when calling replace_argument
auto outputs = ins->outputs();
for(auto out : outputs)
{
// TODO: Check for possible cycles
if(out != rep)
{
instruction::replace_argument(out, ins, rep);
}
assert(out->valid(begin()));
}
// Replacement should not be dead code unless its the last instruction
assert(!rep->outputs().empty() or rep == std::prev(end()));
// Output of the original instruction should only be the replacement or empty
assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(),
ins->outputs().end(),
[&](auto i) { return i == rep; }));
assert(ins->valid(begin()));
assert(rep->valid(begin()));
return rep;
}
instruction_ref module::remove_instruction(instruction_ref ins)
{
assert(has_instruction(ins));
assert(ins->outputs().empty());
ins->clear_arguments();
return impl->erase(ins);
}
instruction_ref module::remove_instructions(instruction_ref first, instruction_ref last)
{
if(first == last)
return first;
// TODO: Check every element
assert(has_instruction(first));
std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); }));
return impl->erase(first, last);
}
instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
{
assert(has_instruction(src));
assert(has_instruction(dst) or is_end(dst, this->end()));
impl->instructions.splice(dst, impl->instructions, src);
return src;
}
instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
{
this->move_instruction(src, dst);
for(auto ins : src->inputs())
this->move_instruction(ins, src);
return src;
}
std::vector<instruction_ref> module::insert_module_instructions(
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*m))
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty())
mod_outputs = {map_ins.at(std::prev(m->end()))};
return mod_outputs;
}
instruction_ref module::add_literal(literal l)
{
impl->emplace_front(std::move(l));
return impl->instructions.begin();
}
instruction_ref module::add_outline(const shape& s)
{
impl->push_front({builtin::outline{s}, s, {}});
return impl->instructions.begin();
}
instruction_ref module::add_parameter(std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->push_front({builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return impl->instructions.begin();
}
instruction_ref module::add_return(std::vector<instruction_ref> args)
{
impl->push_back({builtin::returns{}, {}, std::move(args)});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
instruction_ref module::replace_return(std::vector<instruction_ref> args)
{
auto last = std::prev(this->end());
// If there is no return then add a return
if(last->name() != "@return")
return this->add_return(args);
shape r = compute_shape(last->get_operator(), args);
instruction::replace(last, last->get_operator(), r, std::move(args));
assert(last->valid(begin()));
return last;
}
shape module::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.name() == "@param")
{
return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
return false;
}
});
if(ins != this->end())
return ins->get_shape();
else
return {};
}
std::vector<std::string> module::get_parameter_names() const
{
std::vector<std::string> result;
std::vector<builtin::param> params;
for(auto&& ins : impl->instructions)
{
if(ins.name() == "@param")
{
auto&& param = any_cast<builtin::param>(ins.get_operator());
params.push_back(param);
}
}
std::stable_sort(
params.begin(), params.end(), by(std::less<>{}, [](auto&& p) { return p.order; }));
std::transform(params.begin(), params.end(), std::back_inserter(result), [&](auto&& p) {
return p.parameter;
});
return result;
}
instruction_ref module::get_parameter(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.name() == "@param")
{
return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
return false;
}
});
if(ins != this->end())
return ins;
else
return this->end();
}
std::unordered_map<std::string, shape> module::get_parameter_shapes() const
{
std::unordered_map<std::string, shape> result;
for(auto&& ins : impl->instructions)
{
if(ins.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
result[name] = ins.get_shape();
}
}
return result;
}
bool module::has_instruction(instruction_ref ins) const { return impl->contains(ins); }
std::size_t module::size() const { return impl->instructions.size(); }
instruction_ref module::begin() const { return impl->instructions.begin(); }
instruction_ref module::end() const { return impl->instructions.end(); }
std::vector<shape> module::get_output_shapes() const
{
if(impl->instructions.empty())
return {};
auto last_ins = impl->instructions.back();
if(last_ins.name() == "@return")
{
const auto& output_ins = last_ins.inputs();
std::vector<shape> output_shapes;
std::transform(output_ins.begin(),
output_ins.end(),
std::back_inserter(output_shapes),
[](auto& ins) { return ins->get_shape(); });
return output_shapes;
}
// The else branch is to provide backward compatibility
else
{
return {last_ins.get_shape()};
}
}
instruction_ref module::validate() const
{
return std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& i) {
auto inputs = i.inputs();
bool check_order = std::all_of(
inputs.begin(), inputs.end(), [&](auto in) { return has_instruction(in); });
return !i.valid(impl->instructions.begin(), check_order);
});
}
bool is_borrowed(instruction_ref ins)
{
auto alias = instruction::get_output_alias(ins, true);
if(alias == ins)
return false;
lifetime l = alias->get_operator().get_lifetime();
if(l == lifetime::borrow)
return true;
return is_borrowed(alias);
}
bool is_global(instruction_ref ins)
{
const auto& op = instruction::get_output_alias(ins)->get_operator();
return op.name() == "@param" or op.get_lifetime() == lifetime::global;
}
bool is_dangling(instruction_ref ins) { return not is_global(ins) and is_borrowed(ins); }
instruction_ref module::find_dangling_reference() const
{
auto last = std::prev(end());
if(last->name() == "@return")
{
auto dangling = std::find_if(
last->inputs().begin(), last->inputs().end(), [](auto x) { return is_dangling(x); });
if(dangling != last->inputs().end())
return *dangling;
}
else if(is_dangling(last))
{
return last;
}
return end();
}
void module::finalize(context& ctx)
{
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this))
{
if(trace)
{
std::cout << "Finalize: ";
this->debug_print(ins);
}
ins->finalize(ctx);
for(const auto& smod : ins->module_inputs())
{
smod->finalize(ctx);
}
}
// Warn when an instruction is not normalized
auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); });
if(ins != end())
std::cerr << "WARNING: Instruction needs normalization, performance may be affected."
<< std::endl;
}
void module::debug_print() const { std::cout << *this << std::endl; }
void module::debug_print(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>& names) const
{
if(is_end(ins, this->end()))
{
std::cout << "End instruction" << std::endl;
return;
}
if(not has_instruction(ins))
{
std::cout << "Instruction not part of module" << std::endl;
return;
}
std::stringstream ss;
names = this->print(
[&](auto x, auto ins_names) {
if(x == ins)
{
instruction::print(std::cout, x, ins_names);
std::cout << std::endl;
}
},
names);
}
void module::debug_print(instruction_ref ins) const
{
std::unordered_map<instruction_ref, std::string> names;
this->debug_print(ins, names);
}
void module::debug_print(const std::vector<instruction_ref>& inss) const
{
for(auto ins : inss)
this->debug_print(ins);
std::cout << std::endl;
}
std::unordered_map<instruction_ref, std::string> module::print(
const std::function<void(instruction_ref,
const std::unordered_map<instruction_ref, std::string>&)>& print_func,
std::unordered_map<instruction_ref, std::string> names) const
{
int count = 0;
for(auto ins : iterator_for(*this))
{
std::string var_name;
if(ins->name() == "@param")
{
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
else
{
var_name = this->name();
var_name.append((this->name().empty() ? "@" : ":@"));
var_name.append(std::to_string(count));
}
// count every instruction so index matches loc in the printout program
count++;
names.emplace(ins, var_name);
print_func(ins, names);
}
return names;
}
void module::print(const std::function<
void(instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const
{
this->print(print_func, {});
}
static std::string enclose_name(const std::string& name)
{
return '"' + replace_string(name, "\"", "\\\"") + '"';
}
void module::print_graph(std::ostream& os, bool brief) const
{
os << "digraph {" << std::endl;
os << "\trankdir=LR;" << std::endl;
this->print([&](auto ins, auto ins_names) {
std::string label;
if(brief)
label = ins->name();
else
label = to_string(ins->get_operator());
os << "\t" << enclose_name(ins_names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << ";" << std::endl;
if(!ins->inputs().empty())
{
for(auto&& arg : ins->inputs())
{
os << "\t" << enclose_name(ins_names.at(arg)) << " -> "
<< enclose_name(ins_names.at(ins));
if(not brief)
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]";
os << ";" << std::endl;
}
}
});
os << "}" << std::endl;
}
static std::string cpp_var_name(const std::string& name)
{
return "m" + replace_string(name, "@", "x");
}
static std::string cpp_op_var(const std::string& name, instruction_ref ins)
{
return replace_string(name, "@", ins->name());
}
static void print_op_attributes(std::ostream& os, const std::string& name, const operation& op)
{
std::string x = to_string(op);
if(contains(x, "["))
{
auto start = x.find('[');
auto end = x.find(']');
std::string attribute_text = x.substr(start + 1, end - start - 1);
std::vector<std::string> attributes;
for(auto&& attribute : split_string(attribute_text, ','))
{
if(contains(attribute, '='))
attributes.push_back(attribute);
else
attributes.back() += "," + attribute;
}
for(auto&& attribute : attributes)
{
auto p = split_string(attribute, '=');
auto key = p.front();
auto value = p.back();
if(contains({"bn_mode", "padding_mode"}, key))
continue;
if(key == "mode")
value = enclose_name(trim(value));
os << name << "." << key << " = " << value << ";" << std::endl;
}
}
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx::shape{migraphx::shape::" << s.type_string();
os << ", {" << to_string_range(s.lens()) << "}";
if(not s.standard())
os << ", {" << to_string_range(s.strides()) << "}";
os << "}";
}
std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const
{
os << "migraphx::module p;" << std::endl;
unsigned long seed = 0;
names = this->print(
[&](auto ins, auto ins_names) {
auto op = cpp_op_var(ins_names.at(ins), ins);
if(ins->name().front() != '@')
{
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl;
print_op_attributes(os, op, ins->get_operator());
}
os << "auto " << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << "p.add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
if(use_abs)
os << "migraphx::abs(";
os << "migraphx::generate_literal(";
print_cpp_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ");" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl;
}
else
{
os << "p.add_instruction(" << op;
for(auto input : ins->inputs())
{
os << ", " << cpp_var_name(ins_names.at(input));
}
os << ");" << std::endl;
}
},
names);
return names;
}
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{
this->print([&](auto ins, auto ins_names) {
instruction::print(os, ins, ins_names);
a(ins);
os << std::endl;
});
}
std::vector<module_ref> module::get_sub_modules() const
{
std::vector<module_ref> vec_modules;
for(auto ins : iterator_for(*this))
{
const auto& mod_args = ins->module_inputs();
vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end());
for(const auto& smod : mod_args)
{
auto sub_mods = smod->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
}
}
return vec_modules;
}
module& module::sort()
{
fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin());
for(auto child : ins->inputs())
{
if(!contains(this->impl->instructions, child))
{
continue;
}
self(child);
}
})(std::prev(this->end()));
assert(this->validate() == this->end());
return *this;
}
void module::calc_implicit_deps(const module& smod,
const module& pmod,
instruction_ref ins,
ins_dep_map& deps) const
{
const auto& ins_inputs = ins->inputs();
for(auto ii : iterator_for(smod))
{
const auto& ii_inputs = ii->inputs();
for(auto iii : ii_inputs)
{
if(pmod.has_instruction(iii))
{
if(not contains(ins_inputs, iii))
deps[ins].insert(iii);
}
}
const auto& mod_args = ii->module_inputs();
if(not mod_args.empty())
{
for(const auto* ssmod : mod_args)
{
calc_implicit_deps(*ssmod, pmod, ins, deps);
}
}
}
}
ins_dep_map module::calc_implicit_deps() const
{
ins_dep_map mod_implicit_deps;
for(auto ins : iterator_for(*this))
{
const auto& mod_args = ins->module_inputs();
if(mod_args.empty())
{
continue;
}
for(const auto* mod : mod_args)
{
calc_implicit_deps(*mod, *this, ins, mod_implicit_deps);
}
}
return mod_implicit_deps;
}
bool operator==(const module& x, const module& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const module& m)
{
m.print([&](auto ins, auto ins_names) {
instruction::print(os, ins, ins_names);
os << std::endl;
});
return os;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/msgpack.hpp>
#include <migraphx/serialize.hpp>
#include <msgpack.hpp>
namespace msgpack {
MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{
namespace adaptor {
template <>
struct convert<migraphx::value>
{
const msgpack::object& operator()(const msgpack::object& o, migraphx::value& v) const
{
switch(o.type)
{
case msgpack::type::NIL: {
v = nullptr;
break;
}
case msgpack::type::BOOLEAN: {
v = o.as<bool>();
break;
}
case msgpack::type::POSITIVE_INTEGER: {
v = o.as<std::uint64_t>();
break;
}
case msgpack::type::NEGATIVE_INTEGER: {
v = o.as<std::int64_t>();
break;
}
case msgpack::type::FLOAT32:
case msgpack::type::FLOAT64: {
v = o.as<double>();
break;
}
case msgpack::type::STR: {
v = o.as<std::string>();
break;
}
case msgpack::type::BIN: {
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break;
}
case msgpack::type::ARRAY: {
migraphx::value r = migraphx::value::array{};
std::for_each(
o.via.array.ptr,
o.via.array.ptr + o.via.array.size,
[&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); });
v = r;
break;
}
case msgpack::type::MAP: {
migraphx::value r = migraphx::value::object{};
std::for_each(o.via.map.ptr,
o.via.map.ptr + o.via.map.size,
[&](const msgpack::object_kv& p) {
r[p.key.as<std::string>()] = p.val.as<migraphx::value>();
});
v = r;
break;
}
case msgpack::type::EXT: {
MIGRAPHX_THROW("msgpack EXT type not supported.");
}
}
return o;
}
};
template <>
struct pack<migraphx::value::binary>
{
template <class Stream>
packer<Stream>& operator()(msgpack::packer<Stream>& o,
const migraphx::value::binary& x) const
{
const auto* data = reinterpret_cast<const char*>(x.data());
auto size = x.size();
o.pack_bin(size);
o.pack_bin_body(data, size);
return o;
}
};
template <>
struct pack<migraphx::value>
{
template <class Stream>
void write(msgpack::packer<Stream>& o, const std::nullptr_t&) const
{
o.pack_nil();
}
template <class Stream, class T>
void write(msgpack::packer<Stream>& o, const T& x) const
{
o.pack(x);
}
template <class Stream>
void write(msgpack::packer<Stream>& o, const std::vector<migraphx::value>& v) const
{
if(v.empty())
{
o.pack_array(0);
return;
}
if(not v.front().get_key().empty())
{
o.pack_map(v.size());
for(auto&& x : v)
{
o.pack(x.get_key());
o.pack(x.without_key());
}
}
else
{
o.pack_array(v.size());
for(auto&& x : v)
{
o.pack(x);
}
}
}
template <class Stream>
packer<Stream>& operator()(msgpack::packer<Stream>& o, const migraphx::value& v) const
{
v.visit_value([&](auto&& x) { this->write(o, x); });
return o;
}
};
} // namespace adaptor
} // MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
} // namespace msgpack
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct vector_stream
{
std::vector<char> buffer{};
vector_stream& write(const char* b, std::size_t n)
{
buffer.insert(buffer.end(), b, b + n);
return *this;
}
};
std::vector<char> to_msgpack(const value& v)
{
vector_stream vs;
msgpack::pack(vs, v);
return vs.buffer;
}
value from_msgpack(const char* buffer, std::size_t size)
{
msgpack::object_handle oh = msgpack::unpack(buffer, size);
return oh.get().as<value>();
}
value from_msgpack(const std::vector<char>& buffer)
{
return from_msgpack(buffer.data(), buffer.size());
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// different attributes
// 1) use_input(default)/use_output
// 2) use_rank(default)/use_len
// 3) clip_min(default)/not_clip_min
// 3.1) include_min(default)/exclude_min
// 4) clip_max(default)/not_clip_max
// 4.1) exclude_max(default)/include_max
auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes,
const value& val,
const std::vector<std::size_t>& lens)
{
std::vector<int64_t> result(vec);
int64_t n_rank = lens.size();
std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
if(contains(vec_attrs, op::normalize_attribute::use_output))
{
n_rank = n_rank + vec.size();
}
std::vector<int64_t> max_vals(vec.size(), n_rank);
if(contains(vec_attrs, op::normalize_attribute::use_len))
{
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { return lens[i]; });
}
if(contains(vec_attrs, op::normalize_attribute::clip_max))
{
if(contains(vec_attrs, op::normalize_attribute::include_max))
{
std::transform(result.begin(),
result.end(),
max_vals.begin(),
result.begin(),
[](auto v, auto mv) { return v > mv ? mv : v; });
}
else
{
std::transform(result.begin(),
result.end(),
max_vals.begin(),
result.begin(),
[](auto v, auto mv) { return v >= mv ? mv - 1 : v; });
}
}
else
{
if(contains(vec_attrs, op::normalize_attribute::include_max))
{
if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
}
}
else
{
if(!std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
}
}
}
std::vector<int64_t> min_vals = max_vals;
std::transform(min_vals.begin(), min_vals.end(), min_vals.begin(), [](auto v) { return -v; });
if(contains(vec_attrs, op::normalize_attribute::clip_min))
{
if(contains(vec_attrs, op::normalize_attribute::include_min))
{
std::transform(result.begin(),
result.end(),
min_vals.begin(),
result.begin(),
[](auto v, auto mv) { return v < mv ? mv : v; });
}
else
{
std::transform(result.begin(),
result.end(),
min_vals.begin(),
result.begin(),
[](auto v, auto mv) { return v < mv + 1 ? mv + 1 : v; });
}
}
else
{
if(contains(vec_attrs, op::normalize_attribute::include_min))
{
if(!std::equal(min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
}
}
else
{
if(!std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
}
}
}
std::transform(
result.begin(), result.end(), max_vals.begin(), result.begin(), [](auto v, auto mv) {
return v < 0 ? v + mv : v;
});
return result;
}
auto tune_pad_attribute(const value& val)
{
std::vector<size_t> vec_attrs = val.to_vector<size_t>();
std::vector<size_t> result(vec_attrs.begin(), vec_attrs.end());
std::copy(vec_attrs.begin(), vec_attrs.end(), std::back_inserter(result));
return result;
}
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
{
bool tuned = false;
auto attrs = op.attributes();
auto val = op.to_value();
if(attrs.contains("normalize_padding"))
{
auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
auto padding_size = padding.size();
// for now, assume the dimensions to pad start at dim 2
auto padding_start = 2;
if(padding_size == 2 * (lens.size() - padding_start))
tuned = true;
else if(padding_size != (lens.size() - padding_start))
MIGRAPHX_THROW("inconsistent padding size");
else
{
auto result = tune_pad_attribute(padding);
val["padding"] = result;
op.from_value(val);
tuned = true;
}
}
if(!attrs.contains("normalize_axes"))
{
return tuned;
}
auto attr_v = attrs.at("normalize_axes").without_key();
for(const auto& rv : attr_v)
{
const auto& key = rv.get_key();
if(val.contains(key))
{
auto vv = val.at(key).without_key();
if(vv.is_array())
{
std::vector<int64_t> axes;
if(val.contains("axes"))
{
axes = val.at("axes").without_key().to_vector<int64_t>();
}
auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens);
val[key] = result;
op.from_value(val);
val = op.to_value();
tuned = true;
}
else
{
auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens);
val[key] = result.front();
op.from_value(val);
val = op.to_value();
tuned = true;
}
}
else
{
MIGRAPHX_THROW("NORMALIZE_ATTR : op " + op.name() + " attribute \"" + key +
"\" not exist!");
}
}
return tuned;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <unordered_set>
#include <migraphx/normalize_attributes.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/value.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void normalize_ops::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
auto inputs = ins->inputs();
if(inputs.empty())
continue;
auto lens = inputs[0]->get_shape().lens();
migraphx::operation tuned_op = ins->get_operator();
if(normalize_attributes(tuned_op, lens))
{
m.replace_instruction(ins, tuned_op, inputs);
ins->set_normalized();
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -7,23 +7,15 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(migraphx_onnx onnx.cpp)
file(GLOB ONNX_SRCS ${CONFIGURE_DEPENDS} *.cpp)
add_library(migraphx_onnx ${ONNX_SRCS})
target_include_directories(migraphx_onnx PRIVATE include)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
rocm_set_soversion(migraphx_onnx ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto "-Wl,--exclude-libs,ALL")
target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets(
TARGETS migraphx_onnx
)
if(MIGRAPHX_ENABLE_GPU)
add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx)
add_executable(cifar10 cifar10.cpp)
rocm_clang_tidy_check(cifar10)
target_link_libraries(cifar10 migraphx_cpu migraphx_gpu migraphx_onnx)
endif()
\ No newline at end of file
#include <migraphx/onnx/checks.hpp>
#include <migraphx/errors.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
void check_arg_empty(const argument& arg, const std::string& msg)
{
if(arg.empty())
{
MIGRAPHX_THROW(msg);
}
}
void check_attr_sizes(size_t kdims, size_t attr_size, const std::string& error_msg)
{
if(kdims != attr_size)
{
MIGRAPHX_THROW(error_msg + " k-dims: " + std::to_string(kdims) +
" attribute size: " + std::to_string(attr_size));
}
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <cstdio>
#include <string>
#include <fstream>
#include <numeric>
#include <stdexcept>
#include <migraphx/onnx.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include "softmax.hpp"
auto read_cifar10_images(const std::string& full_path)
{
std::ifstream file(full_path, std::ios::binary);
const size_t nimages = 10;
const size_t nbytes_per_image = 3072;
std::vector<uint8_t> raw_data(nimages * (nbytes_per_image + 1));
std::vector<uint8_t> labels(nimages);
std::vector<float> data(nimages * nbytes_per_image);
if(file.is_open())
{
file.read(reinterpret_cast<char*>(raw_data.data()),
(nbytes_per_image + 1) * nimages * sizeof(uint8_t));
uint8_t* pimage = raw_data.data();
for(size_t i = 0; i < nimages; i++, pimage += nbytes_per_image)
{
labels[i] = *pimage++;
for(size_t j = 0; j < nbytes_per_image; j++)
{
float v = float(*(pimage + j)) / 255.0f;
data[i * nbytes_per_image + j] = v;
}
}
return std::make_pair(labels, data);
}
else
{
throw std::runtime_error("Cannot open file `" + full_path + "`!");
}
}
int main(int argc, char const* argv[])
{
if(argc < 4)
{
throw std::runtime_error("Usage: cifar10 [gpu | cpu] <onnx file> <cifar10 data file>");
}
std::string gpu_cpu = argv[1];
std::string file = argv[2];
std::string datafile = argv[3];
auto prog = migraphx::parse_onnx(file);
std::cout << prog << std::endl;
auto imageset = read_cifar10_images(datafile);
if(gpu_cpu == "gpu")
{
// GPU target
prog.compile(migraphx::gpu::target{});
migraphx::program::parameter_map m;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
for(auto&& x : prog.get_parameter_shapes())
{
m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
}
auto labels = imageset.first;
auto input = imageset.second;
auto ptr = input.data();
for(int i = 0; i < 10; i++)
{
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[3072 * i]});
auto result = migraphx::gpu::from_gpu(prog.eval(m));
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax<float>(logits);
for(auto x : probs)
std::cout << x << " ";
std::cout << std::endl << std::endl;
}
}
else
{
// CPU target
prog.compile(migraphx::cpu::target{});
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
auto labels = imageset.first;
auto input = imageset.second;
auto ptr = input.data();
for(int i = 0; i < 10; i++)
{
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
auto input3 = migraphx::argument{s, &ptr[3072 * i]};
auto result = prog.eval({{"0", input3}});
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax<float>(logits);
for(auto x : probs)
std::cout << x << " ";
std::cout << std::endl;
}
}
}
#include <migraphx/onnx/conv.hpp>
#include <algorithm>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
void recalc_conv_attributes(value& v, size_t kdims)
{
if(not(v["padding"].size() == kdims or v["padding"].size() == kdims * 2))
{
v["padding"].resize(kdims);
std::fill_n(v["padding"].begin(), kdims, 0);
}
if(v["stride"].size() != kdims)
{
v["stride"].resize(kdims);
std::fill_n(v["stride"].begin(), kdims, 1);
}
if(v["dilation"].size() != kdims)
{
v["dilation"].resize(kdims);
std::fill_n(v["dilation"].begin(), kdims, 1);
}
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_CHECKS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_CHECKS_HPP
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
void check_arg_empty(const argument& arg, const std::string& msg);
void check_attr_sizes(size_t kdims, size_t attr_size, const std::string& error_msg);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_CONV_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_CONV_HPP
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
void recalc_conv_attributes(value& v, size_t kdims);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_MAP_ACTIVATION_FUNCTIONS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_MAP_ACTIVATION_FUNCTIONS_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <unordered_map>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
const std::unordered_map<std::string, operation>& map_activation_functions();
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_PARSER_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_PARSER_HPP
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <unordered_map>
#include <functional>
#include <utility>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
namespace onnx = onnx_for_migraphx;
struct onnx_parser
{
std::string filename;
std::string path = ".";
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
struct node_info
{
attribute_map attributes{};
std::size_t num_outputs = 1;
std::string name = "";
module* mod = nullptr;
instruction_ref make_contiguous(instruction_ref ins) const;
instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins,
uint64_t axis) const;
instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const;
instruction_ref add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const;
template <class... Ts>
instruction_ref add_common_op(const std::string& op_name, Ts... xs) const
{
return add_common_op(op_name, {xs...});
}
instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const;
instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args,
const std::vector<module_ref>& mods) const;
template <class... Ts>
instruction_ref add_instruction(const operation& op, Ts... xs) const
{
return add_instruction(op, {xs...});
}
instruction_ref add_literal(literal l) const;
template <class... Ts>
instruction_ref add_literal(Ts&&... xs) const
{
return add_literal(literal{std::forward<Ts>(xs)...});
}
};
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = std::function<std::vector<instruction_ref>(
onnx_parser&, const node_info&, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops;
onnx_parser();
operation load(const std::string& name, const node_info& info) const;
void parse_undefined(module* mod, const std::string& name);
static int64_t get_opset_version(const onnx::ModelProto& model);
void parse_from(std::istream& is, std::string name = "");
void parse_from(const void* data, std::size_t size);
void parse_graph(module* mod, const onnx::GraphProto& graph);
literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
};
shape::type_t get_type(int dtype);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_REGISTER_OP_PARSER_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_REGISTER_OP_PARSER_HPP
#include <migraphx/config.hpp>
#include <migraphx/auto_register.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct op_desc
{
std::string onnx_name = "";
std::string op_name = "";
};
void register_op_parser(const std::string& name, onnx_parser::op_func f);
onnx_parser::op_func get_op_parser(const std::string& name);
std::vector<std::string> get_op_parsers();
inline std::vector<instruction_ref> implicit_multi_op(std::vector<instruction_ref> inss)
{
return inss;
}
inline std::vector<instruction_ref> implicit_multi_op(instruction_ref ins) { return {ins}; }
template <class T>
void register_op_parser()
{
T parser;
for(auto&& opd : parser.operators())
register_op_parser(opd.onnx_name, [opd, parser](auto&&... xs) {
return implicit_multi_op(parser.parse(opd, xs...));
});
}
struct register_op_parser_action
{
template <class T>
static void apply()
{
register_op_parser<T>();
}
};
template <class T>
using op_parser = auto_register<register_op_parser_action, T>;
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_PADDING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_PADDING_HPP
#include <migraphx/config.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
bool is_asym_padding(const std::vector<int64_t>& padding);
void cal_auto_padding_size(onnx_parser::node_info info,
value& v,
const std::vector<std::size_t>& k_lens,
const std::vector<std::size_t>& dilation,
const std::vector<std::size_t>& in_lens,
std::vector<int64_t>& paddings);
void check_padding_mode(const onnx_parser::node_info& info, const std::string& op_name);
void tune_padding_size(const value& v,
std::vector<int64_t>& padding,
int count_include_pad,
std::vector<int64_t>& s_start);
void check_asym_padding(const onnx_parser::node_info& info,
instruction_ref& ins,
const std::vector<int64_t>& padding,
value& v,
int count_include_pad = 0,
float pad_val = 0);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/onnx/map_activation_functions.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
const std::unordered_map<std::string, operation>& map_activation_functions()
{
static const std::unordered_map<std::string, operation> m = {
{"tanh", make_op("tanh")},
{"relu", make_op("relu")},
{"sigmoid", make_op("sigmoid")},
{"leakyrelu", make_op("leaky_relu")},
{"elu", make_op("elu")}};
return m;
}
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <cstdio>
#include <string>
#include <fstream>
#include <numeric>
#include <stdexcept>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include "softmax.hpp"
auto reverse_int(unsigned int i)
{
unsigned char c1;
unsigned char c2;
unsigned char c3;
unsigned char c4;
c1 = i & 255u;
c2 = (i >> 8u) & 255u;
c3 = (i >> 16u) & 255u;
c4 = (i >> 24u) & 255u;
return (static_cast<unsigned int>(c1) << 24u) + (static_cast<unsigned int>(c2) << 16u) +
(static_cast<unsigned int>(c3) << 8u) + c4;
};
std::vector<float>
read_mnist_images(const std::string& full_path, int& number_of_images, int& image_size)
{
using uchar = unsigned char;
std::ifstream file(full_path, std::ios::binary);
if(file.is_open())
{
int magic_number = 0;
int n_rows = 0;
int n_cols = 0;
file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
magic_number = reverse_int(magic_number);
if(magic_number != 2051)
throw std::runtime_error("Invalid MNIST image file!");
file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
number_of_images = reverse_int(number_of_images);
file.read(reinterpret_cast<char*>(&n_rows), sizeof(n_rows));
n_rows = reverse_int(n_rows);
file.read(reinterpret_cast<char*>(&n_cols), sizeof(n_cols));
n_cols = reverse_int(n_cols);
image_size = n_rows * n_cols;
std::vector<float> result(number_of_images * image_size);
for(int i = 0; i < number_of_images; i++)
{
for(int j = 0; j < image_size; j++)
{
uchar tmp;
file.read(reinterpret_cast<char*>(&tmp), 1);
result[i * image_size + j] = tmp / 255.0;
}
}
return result;
}
else
{
throw std::runtime_error("Cannot open file `" + full_path + "`!");
}
}
std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number_of_labels)
{
using uchar = unsigned char;
std::ifstream file(full_path, std::ios::binary);
if(file.is_open())
{
int magic_number = 0;
file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
magic_number = reverse_int(magic_number);
if(magic_number != 2049)
throw std::runtime_error("Invalid MNIST label file!");
file.read(reinterpret_cast<char*>(&number_of_labels), sizeof(number_of_labels));
number_of_labels = reverse_int(number_of_labels);
std::vector<int32_t> result(number_of_labels);
for(int i = 0; i < number_of_labels; i++)
{
uchar tmp;
file.read(reinterpret_cast<char*>(&tmp), 1);
result[i] = tmp;
}
return result;
}
else
{
throw std::runtime_error("Unable to open file `" + full_path + "`!");
}
}
int main(int argc, char const* argv[])
{
if(argc > 3)
{
std::string datafile = argv[2];
std::string labelfile = argv[3];
int nimages = -1;
int image_size = -1;
int nlabels = -1;
std::vector<float> input = read_mnist_images(datafile, nimages, image_size);
std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels);
std::string file = argv[1];
auto prog = migraphx::parse_onnx(file);
std::cout << prog << std::endl << std::endl;
prog.compile(migraphx::gpu::target{});
auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}};
std::cout << s << std::endl;
auto ptr = input.data();
migraphx::program::parameter_map m;
m["output"] =
migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output")));
for(int i = 0; i < 20; i++)
{
std::cout << "label: " << labels[i] << " ----> ";
m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[784 * i]});
auto result = migraphx::gpu::from_gpu(prog.eval(m));
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax(logits);
for(auto x : probs)
std::cout << x << " ";
std::cout << std::endl;
}
std::cout << std::endl;
}
}
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <iostream>
#include <fstream>
#include <unordered_map>
......@@ -9,1745 +8,58 @@
#include <utility>
#include <vector>
#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct onnx_parser
template <class... Ts>
program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func =
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
bool is_pytorch = false;
onnx::onnx_parser parser;
parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value;
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> map_actv_funcs;
onnx_parser()
{
add_generic_op("Relu", op::relu{});
add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Abs", op::abs{});
add_generic_op("Exp", op::exp{});
add_generic_op("Erf", op::erf{});
add_generic_op("Log", op::log{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{});
add_generic_op("Identity", op::identity{});
add_generic_op("Sin", op::sin{});
add_generic_op("Cos", op::cos{});
add_generic_op("Tan", op::tan{});
add_generic_op("Sinh", op::sinh{});
add_generic_op("Cosh", op::cosh{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("Asin", op::asin{});
add_generic_op("Acos", op::acos{});
add_generic_op("Atan", op::atan{});
add_generic_op("Sqrt", op::sqrt{});
add_generic_op("Round", op::round{});
add_generic_op("Sign", op::sign{});
add_generic_op("Ceil", op::ceil{});
add_generic_op("Floor", op::floor{});
add_binary_op("Add", op::add{});
add_binary_op("Div", op::div{});
add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{});
add_binary_op("Pow", op::pow{});
add_variadic_op("Sum", op::add{});
add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{});
add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("MatMul", &onnx_parser::parse_matmul);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat);
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
// init the activation function map
init_actv_func();
}
void init_actv_func()
{
// Support name format of all lower case or the first letter capital
map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
}
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
});
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{
ops.emplace(name, f);
}
template <class F>
void add_mem_op(std::string name, F f)
{
add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
template <class T>
void add_binary_op(std::string name, T x)
{
add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast") and contains(attributes, "axis"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
args[1]);
return prog.add_instruction(x, args[0], l);
}
return prog.add_instruction(x, args);
}
else
{
return add_broadcastable_binary_op(args[0], args[1], x);
}
});
}
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
std::vector<std::size_t> s1)
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
if(s0.size() > s1.size())
{
s0.swap(s1);
}
std::vector<std::size_t> out_lens(s1);
auto offset = s1.size() - s0.size();
std::transform(s0.begin(),
s0.end(),
s1.begin() + offset,
out_lens.begin() + offset,
[&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
}
instruction_ref make_contiguous(instruction_ref ins)
{
if(ins->get_shape().standard())
{
return ins;
}
return prog.add_instruction(op::contiguous{}, ins);
}
template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{
if(arg0->get_shape().lens() != arg1->get_shape().lens())
{
// Get lengths for both arguments
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
auto l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
return prog.add_instruction(x, l0, l1);
}
else
{
return prog.add_instruction(x, {arg0, arg1});
}
}
template <class T>
void add_generic_op(std::string name, T x)
{
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
template <class T>
void add_variadic_op(std::string name, T x)
{
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()),
args.end(),
args.front(),
[this, x](instruction_ref a, instruction_ref b) {
return add_broadcastable_binary_op(a, b, x);
});
});
}
instruction_ref parse_clip(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
op::clip op;
if(contains(attributes, "max"))
{
op.max_val = parse_value(attributes.at("max")).at<float>();
}
if(contains(attributes, "min"))
{
op.min_val = parse_value(attributes.at("min")).at<float>();
}
return prog.add_instruction(op, std::move(args));
}
template <class Op>
instruction_ref parse_softmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 1;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(Op{axis}, std::move(args));
}
template <class Op>
instruction_ref parse_arg_op(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 0)
{
auto ins = prog.add_instruction(Op{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
return prog.add_instruction(Op{axis}, std::move(args));
}
}
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::convolution op;
auto l0 = args[0];
if(contains(attributes, "pads"))
{
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
}
std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
// insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding}, l0);
}
else
{
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
if(contains(attributes, "strides"))
{
copy(attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "dilations"))
{
copy(attributes["dilations"].ints(), op.dilation.begin());
}
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
if(s.find("SAME") != std::string::npos)
{
op.padding_mode = op::padding_mode_t::same;
}
}
if(contains(attributes, "group"))
{
op.group = parse_value(attributes.at("group")).at<int>();
}
if(args.size() == 3)
{
uint64_t axis = 1;
auto l1 = prog.add_instruction(op, l0, args[1]);
auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
return prog.add_instruction(op::add{}, l1, l2);
}
return prog.add_instruction(op, l0, args[1]);
}
instruction_ref parse_pooling(const std::string& name,
attribute_map attributes,
std::vector<instruction_ref> args)
{
op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
auto l0 = args[0];
if(starts_with(name, "Global"))
{
auto lens = args.front()->get_shape().lens();
op.lengths = {lens[2], lens[3]};
}
if(contains(attributes, "pads"))
{
std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
// insert zeros for pad op (args[0] has 4 dims)
padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
l0 = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
l0);
}
else
{
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
}
if(contains(attributes, "strides"))
{
copy(attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "kernel_shape"))
{
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
}
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(s.find("SAME_UPPER") == std::string::npos)
{
MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
}
op.padding_mode = op::padding_mode_t::same;
}
return prog.add_instruction(op, l0);
}
instruction_ref
parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::reshape op;
if(args.size() == 1)
{
literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
if(args.size() == 2)
{
auto s = args[1]->eval();
check_arg_empty(s, "Reshape: dynamic shape is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, make_contiguous(args[0]));
}
instruction_ref
parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
uint64_t axis = 1;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::flatten{axis}, args[0]);
}
instruction_ref
parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::squeeze op;
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::unsqueeze op;
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
// change to hande axis to be negative values
if(!contains(attributes, "axis"))
{
MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
}
int axis = parse_value(attributes.at("axis")).at<int>();
op::concat op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
int axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
op::gather op{axis};
return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
}
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::slice op;
std::vector<size_t> dims = args[0]->get_shape().lens();
size_t num_dims = dims.size();
if(contains(attributes, "axes"))
{
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
}
else
{
op.axes = std::vector<int64_t>(num_dims);
std::iota(op.axes.begin(), op.axes.end(), 0);
}
if(contains(attributes, "ends"))
{
op.ends = get_indices(attributes.at("ends"));
}
if(contains(attributes, "starts"))
{
literal s = parse_value(attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
}
return prog.add_instruction(op, args[0]);
}
instruction_ref parse_constant(const std::string&,
attribute_map attributes,
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
// return empty literal
if(v.get_shape().elements() == 0)
{
return prog.add_literal(literal{});
}
auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
{
migraphx::shape scalar_shape{v.get_shape().type()};
return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
}
return prog.add_literal(v);
}
instruction_ref
parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 1.0f;
float beta = 1.0f;
bool transa = false;
bool transb = false;
if(contains(attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
}
if(contains(attributes, "beta"))
{
beta = parse_value(attributes.at("beta")).at<float>();
}
if(contains(attributes, "transA"))
{
transa = parse_value(attributes.at("transA")).at<bool>();
}
if(contains(attributes, "transB"))
{
transb = parse_value(attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3)
{
if(beta != 0.f && args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
}
}
return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
}
instruction_ref
parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto l0 = args[0];
auto l1 = args[1];
auto l0_lens = l0->get_shape().lens();
auto l1_lens = l1->get_shape().lens();
// args[0] is a vector, prepend 1 to the shape
bool is_a_prepended = false;
if(l0_lens.size() == 1)
{
is_a_prepended = true;
l0_lens.insert(l0_lens.begin(), 1);
l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
}
bool is_b_appended = false;
if(l1_lens.size() == 1)
{
is_b_appended = true;
l1_lens.push_back(1);
l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
}
instruction_ref bl0 = l0;
instruction_ref bl1 = l1;
if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
{
auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
l0_broadcasted_lens = output_lens;
l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
l1_broadcasted_lens = output_lens;
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if(l0_lens != l0_broadcasted_lens)
{
bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
}
if(l1_lens != l1_broadcasted_lens)
{
bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
}
}
auto dot_res = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
--num_axis;
}
if(is_b_appended)
{
dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
}
return dot_res;
}
instruction_ref
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float epsilon = 1e-5f;
float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
if(contains(attributes, "epsilon"))
{
epsilon = parse_value(attributes.at("epsilon")).at<float>();
}
if(contains(attributes, "momentum"))
{
momentum = parse_value(attributes.at("momentum")).at<float>();
}
if(contains(attributes, "spatial"))
{
bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
? op::batch_norm_inference::spatial
: op::batch_norm_inference::per_activation;
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return prog.add_instruction(op, std::move(args));
}
instruction_ref parse_leaky_relu(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
float alpha = 0.01; // default alpha val for leaky relu
if(contains(attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
}
op::leaky_relu op{alpha};
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 1.0; // default alpha val for elu
if(contains(attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
}
op::elu op{alpha};
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
float alpha = 0.0001;
float beta = 0.75;
float bias = 1.0;
int size = 1;
if(contains(attributes, "alpha"))
alpha = parse_value(attributes.at("alpha")).at<float>();
if(contains(attributes, "beta"))
beta = parse_value(attributes.at("beta")).at<float>();
if(contains(attributes, "bias"))
bias = parse_value(attributes.at("bias")).at<float>();
if(contains(attributes, "size"))
size = parse_value(attributes.at("size")).at<int>();
op::lrn op{alpha, beta, bias, size};
return prog.add_instruction(op, args.front());
}
instruction_ref parse_imagescaler(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
float scale = 1.0;
std::vector<float> bias{};
if(contains(attributes, "scale"))
{
scale = parse_value(attributes.at("scale")).at<float>();
}
if(contains(attributes, "bias"))
{
auto&& bias_floats = attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
}
auto input_lens = args.front()->get_shape().lens();
auto scale_val = prog.add_literal(scale);
auto bias_vals = prog.add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});
auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto img_scaled = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
}
instruction_ref
parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::vector<int64_t> perm{};
if(contains(attributes, "perm"))
{
auto&& perm_vals = attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
}
instruction_ref
parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::vector<int64_t> pads{};
float value = 0.0f;
if(contains(attributes, "pads"))
{
auto&& pad_vals = attributes["pads"].ints();
pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
}
// check if padding is actually being done (at least one value is nonzero)
if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
{
return prog.add_instruction(migraphx::op::identity{}, args.front());
}
if(contains(attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
}
if(contains(attributes, "mode"))
{
auto mode = attributes.at("mode").s();
if(mode != "constant")
MIGRAPHX_THROW("migraphx currently only supports constant padding");
}
return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
}
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i);
});
return prog.add_literal(migraphx::literal{s, vec_shape});
}
// Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill
// operator
instruction_ref parse_constant_fill(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
int input_as_shape = 0;
int dtype = 1;
float value = 0.0f;
if(contains(attributes, "dtype"))
{
dtype = parse_value(attributes.at("dtype")).at<int>();
}
shape::type_t type = get_type(dtype);
if(contains(attributes, "input_as_shape"))
{
input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
}
if(contains(attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
}
if(contains(attributes, "extra_shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
}
if(input_as_shape == 1)
{
if(args.size() != 1)
{
MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
}
if(contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time");
}
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else if(input_as_shape == 0)
{
if(!contains(attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
literal ls = parse_value(attributes.at("shape"));
std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims};
std::vector<float> values(s.elements(), value);
return prog.add_literal(migraphx::literal(s, values));
}
else
{
MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
}
}
instruction_ref parse_constant_of_shape(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
if(options.print_program_on_error)
{
literal l_val{};
if(contains(attributes, "value"))
{
l_val = parse_value(attributes.at("value"));
if(l_val.get_shape().elements() != 1)
{
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
}
}
else
// Log the program when it can't be parsed
try
{
l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
parser.parse_from(std::forward<Ts>(xs)...);
}
// input is empty, output is a scalar
auto type = l_val.get_shape().type();
if(args.empty())
{
MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
}
else
catch(...)
{
migraphx::shape s;
// empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0)
{
s = migraphx::shape{type, {1}, {0}};
}
else
{
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
s = migraphx::shape{type, dims};
}
literal l_out{};
l_val.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element
std::vector<val_type> out_vec(s.elements(), val.front());
l_out = literal(s, out_vec);
});
return prog.add_literal(l_out);
std::cerr << parser.prog << std::endl;
throw;
}
}
instruction_ref
parse_expand(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
}
std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
else
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names{"tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
}
// bidirectional case should have two activation functions.
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
if(dirct == op::rnn_direction::bidirectional)
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(),
vec_names.end(),
vec_actv_funcs.begin(),
[&](const auto& fn) { return map_actv_funcs[fn]; });
// To be added later
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
// if the number of arguments is less than 6, append
// undefined operator to have 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), (6 - args.size()), ins);
}
// first output for the concatenation of hidden states
auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
std::move(args));
// second output for the last hidden state
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}
// need 4 activation functions
if(dirct == op::rnn_direction::bidirectional)
{
// 4 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1 four times. If 2 actv functins are provided,
// assume forward and reverse use the same pair of actv
// functions. For the case of 3 actv functions provided,
// assume the 3rd one is repeated once and used by the
// reverse direction.
// This may need change later
if(vec_names.size() == 1)
{
vec_names.insert(vec_names.end(), 3, vec_names.at(0));
}
else if(vec_names.size() == 2)
{
// repeat the activation functions
vec_names.push_back(vec_names.at(0));
vec_names.push_back(vec_names.at(1));
}
else if(vec_names.size() == 3)
{
vec_names.push_back(vec_names.at(2));
}
}
else
{
if(vec_names.size() == 1)
{
vec_names.push_back(vec_names.at(0));
}
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(),
vec_names.end(),
vec_actv_funcs.begin(),
[&](const auto& name) { return map_actv_funcs[name]; });
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int linear_before_reset = 0;
if(contains(attributes, "linear_before_reset"))
{
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
}
// append undefined opeator to make 6 arguments
if(args.size() < 6)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 6 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
std::vector<instruction_ref>
parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
}
}
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
{
direction = attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn_direction::reverse;
}
else if(direction == "forward")
{
dirct = op::rnn_direction::forward;
}
else
{
MIGRAPHX_THROW("LSTM: incorrect direction attribute");
}
std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
if(contains(attributes, "activations"))
{
auto names = attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
return to_lower(name);
});
}
// need 6 activation functions for bidirectional directions
if(dirct == op::rnn_direction::bidirectional)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later
switch(vec_names.size())
{
case 1:
vec_names = {vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0)};
break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(1),
vec_names.at(0),
vec_names.at(1),
vec_names.at(1)};
break;
case 3:
// repeat all three actv funcs once
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(0),
vec_names.at(1),
vec_names.at(2)};
break;
case 4:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(3),
vec_names.at(3)};
break;
case 5:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(4),
vec_names.at(4)};
break;
default: break;
}
}
else
{
switch(vec_names.size())
{
case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
break;
default: break;
}
}
auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
return (map_actv_funcs.count(name) == 0);
});
if(name_it != vec_names.end())
{
MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
}
std::vector<operation> vec_actv_funcs(vec_names.size());
std::transform(vec_names.begin(),
vec_names.end(),
vec_actv_funcs.begin(),
[&](const auto& name) { return map_actv_funcs[name]; });
float clip = 0.0;
if(contains(attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
}
int input_forget = 0;
if(contains(attributes, "input_forget"))
{
input_forget = parse_value(attributes.at("input_forget")).at<int>();
}
// append undefined opeator to make 6 arguments
if(args.size() < 8)
{
auto ins = prog.add_instruction(op::undefined{});
args.insert(args.end(), 8 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = prog.add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
// second output for last lstm output
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
// third output for last cell output
auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
return {hidden_states, last_output, last_cell_output};
}
template <class T>
instruction_ref parse_reduce_oper(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return prog.add_instruction(T{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(T{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
if(!contains(attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
int to_type = parse_value(attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return prog.add_instruction(op::convert{type}, std::move(args));
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
if(model.ParseFromIstream(&is))
{
if(model.has_graph())
{
this->parse_graph(model.graph());
}
}
else
{
MIGRAPHX_THROW("Failed reading onnx file.");
}
}
void parse_graph(const onnx::GraphProto& graph)
{
nodes = get_nodes(graph);
for(auto&& f : graph.initializer())
instructions[f.name()] = prog.add_literal(parse_tensor(f));
for(auto&& input : graph.input())
{
const std::string& name = input.name();
// input not in initializer_data, so it is a real input
if(!contains(instructions, name))
{
// TODO: Get shape of input parameter
shape s = parse_type(input.type());
instructions[name] = prog.add_parameter(name, s);
}
}
for(auto&& output : graph.output())
{
this->parse_node(output.name());
}
}
void parse_undefined(const std::string& name)
{
auto ins = prog.add_instruction(op::undefined{});
instructions[name] = ins;
}
void parse_node(const std::string& name)
{
if(name.empty())
MIGRAPHX_THROW("Onnx node must have a name");
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
std::vector<instruction_ref> args;
for(auto&& input : node.input())
{
if(nodes.count(input) > 0)
{
assert(name != input);
this->parse_node(input);
}
else if(input.empty())
{
this->parse_undefined(input);
}
args.push_back(instructions.at(input));
}
std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0)
{
result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
}
else
{
result = ops[node.op_type()](get_attributes(node), args);
}
// Even no output nodes produce output in migraphx
if(node.output().empty() and result.size() == 1)
{
instructions[name] = result.front();
}
else
{
assert(node.output().size() >= result.size());
std::transform(result.begin(),
result.end(),
node.output().begin(),
std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(y, x); });
}
}
}
static attribute_map get_attributes(const onnx::NodeProto& node)
{
std::unordered_map<std::string, onnx::AttributeProto> result;
for(auto&& attr : node.attribute())
{
result[attr.name()] = attr;
}
return result;
}
static node_map get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
std::size_t n = 0;
for(auto&& node : graph.node())
{
if(node.output().empty())
{
if(node.name().empty())
{
result["migraphx_unamed_node_" + std::to_string(n)] = node;
n++;
}
else
{
result[node.name()] = node;
}
}
for(auto&& output : node.output())
{
result[output] = node;
}
}
return result;
}
static std::vector<int64_t> get_indices(const onnx::AttributeProto& attr)
{
std::vector<int64_t> result;
literal s = parse_value(attr);
s.visit([&](auto v) { copy(v, std::back_inserter(result)); });
// Clamp large indices to -1
std::replace_if(
result.begin(),
result.end(),
[](auto x) { return x > int64_t{std::numeric_limits<std::int32_t>::max()} / 2; },
-1);
return result;
}
template <class T>
static literal from_repeated(shape::type_t t, const T& r)
{
std::size_t size = r.size();
return literal{{t, {size}}, r.begin(), r.end()};
}
static literal parse_value(const onnx::AttributeProto& attr)
{
switch(attr.type())
{
case onnx::AttributeProto::FLOAT: return literal{attr.f()};
case onnx::AttributeProto::INT: return literal{attr.i()};
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
case onnx::AttributeProto::UNDEFINED:
case onnx::AttributeProto::GRAPH:
case onnx::AttributeProto::STRING:
case onnx::AttributeProto::STRINGS:
case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::GRAPHS: return {};
}
MIGRAPHX_THROW("Invalid attribute type");
}
static literal parse_tensor(const onnx::TensorProto& t)
{
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
if(t.has_raw_data())
{
const std::string& s = t.raw_data();
switch(t.data_type())
{
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
case onnx::TensorProto::FLOAT16:
return create_literal(shape::half_type, dims, s.data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, s.data());
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
case onnx::TensorProto::INT8:
case onnx::TensorProto::UINT16:
case onnx::TensorProto::INT16:
case onnx::TensorProto::INT32:
case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::UINT8:
case onnx::TensorProto::STRING:
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::UINT32:
case onnx::TensorProto::UINT64:
case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
}
switch(t.data_type())
{
case onnx::TensorProto::INT8:
case onnx::TensorProto::UINT16:
case onnx::TensorProto::INT16:
case onnx::TensorProto::INT32:
case onnx::TensorProto::BOOL:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT64:
return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT:
return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT16:
{
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half);
}
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::UINT8:
case onnx::TensorProto::STRING:
case onnx::TensorProto::UINT32:
case onnx::TensorProto::UINT64:
case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPHX_THROW("Invalid tensor type");
}
static literal
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
{
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.empty())
return literal{{shape_type}, data};
return literal{{shape_type, dims}, data};
}
template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
{
if(dims.empty())
return literal{{shape_type}, data.begin(), data.end()};
return literal{{shape_type, dims}, data.begin(), data.end()};
}
static shape parse_type(const onnx::TypeProto& t)
{
shape::type_t shape_type{};
switch(t.tensor_type().elem_type())
{
case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
case onnx::TensorProto::UINT8:
case onnx::TensorProto::STRING:
case onnx::TensorProto::BOOL:
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128:
break; // throw std::runtime_error("Unsupported type");
}
std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(),
tensor_dims.end(),
std::back_inserter(dims),
[](auto&& d) -> std::size_t {
if(not d.has_dim_value())
{
long default_batch_size = 1; // FIXME
return default_batch_size;
}
return d.dim_value();
});
return {shape_type, dims};
parser.parse_from(std::forward<Ts>(xs)...);
}
return std::move(parser.prog);
}
shape::type_t get_type(int dtype)
{
switch(dtype)
{
case 1: return shape::float_type;
case 2: return shape::uint8_type;
case 3: return shape::int8_type;
case 4: return shape::uint16_type;
case 5: return shape::int16_type;
case 6: return shape::int32_type;
case 7: return shape::int64_type;
case 10: return shape::half_type;
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
default:
{
MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
program parse_onnx(const std::string& name, const onnx_options& options)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
return parse_onnx_from(options, input, name);
}
void check_arg_empty(const argument& arg, const std::string& msg)
{
if(arg.empty())
{
MIGRAPHX_THROW(msg);
}
}
};
program parse_onnx_buffer(const std::string& buffer, const onnx_options& options)
{
return parse_onnx_from(options, buffer.data(), buffer.size());
}
program parse_onnx(const std::string& name)
program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
onnx_parser parser;
#ifndef NDEBUG
// Log the program when it can't be parsed
try
{
parser.parse_from(input);
}
catch(...)
{
std::cerr << parser.prog << std::endl;
throw;
}
#else
parser.parse_from(input);
#endif
return std::move(parser.prog);
return parse_onnx_from(options, data, size);
}
std::vector<std::string> get_onnx_operators() { return onnx::get_op_parsers(); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment