Unverified Commit c9b86f1c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Module impl (#678)



* add an api get_main_module

* clang format

* modify onnx unit test for module

* clang format

* refactor ops unit test with the get_main_module

* clang format

* code backup

* clang format

* refine module c api

* add python api for module

* clang format

* fix a python api issue

* clang format

* fix cppcheck error

* clang format

* refine unit tests changes

* clang format

* code backup

* code backup

* clang format

* defer some changes to later PRs

* change return of get_main_module from ref to pointer

* clang format

* add unit tests for the get_main_module_api

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* add more unit tests for more code change coverage

* clang format

* fixed a unit test error

* clang format

* fix unit test

* clang format

* code backup

* code change for more code coverage

* change program to module in various passes and matcher

* clang format

* modify the pass API

* code backup

* code backup

* clang format

* code backup

* clang format

* Add option to no generate a destroy method

* Formatting

* fix some review comments

* clang format

* fix review comments

* clang format

* clang format

* code backup

* code backup

* clang format

* fix cppcheck errors

* clang format

* clang format

* fix build errors

* clang format

* modify gpu unit tests to using module

* clang format

* fix cppcheck error

* clang format

* Add flag to enable cpu backend

* Make buffers shared

* Enable optimizations

* Formatting

* fix review comments

* code backup

* clang format

* code backup

* clang format

* fix a bug related to a unit test

* clang format

* clang format

* fix a build error

* remove unnecessary code

* remove unnecessary files

* code backup

* clang format

* remove the compile function from the module class

* clang format

* clang format

* remove the context parameter from the from_value method of the module class

* code refinement

* clang format

* merge changes from develop branch

* clang format

* fix cppcheck error

* clang format

* fix a build error

* fixed a merge error

* fix cppcheck error

* fixed review comments

* clang format

* fix cppcheck error

* fix a cppcheck error

* fix cppcheck error

* fix build error caused by merge

* Add missing has_op function

* Formatting

* merge changes from develop branch

* fix a cppcheck error

* fixed some review comments

* clang format

* remove the begin/end function of the program class

* clang format

* refine code and fix cppcheck error

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* add unit tests for more code coverage

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix a build error in debug mode

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 1dd4e4d9
......@@ -7,8 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
/**
* Replace instructions which take all literals with a literal of the computation.
......
......@@ -8,8 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
/**
* Decompose operators.
......
......@@ -8,8 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
/**
* Rewrite batchnorm to a multiply and add.
......
......@@ -7,8 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
/**
* Rewrite pooling to reduce_mean
......
......@@ -11,8 +11,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
/**
* Rewrite rnn to gemm and add.
......
......@@ -9,8 +9,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
/**
* Schedule instructions for concurrent execution
......
......@@ -15,8 +15,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
struct operation;
#ifdef DOXYGEN
......
......@@ -7,8 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
/**
* Simplify many algebraic instructions to more efficient versions.
......
......@@ -8,8 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
/**
* Eliminate redundant reshapes.
......
#include <migraphx/module.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/make_op.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <set>
#include <utility>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_impl
{
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
std::vector<std::string> input_names;
};
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
static void print_instruction(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
{
os << names.at(ins) << " = ";
os << ins->get_operator();
if(ins->name() == "@literal")
{
if(ins->get_literal().get_shape().elements() > 10)
os << "{ ... }";
else
os << "{" << ins->get_literal() << "}";
}
if(!ins->inputs().empty())
{
char delim = '(';
for(auto&& arg : ins->inputs())
{
os << delim << names.at(arg);
delim = ',';
}
os << ")";
}
// skip return instruction shape
if(ins->name() != "@return")
os << " -> " << ins->get_shape();
}
module::module() : impl(std::make_unique<module_impl>()) {}
module::module(module&&) noexcept = default;
module::~module() noexcept = default;
// copy constructor
module::module(const module& m) { assign(m); }
// copy assignment operator
module& module::operator=(module m)
{
std::swap(m.impl, this->impl);
return *this;
}
void module::assign(const module& m)
{
// clean the current module
if(!impl)
{
impl = std::make_unique<module_impl>();
}
else if(!impl->instructions.empty())
{
impl->instructions.clear();
}
impl->input_names = m.impl->input_names;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(m))
{
instruction_ref copy_ins{};
if(ins->name() == "@literal")
{
auto l = ins->get_literal();
copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l});
}
else if(ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(),
{builtin::param{name}, std::move(s), {}});
}
else if(ins->name() == "@outline")
{
auto s = ins->get_shape();
copy_ins =
impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
}
else
{
// retrieve its mapped input
auto inputs = ins->inputs();
// ensure all inputs have its corresponding copy instructions
assert(std::all_of(
inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i];
});
if(ins->name() == "@return")
{
copy_ins = add_return(copy_inputs);
}
else
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
}
ins_map[ins] = copy_ins;
}
}
instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
}
instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
instruction::replace(ins, op, r, std::move(args));
assert(ins->valid(begin()));
return ins;
}
instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref rep)
{
assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);
if(ins == std::prev(this->end()))
{
return replace_instruction(ins, make_op("identity"), rep);
}
// TODO: Should it be an error if the output is empty?
if(ins->outputs().empty())
{
return rep;
}
// Make a copy of outputs which can be changed when calling replace_argument
auto outputs = ins->outputs();
for(auto out : outputs)
{
// TODO: Check for possible cycles
if(out != rep)
{
instruction::replace_argument(out, ins, rep);
}
assert(out->valid(begin()));
}
// Replacement should not be dead code unless its the last instruction
assert(!rep->outputs().empty() or rep == std::prev(end()));
// Output of the original instruction should only be the replacement or empty
assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(),
ins->outputs().end(),
[&](auto i) { return i == rep; }));
assert(ins->valid(begin()));
assert(rep->valid(begin()));
return rep;
}
instruction_ref module::remove_instruction(instruction_ref ins)
{
assert(has_instruction(ins));
assert(ins->outputs().empty());
ins->clear_arguments();
return impl->instructions.erase(ins);
}
instruction_ref module::remove_instructions(instruction_ref first, instruction_ref last)
{
if(first == last)
return first;
// TODO: Check every element
assert(has_instruction(first));
std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); }));
return impl->instructions.erase(first, last);
}
instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
{
impl->instructions.splice(dst, impl->instructions, src);
return src;
}
instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
{
this->move_instruction(src, dst);
for(auto ins : src->inputs())
this->move_instruction(ins, src);
return src;
}
instruction_ref module::add_literal(literal l)
{
impl->instructions.emplace_front(std::move(l));
return impl->instructions.begin();
}
instruction_ref module::add_outline(const shape& s)
{
impl->instructions.push_front({builtin::outline{s}, s, {}});
return impl->instructions.begin();
}
instruction_ref module::add_parameter(std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->input_names.push_back(name);
impl->instructions.push_front({builtin::param{std::move(name)}, std::move(s), {}});
return impl->instructions.begin();
}
instruction_ref module::add_return(std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
impl->instructions.push_back({builtin::returns{}, {}, args});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
shape module::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.name() == "@param")
{
return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
return false;
}
});
if(ins != this->end())
return ins->get_shape();
else
return {};
}
std::vector<std::string> module::get_parameter_names() const
{
std::vector<std::string> result = impl->input_names;
std::unordered_set<std::string> params;
for(auto&& ins : impl->instructions)
{
if(ins.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
params.insert(name);
}
}
erase_if(result, [&](auto&& name) { return params.count(name) == 0; });
return result;
}
instruction_ref module::get_parameter(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.name() == "@param")
{
return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
return false;
}
});
if(ins != this->end())
return ins;
else
return this->end();
}
std::unordered_map<std::string, shape> module::get_parameter_shapes() const
{
std::unordered_map<std::string, shape> result;
for(auto&& ins : impl->instructions)
{
if(ins.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
result[name] = ins.get_shape();
}
}
return result;
}
bool module::has_instruction(instruction_ref ins) const
{
return std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return std::addressof(*ins) == std::addressof(x);
}) != impl->instructions.end();
}
std::size_t module::size() const { return impl->instructions.size(); }
instruction_ref module::begin() const { return impl->instructions.begin(); }
instruction_ref module::end() const { return impl->instructions.end(); }
std::vector<shape> module::get_output_shapes() const
{
auto last_ins = impl->instructions.back();
if(last_ins.name() == "@return")
{
const auto& output_ins = last_ins.inputs();
std::vector<shape> output_shapes;
std::transform(output_ins.begin(),
output_ins.end(),
std::back_inserter(output_shapes),
[](auto& ins) { return ins->get_shape(); });
return output_shapes;
}
// The else branch is to provide backward compatibility
else
{
return {last_ins.get_shape()};
}
}
instruction_ref module::validate() const
{
return std::find_if(impl->instructions.begin(),
impl->instructions.end(),
[&](const instruction& i) { return !i.valid(impl->instructions.begin()); });
}
void module::finalize(context& ctx)
{
for(auto ins : iterator_for(*this))
{
ins->finalize(ctx);
}
}
value module::to_value() const
{
value result;
value nodes;
this->print([&](auto ins, const auto& names) {
value node;
node["output"] = names.at(ins);
node["name"] = ins->name();
node["shape"] = migraphx::to_value(ins->get_shape());
if(ins->name() == "@literal")
node["literal"] = migraphx::to_value(ins->get_literal());
node["operator"] = ins->get_operator().to_value();
std::vector<std::string> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) { return names.at(i); });
node["inputs"] = inputs;
nodes.push_back(node);
});
result["nodes"] = nodes;
return result;
}
void module::from_value(const value& v)
{
std::unordered_map<std::string, instruction_ref> instructions;
for(const value& node : v.at("nodes"))
{
instruction_ref output;
auto name = node.at("name").to<std::string>();
auto fields = node.at("operator");
if(name == "@param")
{
output = this->add_parameter(fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
}
else if(name == "@literal")
{
output = this->add_literal(migraphx::from_value<literal>(node.at("literal")));
}
else
{
auto op = make_op(name, fields);
std::vector<instruction_ref> inputs;
std::transform(node.at("inputs").begin(),
node.at("inputs").end(),
std::back_inserter(inputs),
[&](const value& i) { return instructions[i.to<std::string>()]; });
if(name == "@return")
output = this->add_return(inputs);
else
output = this->add_instruction(op, inputs);
}
instructions[node.at("output").to<std::string>()] = output;
}
}
void module::debug_print() const { std::cout << *this << std::endl; }
void module::debug_print(instruction_ref ins) const
{
if(ins == this->end())
{
std::cout << "End instruction" << std::endl;
return;
}
if(not has_instruction(ins))
{
std::cout << "Instruction not part of module" << std::endl;
return;
}
std::stringstream ss;
this->print([&](auto x, const auto& names) {
if(x == ins)
{
print_instruction(std::cout, x, names);
std::cout << std::endl;
}
});
}
void module::debug_print(const std::vector<instruction_ref>& inss) const
{
for(auto ins : inss)
this->debug_print(ins);
std::cout << std::endl;
}
void module::print(const std::function<
void(instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const
{
std::unordered_map<instruction_ref, std::string> names;
int count = 0;
for(auto ins : iterator_for(*this))
{
std::string var_name;
if(ins->name() == "@param")
{
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
else
{
var_name = "@" + std::to_string(count);
count++;
}
names.emplace(ins, var_name);
assert(std::all_of(ins->inputs().begin(),
ins->inputs().end(),
[&](auto arg) { return this->has_instruction(arg); }) &&
"DEBUG_PRINT: Instruction not found");
print_func(ins, names);
}
}
static std::string enclose_name(const std::string& name)
{
return '"' + replace_string(name, "\"", "\\\"") + '"';
}
void module::print_graph(std::ostream& os, bool brief) const
{
os << "digraph {" << std::endl;
os << "\trankdir=LR;" << std::endl;
this->print([&](auto ins, const auto& names) {
std::string label;
if(brief)
label = ins->name();
else
label = to_string(ins->get_operator());
os << "\t" << enclose_name(names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << ";" << std::endl;
if(!ins->inputs().empty())
{
for(auto&& arg : ins->inputs())
{
os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins));
if(not brief)
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]";
os << ";" << std::endl;
}
}
});
os << "}" << std::endl;
}
static std::string cpp_var_name(const std::string& name)
{
return "m" + replace_string(name, "@", "x");
}
static std::string cpp_op_var(const std::string& name, instruction_ref ins)
{
return replace_string(name, "@", ins->name());
}
static void print_op_attributes(std::ostream& os, const std::string& name, const operation& op)
{
std::string x = to_string(op);
if(contains(x, "["))
{
auto start = x.find('[');
auto end = x.find(']');
std::string attribute_text = x.substr(start + 1, end - start - 1);
std::vector<std::string> attributes;
for(auto&& attribute : split_string(attribute_text, ','))
{
if(contains(attribute, '='))
attributes.push_back(attribute);
else
attributes.back() += "," + attribute;
}
for(auto&& attribute : attributes)
{
auto p = split_string(attribute, '=');
auto key = p.front();
auto value = p.back();
if(contains({"bn_mode", "padding_mode"}, key))
continue;
if(key == "mode")
value = enclose_name(trim(value));
os << name << "." << key << " = " << value << ";" << std::endl;
}
}
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx::shape{migraphx::shape::" << s.type_string();
os << ", {" << to_string_range(s.lens()) << "}";
if(not s.standard())
os << ", {" << to_string_range(s.strides()) << "}";
os << "}";
}
void module::print_cpp(std::ostream& os) const
{
os << "migraphx::module p;" << std::endl;
// cppcheck-suppress variableScope
unsigned long seed = 0;
this->print([&](auto ins, const auto& names) {
auto op = cpp_op_var(names.at(ins), ins);
if(ins->name().front() != '@')
{
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl;
print_op_attributes(os, op, ins->get_operator());
}
os << "auto " << cpp_var_name(names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << "p.add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
if(use_abs)
os << "migraphx::abs(";
os << "migraphx::generate_literal(";
print_cpp_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ");" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl;
}
else
{
os << "p.add_instruction(" << op;
for(auto input : ins->inputs())
{
os << ", " << cpp_var_name(names.at(input));
}
os << ");" << std::endl;
}
});
}
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
{
this->print([&](auto ins, const auto& names) {
print_instruction(os, ins, names);
a(ins);
os << std::endl;
});
}
module& module::sort()
{
fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin());
for(auto child : ins->inputs())
self(child);
})(std::prev(this->end()));
assert(this->validate() == this->end());
return *this;
}
bool operator==(const module& x, const module& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const module& m)
{
m.print([&](auto ins, const auto& names) {
print_instruction(os, ins, names);
os << std::endl;
});
return os;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -35,7 +35,7 @@ struct parse_onehot : op_parser<parse_onehot>
auto type = args[2]->get_shape().type();
shape s{type, {depth, depth}};
auto l_val = info.mm->add_literal({s, depth_input});
auto l_val = info.add_literal({s, depth_input});
auto gather_out = info.add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]});
// Finally, we need a transpose to move the inner most dim to the axis dim
......
......@@ -70,7 +70,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info,
{
*starts_it = idx;
*ends_it = *starts_it + 1;
slices.push_back(info.mm->add_instruction(
slices.push_back(info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
}
// when padding on the left side, the outermost pad should be at the beginning
......@@ -83,7 +83,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info,
slices.push_back(info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
}
input = info.mm->add_instruction(make_op("concat", {{"axis", axis}}), slices);
input = info.add_instruction(make_op("concat", {{"axis", axis}}), slices);
}
return input;
}
......
......@@ -6,33 +6,30 @@
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/time.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <set>
#include <utility>
#include <migraphx/make_op.hpp>
#include <unordered_set>
#include <map>
#include <cassert>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program_impl
{
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
std::vector<std::string> input_names;
// A map is used to keep references to modules of the program
std::map<std::string, module> modules;
context ctx;
std::string target_name;
};
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
static void print_instruction(std::ostream& os,
instruction_ref ins,
const std::unordered_map<instruction_ref, std::string>& names)
......@@ -65,38 +62,7 @@ static void print_instruction(std::ostream& os,
os << " -> " << ins->get_shape();
}
template <class F>
static void print_program(const program& p, F print_func)
{
std::unordered_map<instruction_ref, std::string> names;
int count = 0;
for(auto ins : iterator_for(p))
{
std::string var_name;
if(ins->name() == "@param")
{
var_name = any_cast<builtin::param>(ins->get_operator()).parameter;
}
else
{
var_name = "@" + std::to_string(count);
count++;
}
names.emplace(ins, var_name);
// TODO: Use all_of
for(auto&& arg : ins->inputs())
{
assert(p.has_instruction(arg) && "Instruction not found");
(void)arg;
}
print_func(ins, names);
}
}
program::program() : impl(std::make_unique<program_impl>()) {}
program::program() : impl(std::make_unique<program_impl>()) { impl->modules["main"] = {}; }
program::program(program&&) noexcept = default;
program::~program() noexcept = default;
......@@ -113,346 +79,99 @@ program& program::operator=(program p)
void program::assign(const program& p)
{
// clean the current program
if(!impl)
{
impl = std::make_unique<program_impl>();
}
else if(!impl->instructions.empty())
else if(!impl->modules.empty())
{
impl->instructions.clear();
impl->modules.clear();
}
impl->ctx = p.impl->ctx;
impl->input_names = p.impl->input_names;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(p))
{
instruction_ref copy_ins{};
if(ins->name() == "@literal")
{
auto l = ins->get_literal();
copy_ins = impl->instructions.insert(impl->instructions.end(), instruction{l});
}
else if(ins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins->get_operator()).parameter;
auto s = ins->get_shape();
copy_ins = impl->instructions.insert(impl->instructions.end(),
{builtin::param{name}, std::move(s), {}});
}
else if(ins->name() == "@outline")
{
auto s = ins->get_shape();
copy_ins =
impl->instructions.insert(impl->instructions.end(), {builtin::outline{s}, s, {}});
}
else
{
// retrieve its mapped input
auto inputs = ins->inputs();
// ensure all inputs have its corresponding copy instructions
assert(std::all_of(
inputs.begin(), inputs.end(), [&](auto i) { return ins_map.count(i) > 0; }));
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return ins_map[i];
});
if(ins->name() == "@return")
{
copy_ins = add_return(copy_inputs);
}
else
{
copy_ins = add_instruction(ins->get_operator(), copy_inputs);
}
}
ins_map[ins] = copy_ins;
}
}
instruction_ref program::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
}
instruction_ref program::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
instruction::backreference(result);
assert(result->valid(begin()));
return result;
}
instruction_ref program::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
instruction::replace(ins, op, r, std::move(args));
assert(ins->valid(begin()));
return ins;
}
instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref rep)
{
assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);
if(ins == std::prev(this->end()))
{
return replace_instruction(ins, make_op("identity"), rep);
}
// TODO: Should it be an error if the output is empty?
if(ins->outputs().empty())
{
return rep;
}
// Make a copy of outputs which can be changed when calling replace_argument
auto outputs = ins->outputs();
for(auto out : outputs)
{
// TODO: Check for possible cycles
if(out != rep)
{
instruction::replace_argument(out, ins, rep);
}
assert(out->valid(begin()));
}
// Replacement should not be dead code unless its the last instruction
assert(!rep->outputs().empty() or rep == std::prev(end()));
// Output of the original instruction should only be the replacement or empty
assert(ins->outputs().empty() or std::all_of(ins->outputs().begin(),
ins->outputs().end(),
[&](auto i) { return i == rep; }));
assert(ins->valid(begin()));
assert(rep->valid(begin()));
return rep;
}
instruction_ref program::remove_instruction(instruction_ref ins)
{
assert(has_instruction(ins));
assert(ins->outputs().empty());
ins->clear_arguments();
return impl->instructions.erase(ins);
}
instruction_ref program::remove_instructions(instruction_ref first, instruction_ref last)
{
if(first == last)
return first;
// TODO: Check every element
assert(has_instruction(first));
std::for_each(first, last, [&](instruction& ins) { ins.clear_arguments(); });
assert(std::all_of(first, last, [&](const instruction& ins) { return ins.outputs().empty(); }));
return impl->instructions.erase(first, last);
}
instruction_ref program::move_instruction(instruction_ref src, instruction_ref dst)
{
impl->instructions.splice(dst, impl->instructions, src);
return src;
}
instruction_ref program::move_instructions(instruction_ref src, instruction_ref dst)
{
this->move_instruction(src, dst);
for(auto ins : src->inputs())
this->move_instruction(ins, src);
return src;
}
instruction_ref program::add_literal(literal l)
{
impl->instructions.emplace_front(std::move(l));
return impl->instructions.begin();
}
instruction_ref program::add_outline(const shape& s)
{
impl->instructions.push_front({builtin::outline{s}, s, {}});
return impl->instructions.begin();
}
instruction_ref program::add_parameter(std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->input_names.push_back(name);
impl->instructions.push_front({builtin::param{std::move(name)}, std::move(s), {}});
return impl->instructions.begin();
}
instruction_ref program::add_return(std::vector<instruction_ref> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
impl->instructions.push_back({builtin::returns{}, {}, args});
auto result = std::prev(impl->instructions.end());
instruction::backreference(result);
assert(result->valid(begin()));
return result;
impl->target_name = p.impl->target_name;
impl->modules = p.impl->modules;
}
shape program::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.name() == "@param")
{
return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
return false;
}
});
if(ins != this->end())
return ins->get_shape();
else
return {};
const auto* mm = this->get_main_module();
return mm->get_parameter_shape(std::move(name));
}
std::vector<std::string> program::get_parameter_names() const
{
std::vector<std::string> result = impl->input_names;
std::unordered_set<std::string> params;
for(auto&& ins : impl->instructions)
{
if(ins.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
params.insert(name);
}
}
erase_if(result, [&](auto&& name) { return params.count(name) == 0; });
return result;
const auto* mm = this->get_main_module();
return mm->get_parameter_names();
}
instruction_ref program::get_parameter(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
if(x.name() == "@param")
{
return any_cast<builtin::param>(x.get_operator()).parameter == name;
}
else
{
return false;
}
});
if(ins != this->end())
return ins;
else
return this->end();
const auto* mm = this->get_main_module();
return mm->get_parameter(std::move(name));
}
std::unordered_map<std::string, shape> program::get_parameter_shapes() const
{
std::unordered_map<std::string, shape> result;
for(auto&& ins : impl->instructions)
{
if(ins.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
result[name] = ins.get_shape();
}
}
return result;
}
bool program::has_instruction(instruction_ref ins) const
{
return std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return std::addressof(*ins) == std::addressof(x);
}) != impl->instructions.end();
const auto* mm = this->get_main_module();
return mm->get_parameter_shapes();
}
std::size_t program::size() const { return impl->instructions.size(); }
instruction_ref program::begin() const { return impl->instructions.begin(); }
instruction_ref program::end() const { return impl->instructions.end(); }
std::size_t program::size() const { return impl->modules.size(); }
std::vector<shape> program::get_output_shapes() const
{
auto last_ins = impl->instructions.back();
if(last_ins.name() == "@return")
{
const auto& output_ins = last_ins.inputs();
std::vector<shape> output_shapes;
std::transform(output_ins.begin(),
output_ins.end(),
std::back_inserter(output_shapes),
[](auto& ins) { return ins->get_shape(); });
return output_shapes;
}
// The else branch is to provide backward compatibility
else
{
return {last_ins.get_shape()};
}
const auto* mm = this->get_main_module();
return mm->get_output_shapes();
}
context& program::get_context() const { return impl->ctx; }
instruction_ref program::validate() const
{
return std::find_if(impl->instructions.begin(),
impl->instructions.end(),
[&](const instruction& i) { return !i.valid(impl->instructions.begin()); });
const auto* mm = this->get_main_module();
return mm->validate();
}
bool program::is_compiled() const { return not this->impl->target_name.empty(); }
void program::compile(const target& t, compile_options options)
{
assert(this->validate() == impl->instructions.end());
assert(not this->is_compiled());
this->impl->target_name = t.name();
this->impl->ctx = t.get_context();
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout};
options.trace(*this);
options.trace();
run_passes(*this, t.get_passes(this->impl->ctx, options), options.trace);
auto invalid = this->validate();
if(invalid != impl->instructions.end())
auto&& passes = t.get_passes(this->impl->ctx, options);
for(auto& mp : impl->modules)
{
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPHX_THROW("Invalid program from compilation at instruction " + std::to_string(index));
auto& modl = mp.second;
assert(modl.validate() == modl.end());
run_passes(modl, passes, options.trace);
auto invalid = this->validate();
if(invalid != modl.end())
{
auto index = std::distance(modl.begin(), invalid);
MIGRAPHX_THROW("Invalid module " + mp.first + " from compilation at instruction " +
std::to_string(index));
}
modl.finalize(this->impl->ctx);
}
this->finalize();
}
void program::finalize()
{
for(auto ins : iterator_for(*this))
for(auto& mp : this->impl->modules)
{
ins->finalize(this->impl->ctx);
mp.second.finalize(this->impl->ctx);
}
}
template <class F>
std::vector<argument> generic_eval(const program& p,
std::vector<argument> generic_eval(const module& p,
context& ctx,
std::unordered_map<std::string, argument> params,
F trace)
......@@ -518,6 +237,16 @@ std::vector<argument> generic_eval(const program& p,
return {results.at(std::prev(p.end()))};
}
template <class F>
std::vector<argument> generic_eval(const program& p,
context& ctx,
std::unordered_map<std::string, argument> params,
F trace)
{
const auto* mm = p.get_main_module();
return generic_eval(*mm, ctx, params, trace);
}
std::vector<argument> program::eval(parameter_map params) const
{
auto& ctx = this->impl->ctx;
......@@ -555,7 +284,7 @@ std::vector<argument> program::eval(parameter_map params) const
}
}
const int program_file_version = 1;
const int program_file_version = 2;
value program::to_value() const
{
......@@ -564,31 +293,24 @@ value program::to_value() const
result["target"] = this->impl->target_name;
if(not this->impl->target_name.empty())
result["context"] = this->impl->ctx.to_value();
value nodes;
print_program(*this, [&](auto ins, const auto& names) {
value node;
node["output"] = names.at(ins);
node["name"] = ins->name();
node["shape"] = migraphx::to_value(ins->get_shape());
if(ins->name() == "@literal")
node["literal"] = migraphx::to_value(ins->get_literal());
node["operator"] = ins->get_operator().to_value();
std::vector<std::string> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) { return names.at(i); });
node["inputs"] = inputs;
nodes.push_back(node);
});
result["nodes"] = nodes;
result["modules"] = value::object{};
auto& module_val = result.at("modules");
for(auto& m : impl->modules)
{
module_val[m.first] = m.second.to_value();
}
return result;
}
void program::from_value(const value& v)
{
auto version = v.at("version").to<int>();
if(version != program_file_version)
std::cout << "Warning: Version mismatch" << std::endl;
{
MIGRAPHX_THROW("Warning: Program version mismatch");
}
this->impl->target_name = v.at("target").to<std::string>();
if(not this->impl->target_name.empty())
{
......@@ -597,35 +319,14 @@ void program::from_value(const value& v)
this->impl->ctx.from_value(v.at("context"));
}
std::unordered_map<std::string, instruction_ref> instructions;
for(const value& node : v.at("nodes"))
auto val_modules = v.at("modules");
for(const auto& vv : val_modules)
{
instruction_ref output;
auto name = node.at("name").to<std::string>();
auto fields = node.at("operator");
if(name == "@param")
{
output = this->add_parameter(fields["parameter"].to<std::string>(),
migraphx::from_value<shape>(node.at("shape")));
}
else if(name == "@literal")
{
output = this->add_literal(migraphx::from_value<literal>(node.at("literal")));
}
else
{
auto op = make_op(name, fields);
std::vector<instruction_ref> inputs;
std::transform(node.at("inputs").begin(),
node.at("inputs").end(),
std::back_inserter(inputs),
[&](const value& i) { return instructions[i.to<std::string>()]; });
if(name == "@return")
output = this->add_return(inputs);
else
output = this->add_instruction(op, inputs);
}
instructions[node.at("output").to<std::string>()] = output;
const auto& key = vv.get_key();
auto val = vv.without_key();
module modl;
modl.from_value(val);
impl->modules[key] = modl;
}
this->finalize();
}
......@@ -698,7 +399,7 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
double calculate_overhead_time = total_time - total_instruction_time;
double calculate_overhead_percent = calculate_overhead_time * 100.0 / total_time;
print_program(*this, [&](auto ins, const auto& names) {
this->print([&](auto ins, auto names) {
print_instruction(std::cout, ins, names);
// skip return instruction
......@@ -741,18 +442,23 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const
{
if(ins == this->end())
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](auto it) {
return (it.second.end() == ins);
}))
{
std::cout << "End instruction" << std::endl;
return;
}
if(not has_instruction(ins))
else if(not std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](auto it) {
return it.second.has_instruction(ins);
}))
{
std::cout << "Instruction not part of program" << std::endl;
return;
}
std::stringstream ss;
print_program(*this, [&](auto x, const auto& names) {
this->print([&](auto x, const auto& names) {
if(x == ins)
{
print_instruction(std::cout, x, names);
......@@ -760,140 +466,28 @@ void program::debug_print(instruction_ref ins) const
}
});
}
void program::debug_print(const std::vector<instruction_ref>& inss) const
{
for(auto ins : inss)
debug_print(ins);
std::cout << std::endl;
}
static std::string enclose_name(const std::string& name)
void program::print(const std::function<
void(instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const
{
return '"' + replace_string(name, "\"", "\\\"") + '"';
}
void program::print_graph(std::ostream& os, bool brief) const
{
os << "digraph {" << std::endl;
os << "\trankdir=LR;" << std::endl;
print_program(*this, [&](auto ins, const auto& names) {
std::string label;
if(brief)
label = ins->name();
else
label = to_string(ins->get_operator());
os << "\t" << enclose_name(names.at(ins)) << "[label=" << enclose_name(label) << "]";
os << ";" << std::endl;
if(!ins->inputs().empty())
{
for(auto&& arg : ins->inputs())
{
os << "\t" << enclose_name(names.at(arg)) << " -> " << enclose_name(names.at(ins));
if(not brief)
os << "[label=" << enclose_name(to_string(ins->get_shape())) << "]";
os << ";" << std::endl;
}
}
});
os << "}" << std::endl;
}
static std::string cpp_var_name(const std::string& name)
{
return "m" + replace_string(name, "@", "x");
}
static std::string cpp_op_var(const std::string& name, instruction_ref ins)
{
return replace_string(name, "@", ins->name());
}
static void print_op_attributes(std::ostream& os, const std::string& name, const operation& op)
{
std::string x = to_string(op);
if(contains(x, "["))
for(const auto& mdl : this->impl->modules)
{
auto start = x.find('[');
auto end = x.find(']');
std::string attribute_text = x.substr(start + 1, end - start - 1);
std::vector<std::string> attributes;
for(auto&& attribute : split_string(attribute_text, ','))
{
if(contains(attribute, '='))
attributes.push_back(attribute);
else
attributes.back() += "," + attribute;
}
for(auto&& attribute : attributes)
{
auto p = split_string(attribute, '=');
auto key = p.front();
auto value = p.back();
if(contains({"bn_mode", "padding_mode"}, key))
continue;
if(key == "mode")
value = enclose_name(trim(value));
os << name << "." << key << " = " << value << ";" << std::endl;
}
mdl.second.print(print_func);
}
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
void program::print_graph(std::ostream& os, bool brief) const
{
os << "migraphx::shape{migraphx::shape::" << s.type_string();
os << ", {" << to_string_range(s.lens()) << "}";
if(not s.standard())
os << ", {" << to_string_range(s.strides()) << "}";
os << "}";
const auto* mm = this->get_main_module();
mm->print_graph(os, brief);
}
void program::print_cpp(std::ostream& os) const
{
os << "migraphx::program p;" << std::endl;
// cppcheck-suppress variableScope
unsigned long seed = 0;
print_program(*this, [&](auto ins, const auto& names) {
auto op = cpp_op_var(names.at(ins), ins);
if(ins->name().front() != '@')
{
os << "migraphx::op::" << ins->name() << " " << op << ";" << std::endl;
print_op_attributes(os, op, ins->get_operator());
}
os << "auto " << cpp_var_name(names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << "p.add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
if(use_abs)
os << "migraphx::abs(";
os << "migraphx::generate_literal(";
print_cpp_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ");" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << "p.add_parameter(" << enclose_name(name) << ",";
print_cpp_shape(os, ins->get_shape());
os << ");" << std::endl;
}
else
{
os << "p.add_instruction(" << op;
for(auto input : ins->inputs())
{
os << ", " << cpp_var_name(names.at(input));
}
os << ");" << std::endl;
}
});
const auto* mm = this->get_main_module();
mm->print_cpp(os);
}
void program::dry_run(std::unordered_map<std::string, argument> params) const
......@@ -902,23 +496,26 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; });
}
void program::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
{
print_program(*this, [&](auto ins, const auto& names) {
print_instruction(os, ins, names);
a(ins);
os << std::endl;
});
for(auto& modl : this->impl->modules)
{
std::cout << modl.first << ":" << std::endl;
modl.second.annotate(os, a);
}
}
module* program::get_main_module() { return &impl->modules["main"]; }
const module* program::get_main_module() const { return &impl->modules["main"]; }
program& program::sort()
{
fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin());
for(auto child : ins->inputs())
self(child);
})(std::prev(this->end()));
assert(this->validate() == this->end());
for(auto& modl : this->impl->modules)
{
modl.second.sort();
}
return *this;
}
......@@ -926,10 +523,13 @@ bool operator==(const program& x, const program& y) { return to_string(x) == to_
std::ostream& operator<<(std::ostream& os, const program& p)
{
print_program(p, [&](auto ins, const auto& names) {
print_instruction(os, ins, names);
for(auto& mp : p.impl->modules)
{
os << "Module " << mp.first << ": " << std::endl;
os << mp.second;
os << std::endl;
});
}
return os;
}
......
......@@ -208,17 +208,6 @@ migraphx::shape to_shape(const py::buffer_info& info)
}
}
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_wrap
{
migraphx::program* prog;
operator const migraphx::program&() const { return *prog; }
operator migraphx::program&() { return *prog; }
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
......@@ -258,12 +247,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module_wrap>(m, "module")
.def("print", [](const migraphx::module_wrap& mm) { std::cout << *mm.prog << std::endl; })
.def("__eq__", std::equal_to<migraphx::program>{})
.def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__",
[](const migraphx::module_wrap& mm) { return migraphx::to_string(*mm.prog); });
py::class_<migraphx::module>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
.def("__eq__", std::equal_to<migraphx::module>{})
.def("__ne__", std::not_equal_to<migraphx::module>{})
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
......@@ -284,7 +272,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("get_main_module",
[](migraphx::program& p) {
auto* mm = p.get_main_module();
return migraphx::module_wrap{mm};
return *mm;
})
.def("run",
[](migraphx::program& p, py::dict params) {
......
......@@ -7,15 +7,14 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
namespace cpu {
struct lowering
{
std::string name() const { return "cpu::lowering"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace cpu
......
......@@ -526,14 +526,14 @@ struct cpu_literal
struct cpu_apply
{
module* prog;
module* modl;
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
std::unordered_map<instruction_ref, std::string> prog_output_names{};
instruction_ref last{};
void create_output_names()
{
this->last = instruction::get_output_alias(std::prev(prog->end()));
this->last = instruction::get_output_alias(std::prev(modl->end()));
if(this->last->name() == "@return")
{
const auto& prog_outputs = last->inputs();
......@@ -558,7 +558,7 @@ struct cpu_apply
auto&& op = ins->get_operator();
if(allocate)
replace(ins, make_op(cpu_name, op.to_value()));
return prog->replace_instruction(ins, make_op(cpu_name, op.to_value()), ins->inputs());
return modl->replace_instruction(ins, make_op(cpu_name, op.to_value()), ins->inputs());
});
}
......@@ -610,7 +610,7 @@ struct cpu_apply
void apply()
{
init();
for(auto it : iterator_for(*prog))
for(auto it : iterator_for(*modl))
{
if(it->name() == "@literal")
{
......@@ -629,7 +629,7 @@ struct cpu_apply
instruction_ref apply_literal(instruction_ref ins) const
{
return prog->replace_instruction(ins, cpu_literal{ins->get_literal().get_argument()});
return modl->replace_instruction(ins, cpu_literal{ins->get_literal().get_argument()});
}
instruction_ref apply_pooling(instruction_ref ins)
......@@ -651,7 +651,7 @@ struct cpu_apply
{
auto inputs = ins->inputs();
inputs.push_back(insert_allocation(ins, ins->get_shape()));
return prog->replace_instruction(ins, op, inputs);
return modl->replace_instruction(ins, op, inputs);
}
instruction_ref insert_allocation(instruction_ref ins, const shape& s)
......@@ -659,18 +659,18 @@ struct cpu_apply
auto ins_alias = instruction::get_output_alias(ins);
if(last->name() == "@return" and prog_output_names.count(ins_alias) > 0)
{
return prog->add_parameter(prog_output_names[ins_alias], s);
return modl->add_parameter(prog_output_names[ins_alias], s);
}
else if(ins == last)
{
return prog->add_parameter("output", s);
return modl->add_parameter("output", s);
}
return prog->insert_instruction(ins, make_op("cpu::allocate", {{"shape", to_value(s)}}));
return modl->insert_instruction(ins, make_op("cpu::allocate", {{"shape", to_value(s)}}));
}
};
void lowering::apply(module& p) const { cpu_apply{&p}.apply(); }
void lowering::apply(module& m) const { cpu_apply{&m}.apply(); }
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -356,7 +356,7 @@ struct find_triadd_layernorm
match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
}
void apply(program& p, const match::matcher_result& r) const
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto triadd = ins->inputs().front();
......
......@@ -7,8 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
namespace gpu {
......
......@@ -7,8 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
namespace gpu {
......
......@@ -7,8 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
struct module;
namespace gpu {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment