Unverified Commit 1b098fd7 authored by Paul Fultz II, committed by GitHub

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
#include <migraphx/gpu/preallocate_param.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/program.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void preallocate_param::apply(program& p) const
void preallocate_param::apply(module& m) const
{
for(auto ins : iterator_for(p))
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() != "@param")
continue;
std::string id = any_cast<builtin::param>(ins->get_operator()).parameter;
if(id != param)
if(param != any_cast<builtin::param>(ins->get_operator()).parameter)
continue;
argument a = allocate_gpu(ins->get_shape());
ctx->get_current_device().preallocations[id] = a;
auto r = p.insert_instruction(ins, hip_load_memory{a.get_shape(), id});
p.replace_instruction(ins, r);
std::string id = m.name() + ":" + param;
auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
m.replace_instruction(ins, r);
m.move_instruction(ins, m.end());
}
m.remove_instructions(std::next(last), m.end());
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/process.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/env.hpp>
#include <functional>
#include <iostream>
#include <unistd.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_CMD_EXECUTE)
std::function<void(const char*)> redirect_to(std::ostream& os)
{
return [&](const char* x) { os << x; };
}
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
{
int ec = 0;
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl;
auto closer = [&](FILE* stream) {
auto status = pclose(stream);
ec = WIFEXITED(status) ? WEXITSTATUS(status) : 0; // NOLINT
};
{
// TODO: Use execve instead of popen
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
if(!pipe)
MIGRAPHX_THROW("popen() failed: " + cmd);
std::array<char, 128> buffer;
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
std_out(buffer.data());
}
return ec;
}
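// Usage sketch (hypothetical command; redirect_to is defined above):
//   std::stringstream ss;
//   int rc = exec("echo hello", redirect_to(ss));
//   // rc holds the exit status and ss now contains the captured stdout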
struct process_impl
{
std::string command{};
fs::path cwd{};
std::string get_command() const
{
std::string result;
if(not cwd.empty())
result += "cd " + cwd.string() + "; ";
result += command;
return result;
}
};
process::process(const std::string& cmd) : impl(std::make_unique<process_impl>())
{
impl->command = cmd;
}
process::process(process&&) noexcept = default;
process& process::operator=(process rhs)
{
std::swap(impl, rhs.impl);
return *this;
}
process::~process() noexcept = default;
process& process::cwd(const fs::path& p)
{
impl->cwd = p;
return *this;
}
void process::exec()
{
auto ec = migraphx::exec(impl->get_command(), redirect_to(std::cout));
if(ec != 0)
MIGRAPHX_THROW("Command " + impl->get_command() + " exited with status " +
std::to_string(ec));
}
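// Usage sketch (assuming `build_dir` is an existing fs::path):
//   process{"cmake --build ."}.cwd(build_dir).exec(); // throws on nonzero exit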
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -6,87 +6,38 @@
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <set>
#include <utility>
#include <unordered_set>
#include <map>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using milliseconds = std::chrono::duration<double, std::milli>;
struct program_impl
{
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
// A map is used to keep references to modules of the program
std::unordered_map<std::string, module> modules;
context ctx;
std::string target_name;
};
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
static void print_instruction(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
os << ")";
}
os << " -> " << ins->get_shape();
}
template <class F>
static void print_program(const program& p, F print_func)
{
std::unordered_map<instruction_ref, std::string> names;
int count = 0;
for(auto ins : iterator_for(p))
{
std::string var_name;
if(ins->name() == "@param")
{
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
else
{
var_name = "@" + std::to_string(count);
count++;
}
names.emplace(ins, var_name);
// TODO: Use all_of
for(auto&& arg : ins->inputs())
{
assert(p.has_instruction(arg) && "Instruction not found");
(void)arg;
}
print_func(ins, names);
}
}
program::program() : impl(std::make_unique<program_impl>()) {}
program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
program::program(program&&) noexcept = default;
program::~program() noexcept = default;
@@ -103,287 +54,204 @@ program& program::operator=(program p)
void program::assign(const program& p)
{
// clean the current program
if(!impl)
{
impl = std::make_unique<program_impl>();
}
else if(!impl->instructions.empty())
else if(!impl->modules.empty())
{
impl->instructions.clear();
impl->modules.clear();
}
impl->ctx = p.impl->ctx;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(p))
{
instruction_ref copy_ins{};
if(ins->name() == "@literal")
{
auto l = ins->get_literal();
copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l});
}
else if(ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(),
{builtin::param{name}, std::move(s), {}});
}
else if(ins->name() == "@outline")
{
auto s = ins->get_shape();
copy_ins =
impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
}
else
{
// retrieve its mapped input
auto inputs = ins->inputs();
// ensure all inputs have their corresponding copy instructions
assert(std::all_of(
inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i];
});
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
ins_map[ins] = copy_ins;
}
}
instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
}
instruction_ref program::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
impl->ctx = p.impl->ctx;
impl->target_name = p.impl->target_name;
impl->modules = p.impl->modules;
instruction_ref program::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
instruction::replace(ins, op, r, std::move(args));
assert(ins->valid(begin()));
return ins;
}
// build a map from old ins to new ins
// Build a map from old module to new module
std::unordered_map<module_ref, module_ref> mod_map;
std::transform(
impl->modules.begin(),
impl->modules.end(),
std::inserter(mod_map, mod_map.begin()),
[&](auto&& xp) { return std::make_pair(&p.impl->modules.at(xp.first), &xp.second); });
instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref rep)
{
assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);
if(ins == std::prev(this->end()))
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto&& pp : mod_map)
{
return replace_instruction(ins, op::identity{}, rep);
auto old_ins = iterator_for(*pp.first);
auto new_ins = iterator_for(*pp.second);
std::transform(old_ins.begin(),
old_ins.end(),
new_ins.begin(),
std::inserter(ins_map, ins_map.begin()),
[](auto x, auto y) { return std::make_pair(x, y); });
}
// TODO: Should it be an error if the output is empty?
if(ins->outputs().empty())
{
return rep;
}
// Make a copy of outputs which can be changed when calling replace_argument
auto outputs = ins->outputs();
for(auto out : outputs)
// Update all references from all modules
for(auto&& mp : impl->modules)
{
// TODO: Check for possible cycles
if(out != rep)
{
instruction::replace_argument(out, ins, rep);
}
assert(out->valid(begin()));
for(auto ins : iterator_for(mp.second))
instruction::replace_refs(ins, ins_map, mod_map);
}
// Replacement should not be dead code unless it's the last instruction
assert(!rep->outputs().empty() or rep == std::prev(end()));
// Output of the original instruction should only be the replacement or empty
assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(),
ins->outputs().end(),
[&](auto i) { return i == rep; }));
assert(ins->valid(begin()));
assert(rep->valid(begin()));
return rep;
}
instruction_ref program::remove_instruction(instruction_ref ins)
{
assert(has_instruction(ins));
assert(ins->outputs().empty());
ins->clear_arguments();
return impl->instructions.erase(ins);
}
instruction_ref program::remove_instructions(instruction_ref first, instruction_ref last)
{
if(first == last)
return first;
// TODO: Check every element
assert(has_instruction(first));
std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); }));
return impl->instructions.erase(first, last);
}
instruction_ref program::move_instruction(instruction_ref src, instruction_ref dst)
{
impl->instructions.splice(dst, impl->instructions, src);
return src;
}
instruction_ref program::add_literal(literal l)
{
impl->instructions.emplace_front(std::move(l));
return impl->instructions.begin();
}
instruction_ref program::add_outline(const shape& s)
shape program::get_parameter_shape(std::string name) const
{
impl->instructions.push_front({builtin::outline{s}, s, {}});
return impl->instructions.begin();
const auto* mm = this->get_main_module();
return mm->get_parameter_shape(std::move(name));
}
instruction_ref program::add_parameter(std::string name, shape s)
std::vector<std::string> program::get_parameter_names() const
{
assert(get_parameter_shape(name) == shape{});
impl->instructions.push_front({builtin::param{std::move(name)}, std::move(s), {}});
return impl->instructions.begin();
}
shape program::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.name() == "@param")
{
return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
return false;
}
});
if(ins != this->end())
return ins->get_shape();
else
return {};
const auto* mm = this->get_main_module();
return mm->get_parameter_names();
}
instruction_ref program::get_parameter(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.name() == "@param")
{
return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
return false;
}
});
if(ins != this->end())
return ins;
else
return this->end();
const auto* mm = this->get_main_module();
return mm->get_parameter(std::move(name));
}
std::unordered_map<std::string, shape> program::get_parameter_shapes() const
{
std::unordered_map<std::string, shape> result;
for(auto&& ins : impl->instructions)
{
if(ins.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
result[name] = ins.get_shape();
}
}
return result;
const auto* mm = this->get_main_module();
return mm->get_parameter_shapes();
}
bool program::has_instruction(instruction_ref ins) const
std::size_t program::size() const { return impl->modules.size(); }
std::vector<shape> program::get_output_shapes() const
{
return std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return std::addressof(*ins) == std::addressof(x);
}) != impl->instructions.end();
const auto* mm = this->get_main_module();
return mm->get_output_shapes();
}
std::size_t program::size() const { return impl->instructions.size(); }
instruction_ref program::begin() const { return impl->instructions.begin(); }
instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().get_shape(); }
context& program::get_context() const { return impl->ctx; }
instruction_ref program::validate() const
{
return std::find_if(impl->instructions.begin(),
impl->instructions.end(),
[&](const instruction& i) { return !i.valid(impl->instructions.begin()); });
const auto* mm = this->get_main_module();
return mm->validate();
}
bool program::is_compiled() const { return not this->impl->target_name.empty(); }
void program::compile(const target& t, compile_options options)
{
assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context();
assert(not this->is_compiled());
this->impl->target_name = t.name();
this->impl->ctx = t.get_context();
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout};
options.trace(*this);
options.trace();
run_passes(*this, t.get_passes(this->impl->ctx, options), options.trace);
auto invalid = this->validate();
if(invalid != impl->instructions.end())
auto&& passes = t.get_passes(this->impl->ctx, options);
run_passes(*this, passes, options.trace);
auto mods = this->get_modules();
// Validate and finalize
for(const auto& mod : reverse(mods))
{
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index));
auto invalid = mod->validate();
if(invalid != mod->end())
{
MIGRAPHX_THROW("Invalid module " + mod->name() + " from compilation at instruction " +
std::to_string(std::distance(mod->begin(), invalid)));
}
auto dangling = mod->find_dangling_reference();
if(dangling != mod->end())
{
auto index = std::distance(mod->begin(), dangling);
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index));
}
mod->finalize(this->impl->ctx);
}
this->finalize();
}
void program::finalize()
{
for(auto ins : iterator_for(*this))
auto* mm = this->get_main_module();
mm->finalize(this->impl->ctx);
}
template <class T>
std::string classify(T x)
{
switch(std::fpclassify(x))
{
ins->finalize(this->impl->ctx);
case FP_INFINITE: return "inf";
case FP_NAN: return "nan";
case FP_NORMAL: return "normal";
case FP_SUBNORMAL: return "subnormal";
case FP_ZERO: return "zero";
default: return "unknown";
}
}
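// For example, classify(0.0f) yields "zero" and classify(1e-41f) yields
// "subnormal"; classify_argument below collects these labels over a tensor.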
std::unordered_set<std::string> classify_argument(const argument& a)
{
std::unordered_set<std::string> result;
a.visit(
[&](auto t) {
for(const auto& x : t)
result.insert(classify(x));
},
[&](const auto& xs) {
for(const auto& x : xs)
{
auto r = classify_argument(x);
result.insert(r.begin(), r.end());
}
});
return result;
}
void preview_argument(std::ostream& os, const argument& a)
{
a.visit(
[&](auto t) {
if(t.size() <= 10)
{
os << t;
}
else
{
os << to_string_range(t.begin(), t.begin() + 5);
os << ", ..., ";
os << to_string_range(t.end() - 5, t.end());
}
},
[&](const auto& xs) {
for(const auto& x : xs)
{
os << '{';
preview_argument(os, x);
os << '}';
}
});
}
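// For a 12-element tensor {0, 1, ..., 11} this prints
// "0, 1, 2, 3, 4, ..., 7, 8, 9, 10, 11" rather than the full contents.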
template <class F>
argument generic_eval(const program& p,
context& ctx,
std::unordered_map<std::string, argument> params,
F trace)
{
assert(p.validate() == p.end());
std::unordered_map<instruction_ref, argument> results;
results.reserve(p.size() * 2);
std::vector<argument> generic_eval(const module* mod,
context& ctx,
std::unordered_map<std::string, argument> params,
std::unordered_map<instruction_ref, argument> results,
F make_trace)
{
assert(mod->validate() == mod->end());
results.reserve(mod->size() * 2);
std::vector<argument> values;
values.reserve(16);
for(auto ins : iterator_for(p))
auto trace = make_trace(mod);
for(auto ins : iterator_for(*mod))
{
assert(results.find(ins) == results.end());
const auto& name = ins->name();
if(name == "@literal")
{
@@ -407,6 +275,19 @@ argument generic_eval(const program& p,
{
results.emplace(ins, trace(ins, [&] { return argument{ins->get_shape(), nullptr}; }));
}
else if(name == "@return")
{
std::vector<argument> prog_outputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(prog_outputs),
[&](instruction_ref i) {
assert(results.find(i) != results.end());
return results[i];
});
return prog_outputs;
}
else
{
values.resize(ins->inputs().size());
@@ -415,52 +296,280 @@ argument generic_eval(const program& p,
assert(results.find(i) != results.end());
return results[i];
});
const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) {
auto ssctx = ctx;
return generic_eval(smod, ssctx, inputs, results, make_trace);
};
results.emplace(ins, trace(ins, [&] {
return ins->get_operator().compute(ctx, ins->get_shape(), values);
return ins->normalized_operator().compute(
ctx, ins->get_shape(), values, mod_args, module_eval);
}));
}
assert(results.find(ins) != results.end());
assert(results.at(ins).get_shape() == ins->get_shape());
}
return results.at(std::prev(p.end()));
return {results.at(std::prev(mod->end()))};
}
argument program::eval(std::unordered_map<std::string, argument> params) const
template <class F>
std::vector<argument> generic_eval(const program& p,
context& ctx,
std::unordered_map<std::string, argument> params,
F make_trace)
{
const module* mm = p.get_main_module();
return generic_eval(mm, ctx, params, {}, make_trace);
}
std::vector<argument> program::eval(parameter_map params) const
{
auto& ctx = this->impl->ctx;
#ifndef NDEBUG
auto sctx = ctx;
auto check_context = [&](auto f) {
assert(is_shared(ctx, sctx));
auto x = f();
sctx = ctx;
return x;
auto with_check_context = [&](auto f) {
return [=, &ctx](auto&&) {
auto sctx = std::make_shared<context>(ctx);
auto check_context = [=, &ctx](auto g) {
assert(is_shared(ctx, *sctx));
auto x = g();
*sctx = ctx;
return x;
};
return [=](auto&&... xs) { return f(xs..., check_context); };
};
};
#else
auto check_context = [](auto f) { return f(); };
auto with_check_context = [](auto f) {
return [=](auto&&) {
return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
};
};
#endif
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
if(trace_level > 0)
{
return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
ctx.finish();
std::cout << "Run instruction: ";
this->debug_print(ins);
auto result = check_context(f);
ctx.finish();
if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load")
std::cout << "Ouput: " << result << std::endl;
return result;
std::unordered_map<instruction_ref, std::string> ins_out;
// get instruction names
this->print([&](auto x, auto ins_names) {
std::stringstream ss;
instruction::print(ss, x, ins_names);
ins_out[x] = ss.str();
});
return generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{};
auto result = check_context(f);
double t1 = t.record<milliseconds>();
ctx.finish();
double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
if(trace_level > 1 and ins->name().front() != '@' and
ins->name() != "load" and not result.empty())
{
target tgt = make_target(this->impl->target_name);
auto buffer = tgt.copy_from(result);
if(trace_level == 2)
{
std::cout << "Output has "
<< to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
}
else
{
std::cout << "Output: " << buffer << std::endl;
}
}
return result;
}));
}
else
{
return generic_eval(
*this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); });
return generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f);
}));
}
}
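// Usage sketch (assuming a compiled program `p` and a host buffer `buf`
// matching the parameter's shape):
//   parameter_map params;
//   params["x"] = argument{p.get_parameter_shape("x"), buf};
//   std::vector<argument> results = p.eval(params);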
const int program_file_version = 5;
value program::to_value() const
{
value result;
result["version"] = program_file_version;
result["target"] = this->impl->target_name;
if(not this->impl->target_name.empty())
result["context"] = this->impl->ctx.to_value();
value module_vals = value::object{};
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : this->get_modules())
{
value mod_val;
value nodes;
mod_val["name"] = mod->name();
names = mod->print(
[&](auto ins, auto ins_names) {
value node;
node["output"] = ins_names.at(ins);
node["name"] = ins->name();
node["shape"] = migraphx::to_value(ins->get_shape());
node["normalized"] = ins->is_normalized();
if(ins->name() == "@literal")
node["literal"] = migraphx::to_value(ins->get_literal());
node["operator"] = ins->get_operator().to_value();
std::vector<std::string> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) {
assert(contains(ins_names, i));
return ins_names.at(i);
});
node["inputs"] = inputs;
auto module_args = ins->module_inputs();
if(not module_args.empty())
{
std::vector<std::string> module_inputs;
std::transform(module_args.begin(),
module_args.end(),
std::back_inserter(module_inputs),
[&](auto mod_ref) { return mod_ref->name(); });
node["module_inputs"] = module_inputs;
}
nodes.push_back(node);
},
names);
mod_val["nodes"] = nodes;
module_vals[mod->name()] = mod_val;
}
result["modules"] = module_vals;
return result;
}
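// The resulting value is roughly shaped like this (a sketch; field order may
// differ):
//   {version: 5, target: "gpu", context: {...},
//    modules: {main: {name: "main", nodes: [{output: "@0", ...}, ...]}}}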
static void mod_from_val(module_ref mod,
const value& v,
std::unordered_map<std::string, instruction_ref>& instructions,
const std::unordered_map<std::string, module_ref>& map_mods)
{
const auto& module_val = v.at(mod->name());
for(const value& node : module_val.at("nodes"))
{
instruction_ref output;
auto name = node.at("name").to<std::string>();
auto fields = node.at("operator");
auto normalized = node.at("normalized").to<bool>();
if(name == "@param")
{
output = mod->add_parameter(fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
}
else if(name == "@literal")
{
output = mod->add_literal(migraphx::from_value<literal>(node.at("literal")));
}
else
{
auto op = make_op(name, fields);
std::vector<instruction_ref> inputs;
std::transform(node.at("inputs").begin(),
node.at("inputs").end(),
std::back_inserter(inputs),
[&](const value& i) {
auto i_name = i.to<std::string>();
assert(contains(instructions, i_name));
return instructions.at(i_name);
});
std::vector<module_ref> module_inputs;
if(node.contains("module_inputs"))
{
std::transform(node.at("module_inputs").begin(),
node.at("module_inputs").end(),
std::back_inserter(module_inputs),
[&](const value& i) { return map_mods.at(i.to<std::string>()); });
for(auto& smod : module_inputs)
{
mod_from_val(smod, v, instructions, map_mods);
}
}
if(name == "@return")
{
output = mod->add_return(inputs);
}
else if(module_inputs.empty())
{
output = mod->add_instruction(op, inputs);
}
else
{
output = mod->add_instruction(op, inputs, module_inputs);
}
}
output->set_normalized(normalized);
instructions[node.at("output").to<std::string>()] = output;
}
}
void program::from_value(const value& v)
{
auto version = v.at("version").to<int>();
if(version != program_file_version)
{
MIGRAPHX_THROW("Warning: Program version mismatch");
}
this->impl->target_name = v.at("target").to<std::string>();
if(not this->impl->target_name.empty())
{
target t = make_target(this->impl->target_name);
this->impl->ctx = t.get_context();
this->impl->ctx.from_value(v.at("context"));
}
auto module_vals = v.at("modules");
for(const auto& vv : module_vals)
{
const auto& name = vv.get_key();
if(name == "main")
continue;
impl->modules.emplace(name, name);
}
std::unordered_map<std::string, module_ref> map_mods;
std::transform(impl->modules.begin(),
impl->modules.end(),
std::inserter(map_mods, map_mods.end()),
[&](auto&& pp) { return std::make_pair(pp.first, &pp.second); });
std::unordered_map<std::string, instruction_ref> map_insts;
auto* mm = get_main_module();
mod_from_val(mm, module_vals, map_insts, map_mods);
this->finalize();
}
double common_average(const std::vector<double>& v)
{
std::size_t n = v.size() / 4;
@@ -468,10 +577,38 @@ double common_average(const std::vector<double>& v)
return total / std::distance(v.begin() + n, v.end() - n);
}
void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) const
std::string perf_group(const operation& op)
{
auto attr = op.attributes();
if(attr.contains("group"))
return attr.at("group").to<std::string>();
return op.name();
}
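// For example, an operation whose attributes contain {group: "gemm"} would be
// reported under "gemm" in the perf summary; other ops fall back to their name.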
void program::mark(const parameter_map& params, marker&& m)
{
using milliseconds = std::chrono::duration<double, std::milli>;
auto& ctx = this->impl->ctx;
auto& ctx = this->impl->ctx;
// Run once by itself
eval(params);
ctx.finish();
// Start marking
m.mark_start(*this);
generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
argument result;
m.mark_start(ins);
result = f();
m.mark_stop(ins);
return result;
}));
m.mark_stop(*this);
}
void program::perf_report(std::ostream& os,
std::size_t n,
parameter_map params,
std::size_t batch) const
{
auto& ctx = this->impl->ctx;
// Run once by itself
eval(params);
ctx.finish();
@@ -488,21 +625,22 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
std::sort(total_vec.begin(), total_vec.end());
std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
// Fill the map
generic_eval(*this, ctx, params, [&](auto ins, auto) {
generic_eval(*this, ctx, params, always([&](auto ins, auto) {
ins_vec[ins].reserve(n);
return argument{};
});
return argument{ins->get_shape(), nullptr};
}));
// Run and time each instruction
for(std::size_t i = 0; i < n; i++)
{
generic_eval(*this, ctx, params, [&](auto ins, auto f) {
generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
argument result;
ins_vec[ins].push_back(time<milliseconds>([&] {
result = f();
ctx.finish();
}));
return result;
});
}));
}
for(auto&& p : ins_vec)
std::sort(p.second.begin(), p.second.end());
@@ -523,14 +661,20 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
for(auto&& p : ins_vec)
{
double avg = common_average(p.second);
op_times[p.first->name()] += avg;
op_times[perf_group(p.first->get_operator())] += avg;
total_instruction_time += avg;
}
double calculate_overhead_time = total_time - total_instruction_time;
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
print_program(*this, [&](auto ins, const auto& names) {
print_instruction(std::cout, ins, names);
std::unordered_map<instruction_ref, std::string> names;
this->print(names, [&](auto ins, auto ins_names) {
instruction::print(std::cout, ins, ins_names);
// skip return instruction
if(ins->name() == "@return")
return;
double avg = common_average(ins_vec[ins]);
double percent = std::ceil(100.0 * avg / total_instruction_time);
os << ": " << avg << "ms, " << percent << "%";
@@ -555,7 +699,8 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
os << std::endl;
os << "Rate: " << rate << "/sec" << std::endl;
os << "Batch size: " << batch << std::endl;
os << "Rate: " << rate * batch << "/sec" << std::endl;
os << "Total time: " << total_time << "ms" << std::endl;
os << "Total instructions time: " << total_instruction_time << "ms" << std::endl;
os << "Overhead time: " << overhead_time << "ms"
@@ -567,86 +712,231 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const
{
if(ins == this->end())
std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
return is_end(pp.second.end(), ins);
}))
{
std::cout << "End instruction" << std::endl;
return;
}
if(not has_instruction(ins))
else if(std::none_of(this->impl->modules.begin(),
this->impl->modules.end(),
[&](const auto& pp) { return pp.second.has_instruction(ins); }))
{
std::cout << "Instruction not part of program" << std::endl;
return;
}
std::stringstream ss;
print_program(*this, [&](auto x, const auto& names) {
this->print(names, [&](auto x, auto ins_names) {
if(x == ins)
{
print_instruction(std::cout, x, names);
instruction::print(std::cout, x, ins_names);
std::cout << std::endl;
}
});
}
void program::debug_print(const std::vector<instruction_ref>& inss) const
void program::print(
std::unordered_map<instruction_ref, std::string>& names,
const std::function<void(instruction_ref, std::unordered_map<instruction_ref, std::string>)>&
print_func) const
{
for(auto ins : inss)
debug_print(ins);
std::cout << std::endl;
for(const auto& pp : this->impl->modules)
{
names = pp.second.print(print_func, names);
}
}
static std::string enclose_name(const std::string& name)
void program::print(
const std::function<void(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>)>& print_func) const
{
return '"' + replace_string(name, "\"", "\\\"") + '"';
std::unordered_map<instruction_ref, std::string> names;
this->print(names, print_func);
}
void program::print_graph(std::ostream& os, bool brief) const
{
os << "digraph {" << std::endl;
os << "\trankdir=LR;" << std::endl;
print_program(*this, [&](auto ins, const auto& names) {
std::string label;
if(brief)
label = ins->name();
else
label = to_string(ins->get_operator());
os << "\t" << enclose_name(names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << ";" << std::endl;
if(!ins->inputs().empty())
const auto* mm = this->get_main_module();
mm->print_graph(os, brief);
}
void program::print_cpp(std::ostream& os) const
{
auto vec_modules = this->get_modules();
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : vec_modules)
{
os << "module: \"" << mod->name() << "\"" << std::endl;
names = mod->print_cpp(os, names);
os << std::endl;
}
}
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
return argument{ins->get_shape(), nullptr};
}));
}
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
{
for(auto& pp : this->impl->modules)
{
os << pp.first << ":" << std::endl;
pp.second.annotate(os, a);
}
}
const module* program::get_module(const std::string& name) const { return &impl->modules.at(name); }
module* program::create_module(const std::string& name)
{
assert(not contains(impl->modules, name));
auto r = impl->modules.emplace(name, name);
return &(r.first->second);
}
module* program::get_module(const std::string& name) { return &impl->modules.at(name); }
module* program::get_main_module() { return get_module("main"); }
const module* program::get_main_module() const { return get_module("main"); }
template <class T>
std::vector<T*> generic_get_modules(T* mm)
{
std::vector<T*> vec_modules;
vec_modules.push_back(mm);
auto sub_modules = mm->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end());
return vec_modules;
}
template <class Map, class T, class OutputIterator>
void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputIterator out)
{
std::unordered_set<std::string> used;
std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
return mod->name();
});
transform_if(
m.begin(),
m.end(),
out,
[&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
}
std::vector<const module*> program::get_modules() const
{
auto result = generic_get_modules(this->get_main_module());
generic_get_unused_modules(impl->modules, result, std::back_inserter(result));
return result;
}
std::vector<module*> program::get_modules()
{
auto result = generic_get_modules(this->get_main_module());
generic_get_unused_modules(impl->modules, result, std::back_inserter(result));
return result;
}
template <class Map, class T>
bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name)
{
bool is_unused = false;
generic_get_unused_modules(m, mods, make_function_output_iterator([&](auto* mod) {
if(mod->name() == name)
is_unused = true;
}));
return is_unused;
}
template <class Map>
bool references_instruction(Map& m, const instruction& ins, const std::string& name)
{
return std::any_of(m.begin(), m.end(), [&](auto&& p) {
if(p.first == name)
return false;
return std::any_of(p.second.begin(), p.second.end(), [&](auto&& i) {
return std::any_of(i.inputs().begin(), i.inputs().end(), [&](auto&& j) {
return std::addressof(*j) == std::addressof(ins);
});
});
});
}
void program::remove_module(const std::string& name)
{
// cppcheck-suppress assertWithSideEffect
assert(is_unused_module(impl->modules, generic_get_modules(this->get_main_module()), name) &&
"Module used in program");
assert(std::none_of(
impl->modules.at(name).begin(),
impl->modules.at(name).end(),
[&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
"Instruction referenced in another module");
// If an instruction has an input outside of the current module, we need to
// remove the instruction from its input's outputs
auto& mod = impl->modules.at(name);
for(auto ins : iterator_for(mod))
{
auto inputs = ins->inputs();
for(auto in : inputs)
{
for(auto&& arg : ins->inputs())
if(not mod.has_instruction(in))
{
os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins));
if(not brief)
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]";
os << ";" << std::endl;
in->remove_output(ins);
}
}
});
os << "}" << std::endl;
}
impl->modules.erase(name);
}
void program::dry_run(std::unordered_map<std::string, argument> params) const
void program::remove_unused_modules()
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
std::vector<module*> unused;
generic_get_unused_modules(
impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused));
for(auto* m : unused)
this->remove_module(m->name());
}
void program::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
program& program::sort()
{
print_program(*this, [&](auto ins, const auto& names) {
print_instruction(os, ins, names);
a(ins);
os << std::endl;
});
for(auto& pp : this->impl->modules)
{
pp.second.sort();
}
return *this;
}
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p)
{
print_program(p, [&](auto ins, const auto& names) {
print_instruction(os, ins, names);
auto vec_modules = p.get_modules();
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : vec_modules)
{
os << "module: \"" << mod->name() << "\"" << std::endl;
names = mod->print(
[&](auto ins, auto ins_names) {
instruction::print(os, ins, ins_names);
os << std::endl;
},
names);
os << std::endl;
});
}
return os;
}
@@ -3,6 +3,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <unordered_set>
namespace migraphx {
@@ -20,33 +21,42 @@ bool skip_propogate(instruction_ref ins)
return false;
}
void propagate_constant::apply(program& p) const
bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
void propagate_constant::apply(module& m) const
{
for(auto i : iterator_for(p))
std::unordered_set<instruction_ref> const_instrs;
auto last = std::prev(m.end());
// Find instructions that can be evaluated to a literal
for(auto i : iterator_for(m))
{
if(i->name() != "@literal")
if(is_const(i) and i != last)
continue;
if(i->outputs().empty())
continue;
fix([&](auto self, auto ins) {
std::unordered_set<instruction_ref> children(ins->outputs().begin(),
ins->outputs().end());
for(auto child : children)
{
if(child->name() == "@literal" or skip_propogate(child))
{
self(child);
continue;
}
auto r = child->eval();
if(not r.empty())
{
assert(r.get_shape() == child->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
self(p.replace_instruction(child, l));
}
}
})(i);
std::copy_if(
i->inputs().begin(),
i->inputs().end(),
std::inserter(const_instrs, const_instrs.begin()),
[&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; });
}
// Compute literals in parallel
std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
std::vector<argument> literals(const_instrs_vec.size());
par_for(const_instrs_vec.size(), 1, [&](const auto i) {
literals[i] = const_instrs_vec[i]->eval();
});
// Replace instructions in m
for(size_t i = 0; i < const_instrs_vec.size(); i++)
{
if(not literals[i].empty())
{
assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instrs_vec[i], l);
}
}
}
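// Sketch of the effect: if `w = transpose(@literal)` feeds a non-constant op,
// `w` can be evaluated, so it is computed once (in parallel via par_for) and
// replaced with a plain literal for later passes to consume.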
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
if(MIGRAPHX_ENABLE_PYTHON)
find_program(DEFAULT_PYTHON_EXE python)
if(DEFAULT_PYTHON_EXE)
set(PYTHON_EXECUTABLE ${DEFAULT_PYTHON_EXE} CACHE PATH "Path to python executable")
endif()
find_package(pybind11 REQUIRED)
pybind11_add_module(migraphx_py migraphx_py.cpp)
set_target_properties(migraphx_py PROPERTIES
OUTPUT_NAME migraphx
C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden
)
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
endif()
rocm_install_targets(TARGETS migraphx_py)
include(PythonModules)
add_custom_target(migraphx_py)
foreach(PYTHON_VERSION ${PYTHON_VERSIONS})
py_add_module(migraphx_py_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION})
add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION})
endforeach()
endif()
# -------------------------------------------------------------------------
# Copyright (c) Advanced Micro Devices. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Implements ONNX's backend API.
"""
import sys
if sys.version_info < (3, 0):
sys.exit()
from onnx import ModelProto
from onnx.checker import check_model
from onnx.backend.base import Backend
import migraphx
from onnx_migraphx.backend_rep import MIGraphXBackendRep
def get_device():
return ("CPU", "GPU")
class MIGraphXBackend(Backend):
_device = "GPU"
_input_names = []
_prog_string = ""
@classmethod
def set_device(cls, device):
cls._device = device
"""
Implements
`ONNX's backend API <https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md>`_
with *MIGraphX*.
The backend is mostly used when you need to switch between
multiple runtimes with the same API.
`Importing models from ONNX to Caffe2 <https://github.com/onnx/tutorials/blob/master/tutorials/OnnxCaffe2Import.ipynb>`_
shows how to use *caffe2* as a backend for a converted model.
Note: This is not the official Python API.
""" # noqa: E501
@classmethod
def get_program(cls):
return cls._prog_string
@classmethod
def is_compatible(cls, model, device=None, **kwargs):
"""
Return whether the model is compatible with the backend.
:param model: unused
:param device: None to use the default device or a string (ex: `'CPU'`)
:return: boolean
"""
device = cls._device
return cls.supports_device(device)
@classmethod
def supports_device(cls, device):
"""
Check whether the backend is compiled with particular device support.
In particular it's used in the testing suite.
"""
return device in get_device()
@classmethod
def prepare(cls, model, device=None, **kwargs):
"""
Loads the model and creates a :class:`migraphx.program`
ready to be used as a backend.
:param model: ModelProto (returned by `onnx.load`),
string for a filename or bytes for a serialized model
:param device: requested device for the computation,
None means the default one which depends on
the compilation settings
:param kwargs: additional keyword arguments (unused)
:return: :class:`migraphx.program`
"""
if isinstance(model, MIGraphXBackendRep):
return model
elif isinstance(model, migraphx.program):
return MIGraphXBackendRep(model, cls._input_names)
elif isinstance(model, (str, bytes)):
if device is not None and not cls.supports_device(device):
raise RuntimeError(
"Incompatible device expected '{0}', got '{1}'".format(
device, get_device()))
inf = migraphx.parse_onnx_buffer(model)
cls._prog_string = str("\nProgram =\n{}".format(inf))
device = cls._device
cls._input_names = inf.get_parameter_names()
inf.compile(migraphx.get_target(device.lower()))
cls._prog_string = cls._prog_string + str(
"\nCompiled program =\n{}".format(inf))
return cls.prepare(inf, device, **kwargs)
else:
# type: ModelProto
check_model(model)
bin = model.SerializeToString()
return cls.prepare(bin, device, **kwargs)
@classmethod
def run_model(cls, model, inputs, device=None, **kwargs):
"""
Compute the prediction.
:param model: :class:`migraphx.program` returned
by function *prepare*
:param inputs: inputs
:param device: requested device for the computation,
None means the default one which depends on
the compilation settings
:param kwargs: see :class:`migraphx.program`
:return: predictions
"""
rep = cls.prepare(model, device, **kwargs)
return rep.run(inputs, **kwargs)
@classmethod
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
'''
This method is not implemented as it is much more efficient
to run a whole model than every node independently.
'''
raise NotImplementedError(
"It is much more efficient to run a whole model than every node independently."
)
is_compatible = MIGraphXBackend.is_compatible
prepare = MIGraphXBackend.prepare
run = MIGraphXBackend.run_model
supports_device = MIGraphXBackend.supports_device
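# Usage sketch (assumes an ONNX model file "model.onnx"; names are
# illustrative):
#   import onnx
#   from onnx_migraphx.backend import prepare
#   rep = prepare(onnx.load("model.onnx"), device="GPU")
#   outputs = rep.run([input_array])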
# -------------------------------------------------------------------------
# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Implements ONNX's backend API.
"""
import sys
if sys.version_info < (3, 0):
sys.exit()
import migraphx
from onnx.backend.base import BackendRep
import numpy as np
from typing import Any, Tuple
class MIGraphXBackendRep(BackendRep):
"""
Computes the prediction for a pipeline converted into
a :class:`migraphx.program`.
"""
def __init__(self, prog, input_names):
"""
:param prog: :class:`migraphx.program`
"""
self._program = prog
self._input_names = input_names
def run(self, inputs, **kwargs): # type: (Any, **Any) -> Tuple[Any, ...]
"""
Computes the prediction.
See :meth:`migraphx.program.run`.
"""
if isinstance(inputs, list):
inps = {}
for i, name in enumerate(self._input_names):
inps[name] = migraphx.argument(inputs[i])
mgx_outputs = self._program.run(inps)
outs = []
for out in mgx_outputs:
outs.append(np.array(out))
return outs
else:
inp = list(self._program.get_parameter_shapes().keys())
if len(inp) != 1:
raise RuntimeError("Model expects {0} inputs".format(len(inp)))
inps = {inp[0]: migraphx.argument(inputs)}
mgx_outputs = self._program.run(inps)
outs = []
for out in mgx_outputs:
outs.append(np.array(out))
return outs
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#endif
using half = half_float::half;
namespace py = pybind11;
template <class F>
struct throw_half
{
F f;
#ifdef __clang__
#define MIGRAPHX_PUSH_UNUSED_WARNING \
_Pragma("clang diagnostic push") \
_Pragma("clang diagnostic ignored \"-Wused-but-marked-unused\"")
#define MIGRAPHX_POP_WARNING _Pragma("clang diagnostic pop")
#else
#define MIGRAPHX_PUSH_UNUSED_WARNING
#define MIGRAPHX_POP_WARNING
#endif
#define MIGRAPHX_PYBIND11_MODULE(...) \
MIGRAPHX_PUSH_UNUSED_WARNING \
PYBIND11_MODULE(__VA_ARGS__) \
MIGRAPHX_POP_WARNING
template <class A>
void operator()(A a) const
namespace migraphx {
migraphx::value to_value(py::kwargs kwargs);
migraphx::value to_value(py::list lst);
template <class T, class F>
void visit_py(T x, F f)
{
if(py::isinstance<py::kwargs>(x))
{
f(a);
f(to_value(x.template cast<py::kwargs>()));
}
void operator()(migraphx::shape::as<migraphx::half>) const
else if(py::isinstance<py::list>(x))
{
f(to_value(x.template cast<py::list>()));
}
else if(py::isinstance<py::bool_>(x))
{
f(x.template cast<bool>());
}
else if(py::isinstance<py::int_>(x))
{
throw std::runtime_error("Half not supported in python yet.");
f(x.template cast<int>());
}
else if(py::isinstance<py::float_>(x))
{
f(x.template cast<float>());
}
else if(py::isinstance<py::str>(x))
{
f(x.template cast<std::string>());
}
else
{
MIGRAPHX_THROW("VISIT_PY: Unsupported data type!");
}
}
void operator()(migraphx::tensor_view<migraphx::half>) const
migraphx::value to_value(py::list lst)
{
migraphx::value v = migraphx::value::array{};
for(auto val : lst)
{
throw std::runtime_error("Half not supported in python yet.");
visit_py(val, [&](auto py_val) { v.push_back(py_val); });
}
};
template <class F>
struct skip_half
return v;
}
migraphx::value to_value(py::kwargs kwargs)
{
F f;
migraphx::value v = migraphx::value::object{};
template <class A>
void operator()(A a) const
for(auto arg : kwargs)
{
f(a);
auto&& key = py::str(arg.first);
auto&& val = arg.second;
visit_py(val, [&](auto py_val) { v[key] = py_val; });
}
return v;
}
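// E.g. a bound call f(type="float", lens=[2, 3]) produces the value
// {type: "float", lens: [2, 3]} via the kwargs overload above.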
} // namespace migraphx
void operator()(migraphx::shape::as<migraphx::half>) const {}
namespace pybind11 {
namespace detail {
void operator()(migraphx::tensor_view<migraphx::half>) const {}
template <>
struct npy_format_descriptor<half>
{
static std::string format()
{
// following: https://docs.python.org/3/library/struct.html#format-characters
return "e";
}
static constexpr auto name() { return _("half"); }
};
} // namespace detail
} // namespace pybind11
template <class F>
void visit_type(const migraphx::shape& s, F f)
{
s.visit_type(throw_half<F>{f});
s.visit_type(f);
}
template <class T, class F>
void visit(const migraphx::raw_data<T>& x, F f)
{
x.visit(throw_half<F>{f});
x.visit(f);
}
template <class F>
void visit_types(F f)
{
migraphx::shape::visit_types(skip_half<F>{f});
migraphx::shape::visit_types(f);
}
template <class T>
@@ -82,12 +146,26 @@ py::buffer_info to_buffer_info(T& x)
strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
py::buffer_info b;
visit_type(s, [&](auto as) {
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<decltype(as())>::format(),
s.lens().size(),
s.lens(),
strides);
// MIGraphX uses int8_t to store the bool type, so we need to
// explicitly specify the data type as bool for Python
if(s.type() == migraphx::shape::bool_type)
{
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<bool>::format(),
s.lens().size(),
s.lens(),
strides);
}
else
{
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<decltype(as())>::format(),
s.lens().size(),
s.lens(),
strides);
}
});
return b;
}
@@ -97,34 +175,59 @@ migraphx::shape to_shape(const py::buffer_info& info)
migraphx::shape::type_t t;
std::size_t n = 0;
visit_types([&](auto as) {
if(info.format == py::format_descriptor<decltype(as())>::format())
if(info.format == py::format_descriptor<decltype(as())>::format() or
(info.format == "l" and py::format_descriptor<decltype(as())>::format() == "q") or
(info.format == "L" and py::format_descriptor<decltype(as())>::format() == "Q"))
{
t = as.type_enum();
n = sizeof(as());
}
else if(info.format == "?" and py::format_descriptor<decltype(as())>::format() == "b")
{
t = migraphx::shape::bool_type;
n = sizeof(bool);
}
});
if(n == 0)
{
MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type" + info.format);
MIGRAPHX_THROW("MIGRAPHX PYTHON: Unsupported data type " + info.format);
}
auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
return n > 0 ? i / n : 0;
});
return migraphx::shape{t, info.shape, strides};
// scalar support
if(info.shape.empty())
{
return migraphx::shape{t};
}
else
{
return migraphx::shape{t, info.shape, strides};
}
}
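// For instance, a contiguous float32 numpy array of shape (2, 3) arrives with
// format "f" and byte strides {12, 4}, and maps to
// shape{float_type, {2, 3}, {3, 1}}; a 0-d array maps to a scalar shape.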
PYBIND11_MODULE(migraphx, m)
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
.def(py::init<>())
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float"));
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
else
return migraphx::shape(t, lens);
}))
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
@@ -155,41 +258,183 @@ PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::instruction_ref>(m, "instruction_ref");
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
.def(
"add_instruction",
[](migraphx::module& mm,
const migraphx::operation& op,
std::vector<migraphx::instruction_ref>& args,
std::vector<migraphx::module*>& mod_args) {
return mm.add_instruction(op, args, mod_args);
},
py::arg("op"),
py::arg("args"),
py::arg("mod_args") = std::vector<migraphx::module*>{})
.def(
"add_literal",
[](migraphx::module& mm, py::buffer data) {
py::buffer_info info = data.request();
auto literal_shape = to_shape(info);
return mm.add_literal(literal_shape, reinterpret_cast<char*>(info.ptr));
},
py::arg("data"))
.def(
"add_parameter",
[](migraphx::module& mm, const std::string& name, const migraphx::shape shape) {
return mm.add_parameter(name, shape);
},
py::arg("name"),
py::arg("shape"))
.def(
"add_return",
[](migraphx::module& mm, std::vector<migraphx::instruction_ref>& args) {
return mm.add_return(args);
},
py::arg("args"))
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
.def(py::init([]() { return migraphx::program(); }))
.def("get_parameter_names", &migraphx::program::get_parameter_names)
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_shape", &migraphx::program::get_shape)
.def("compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy) {
migraphx::compile_options options;
options.offload_copy = offload_copy;
p.compile(t, options);
},
py::arg("t"),
py::arg("offload_copy") = true)
.def("run", &migraphx::program::eval)
.def("get_output_shapes", &migraphx::program::get_output_shapes)
.def(
"compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
migraphx::compile_options options;
options.offload_copy = offload_copy;
options.fast_math = fast_math;
p.compile(t, options);
},
py::arg("t"),
py::arg("offload_copy") = true,
py::arg("fast_math") = true)
.def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
.def(
"create_module",
[](migraphx::program& p, const std::string& name) { return p.create_module(name); },
py::arg("name"))
.def("run",
[](migraphx::program& p, py::dict params) {
migraphx::parameter_map pm;
for(auto x : params)
{
std::string key = x.first.cast<std::string>();
py::buffer b = x.second.cast<py::buffer>();
py::buffer_info info = b.request();
pm[key] = migraphx::argument(to_shape(info), info.ptr);
}
return p.eval(pm);
})
.def("sort", &migraphx::program::sort)
.def("print", [](const migraphx::program& p) { std::cout << p << std::endl; })
.def("__eq__", std::equal_to<migraphx::program>{})
.def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
m.def("parse_tf",
&migraphx::parse_tf,
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true);
m.def("parse_onnx", &migraphx::parse_onnx);
py::class_<migraphx::operation>(m, "op")
.def(py::init([](const std::string& name, py::kwargs kwargs) {
migraphx::value v = migraphx::value::object{};
if(kwargs)
{
v = migraphx::to_value(kwargs);
}
return migraphx::make_op(name, v);
}))
m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu")
return migraphx::cpu::target{};
#ifdef HAVE_GPU
if(name == "gpu")
return migraphx::gpu::target{};
#endif
throw std::runtime_error("Target not found: " + name);
});
.def("name", &migraphx::operation::name);
m.def(
"parse_tf",
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def(
"parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def(
"parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def(
"load",
[](const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::load(name, options);
},
"Load MIGraphX program",
py::arg("filename"),
py::arg("format") = "msgpack");
m.def(
"save",
[](const migraphx::program& p, const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::save(p, name, options);
},
"Save MIGraphX program",
py::arg("p"),
py::arg("filename"),
py::arg("format") = "msgpack");
m.def("get_target", &migraphx::make_target);
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
m.def("quantize_fp16",
&migraphx::quantize_fp16,
py::arg("prog"),
      py::arg("ins_names") = std::vector<std::string>{"all"});
@@ -198,14 +443,14 @@ PYBIND11_MODULE(migraphx, m)
m.def("quantize_int8",
&migraphx::quantize_int8,
py::arg("prog"),
py::arg("t"),
py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
m.def("from_gpu", &migraphx::gpu::from_gpu);
m.def("gpu_sync", &migraphx::gpu::gpu_sync);
m.def("gpu_sync", [] { migraphx::gpu::gpu_sync(); });
#endif
#ifdef VERSION_INFO
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <utility>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <set>
#include <iomanip>
#include <fstream>
#include <algorithm>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
instruction_ref insert_quant_ins(program& prog,
instruction_ref& ins,
shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
float scale = 1.0f,
float shift = 0.0f)
{
if(map_ins.count(ins) > 0)
{
return map_ins[ins];
}
if(ins->name() == "undefined")
{
return ins;
}
assert(ins->get_shape().type() == shape::float_type or
ins->get_shape().type() == shape::double_type or
ins->get_shape().type() == shape::int32_type or
ins->get_shape().type() == shape::half_type);
instruction_ref quant_ins{};
auto insert_loc = std::next(ins);
if(type == shape::int8_type)
{
auto scaled_ins = ins;
if(scale != 1.0f)
{
auto float_ins = scaled_ins;
if(scaled_ins->get_shape().type() != shape::float_type)
{
float_ins =
prog.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
}
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = prog.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = prog.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
}
auto shifted_ins = scaled_ins;
if(shift != 0.0f)
{
auto float_ins = shifted_ins;
if(shifted_ins->get_shape().type() != shape::float_type)
{
float_ins = prog.insert_instruction(
insert_loc, op::convert{shape::float_type}, shifted_ins);
}
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = prog.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
}
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, rounded_ins);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
}
else
{
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, ins);
}
map_ins[ins] = quant_ins;
return quant_ins;
}
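// Editor's note: a minimal scalar sketch (not part of the original source) of
// the mapping that insert_quant_ins builds out of mul/add/round/clip/convert
// instructions above: v_int8 = clamp(round(fp * scale + shift), -128, 127).
// The helper name quantize_scalar is hypothetical; assumes <cmath>/<cstdint>.
inline int8_t quantize_scalar(float fp, float scale, float shift)
{
    float v = std::round(fp * scale + shift);         // scale, shift, then round
    v       = std::max(-128.0f, std::min(v, 127.0f)); // clip to the int8 range
    return static_cast<int8_t>(v);                    // final convert
}
// e.g. quantize_scalar(0.5f, 64.0f, 0.0f) == 32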
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
// For the conversion, there could be cases of overflowing, but it
@@ -102,402 +30,29 @@ instruction_ref insert_quant_ins(program& prog,
// is very rare in the area of deep learning, so we just do a
// truncate of the input to get the fp16.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog))
{
// "all" indicates that every instruction is converted
if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
{
continue;
}
shape::type_t orig_type = ins->get_shape().type();
// process all inputs; if an input is fp32 or fp64, convert it
// to fp16 by adding a convert operator.
auto inputs = ins->inputs();
std::vector<instruction_ref> converted_inputs;
for(auto input : inputs)
{
auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type)
{
// if the input is a convert operator, use its input
// as the current input
instruction_ref input_fp16{};
if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == shape::half_type)
{
input_fp16 = input->inputs().front();
}
else
{
input_fp16 = insert_quant_ins(prog, input, shape::half_type, map_fp16);
}
converted_inputs.push_back(input_fp16);
}
else
{
converted_inputs.push_back(input);
}
}
// no inputs were changed, so go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
auto op = ins->get_operator();
auto ins_shape = compute_shape(op, converted_inputs);
if(ins_shape.type() != orig_type)
{
// check the dead code case to avoid assert
bool output_empty = ins->outputs().empty();
auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
if(!output_empty)
{
prog.replace_instruction(ins, ins_orig_type);
}
}
prog.replace_instruction(ins, op, converted_inputs);
}
}
static void ins_quantize_int8(program& prog,
instruction_ref ins,
std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params)
{
auto orig_type = ins->get_shape().type();
auto inputs = ins->inputs();
if(ins->name() == "dot")
{
auto dot_op = any_cast<op::dot>(ins->get_operator());
float new_alpha = dot_op.alpha / (ins_quant_params[0].first * ins_quant_params[1].first);
float new_beta = dot_op.beta;
// We need an additional check on the new_alpha value: if
// abs(new_alpha) >= 50 (a provisional threshold), the relative
// rounding error is small enough to round it to an integer and
// use it as the alpha of the quant_dot
float threshold = 50.0f;
if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold)
{
int32_t quant_alpha = static_cast<int32_t>(std::round(new_alpha));
int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type)
{
prog.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
}
else
{
auto quant_dot = prog.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
prog.replace_instruction(ins, op::convert{orig_type}, quant_dot);
}
}
// either alpha or beta cannot be quantized because the relative
// rounding error would be too large
else
{
if(converted_inputs.size() == 3)
{
converted_inputs.pop_back();
}
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
if(inputs.size() == 3 and dot_op.beta != 0.0f)
{
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{};
if(orig_type != shape::float_type)
{
auto fp32_c =
prog.insert_instruction(ins, op::convert{shape::float_type}, inputs.back());
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, fp32_c);
}
else
{
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
}
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
}
else
{
auto f_res = prog.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
prog.replace_instruction(ins, op::convert{orig_type}, f_res);
}
}
else
{
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
}
else
{
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
prog.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
}
}
}
}
else if(ins->name() == "convolution")
{
// Current MIOpen convolution does not support alpha and beta,
// so we need a separate multiply to adjust the output
auto conv_op = any_cast<op::convolution>(ins->get_operator());
auto padding = conv_op.padding;
auto stride = conv_op.stride;
auto dilation = conv_op.dilation;
auto padding_mode = conv_op.padding_mode;
auto group = conv_op.group;
auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);
auto quant_conv = prog.insert_instruction(
ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs);
float threshold = 50.0f;
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
if(quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{
auto l_factor = prog.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
prog.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
}
// convert the quant_conv output to float type, multiply by the factor,
// and convert back to the original type
else
{
auto float_conv =
prog.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto l_factor = prog.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::mul{}, l_factor, float_conv);
}
else
{
auto adjusted_conv = prog.insert_instruction(ins, op::mul{}, l_factor, float_conv);
prog.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
}
}
}
else
{
MIGRAPHX_THROW("QUANTIZE_INT8: does not support operator " + ins->name());
}
}
// int8 quantization differs from fp16 since int8 can only represent values
// in the range -128 ~ 127. To convert a float or double to int8, we need a
// scale and a shift; the conversion is then v_int8 = fp * scale + shift.
// To simplify the changes, we treat the shift as 0.0f for now.
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names)
{
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
{
for(std::size_t i = 0; i < quant_params.size(); ++i)
{
auto param = quant_params.at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl;
}
std::cout << std::endl;
}
// For now, we only support the int8 quantization of gemm and convolution
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(!std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
std::size_t quant_param_index = 0;
std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
std::unordered_map<instruction_ref, std::size_t> map_ins_index;
for(auto ins : iterator_for(prog))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
// for the dot operator, there can be 2 or 3 input arguments;
// if the 3rd argument is present, convert it to an int32.
std::vector<instruction_ref> converted_inputs;
// process all inputs; if an input is fp32 or fp64, convert it
// to int8 by adding a convert operator, then replace the
// operator with the corresponding int8 version
auto inputs = ins->inputs();
std::vector<std::pair<float, float>> ins_quant_params;
for(auto input : inputs)
{
// calculate the index of each instruction to be quantized
std::size_t ins_index =
(map_ins_index.count(input) > 0) ? map_ins_index[input] : quant_param_index++;
map_ins_index[input] = ins_index;
auto param = quant_params[map_ins_index[input]];
ins_quant_params.push_back(param);
// In general, the target_type is int8, but for the dot
// operation, if it has 3 inputs, then the last one should
// be converted to int32_type
shape::type_t quant_type = shape::int8_type;
if((ins->name() == "dot") and (inputs.size() == 3) and (input == inputs.back()))
{
quant_type = shape::int32_type;
}
auto s = input->get_shape();
if((s.type() == shape::float_type or s.type() == shape::double_type or
s.type() == shape::half_type or s.type() == shape::int32_type) and
s.type() != quant_type)
{
// if the input is a convert operator, use its input
// as the current input
instruction_ref quant_input{};
if(input->name() == "convert" and
input->inputs().front()->get_shape().type() == quant_type)
{
quant_input = input->inputs().front();
// the scale in this case is not used, so set the scale
// to 1.0f for this parameter
ins_quant_params.back() = std::pair<float, float>(1.0f, 0.0f);
}
else
{
quant_input = insert_quant_ins(
prog, input, quant_type, map_quant_ins, param.first, param.second);
}
converted_inputs.push_back(quant_input);
}
else
{
converted_inputs.push_back(input);
}
}
// no inputs were changed, so go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
ins_quantize_int8(prog, ins, converted_inputs, ins_quant_params);
}
if(quant_param_index != quant_params.size())
{
MIGRAPHX_THROW("QUANTIZE_INT8: number of scales does not match");
}
run_passes(prog,
{quantize_fp16_pass{ins_names},
eliminate_common_subexpression{},
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
simplify_qdq{},
dead_code_elimination{}});
}
void quantize_int8(program& prog,
const target& t,
const std::vector<program::parameter_map>& calibration,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
{
// insert capture operator
auto cap_prog = prog;
auto int8_quant_params = capture_arguments(cap_prog, t, ins_names);
// use the calibration data to compute the quantization scale
cap_prog.compile(t);
// use all calibration data to run the program to calculate the
// quantization scale and shift
for(auto&& arg : calibration)
{
program::parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
{
assert(x.second == arg.at(x.first).get_shape());
m[x.first] = t.copy_to(arg.at(x.first));
}
else
{
m[x.first] = t.allocate(x.second);
}
}
cap_prog.eval(m);
}
quantize_int8_impl(prog, *int8_quant_params, ins_names);
}
// For the input of each input argument, we need to insert a
// capture operator to compute the scale and shift
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func)
{
size_t num_quant_params = 0;
// the int8 quantization only support dot and convolution
std::set<std::string> op_names = {"dot", "convolution"};
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(!std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
}
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(prog))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
instruction_ref new_ins{};
if(ins_map.count(input) > 0)
{
new_ins = ins_map[input];
}
else
{
new_ins = prog.insert_instruction(
std::next(input), op::capture{num_quant_params++, func}, input);
ins_map[input] = new_ins;
}
new_args.push_back(new_ins);
}
instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
return num_quant_params;
}
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names)
{
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
@@ -505,7 +60,6 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names)
auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift are needed only for the int8 type, and we do
// not consider the shift, so it is set to 0
std::vector<float> vec_val;
@@ -528,12 +82,56 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names)
int8_quant_params->at(ins_index) = param_pair;
};
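// Editor's note: a hedged sketch (not part of the original source) of the
// symmetric scale rule the elided lambda body is believed to apply: track the
// running max |x| over all calibration runs and derive scale = 127 / max_abs,
// leaving the data unscaled when nothing nonzero has been seen.
inline float int8_scale_from_max_abs(float max_abs)
{
    // all-zero data needs no scaling; otherwise map max_abs onto the int8 limit
    return (max_abs == 0.0f) ? 1.0f : 127.0f / max_abs;
}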
auto num_params = capture_arguments(prog, ins_names, calc_quant_params);
// pass to add capture argument op
std::size_t param_num = 0;
run_passes(prog, {capture_arguments_pass{ins_names, calc_quant_params, &param_num}});
int8_quant_params->resize(param_num, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(param_num, 0.0f);
// use the calibration data to compute the quantization scale
auto capture_prog = prog;
capture_prog.compile(t);
// use all calibration data to run the program to calculate the
// quantization scale and shift
for(auto&& arg : calibration)
{
parameter_map m;
for(auto&& x : capture_prog.get_parameter_shapes())
{
if(arg.count(x.first) > 0)
{
assert(x.second == arg.at(x.first).get_shape());
m[x.first] = t.copy_to(arg.at(x.first));
}
else
{
m[x.first] = t.allocate(x.second);
}
}
capture_prog.eval(m);
}
int8_quant_params->resize(num_params, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(num_params, 0.0f);
// print the quantization parameters for the main module only
if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
{
for(std::size_t i = 0; i < int8_quant_params->size(); ++i)
{
auto param = int8_quant_params->at(i);
std::cout << "ins_index = " << i << ", scale = " << param.first
<< ", shift = " << param.second << std::endl;
}
std::cout << std::endl;
}
return int8_quant_params;
run_passes(prog,
{quantize_int8_pass{ins_names, *int8_quant_params},
eliminate_common_subexpression{},
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
simplify_qdq{},
dead_code_elimination{}});
}
} // namespace MIGRAPHX_INLINE_NS
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static void quantize_module(module& m, const std::vector<std::string>& ins_names)
{
for(auto ins : iterator_for(m))
{
// skip instructions that are not in the set to be quantized
if(not(contains(ins_names, ins->name()) or contains(ins_names, "all")))
continue;
// skip return and convert instructions
if(contains({"@return", "convert"}, ins->name()))
continue;
if(ins->inputs().empty())
continue;
auto mod_inputs = ins->module_inputs();
auto s = ins->get_shape();
// Convert back to original type before quantizing the inputs
if(mod_inputs.empty())
{
auto r = m.insert_instruction(
std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
m.replace_instruction(ins, r);
}
// Convert each of the inputs that are floating point to fp16
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
auto input_type = input->get_shape().type();
if(input_type != shape::float_type and input_type != shape::double_type)
return input;
return m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::half_type}}), input);
});
// Replace inputs
m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
}
}
void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
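// Editor's note: a minimal usage sketch (not part of the original source),
// assuming <migraphx/pass_manager.hpp> and <migraphx/dead_code_elimination.hpp>
// are available; it drives the pass above the same way the run_passes pipeline
// elsewhere in this change does, converting every eligible instruction and
// pruning the converts left dead behind.
static void example_lower_module_to_fp16(migraphx::module& m)
{
    migraphx::run_passes(
        m, {migraphx::quantize_fp16_pass{{"all"}}, migraphx::dead_code_elimination{}});
}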
#include <migraphx/operation.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <numeric>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
static std::vector<shape::type_t>& get_quantizable_type()
{
static std::vector<shape::type_t> quantizable_types = {
    shape::float_type, shape::double_type, shape::half_type};
return quantizable_types;
}
void quantize_int8_pass::apply(module& m) const // NOLINT
{
const auto& quantizable_types = get_quantizable_type();
for(auto ins : iterator_for(m))
{
if(ins->name() != "capture")
continue;
auto op_val = ins->get_operator().to_value();
assert(op_val.contains("ins_index"));
auto param_index = op_val.at("ins_index").to<std::size_t>();
auto param = quant_params[param_index];
auto input = ins->inputs().front();
auto s = input->get_shape();
if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type)
{
auto zero_point = m.add_literal(static_cast<int8_t>(param.second));
auto scale = m.add_literal(literal({s.type()}, {1.0f / param.first}));
const auto& lens = s.lens();
scale =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
zero_point = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), zero_point);
auto q_in =
m.insert_instruction(ins, make_op("quantizelinear"), input, scale, zero_point);
auto dq_in =
m.insert_instruction(ins, make_op("dequantizelinear"), q_in, scale, zero_point);
m.replace_instruction(ins, dq_in);
}
}
}
void capture_arguments_pass::apply(module& m) const // NOLINT
{
assert(param_index != nullptr);
for(auto ins : iterator_for(m))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
new_args.push_back(new_in);
}
m.replace_instruction(ins, ins->get_operator(), new_args);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
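// Editor's note: an illustration (not part of the original source) of the
// rewrite quantize_int8_pass performs for one captured input x with a
// quant_params entry {scale, shift}:
//     scale_l = multibroadcast(literal(1.0f / scale))
//     zp_l    = multibroadcast(literal(int8(shift)))
//     q       = quantizelinear(x, scale_l, zp_l)    // fp -> int8
//     dq      = dequantizelinear(q, scale_l, zp_l)  // int8 -> fp
// The capture instruction is replaced by dq, leaving a q/dq pair that a later
// simplify_qdq pass can fuse into quant_dot or quant_convolution.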
#include <migraphx/reduce_dims.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
{
std::vector<std::size_t> new_lens;
for(const auto& s : shapes)
{
assert(n < s.lens().size());
if((n + 1) >= s.lens().size())
return false;
auto astride = s.strides()[n];
auto alen = s.lens()[n];
auto bstride = s.strides()[n + 1];
auto blen = s.lens()[n + 1];
if(astride == bstride * blen or alen == 1)
new_lens.push_back(alen * blen);
}
if(new_lens.size() != shapes.size())
return false;
std::size_t i = 0;
for(auto& s : shapes)
{
auto lens = s.lens();
auto strides = s.strides();
lens.erase(lens.begin() + n);
strides.erase(strides.begin() + n);
lens[n] = new_lens[i];
s = shape{s.type(), lens, strides};
i++;
}
return true;
}
void reduce_dim1(std::vector<shape>& shapes)
{
if(std::any_of(shapes.begin(), shapes.end(), [&](const auto& s) {
return s.lens().size() < 2 or s.lens().back() != 1;
}))
return;
for(auto& s : shapes)
{
auto lens = s.lens();
auto strides = s.strides();
lens.pop_back();
strides.pop_back();
s = shape{s.type(), lens, strides};
}
}
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
while(reduce_dim(shapes, n) and n < shapes.size()) {}
return n + 1;
}
void reduce_dim_all(std::vector<shape>& shapes)
{
std::size_t n = 0;
while(n < shapes.front().lens().size() - 1)
n = reduce_dim_all(shapes, n);
reduce_dim1(shapes);
}
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
{
return std::accumulate(
shapes.begin() + 1, shapes.end(), shapes.front().lens(), [](auto&& lens, auto&& s) {
std::vector<std::size_t> result;
const auto* x = &s.lens();
const auto* y = &lens;
if(x->size() > y->size())
std::swap(x, y);
std::transform(
x->begin(), x->end(), y->begin(), std::back_inserter(result), [&](auto a, auto b) {
return std::max(a, b);
});
return result;
});
}
shape mask_shape(const shape& s, const std::vector<std::size_t>& lens)
{
assert(s.lens().size() == lens.size());
std::vector<std::size_t> rstrides(lens.size());
std::size_t stride = 1;
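    // i counts down from lens.size() - 1; unsigned wraparound below 0 makes
    // i >= lens.size(), which ends the loop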
for(std::size_t i = lens.size() - 1; i < lens.size(); i--)
{
if(lens[i] == s.lens()[i])
{
rstrides[i] = stride;
stride *= lens[i];
}
else if(lens[i] != 1 and s.lens()[i] != 1)
{
return shape{};
}
}
return shape{s.type(), lens, rstrides};
}
std::vector<shape> reduce_dims(const std::vector<shape>& shapes)
{
if(shapes.empty())
return {};
auto result = shapes;
auto base = base_lens(shapes);
for(auto&& s : shapes)
{
if(s.lens().size() != base.size())
return shapes;
if(s.lens() == base)
continue;
auto mshape = mask_shape(s, base);
if(mshape.lens().size() != base.size())
return shapes;
result.push_back(mshape);
}
reduce_dim_all(result);
result.erase(result.begin() + shapes.size(), result.end());
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
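// Editor's note: a worked example (not part of the original source) of the
// merging rule in reduce_dim above: dimensions n and n+1 fuse when
// strides[n] == strides[n+1] * lens[n+1] (or lens[n] == 1), so packed
// standard shapes collapse all the way down to a single dimension.
static void example_reduce_dims()
{
    std::vector<migraphx::shape> in = {
        migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}, // strides {12, 4, 1}
        migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}};
    auto out = migraphx::reduce_dims(in);
    (void)out; // both result shapes have lens {24} and strides {1}
}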
#include <migraphx/register_op.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_map<std::string, operation>& op_map()
{
static std::unordered_map<std::string, operation> m; // NOLINT
return m;
}
void register_op(const operation& op) { op_map()[op.name()] = op; }
operation load_op(const std::string& name)
{
return at(op_map(), name, "Operator not found: " + name);
}
bool has_op(const std::string& name) { return op_map().count(name) == 1; }
std::vector<std::string> get_operators()
{
std::vector<std::string> result;
std::transform(op_map().begin(), op_map().end(), std::back_inserter(result), [&](auto&& p) {
return p.first;
});
std::sort(result.begin(), result.end());
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
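// Editor's note: a usage sketch (not part of the original source) for the
// registry above; "add" is only an example of an operator name registered
// elsewhere via register_op.
static void example_lookup_op()
{
    if(migraphx::has_op("add"))
    {
        auto op = migraphx::load_op("add"); // at() throws "Operator not found" if absent
        (void)op;                           // op.name() == "add"
    }
}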
#include <unordered_map>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_map<std::string, target>& target_map()
{
static std::unordered_map<std::string, target> m; // NOLINT
return m;
}
void register_target(const target& t) { target_map()[t.name()] = t; }
target make_target(const std::string& name)
{
const auto it = target_map().find(name);
if(it == target_map().end())
{
MIGRAPHX_THROW("Requested target '" + name + "' is not enabled or not supported");
}
return it->second;
}
std::vector<std::string> get_targets()
{
std::vector<std::string> result;
std::transform(target_map().begin(),
target_map().end(),
std::back_inserter(result),
[&](auto&& p) { return p.first; });
std::sort(result.begin(), result.end());
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
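// Editor's note (not part of the original source): make_target is the same
// lookup the Python bindings above expose as migraphx.get_target, e.g.
//     migraphx::target t = migraphx::make_target("cpu");
// which throws via MIGRAPHX_THROW when no target named "cpu" was registered.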
#include <migraphx/replace_allocate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/allocate.hpp>
#include <map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_map<instruction_ref, std::string> create_output_names(const module& mod)
{
std::unordered_map<instruction_ref, std::string> mod_output_names{};
auto last = std::prev(mod.end());
if(last->name() == "@return")
{
const auto& prog_outputs = last->inputs();
std::vector<instruction_ref> outputs_alias(prog_outputs.size());
std::transform(prog_outputs.begin(),
prog_outputs.end(),
outputs_alias.begin(),
[](const auto& i) { return instruction::get_output_alias(i); });
std::size_t index = 0;
for(auto ins : outputs_alias)
{
mod_output_names[ins] = mod.name() + ":#output_" + std::to_string(index++);
}
}
else
{
auto ins = instruction::get_output_alias(last);
mod_output_names[ins] = "output";
}
return mod_output_names;
}
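// Editor's note: an illustration (not part of the original source) of the
// naming scheme above. A module "main" ending in @return with two outputs
// yields the parameter names "main:#output_0" and "main:#output_1"; a module
// without @return maps the output alias of its last instruction to "output".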
void insert_submod_allocations(instruction_ref ins, module& mod, const allocation_model& model)
{
std::vector<instruction_ref> inputs = ins->inputs();
std::vector<module_ref> mod_args = ins->module_inputs();
std::map<std::string, shape> name_shapes;
for(const auto& smod : mod_args)
{
auto ps = smod->get_parameter_shapes();
name_shapes.insert(ps.begin(), ps.end());
}
for(auto& pn : name_shapes)
{
const auto& s = pn.second;
instruction_ref output{};
output = mod.insert_instruction(ins, model.allocate(s));
inputs.push_back(output);
}
mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args);
}
void replace_allocate::apply(module& m) const
{
auto mod_output_names = create_output_names(m);
bool main_offload_copy = m.name() == "main" ? this->offload_copy : false;
for(auto ins : iterator_for(m))
{
auto op = ins->get_operator();
auto op_name = op.name();
// check if allocations from submodules need to be inserted
// for now, only the "if" operator is affected
if(op_name == "if")
{
insert_submod_allocations(ins, m, model);
continue;
}
if(op_name != "allocate")
continue;
auto s = ins->get_shape();
if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
{
auto out_param = m.add_parameter(mod_output_names[ins], s);
m.replace_instruction(ins, out_param);
continue;
}
m.replace_instruction(
ins,
m.insert_instruction(ins,
make_op(model.name(), migraphx::value{{"shape", to_value(s)}})));
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(program& p) const
void rewrite_batchnorm::apply(module& m) const
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(ins->name() != "batch_norm_inference")
continue;
@@ -26,7 +28,8 @@ void rewrite_batchnorm::apply(program& p) const
if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
continue;
auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};
std::vector<std::size_t> lens = ins->inputs()[1]->get_shape().lens();
shape s{ins->get_shape().type(), lens};
// Get epsilon
auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon;
@@ -43,13 +46,13 @@ void rewrite_batchnorm::apply(program& p) const
});
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
p.replace_instruction(ins, add);
auto a_ins = m.add_literal({a.get_shape(), a.data()});
auto a_broadcast = m.insert_instruction(ins, broadcast, a_ins);
auto mul = m.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = m.add_literal({b.get_shape(), b.data()});
auto b_broadcast = m.insert_instruction(ins, broadcast, b_ins);
auto add = m.insert_instruction(ins, make_op("add"), mul, b_broadcast);
m.replace_instruction(ins, add);
}
}
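// Editor's note: the algebra behind the rewrite above (the computation of a
// and b is elided in this hunk). Batchnorm inference
//     y = gamma * (x - mean) / sqrt(variance + epsilon) + bias
// folds into the affine form y = a * x + b with, per channel,
//     a = gamma / sqrt(variance + epsilon)
//     b = bias - a * mean
// which is what the broadcast mul/add pair inserted above computes.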
@@ -4,39 +4,54 @@
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(program& prog) const
void rewrite_pooling::apply(module& m) const
{
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(m))
{
if(ins->name() != "pooling")
continue;
if(ins->get_shape().lens().size() != 4)
continue;
if(ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape();
if(not s.standard())
continue;
auto&& op = any_cast<op::pooling>(ins->get_operator());
if(op.mode != "average")
continue;
if(op.padding[0] != 0 and op.padding[1] != 0)
if(!std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue;
if(op.stride[0] != 1 and op.stride[1] != 1)
if(!std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
continue;
if(s.lens()[2] != op.lengths[0] and s.lens()[3] != op.lengths[1])
auto lens = s.lens();
if(!std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
auto reshape =
prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
auto pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
prog.replace_instruction(ins, op::reshape{{n, c, 1, 1}}, pooling);
auto reshape = m.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{};
// average pooling
if(op.mode == op::pooling_mode::average)
{
pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
}
// max pooling
else
{
pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
}
std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
}
}
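// Editor's note: a worked shape trace (not part of the original source) for
// the rewrite above on a global average pool with input {N, C, H, W} =
// {2, 3, 5, 5}:
//     reshape {dims = {6, -1}}      -> {6, 25}
//     reduce_mean {axes = {1}}      -> {6, 1}
//     reshape {dims = {2, 3, 1, 1}} -> {2, 3, 1, 1}
// A pooling window spanning the whole spatial extent is just a row reduction.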
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void apply_quantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "quantizelinear");
auto x = ins->inputs()[0];
auto y_scale = ins->inputs()[1];
if(x->get_shape().type() != y_scale->get_shape().type())
{
x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
if(ins->inputs().size() == 3)
{
auto zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
}
int64_t max_quant = 0;
int64_t min_quant = 0;
ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max();
min_quant = qt.min();
});
auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
auto min_arg = m.add_literal(literal(s, min_data));
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
void apply_dequantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "dequantizelinear");
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
auto x_scale = ins->inputs()[1];
if(ins->inputs().size() == 3)
{
auto x_zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
}
m.replace_instruction(ins, make_op("mul"), x, x_scale);
}
void rewrite_quantization::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() == "quantizelinear")
{
apply_quantizelinear(m, ins);
}
else if(ins->name() == "dequantizelinear")
{
apply_dequantizelinear(m, ins);
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
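// Editor's note: a worked example (not part of the original source) of
// apply_quantizelinear above with x = 2.6f, y_scale = 2.0f, zero point 1,
// and an int8 output shape:
//     div = 2.6 / 2.0 = 1.3; round -> 1.0; add zero point -> 2.0
//     clip to [-128, 127] -> 2.0; convert -> int8 value 2
// apply_dequantizelinear inverts it up to rounding: (2 - 1) * 2.0 = 2.0,
// close to the original 2.6.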