Unverified Commit 23cb7917 authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Merge branch 'develop' into blas_tuning

parents b5fcc0bc ea32ca70
......@@ -35,8 +35,7 @@ namespace onnx {
struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
{
const std::set<shape::type_t> valid_types = {
shape::float_type, shape::half_type, shape::double_type};
std::set<shape::type_t> valid_types = {shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"RandomUniform"}, {"RandomUniformLike"}}; }
......@@ -97,7 +96,7 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> d(high, low);
std::uniform_real_distribution<> d(low, high);
std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
......
......@@ -30,8 +30,11 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
/**
* If static shape input, creates a literal in migraphx.
* If dynamic shape input, creates a dimensions_of operator in migraphx (runtime evaluation of
* shape).
*/
struct parse_shape : op_parser<parse_shape>
{
std::vector<op_desc> operators() const { return {{"Shape"}}; }
......@@ -43,14 +46,55 @@ struct parse_shape : op_parser<parse_shape>
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i);
});
auto input_shape = args[0]->get_shape();
int input_ndim = input_shape.ndim();
std::size_t start = 0;
std::size_t end = input_ndim;
// Normalizing the start and end is handled here because of how the static shape version
// works. Clamping to [-r, r], where r is ndim of input and then making positive.
auto normalize_ind = [&](int64_t ind) {
if(ind < (-1 * input_ndim))
{
ind = -1 * input_ndim;
}
if(ind > input_ndim)
{
ind = input_ndim;
}
return (ind >= 0) ? ind : input_ndim + ind;
};
if(contains(info.attributes, "end"))
{
end = normalize_ind(info.attributes.at("end").i());
}
if(contains(info.attributes, "start"))
{
start = normalize_ind(info.attributes.at("start").i());
}
if(end <= start)
{
MIGRAPHX_THROW("PARSE_SHAPE: ending axis <= starting axis, end: " +
std::to_string(end) + " start: " + std::to_string(start));
}
if(input_shape.dynamic())
{
return info.add_instruction(make_op("dimensions_of", {{"start", start}, {"end", end}}),
args[0]);
}
else
{
std::size_t output_ndim = end - start;
std::vector<int64_t> vec_shape(output_ndim);
migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
std::vector<std::size_t> input_lens = input_shape.lens();
std::transform(input_lens.begin() + start,
input_lens.begin() + end,
vec_shape.begin(),
[](auto i) { return int64_t(i); });
return info.add_literal(migraphx::literal{s, vec_shape});
}
}
};
} // namespace onnx
......
......@@ -56,6 +56,7 @@ struct parse_where : op_parser<parse_where>
auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(args[0]->get_shape().lens() != lens)
{
args[0] =
......
......@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct module_pm : module_pass_manager
{
module* mod = nullptr;
module* root_mod = nullptr;
tracer* t = nullptr;
module* common_parent = nullptr;
program* prog = nullptr;
module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
module_pm(module* pmod = nullptr, module* rmod = nullptr, tracer* pt = nullptr)
: mod(pmod), root_mod(rmod), t(pt)
{
}
template <class... Ts>
void trace(Ts&&... xs) const
{
......@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
virtual module* get_root_module() override
{
if(root_mod != nullptr)
return root_mod;
assert(prog);
return prog->get_main_module();
}
......@@ -123,33 +131,24 @@ struct module_pm : module_pass_manager
module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
for(const auto& p : passes)
{
module_pm{&mod, &trace}.run_pass(p);
}
}
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& passes, tracer trace)
{
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
std::unordered_set<module_ref> visited;
for(const auto& p : passes)
{
auto mods = prog.get_modules();
auto tree = prog.get_module_tree();
std::vector<module_ref> sub_mods = root_mod->get_sub_modules();
sub_mods.insert(sub_mods.begin(), root_mod);
visited.clear();
for(const auto& mod : reverse(mods))
for(const auto& mod : reverse(sub_mods))
{
if(mod->bypass())
continue;
if(not visited.insert(mod).second)
continue;
module_pm mpm{mod, &trace};
module_pm mpm{mod, root_mod, &trace};
mpm.prog = &prog;
auto parents = range(tree.equal_range(mod));
auto nparents = distance(parents);
......@@ -167,5 +166,20 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
}
}
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
for(const auto& p : passes)
{
module_pm{&mod, &mod, &trace}.run_pass(p);
}
}
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
run_passes(prog, prog.get_main_module(), passes, trace);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
return it->first;
}
std::vector<shape> normalize_permutation(const std::vector<shape>& shapes)
{
auto result = shapes;
auto perm = find_permutation(shapes);
std::transform(result.begin(), result.end(), result.begin(), [&](auto s) {
return reorder_shape(s, perm);
});
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -21,6 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/version.h>
#include <migraphx/compile_options.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
......@@ -38,12 +40,14 @@
#include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp>
#include <migraphx/supported_segments.hpp>
#include <iostream>
#include <queue>
#include <sstream>
#include <algorithm>
#include <set>
#include <unordered_map>
#include <utility>
#include <unordered_set>
#include <map>
#include <cassert>
......@@ -53,12 +57,23 @@ inline namespace MIGRAPHX_INLINE_NS {
using milliseconds = std::chrono::duration<double, std::milli>;
struct mark_instruction_target
{
std::size_t target_id = 0;
std::string name() const { return "mark_instruction_target"; }
void apply(module& m) const
{
for(auto& ins : m)
ins.set_target_id(target_id);
}
};
struct program_impl
{
// A map is used to keep references to modules of the program
std::unordered_map<std::string, module> modules;
context ctx;
std::string target_name;
std::vector<context> contexts;
std::vector<target> targets;
};
program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
......@@ -82,14 +97,8 @@ void program::assign(const program& p)
{
impl = std::make_unique<program_impl>();
}
else if(not impl->modules.empty())
{
impl->modules.clear();
}
impl->ctx = p.impl->ctx;
impl->target_name = p.impl->target_name;
impl->modules = p.impl->modules;
*impl = *p.impl;
// build a map from old ins to new ins
// Build a map from old module to new module
......@@ -152,7 +161,11 @@ std::vector<shape> program::get_output_shapes() const
return mm->get_output_shapes();
}
context& program::get_context() const { return impl->ctx; }
context& program::get_context() const
{
assert(impl->contexts.size() == 1);
return impl->contexts.front();
}
instruction_ref program::validate() const
{
......@@ -203,20 +216,106 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
return p;
}
bool program::is_compiled() const { return not this->impl->target_name.empty(); }
bool program::is_compiled() const { return not this->impl->contexts.empty(); }
void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts)
{
// Gather all the target roots
std::unordered_multimap<std::size_t, module_ref> roots;
auto mods = this->get_modules();
for(const auto* mod : mods)
{
for(const auto& ins : *mod)
{
if(ins.name() != "run_on_target")
continue;
auto v = ins.get_operator().to_value();
module_ref root = ins.module_inputs().front();
std::size_t root_target_id = v.at("target_id").to<std::size_t>();
assert(root_target_id < targets.size());
roots.insert({root_target_id, root});
}
}
auto trace = tracer{};
// TODO: Add tracer based on compile options
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
trace = tracer{std::cout};
trace(*this);
trace();
// It is assumed that all instructions outside of any root module would run on "ref" target
// Ref target may or may not be passed as one of the target for the "compile()".
// If it is not passed, Create one and add context of it into the map.
auto target_idx = [&](const std::string& t_name) {
return static_cast<std::size_t>(
std::find_if(
targets.begin(), targets.end(), [&](const auto& t) { return t.name() == t_name; }) -
targets.begin());
};
std::size_t ref_target_id = target_idx("ref");
if(ref_target_id == targets.size())
{
this->impl->contexts.resize(targets.size() + 1);
this->impl->contexts[ref_target_id] = migraphx::make_target("ref").get_context();
// users could pass lessers compile_ops than targets, in that case use default compile_opts
compile_opts.resize(targets.size() + 1, migraphx::compile_options{});
}
else
{
this->impl->contexts.resize(targets.size());
compile_opts.resize(targets.size(), migraphx::compile_options{});
}
// mark all the instruction as ref target first, later change target_id based on root-target
run_passes(*this, {mark_instruction_target{ref_target_id}});
// Run passes on each root target
for(const auto i : range(targets.size()))
{
const auto& root_target = targets.at(i);
auto root_target_id = i;
auto root_modules_range = roots.equal_range(root_target_id);
this->impl->contexts[root_target_id] = root_target.get_context();
for(const auto& [id, current_mod] : range(root_modules_range))
{
auto passes = root_target.get_passes(this->impl->contexts[root_target_id],
compile_opts[root_target_id]);
passes.push_back(mark_instruction_target{static_cast<size_t>(root_target_id)});
run_passes(*this, current_mod, passes, trace);
auto invalid = current_mod->validate();
if(invalid != current_mod->end())
{
MIGRAPHX_THROW("Invalid module " + current_mod->name() +
" from compilation at instruction " +
std::to_string(std::distance(current_mod->begin(), invalid)));
}
auto dangling = current_mod->find_dangling_reference();
if(dangling != current_mod->end())
{
auto index = std::distance(current_mod->begin(), dangling);
MIGRAPHX_THROW("Dangling reference in module " + current_mod->name() +
" from instruction " + std::to_string(index));
}
}
}
this->finalize();
}
void program::compile(const target& t, compile_options options)
{
// todo: combine with multi-target compile method
assert(not this->is_compiled());
this->impl->target_name = t.name();
this->impl->ctx = t.get_context();
this->impl->targets = {t};
this->impl->contexts = {t.get_context()};
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout};
options.trace(*this);
options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options);
auto&& passes = t.get_passes(this->impl->contexts.front(), options);
run_passes(*this, passes, options.trace);
auto mods = this->get_modules();
// Validate and finalize
......@@ -235,14 +334,14 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index));
}
mod->finalize(this->impl->ctx);
mod->finalize(this->impl->contexts);
}
}
void program::finalize()
{
auto* mm = this->get_main_module();
mm->finalize(this->impl->ctx);
mm->finalize(this->impl->contexts);
}
template <class T>
......@@ -259,6 +358,31 @@ std::string classify(T x)
}
}
void print_statistics(std::ostream& os, const argument& a)
{
a.visit(
[&](auto t) {
os << "Min value: " << *std::min_element(t.begin(), t.end()) << ", ";
os << "Max value: " << *std::max_element(t.begin(), t.end()) << ", ";
double num_elements = t.size();
auto mean = std::accumulate(t.begin(), t.end(), 0.0) / num_elements;
auto stddev = std::sqrt(
std::accumulate(t.begin(),
t.end(),
0.0,
[&](auto r, auto v) { return r + std::pow((v - mean), 2.0); }) /
num_elements);
os << "Mean: " << mean << ", ";
os << "StdDev: " << stddev << "\n";
},
[&](const auto& xs) {
for(const auto& x : xs)
{
print_statistics(os, x);
}
});
}
std::unordered_set<std::string> classify_argument(const argument& a)
{
std::unordered_set<std::string> result;
......@@ -304,16 +428,15 @@ void preview_argument(std::ostream& os, const argument& a)
template <class F>
std::vector<argument> generic_eval(const module* mod,
context& ctx,
std::vector<context>& ctx,
std::unordered_map<std::string, argument> params,
std::unordered_map<instruction_ref, argument> results,
F make_trace)
F trace)
{
assert(mod->validate() == mod->end());
results.reserve(mod->size() * 2);
std::vector<argument> values;
values.reserve(16);
auto trace = make_trace(mod);
for(auto ins : iterator_for(*mod))
{
assert(results.find(ins) == results.end());
......@@ -366,17 +489,21 @@ std::vector<argument> generic_eval(const module* mod,
assert(results.find(i) != results.end());
return results[i];
});
const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) {
auto ssctx = ctx;
return generic_eval(smod, ssctx, inputs, results, make_trace);
return generic_eval(smod, ctx, inputs, results, trace);
};
results.emplace(ins, trace(ins, [&] {
return ins->normalized_operator().compute(
ctx, ins->get_shape(), values, mod_args, module_eval);
results.emplace(
ins, trace(ins, [&] {
auto op = ins->normalized_operator();
if(op.is_context_free())
return op.compute(ins->get_shape(), values, mod_args, module_eval);
if(ins->get_target_id() >= ctx.size())
MIGRAPHX_THROW("No context available for " + op.name());
return op.compute(
ctx[ins->get_target_id()], ins->get_shape(), values, mod_args, module_eval);
}));
}
assert(results.find(ins) != results.end());
......@@ -390,44 +517,25 @@ std::vector<argument> generic_eval(const module* mod,
template <class F>
std::vector<argument> generic_eval(const program& p,
context& ctx,
std::vector<context>& ctx,
std::unordered_map<std::string, argument> params,
F make_trace)
F trace)
{
const module* mm = p.get_main_module();
return generic_eval(mm, ctx, params, {}, make_trace);
return generic_eval(mm, ctx, params, {}, trace);
}
std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
{
auto& ctx = this->impl->ctx;
#ifndef NDEBUG
auto with_check_context = [&](auto f) {
return [=, &ctx](auto&&) {
auto sctx = std::make_shared<context>(ctx);
auto check_context = [=, &ctx](auto g) {
assert(is_shared(ctx, *sctx));
auto x = g();
*sctx = ctx;
return x;
};
return [=](auto&&... xs) { return f(xs..., check_context); };
};
};
#else
auto with_check_context = [](auto f) {
return [=](auto&&) {
return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
};
};
#endif
auto& contexts = this->impl->contexts;
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
std::vector<argument> ret;
if(exec_env.async)
{
ctx.wait_for(exec_env.queue);
assert(contexts.size() == 1);
contexts.front().wait_for(exec_env.queue);
}
if(trace_level > 0)
......@@ -439,32 +547,42 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
instruction::print(ss, x, ins_names);
ins_out[x] = ss.str();
});
ret = generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
const auto& ctx = contexts[ins->get_target_id()];
ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{};
auto result = check_context(f);
auto result = f();
double t1 = t.record<milliseconds>();
ctx.finish();
double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
if(trace_level > 1 and ins->name().front() != '@' and
ins->name() != "load" and not result.empty())
if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
not result.empty())
{
migraphx::argument buffer;
try
{
const target& tgt = this->impl->targets.at(ins->get_target_id());
buffer = tgt.copy_from(result);
}
catch(const migraphx::exception&)
{
target tgt = make_target(this->impl->target_name);
auto buffer = tgt.copy_from(result);
// instruction was run on host then no need to copy buffer from target
buffer = result;
}
catch(...)
{
MIGRAPHX_THROW("MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
}
if(trace_level == 2)
{
std::cout << "Output has "
<< to_string_range(classify_argument(buffer))
std::cout << "Output has " << to_string_range(classify_argument(buffer))
<< std::endl;
std::cout << "Output: ";
preview_argument(std::cout, buffer);
std::cout << std::endl;
print_statistics(std::cout, buffer);
}
else
{
......@@ -472,36 +590,49 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
}
}
return result;
}));
});
}
else
{
ret = generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f);
}));
ret = generic_eval(*this, contexts, std::move(params), [&](auto&&, auto f) { return f(); });
}
if(exec_env.async)
{
ctx.finish_on(exec_env.queue);
assert(contexts.size() == 1);
contexts.front().finish_on(exec_env.queue);
}
return ret;
}
const int program_file_version = 5;
void program::finish() const
{
for(const auto& ctx : this->impl->contexts)
ctx.finish();
}
std::string get_migraphx_version()
{
std::stringstream ss;
ss << std::to_string(MIGRAPHX_VERSION_MAJOR) << "." << std::to_string(MIGRAPHX_VERSION_MINOR)
<< "." << std::to_string(MIGRAPHX_VERSION_PATCH);
return ss.str();
}
/*
program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file.
*/
const int program_file_version = 6;
value program::to_value() const
{
value result;
result["version"] = program_file_version;
result["target"] = this->impl->target_name;
if(not this->impl->target_name.empty())
result["context"] = this->impl->ctx.to_value();
result["migraphx_version"] = get_migraphx_version();
result["targets"] = migraphx::to_value(this->impl->targets);
result["contexts"] = migraphx::to_value(this->impl->contexts);
value module_vals = value::object{};
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : this->get_modules())
......@@ -597,7 +728,7 @@ static void mod_from_val(module_ref mod,
std::back_inserter(module_inputs),
[&](const value& i) { return map_mods.at(i.to<std::string>()); });
for(auto& smod : module_inputs)
for(const auto& smod : module_inputs)
{
mod_from_val(smod, v, instructions, map_mods);
}
......@@ -626,15 +757,27 @@ void program::from_value(const value& v)
auto version = v.at("version").to<int>();
if(version != program_file_version)
{
MIGRAPHX_THROW("Warning: Program version mismatch");
MIGRAPHX_THROW(
"Error: Program version mismatch. MXR file was created using program file version: " +
std::to_string(version) + ", while installed MIGraphX is using program file version: " +
std::to_string(program_file_version) +
", Try regenerating MXR file using installed MIGraphX and running again.");
}
this->impl->target_name = v.at("target").to<std::string>();
if(not this->impl->target_name.empty())
auto migx_version = v.at("migraphx_version").to<std::string>();
if(migx_version != get_migraphx_version())
{
target t = make_target(this->impl->target_name);
this->impl->ctx = t.get_context();
this->impl->ctx.from_value(v.at("context"));
std::cout << "WARNING: MXR File was created using MIGraphX version: " << migx_version
<< ", while installed MIGraphX is at version: " << get_migraphx_version()
<< ", operators implementation could be mismatched.";
}
migraphx::from_value(v.at("targets"), this->impl->targets);
for(auto i : range(this->impl->targets.size()))
{
this->impl->contexts.push_back(this->impl->targets[i].get_context());
this->impl->contexts.back().from_value(v.at("contexts")[i]);
}
auto module_vals = v.at("modules");
......@@ -655,6 +798,8 @@ void program::from_value(const value& v)
auto* mm = get_main_module();
mod_from_val(mm, module_vals, map_insts, map_mods);
// Finalize a compiled model
if(not this->impl->contexts.empty())
this->finalize();
}
......@@ -675,19 +820,19 @@ std::string perf_group(const operation& op)
void program::mark(const parameter_map& params, marker&& m)
{
auto& ctx = this->impl->ctx;
auto& ctx = this->impl->contexts;
// Run once by itself
eval(params);
ctx.finish();
this->finish();
// Start marking
m.mark_start(*this);
generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
generic_eval(*this, ctx, params, [&](auto ins, auto f) {
argument result;
m.mark_start(ins);
result = f();
m.mark_stop(ins);
return result;
}));
});
m.mark_stop(*this);
}
......@@ -696,10 +841,10 @@ void program::perf_report(std::ostream& os,
parameter_map params,
std::size_t batch) const
{
auto& ctx = this->impl->ctx;
auto& ctx = this->impl->contexts;
// Run once by itself
eval(params);
ctx.finish();
this->finish();
// Run and time entire program
std::vector<double> total_vec;
total_vec.reserve(n);
......@@ -707,28 +852,28 @@ void program::perf_report(std::ostream& os,
{
total_vec.push_back(time<milliseconds>([&] {
eval(params);
ctx.finish();
this->finish();
}));
}
std::sort(total_vec.begin(), total_vec.end());
std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
// Fill the map
generic_eval(*this, ctx, params, always([&](auto ins, auto) {
generic_eval(*this, ctx, params, [&](auto ins, auto) {
ins_vec[ins].reserve(n);
return argument{ins->get_shape(), nullptr};
}));
});
// Run and time each instruction
for(std::size_t i = 0; i < n; i++)
{
generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
generic_eval(*this, ctx, params, [&](auto ins, auto f) {
argument result;
ins_vec[ins].push_back(time<milliseconds>([&] {
result = f();
ctx.finish();
this->impl->contexts[ins->get_target_id()].finish();
}));
return result;
}));
});
}
for(auto&& p : ins_vec)
std::sort(p.second.begin(), p.second.end());
......@@ -861,7 +1006,9 @@ void program::print_py(std::ostream& os) const
os << "p = migraphx.program()\n";
for(auto& mod : vec_modules)
{
std::string var_name = "m" + mod->name();
std::string var_name = "m";
if(mod->name() != "main")
var_name += mod->name();
os << var_name << " = ";
if(mod->name() == "main")
os << "p.get_main_module()";
......@@ -894,10 +1041,10 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const
{
auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
auto& ctx = this->impl->contexts;
generic_eval(*this, ctx, std::move(params), [](auto ins, auto&&...) {
return argument{ins->get_shape(), nullptr};
}));
});
}
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
......@@ -1039,17 +1186,25 @@ void program::remove_unused_modules()
std::vector<module*> unused;
generic_get_unused_modules(
impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused));
for(auto* m : unused)
for(const auto* m : unused)
this->remove_module(m->name());
}
program& program::sort()
{
for(auto& pp : this->impl->modules)
std::queue<migraphx::module_ref> mqueue;
mqueue.push(get_main_module());
while(not mqueue.empty())
{
pp.second.sort();
module_ref current_mod = mqueue.front();
current_mod->sort();
mqueue.pop();
auto child_mods = current_mod->get_sub_modules(true);
for(auto& sub_mod : child_mods)
{
mqueue.push(sub_mod);
}
}
return *this;
}
......
......@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
{
module& m = mpm.get_module();
module_ref root_module = mpm.get_root_module();
if(m.name() == "main")
if(m == *root_module)
return;
for(auto ins : iterator_for(m))
......
......@@ -23,14 +23,25 @@
#####################################################################################
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
add_library(migraphx_py py_loader.cpp)
migraphx_generate_export_header(migraphx_py)
target_include_directories(migraphx_py PRIVATE include)
target_link_libraries(migraphx_py PUBLIC migraphx)
rocm_install_targets(TARGETS migraphx_py INCLUDE include)
if(MIGRAPHX_ENABLE_PYTHON)
include(PythonModules)
add_custom_target(migraphx_py)
foreach(PYTHON_VERSION ${PYTHON_VERSIONS})
py_add_module(migraphx_py_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
py_add_module(migraphx_pybind_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx)
target_link_libraries(migraphx_pybind_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets)
rocm_install_targets(TARGETS migraphx_pybind_${PYTHON_VERSION})
add_dependencies(migraphx_py migraphx_pybind_${PYTHON_VERSION})
add_library(migraphx_py_${PYTHON_VERSION} py.cpp)
target_include_directories(migraphx_py_${PYTHON_VERSION} PRIVATE include)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PUBLIC migraphx)
target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE pybind11::pybind11 python${PYTHON_VERSION}::runtime)
rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION})
add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION})
endforeach()
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/py/export.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_PY_EXPORT program load_py(const std::string& filename);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PY_HPP
......@@ -95,6 +95,10 @@ void visit_py(T x, F f)
{
f(x.template cast<std::string>());
}
else if(py::isinstance<migraphx::shape::dynamic_dimension>(x))
{
f(migraphx::to_value(x.template cast<migraphx::shape::dynamic_dimension>()));
}
else
{
MIGRAPHX_THROW("VISIT_PY: Unsupported data type!");
......@@ -165,6 +169,9 @@ template <class T>
py::buffer_info to_buffer_info(T& x)
{
migraphx::shape s = x.get_shape();
assert(s.type() != migraphx::shape::tuple_type);
if(s.dynamic())
MIGRAPHX_THROW("MIGRAPHX PYTHON: dynamic shape argument passed to to_buffer_info");
auto strides = s.strides();
std::transform(
strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
......@@ -177,7 +184,7 @@ py::buffer_info to_buffer_info(T& x)
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<bool>::format(),
s.lens().size(),
s.ndim(),
s.lens(),
strides);
}
......@@ -186,7 +193,7 @@ py::buffer_info to_buffer_info(T& x)
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<decltype(as())>::format(),
s.lens().size(),
s.ndim(),
s.lens(),
strides);
}
......@@ -241,6 +248,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float"));
if(v.contains("dyn_dims"))
{
auto dyn_dims =
migraphx::from_value<std::vector<migraphx::shape::dynamic_dimension>>(
v.at("dyn_dims"));
return migraphx::shape(t, dyn_dims);
}
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
......@@ -250,15 +264,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("ndim", &migraphx::shape::ndim)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size)
.def("dyn_dims", &migraphx::shape::dyn_dims)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
.def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar)
.def("dynamic", &migraphx::shape::dynamic)
.def("__eq__", std::equal_to<migraphx::shape>{})
.def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
......@@ -266,6 +283,15 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);
py::class_<migraphx::shape::dynamic_dimension>(shape_cls, "dynamic_dimension")
.def(py::init<>())
.def(py::init<std::size_t, std::size_t>())
.def(py::init<std::size_t, std::size_t, std::set<std::size_t>>())
.def_readwrite("min", &migraphx::shape::dynamic_dimension::min)
.def_readwrite("max", &migraphx::shape::dynamic_dimension::max)
.def_readwrite("optimals", &migraphx::shape::dynamic_dimension::optimals)
.def("is_fixed", &migraphx::shape::dynamic_dimension::is_fixed);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def(py::init([](py::buffer b) {
......@@ -440,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
migraphx::shape::dynamic_dimension default_dyn_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
options.map_input_dims = map_input_dims;
options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
......@@ -454,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("default_dim_value") = 0,
py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
......@@ -464,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
migraphx::shape::dynamic_dimension default_dyn_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
options.map_input_dims = map_input_dims;
options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("default_dim_value") = 0,
py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
......@@ -505,6 +547,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("format") = "msgpack");
m.def("get_target", &migraphx::make_target);
m.def("create_argument", [](const migraphx::shape& s, const std::vector<double>& values) {
if(values.size() != s.elements())
MIGRAPHX_THROW("Values and shape elements do not match");
migraphx::argument a{s};
a.fill(values.begin(), values.end());
return a;
});
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
m.def("quantize_fp16",
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/file_buffer.hpp>
#include <pybind11/embed.h>
namespace py = pybind11;
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif
// extern "C" is used to disable name mangling, but the function will still be called from C++
extern "C" program migraphx_load_py(const std::string& filename);
#ifdef __clang__
#pragma clang diagnostic pop
#endif
const std::string& python_path()
{
static const auto path = dynamic_loader::path(&migraphx_load_py).parent_path().string();
return path;
}
static py::dict run_file(const std::string& file)
{
py::object scope = py::module_::import("__main__").attr("__dict__");
std::string buffer;
buffer.append("import sys\n");
buffer.append("sys.path.insert(0, '" + python_path() + "')\n");
buffer.append("import migraphx\n");
buffer.append(read_string(file));
py::exec(buffer, scope);
return scope.cast<py::dict>();
}
extern "C" program migraphx_load_py(const std::string& filename)
{
py::scoped_interpreter guard{};
py::dict vars = run_file(filename);
auto it = std::find_if(vars.begin(), vars.end(), [](const auto& p) {
return py::isinstance<migraphx::program>(p.second);
});
if(it == vars.end())
MIGRAPHX_THROW("No program variable found");
return it->second.cast<migraphx::program>();
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/py.hpp>
#include <migraphx/dynamic_loader.hpp>
#include <migraphx/process.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static std::vector<fs::path> find_available_python_versions()
{
std::vector<fs::path> result;
auto path = dynamic_loader::path(&load_py).parent_path();
for(const auto& entry : fs::directory_iterator{path})
{
auto p = entry.path();
if(not fs::is_regular_file(p))
continue;
if(not contains(p.stem().string(), "migraphx_py_"))
continue;
result.push_back(p);
}
std::sort(result.begin(), result.end(), std::greater<>{});
return result;
}
static dynamic_loader load_py_lib()
{
auto libs = find_available_python_versions();
for(const auto& lib : libs)
{
auto result = dynamic_loader::try_load(lib);
if(result.has_value())
return *result;
}
MIGRAPHX_THROW("Cant find a viable version of python");
}
static dynamic_loader py_lib()
{
static dynamic_loader lib = load_py_lib();
return lib;
}
MIGRAPHX_PY_EXPORT program load_py(const std::string& filename)
{
static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py");
return f(filename);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -29,6 +29,7 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
......@@ -48,19 +49,12 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
// This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator.
// For the conversion, there could be cases of overflowing, but it
// is very rare in the area of deeping learning, so we just do a
// truncate of the input to get the fp16.
// For the conversion, there could be cases of overflowing or underflowing, but it
// is uncommon. Run optimize_module() before converting to fp16 to const eval and fold in FP32 to
// avoid loss of precision.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
run_passes(prog,
{quantize_fp16_pass{ins_names},
eliminate_common_subexpression{},
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
simplify_qdq{},
dead_code_elimination{}});
run_passes(prog, {optimize_module{}, quantize_fp16_pass{ins_names}, optimize_module{}});
}
void quantize_int8(program& prog,
......
......@@ -52,14 +52,6 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
auto mod_inputs = ins->module_inputs();
auto s = ins->get_shape();
// Convert back to original type before quantizing the inputs
if(mod_inputs.empty())
{
auto r = m.insert_instruction(
std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
m.replace_instruction(ins, r);
}
// Convert each of the inputs that are floating point to fp16
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
......@@ -70,8 +62,17 @@ static void quantize_module(module& m, const std::vector<std::string>& ins_names
ins, make_op("convert", {{"target_type", shape::half_type}}), input);
});
// Replace inputs
m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
// Insert quantized ins
auto converted_ins = m.insert_instruction(ins, ins->get_operator(), inputs, mod_inputs);
// Convert back to original type after quantizing
if(mod_inputs.empty())
{
converted_ins = m.insert_instruction(
ins, make_op("convert", {{"target_type", s.type()}}), converted_ins);
}
// Replace original instruction
m.replace_instruction(ins, converted_ins);
}
}
......
......@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/pass_manager.hpp>
#include <migraphx/replace_allocate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
......@@ -84,10 +85,11 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocatio
mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args);
}
void replace_allocate::apply(module& m) const
void replace_allocate::apply(module_pass_manager& mpm) const
{
module& m = mpm.get_module();
auto mod_output_names = create_output_names(m);
bool main_offload_copy = m.name() == "main" ? this->offload_copy : false;
bool root_offload_copy = (*mpm.get_root_module() == m) ? this->offload_copy : false;
for(auto ins : iterator_for(m))
{
auto op = ins->get_operator();
......@@ -104,7 +106,7 @@ void replace_allocate::apply(module& m) const
continue;
auto s = ins->get_shape();
if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
if(not root_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
{
auto out_param = m.add_parameter(mod_output_names[ins], s);
m.replace_instruction(ins, out_param);
......
......@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -40,15 +41,18 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(x->get_shape().type() != y_scale->get_shape().type())
{
x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
x = m.insert_instruction(
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
}
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
if(ins->inputs().size() == 3)
{
auto zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
auto zero_point =
m.insert_instruction(ins,
make_op("convert", {{"target_type", y_scale->get_shape().type()}}),
ins->inputs()[2]);
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
}
......@@ -59,12 +63,9 @@ void apply_quantizelinear(module& m, instruction_ref ins)
min_quant = qt.min();
});
auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
auto min_arg = m.add_literal(literal(s, min_data));
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
......@@ -72,14 +73,16 @@ void apply_quantizelinear(module& m, instruction_ref ins)
void apply_dequantizelinear(module& m, instruction_ref ins)
{
assert(ins->name() == "dequantizelinear");
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
auto x_scale = ins->inputs()[1];
auto x = m.insert_instruction(
ins, make_op("convert", {{"target_type", x_scale->get_shape().type()}}), ins->inputs()[0]);
if(ins->inputs().size() == 3)
{
auto x_zero_point = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
auto x_zero_point =
m.insert_instruction(ins,
make_op("convert", {{"target_type", x_scale->get_shape().type()}}),
ins->inputs()[2]);
x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -273,9 +273,23 @@ shape shape::from_permutation(type_t t,
shape::type_t shape::type() const { return impl->m_type; }
const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
const std::vector<std::size_t>& shape::lens() const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: lens() called on a dynamic shape");
}
return impl->m_lens;
}
const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
const std::vector<std::size_t>& shape::strides() const
{
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: strides() called on a dynamic shape");
}
return impl->m_strides;
}
std::size_t shape::ndim() const
{
......@@ -535,7 +549,14 @@ bool shape::any_of_dynamic() const
});
}
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const
{
if(not this->dynamic())
{
MIGRAPHX_THROW("SHAPE: dyn_dims() called on a static shape");
}
return impl->m_dyn_dims;
}
std::vector<std::size_t> shape::min_lens() const
{
......@@ -680,10 +701,20 @@ void migraphx_to_value(value& v, const shape& s)
{
value result;
result["type"] = migraphx::to_value(s.type_string());
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["sub_shapes"] = migraphx::to_value(s.sub_shapes());
// avoid calling functions that will throw
if(s.dynamic())
{
result["lens"] = {};
result["strides"] = {};
result["dynamic_dimensions"] = migraphx::to_value(s.dyn_dims());
}
else
{
result["lens"] = migraphx::to_value(s.lens());
result["strides"] = migraphx::to_value(s.strides());
result["dynamic_dimensions"] = {};
}
v = result;
}
......@@ -706,13 +737,9 @@ void migraphx_from_value(const value& v, shape& s)
{
auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) {
auto x_min = x.at("min").template to<size_t>();
auto x_max = x.at("max").template to<size_t>();
auto v_optimals = x.at("optimals");
std::set<size_t> set_x_optimals =
from_value<std::set<std::size_t>>(x.at("optimals"));
return shape::dynamic_dimension{x_min, x_max, set_x_optimals};
std::transform(
v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) {
return from_value<shape::dynamic_dimension>(x);
});
s = shape{shape::parse_type(t), dyn_dims};
......
......@@ -204,6 +204,131 @@ struct find_mul_slice_conv
}
};
struct find_mul_dot
{
auto matcher() const
{
auto is_dot_const_inputs =
match::name("dot")(match::any_of[match::inputs()](match::is_constant()));
return match::name("mul")(match::either_arg(0, 1)(
is_dot_const_inputs.bind("dot"), match::name("broadcast", "multibroadcast").bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = dot_ins->inputs()[0];
auto b_ins = dot_ins->inputs()[1];
auto c_ins = r.instructions["c"];
const auto& c_strides = c_ins->get_shape().strides();
// There should only be one stride that is not zero
if(std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
auto add_mul_const = [&](instruction_ref x_ins) {
if(not x_ins->can_eval())
return m.end();
auto broadcast_v = c_ins->get_operator().to_value();
broadcast_v["out_lens"] = x_ins->get_shape().lens();
auto cb_ins =
m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs());
return m.insert_instruction(ins, make_op("mul"), x_ins, cb_ins);
};
if(c_strides.back() == 1)
{
b_ins = add_mul_const(b_ins);
}
else if(c_strides[c_strides.size() - 2] == 1)
{
a_ins = add_mul_const(a_ins);
}
else if(c_ins->get_shape().scalar())
{
if(a_ins->can_eval())
a_ins = add_mul_const(a_ins);
else
b_ins = add_mul_const(b_ins);
}
else
{
return;
}
if(contains({a_ins, b_ins}, m.end()))
return;
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
}
};
struct find_dot_mul
{
auto matcher() const
{
auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant());
auto mul = match::name("mul")(
match::used_once(),
match::either_arg(0, 1)(const_broadcast.bind("d"),
match::none_of(match::is_constant()).bind("z")));
return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
auto d_ins = r.instructions["d"];
auto c_ins = r.instructions["c"];
auto z_ins = r.instructions["z"];
const auto& d_strides = d_ins->get_shape().strides();
// There should only be one stride that is not zero
if(std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
if(not d_ins->get_shape().scalar())
{
if(d_strides.back() == 1 and not b_ins->can_eval())
return;
if(d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval())
return;
}
auto broadcast_v = d_ins->get_operator().to_value();
auto c_lens = c_ins->get_shape().lens();
std::vector<int64_t> permutation(c_lens.size());
std::iota(permutation.begin(), permutation.end(), 0);
std::swap(permutation.back(), permutation[permutation.size() - 2]);
c_lens = reorder_dims(c_lens, permutation);
broadcast_v["out_lens"] = c_lens;
auto db_ins =
m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
auto db_transpose_ins =
m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), db_ins);
auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_transpose_ins);
if(c_ins == b_ins)
{
a_ins = z_ins;
b_ins = cd_ins;
}
else
{
a_ins = cd_ins;
b_ins = z_ins;
}
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
}
};
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
......@@ -361,30 +486,123 @@ struct find_inner_broadcast
{
auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
static auto non_scalar_op(const std::string& name)
{
return [=](instruction_ref ins) {
if(ins->get_shape().scalar())
return false;
return ins->name() == name;
};
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto broadcasts = ins->inputs();
if(broadcasts.empty())
return;
// Skip if different data types are used
if(any_of(broadcasts, [&](auto i) {
return i->get_shape().type() != broadcasts.front()->get_shape().type();
}))
return;
bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
any_of(broadcasts, non_scalar_op("multibroadcast"));
// If the broadcast is not a single dimension, then dont perform inner_broadcast
if(mixed_broadcasts and any_of(broadcasts, [&](instruction_ref i) {
if(i->get_shape().scalar())
return false;
if(i->name() == "multibroadcast")
return false;
auto input = i->inputs().at(0);
const auto& lens = input->get_shape().lens();
return std::count_if(lens.begin(), lens.end(), [&](std::size_t d) {
return d == 1;
}) < (lens.size() - 1);
}))
return;
std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(),
broadcasts.end(),
std::back_inserter(inputs),
[](auto i) { return i->inputs().front(); });
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape() != inputs.front()->get_shape() and
i->get_shape().elements() != 1;
}))
return;
auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
return not i->get_shape().scalar();
[&](instruction_ref i) {
auto input = i->inputs().front();
if(mixed_broadcasts and not i->get_shape().scalar() and
i->get_shape().lens().size() > 1)
return m.insert_instruction(i, make_op("squeeze"), input);
return input;
});
if(b_it == broadcasts.end())
b_it = broadcasts.begin();
std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref i) {
if(i->get_shape().scalar())
return 2;
else if(i->name() == "broadcast")
return 0;
if(i->name() == "multibroadcast")
return 1;
return 3;
}));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
m.replace_instruction(ins, (*b_it)->get_operator(), op);
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
};
struct find_dot_broadcast
{
auto matcher() const
{
return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a = ins->inputs()[0];
auto b = ins->inputs()[1];
if(a->get_operator().name() != b->get_operator().name())
return;
if(ins->get_shape().lens().size() < 3)
return;
auto nbatch_axes = ins->get_shape().lens().size() - 2;
const auto& a_strides = a->get_shape().strides();
const auto& b_strides = b->get_shape().strides();
// Find leading batch axes that are broadcasted
auto p =
std::mismatch(a_strides.begin(),
a_strides.begin() + nbatch_axes,
b_strides.begin(),
b_strides.begin() + nbatch_axes,
[](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
auto naxes = p.first - a_strides.begin();
assert(naxes <= nbatch_axes);
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), 0);
auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref {
auto input = b_ins->inputs()[0];
std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
b_ins->get_shape().lens().end());
if(b_ins->name() == "multibroadcast")
{
return m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
}
else if(b_ins->name() == "broadcast")
{
auto v = b_ins->get_operator().to_value();
auto axis = v.at("axis").to<std::size_t>() - naxes;
return m.insert_instruction(
ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
}
assert(false);
return m.end();
};
auto a1 = insert_broadcast(a);
auto b1 = insert_broadcast(b);
auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1);
auto broadcast = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
m.replace_instruction(ins, broadcast);
}
};
......@@ -393,7 +611,8 @@ struct find_concat_op
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::any_of(match::pointwise(), match::name("broadcast")), match::used_once()));
match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
match::used_once()));
}
template <class Iterator>
......@@ -412,7 +631,8 @@ struct find_concat_op
static bool is_valid_op(const operation& op)
{
return op.name() == "broadcast" or op.attributes().contains("pointwise");
return contains({"broadcast", "multibroadcast"}, op.name()) or
op.attributes().contains("pointwise");
}
void apply(module& m, const match::matcher_result& r) const
......@@ -440,6 +660,16 @@ struct find_concat_op
op = b;
iaxis = 0;
}
else if(op.name() == "multibroadcast")
{
shape bshape = (*start)->get_shape();
auto input = (*start)->inputs()[0];
if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
return {start, last};
op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
auto delta = bshape.lens().size() - input->get_shape().lens().size();
iaxis -= delta;
}
std::vector<instruction_ref> concats;
for(std::size_t i = 0; i < x->inputs().size(); i++)
......@@ -865,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return (dots >= 2 or convs >= 2);
return (dots >= 2 or convs >= 2 or qdots >= 2);
}
struct find_conv_dot_horiz_fusion
......@@ -880,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator())
return false;
if(not contains({"dot", "convolution"}, i->name()))
if(not contains({"quant_dot", "dot", "convolution"}, i->name()))
return true;
auto x = i->inputs()[1]->get_shape().lens();
auto y = j->inputs()[1]->get_shape().lens();
......@@ -888,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return false;
// Check that non-axes match
int axis = 1;
if(i->name() == "dot")
if(i->name() == "dot" or i->name() == "quant_dot")
{
axis = x.size() - 1;
}
......@@ -899,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if(std::distance(start, last) < 2)
return;
auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name))
if(not contains({"quant_dot", "dot", "convolution"}, name))
return;
auto op = (*start)->get_operator();
int group = 1;
......@@ -914,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
int axis = 1;
int concat_axis = 0;
if(name == "dot")
if(name == "dot" or name == "quant_dot")
{
axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis;
......@@ -1260,12 +1491,15 @@ void simplify_algebra::apply(module& m) const
{
match::find_matches(m,
find_inner_broadcast{},
find_dot_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_add_convs{},
find_conv_dot_horiz_fusion{},
find_mul_conv{},
find_mul_slice_conv{},
find_mul_dot{},
find_dot_mul{},
find_mul_add{},
find_unit_ops{},
find_neg_unit_ops{},
......
......@@ -89,38 +89,23 @@ struct find_reshaper
{
auto matcher() const
{
return match::name(reshaper_names())(
match::any_of[match::outputs()](match::name(reshaper_names())));
auto reshaper = match::name(reshaper_names());
auto contiguous = match::name("contiguous");
auto no_output_reshape = match::none_of[match::outputs()](reshaper);
auto input_reshape = match::arg(0)(match::skip(contiguous)(reshaper));
auto input = match::skip(reshaper, contiguous)(match::any().bind("x"));
return reshaper(no_output_reshape, input_reshape, input);
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(not reshapes.back()->inputs().empty());
assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens();
std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
m.replace_instruction(r.first, r.second);
}
if(not input->get_shape().standard())
input = m.insert_instruction(ins, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
}
};
......@@ -804,9 +789,9 @@ void simplify_reshapes::apply(module& m) const
match::find_matches(m,
find_where_op{},
find_resize{},
find_reshape_cont{},
find_nop_reshapes{},
find_reshaper{},
find_reshape_cont{},
find_transpose{},
find_concat_transpose{},
find_concat_multibroadcasts{},
......
......@@ -100,10 +100,10 @@ struct find_static_2in_broadcasts
} // namespace
/**
* Makes all the shapes in the dynamic_dimension range.
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those
* work. Inserts select_module instruction to the top. Replaces return, bypassing other
* instructions.
* instructions. Skips if the dynamic parameter outputs to a select_module operator.
*/
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
......@@ -111,7 +111,13 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes();
optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
if(dd_check.has_value())
auto any_sm_next = [&](auto ddc) {
auto p_outputs = mm->get_parameter(ddc->dyn_param_str)->outputs();
return std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
return ins->name() == "select_module";
});
};
if(dd_check.has_value() and not any_sm_next(dd_check))
{
const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment