Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for the Softmax and LogSoftmax operators.
struct parse_softmax : op_parser<parse_softmax>
{
    /// ONNX operator names paired with the operator name handed to make_op.
    std::vector<op_desc> operators() const
    {
        return {{"Softmax", "softmax"}, {"LogSoftmax", "logsoftmax"}};
    }

    /// Adds the operator named by opd.op_name with the resolved "axis" value.
    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        // The default axis is 1 before opset 13 and -1 otherwise.
        int64_t axis = (parser.opset_version < 13) ? 1 : -1;

        // An explicit "axis" attribute takes precedence over the opset default.
        if(contains(info.attributes, "axis"))
            axis = parser.parse_value(info.attributes.at("axis")).at<int>();

        return info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for Softplus, built from primitive pointwise operators.
struct parse_softplus : op_parser<parse_softplus>
{
    /// Operator name(s) handled by this parser.
    std::vector<op_desc> operators() const { return {{"Softplus"}}; }

    /// Emits y = ln(exp(x) + 1) for the first input.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto x             = args[0];
        const auto in_type = x->get_shape().type();
        const auto in_lens = x->get_shape().lens();

        // A single-element literal 1 of the input type, broadcast to the input dimensions.
        auto one  = info.add_literal(migraphx::literal{migraphx::shape{in_type}, {1}});
        auto ones = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), one);

        auto exp_x    = info.add_instruction(migraphx::make_op("exp"), x);
        auto exp_plus = info.add_instruction(migraphx::make_op("add"), exp_x, ones);
        return info.add_instruction(migraphx::make_op("log"), exp_plus);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for Softsign, built from primitive pointwise operators.
struct parse_softsign : op_parser<parse_softsign>
{
    /// Operator name(s) handled by this parser.
    std::vector<op_desc> operators() const { return {{"Softsign"}}; }

    /// Emits y = x / (1 + |x|) for the first input.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto x             = args[0];
        const auto in_type = x->get_shape().type();
        const auto in_lens = x->get_shape().lens();

        // A single-element literal 1 of the input type, broadcast to the input dimensions.
        auto one  = info.add_literal(migraphx::literal{migraphx::shape{in_type}, {1}});
        auto ones = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", in_lens}}), one);

        auto abs_x = info.add_instruction(migraphx::make_op("abs"), x);
        auto denom = info.add_instruction(migraphx::make_op("add"), abs_x, ones);
        return info.add_instruction(migraphx::make_op("div"), x, denom);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for SpaceToDepth, lowered to reshape -> transpose -> reshape.
struct parse_spacetodepth : op_parser<parse_spacetodepth>
{
    /// Operator name(s) handled by this parser.
    std::vector<op_desc> operators() const { return {{"SpaceToDepth"}}; }

    /// Rearranges blocks of the two trailing dimensions of a {N, C, H, W} input into
    /// the channel dimension.
    ///
    /// Throws when blocksize is less than 1, when the input does not have exactly
    /// four dimensions, or when H or W is not divisible by blocksize.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto s = args[0]->get_shape();

        // blocksize attribute of SpaceToDepth; a blocksize of 1 makes this a no-op
        int blocksize = 1;
        if(contains(info.attributes, "blocksize"))
        {
            blocksize = info.attributes.at("blocksize").i();
        }
        if(blocksize < 1)
        {
            // blockSize less than 1 would rather result in DepthToSpace instead of SpaceToDepth
            MIGRAPHX_THROW("SpaceToDepth: blocksize is less than 1");
        }

        auto res_lens = s.lens(); // {N, C, H, W}
        // Indices 1..3 are read below, so reject any other rank instead of reading
        // past the end of the dimension vector.
        if(res_lens.size() != 4)
        {
            MIGRAPHX_THROW("SpaceToDepth: input must have 4 dimensions {N, C, H, W}");
        }
        if(((res_lens[2] % blocksize) != 0) or ((res_lens[3] % blocksize) != 0))
        {
            MIGRAPHX_THROW("SpaceToDepth: div by blocksize quotient not int ");
        }

        // res_shape = (N, Co, Ho, Wo)
        res_lens[1] = res_lens[1] * blocksize * blocksize; // Co = C * (blocksize ^ 2)
        res_lens[2] = res_lens[2] / blocksize;             // Ho = (H / blocksize)
        res_lens[3] = res_lens[3] / blocksize;             // Wo = (W / blocksize)

        // Intermediate view: {N, C, Ho, blocksize, Wo, blocksize}
        auto trans_lens = s.lens(); // {N, C, H, W}
        trans_lens[2]   = res_lens[2];
        trans_lens[3]   = blocksize;
        trans_lens.push_back(res_lens[3]);
        trans_lens.push_back(blocksize);

        // Move both blocksize dimensions in front of C so the final reshape folds them into Co.
        std::vector<int64_t> perm = {0, 3, 5, 1, 2, 4};
        auto temp1 = info.add_instruction(make_op("reshape", {{"dims", trans_lens}}), args[0]);
        auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1);
        return info.add_instruction(make_op("reshape", {{"dims", res_lens}}),
                                    info.make_contiguous(temp2));
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for Split, lowered to one "slice" instruction per output.
struct parse_split : op_parser<parse_split>
{
    /// Operator name(s) handled by this parser.
    std::vector<op_desc> operators() const { return {{"Split"}}; }

    /// Returns one slice of the first input per split along the requested axis.
    ///
    /// Split sizes come from the "split" attribute when present; otherwise the axis
    /// is divided evenly into info.num_outputs pieces. Throws when the sizes do not
    /// sum to the axis length or when the axis cannot be divided evenly.
    std::vector<instruction_ref> parse(const op_desc& opd,
                                       const onnx_parser& parser,
                                       onnx_parser::node_info info,
                                       std::vector<instruction_ref> args) const
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
            axis = parser.parse_value(info.attributes.at("axis")).at<int>();

        auto in_lens       = args[0]->get_shape().lens();
        int64_t rank       = in_lens.size();
        int64_t tuned_axis = tune_axis(rank, axis, opd.op_name);

        std::vector<int64_t> vec_splits;
        if(not contains(info.attributes, "split"))
        {
            // no split attribute, input is equally divided
            if((in_lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               std::to_string(info.num_outputs) + " splits!");
            }
            auto chunk = in_lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, chunk);
        }
        else
        {
            literal split_lit = parser.parse_value(info.attributes.at("split"));
            split_lit.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            const auto total = std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0));
            if(total != static_cast<int64_t>(in_lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }

        // Consecutive, non-overlapping slices. Note the slice receives the axis as
        // given by the attribute; tuned_axis is only used to index the dimensions.
        std::vector<instruction_ref> slices;
        int64_t begin = 0;
        for(auto len : vec_splits)
        {
            const int64_t end = begin + len;
            slices.push_back(info.add_instruction(
                make_op("slice", {{"axes", {axis}}, {"starts", {begin}}, {"ends", {end}}}),
                args[0]));
            begin = end;
        }
        return slices;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for the Squeeze and Unsqueeze operators.
struct parse_squeeze : op_parser<parse_squeeze>
{
    /// ONNX operator names paired with the operator name handed to parser.load.
    std::vector<op_desc> operators() const
    {
        return {{"Squeeze", "squeeze"}, {"Unsqueeze", "unsqueeze"}};
    }

    /// Overwrites the "axes" field of op in place and returns a copy of the result.
    operation assign_axes(operation& op, const std::vector<int64_t>& axes) const
    {
        auto fields    = op.to_value();
        fields["axes"] = axes;
        op.from_value(fields);
        return op;
    }

    /// Loads the operator from the node and, when a second input is given, takes the
    /// axes from that input instead; the input must be evaluable at parse time.
    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto op = parser.load(opd.op_name, info);

        const bool axes_as_input = (args.size() == 2);
        if(axes_as_input)
        {
            auto axes_arg = args.at(1)->eval();
            check_arg_empty(axes_arg, "PARSE_" + opd.op_name + ": cannot handle variable axes!");

            std::vector<int64_t> axes;
            axes_arg.visit([&](auto s) { axes.assign(s.begin(), s.end()); });
            op = assign_axes(op, axes);
        }

        // The operator is applied to a contiguous version of the data input.
        auto data = info.make_contiguous(args.front());
        return info.add_instruction(op, data);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for ThresholdedRelu, lowered to greater + where.
struct parse_thresholdedrelu : op_parser<parse_thresholdedrelu>
{
    /// Operator name(s) handled by this parser.
    std::vector<op_desc> operators() const { return {{"ThresholdedRelu"}}; }

    /// Emits where(x > alpha, x, 0); alpha defaults to 1.0 when the attribute is absent.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        float alpha = 1.0f;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
        }

        auto x       = args[0];
        auto x_shape = x->get_shape();

        // Single-element literals of the input type for the constant 0 and the threshold.
        auto zero_lit = info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}});
        auto alpha_lit =
            info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {alpha}});

        // Both literals are broadcast to the input dimensions with the same operator.
        auto broadcast = migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}});
        auto zeros     = info.add_instruction(broadcast, zero_lit);
        auto alphas    = info.add_instruction(broadcast, alpha_lit);

        auto above = info.add_instruction(migraphx::make_op("greater"), x, alphas);
        return info.add_instruction(migraphx::make_op("where"), above, x, zeros);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for Tile, lowered to repeated "concat" instructions.
struct parse_tile : op_parser<parse_tile>
{
    /// Operator name(s) handled by this parser.
    std::vector<op_desc> operators() const { return {{"Tile"}}; }

    /// Concatenates the first input with itself along each axis i so that it appears
    /// repeats[i] times. The repeats input (args[1]) must be evaluable at parse time.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "PARSE_TILE: dynamic shape is not supported");
        std::vector<std::int64_t> repeats;
        arg_s.visit([&](auto input) { repeats.assign(input.begin(), input.end()); });

        auto l0 = args[0];
        // Signed 64-bit counters match the element type of repeats and avoid comparing
        // a signed index against the unsigned container size.
        const auto n_axes = static_cast<std::int64_t>(repeats.size());
        for(std::int64_t i = 0; i < n_axes; i++)
        {
            // l1 is the tensor as it stands before tiling this axis; it is appended
            // repeats[i] - 1 times, so a repeat count of 1 or less adds nothing.
            auto l1 = l0;
            for(std::int64_t j = 1; j < repeats[i]; j++)
            {
                l0 = info.add_instruction(make_op("concat", {{"axis", i}}), l0, l1);
            }
        }
        return l0;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for TopK; produces the values and the indices as two outputs.
struct parse_topk : op_parser<parse_topk>
{
    /// Operator name(s) handled by this parser.
    std::vector<op_desc> operators() const { return {{"TopK"}}; }

    /// Adds a "topk" instruction on the first input and returns its two tuple elements.
    ///
    /// k is read from the second input when exactly two inputs are given (it must be
    /// evaluable at parse time), otherwise from the "k" attribute, otherwise it is 0.
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& parser,
                                       onnx_parser::node_info info,
                                       std::vector<instruction_ref> args) const
    {
        int64_t k = 0;
        if(args.size() == 2)
        {
            auto k_arg = args.at(1)->eval();
            check_arg_empty(k_arg, "PARSE_TopK: k input must be constant");
            k = k_arg.at<int>();
        }
        else if(contains(info.attributes, "k"))
        {
            k = info.attributes.at("k").i();
        }

        // "largest" defaults to true and "axis" to -1 when the attributes are absent.
        const bool largest = contains(info.attributes, "largest")
                                 ? static_cast<bool>(info.attributes.at("largest").i())
                                 : true;

        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
            axis = parser.parse_value(info.attributes.at("axis")).at<int>();

        auto topk = info.add_instruction(
            make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0));

        // Tuple element 0 is returned first, element 1 second.
        auto values  = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), topk);
        auto indices = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), topk);
        return {values, indices};
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for Transpose.
struct parse_transpose : op_parser<parse_transpose>
{
    /// Operator name(s) handled by this parser.
    std::vector<op_desc> operators() const { return {{"Transpose"}}; }

    /// Adds a "transpose" instruction on the first input.
    ///
    /// The permutation comes from the "perm" attribute; when it is absent or empty the
    /// dimension order is reversed. Throws when the permutation length differs from
    /// the number of input dimensions.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            // at() is a pure lookup; operator[] could insert a default entry.
            auto&& perm_vals = info.attributes.at("perm").ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }

        // if perm is empty, use the default value: {n_dim - 1, ..., 1, 0}
        auto n_dim = args.front()->get_shape().lens().size();
        if(perm.empty())
        {
            perm.resize(n_dim);
            std::iota(perm.rbegin(), perm.rend(), 0);
        }

        if(perm.size() != n_dim)
        {
            MIGRAPHX_THROW("PARSE_TRANSPOSE: perm and input have different number of dims!");
        }

        return info.add_instruction(make_op("transpose", {{"permutation", perm}}), args.front());
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for the variadic Sum, Max and Min operators.
struct parse_variadic_op : op_parser<parse_variadic_op>
{
    /// ONNX operator names paired with the binary operator name used to combine inputs.
    std::vector<op_desc> operators() const
    {
        return {{"Sum", "add"}, {"Max", "max"}, {"Min", "min"}};
    }

    /// Folds the inputs left to right with the broadcastable binary operator named by
    /// opd.op_name. A single input is returned unchanged.
    instruction_ref parse(const op_desc& opd,
                          const onnx_parser&,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        instruction_ref combined = args.front();
        for(auto it = std::next(args.begin()); it != args.end(); ++it)
        {
            combined = info.add_broadcastable_binary_op(opd.op_name, combined, *it);
        }
        return combined;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
/// ONNX parser for Where.
struct parse_where : op_parser<parse_where>
{
    /// Operator name(s) handled by this parser.
    std::vector<op_desc> operators() const { return {{"Where"}}; }

    /// Broadcasts the three inputs to their common dimensions and adds a "where"
    /// instruction on them, keeping the input order.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // Common dimensions of all three inputs, computed pairwise.
        auto lens =
            compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
        lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());

        // Only inputs whose dimensions differ from the common ones get a multibroadcast.
        auto broadcast_if_needed = [&](instruction_ref ins) {
            if(ins->get_shape().lens() == lens)
                return ins;
            return info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), ins);
        };

        auto first  = broadcast_if_needed(args[0]);
        auto second = broadcast_if_needed(args[1]);
        auto third  = broadcast_if_needed(args[2]);

        return info.add_instruction(make_op("where"), first, second, third);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
/// Computes the softmax of a flat vector: result[i] = exp(p[i]) / sum_j exp(p[j]).
///
/// The maximum element is subtracted before exponentiation. This leaves the result
/// mathematically unchanged but keeps exp() from overflowing for large inputs.
/// The sum is accumulated in T, so double inputs are not rounded through float.
///
/// @param p input values
/// @return a vector of the same length as p; empty when p is empty
template <typename T>
std::vector<T> softmax(const std::vector<T>& p)
{
    std::vector<T> result(p.size());
    if(p.empty())
        return result;

    const T peak = *std::max_element(p.begin(), p.end());
    std::transform(
        p.begin(), p.end(), result.begin(), [=](auto x) { return std::exp(x - peak); });

    // The init value fixes the accumulator type, so it must be T rather than float.
    const T s = std::accumulate(result.begin(), result.end(), T(0));
    std::transform(result.begin(), result.end(), result.begin(), [=](auto x) { return x / s; });
    return result;
}
//
// Supporting functions for enum values used in operator parameters.
// These values are declared as "enum class" and should include << streaming operators
// to be able to write their values in human-readable format so users can
// save and edit model files.
//
#include <sstream>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
// the strings for the enum are the same as the values used for onnx parsing
// but this enum is not onnx-specific: strings must be converted when parsing tf
static const std::vector<std::string> pooling_mode_str = {"average", "max", "lpnorm"};
os << pooling_mode_str[static_cast<std::underlying_type<pooling_mode>::type>(v)];
return os;
}
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
static const std::vector<std::string> rnn_direction_str = {
"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/operation.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Writes an operation into v as two entries: "name" holds op.name() and
/// "operator" holds op.to_value(). "name" is written first.
void migraphx_to_value(value& v, const operation& op)
{
    const auto op_name = op.name();
    v["name"]          = op_name;

    const auto op_fields = op.to_value();
    v["operator"]        = op_fields;
}
/// Rebuilds an operation from v by passing its "name" entry (as a string) and its
/// "operator" entry to make_op, and assigns the result to op.
void migraphx_from_value(const value& v, operation& op)
{
    const auto op_name = v.at("name").to<std::string>();
    op                 = make_op(op_name, v.at("operator"));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(program& p) const void memory_coloring::apply(module& m) const
{ {
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{})) if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{ {
memory_coloring_impl opt(&p, allocation_op, verify); memory_coloring_impl opt(&m, allocation_op, verify);
opt.run(); opt.run();
} }
} }
......
#include <migraphx/op/load.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraphx { namespace migraphx {
...@@ -6,8 +9,11 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,8 +9,11 @@ inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring_impl::run() void memory_coloring_impl::run()
{ {
// calc implicit dependencies
mod_implicit_deps = p_mod->calc_implicit_deps();
MIGRAPHX_DEBUG(dump("---Before memory coloring---")); MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
MIGRAPHX_DEBUG(dump_program()); MIGRAPHX_DEBUG(dump_module());
build(); build();
if(num_of_lives != 0) if(num_of_lives != 0)
{ {
...@@ -19,7 +25,10 @@ void memory_coloring_impl::run() ...@@ -19,7 +25,10 @@ void memory_coloring_impl::run()
allocate(interval); allocate(interval);
alloc_queue.pop(); alloc_queue.pop();
} }
// rewrite happens after all modules are processed
rewrite(); rewrite();
if(enable_verify) if(enable_verify)
verify(); verify();
} }
...@@ -31,7 +40,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval) ...@@ -31,7 +40,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
std::size_t size = s.bytes(); std::size_t size = s.bytes();
if(size == 0) if(size == 0)
return false; return false;
std::size_t element_size = size / s.elements(); std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
live_range& segment = interval->segment; live_range& segment = interval->segment;
int vn = segment.vn; int vn = segment.vn;
std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue; std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
...@@ -41,7 +50,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval) ...@@ -41,7 +50,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
if(conflict_table.find(vn) != conflict_table.end()) if(conflict_table.find(vn) != conflict_table.end())
{ {
std::set<int>& vn_set = conflict_table[vn]; std::set<int>& vn_set = conflict_table[vn];
for(auto& iter : vn_set) for(const auto& iter : vn_set)
{ {
live_range* range = live_ranges[iter]; live_range* range = live_ranges[iter];
long long offset = range->offset; long long offset = range->offset;
...@@ -96,13 +105,13 @@ bool memory_coloring_impl::allocate(interval_ptr interval) ...@@ -96,13 +105,13 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
void memory_coloring_impl::build() void memory_coloring_impl::build()
{ {
std::size_t num_of_instrs = p_program->size(); std::size_t num_of_instrs = p_mod->size();
if(num_of_instrs == 0) if(num_of_instrs == 0)
return; return;
auto cur_points = num_of_instrs * 2; auto cur_points = num_of_instrs * 2;
instruction_ref iter = p_program->end(); instruction_ref iter = p_mod->end();
instruction_ref begin = p_program->begin(); instruction_ref begin = p_mod->begin();
std::vector<instruction_ref> dead_instrs; std::vector<instruction_ref> dead_instrs;
std::set<int> live_set; std::set<int> live_set;
// Build live intervals. // Build live intervals.
...@@ -134,8 +143,19 @@ void memory_coloring_impl::build() ...@@ -134,8 +143,19 @@ void memory_coloring_impl::build()
{ {
is_dead = true; is_dead = true;
} }
for(auto&& arg : iter->inputs())
auto inputs = iter->inputs();
if(contains(mod_implicit_deps, iter))
{ {
const auto& impl_deps = mod_implicit_deps.at(iter);
inputs.insert(inputs.end(), impl_deps.begin(), impl_deps.end());
}
for(auto&& arg : inputs)
{
if(not p_mod->has_instruction(arg))
continue;
if(is_param(arg) || is_outline(arg)) if(is_param(arg) || is_outline(arg))
{ {
if(is_output_param(arg)) if(is_output_param(arg))
...@@ -182,8 +202,8 @@ void memory_coloring_impl::rewrite() ...@@ -182,8 +202,8 @@ void memory_coloring_impl::rewrite()
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float)); dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float));
shape s = {shape::float_type, dims}; shape s = {shape::float_type, dims};
instruction_ref scratch_param = p_program->add_parameter("scratch", s); instruction_ref scratch_param = p_mod->add_parameter("scratch", s);
for(auto ins : iterator_for(*p_program)) for(auto ins : iterator_for(*p_mod))
{ {
const instruction* p_iter = &(*ins); const instruction* p_iter = &(*ins);
if(instr2_live.find(p_iter) != instr2_live.end()) if(instr2_live.find(p_iter) != instr2_live.end())
...@@ -207,13 +227,15 @@ void memory_coloring_impl::rewrite() ...@@ -207,13 +227,15 @@ void memory_coloring_impl::rewrite()
if(is_allocate(ins)) if(is_allocate(ins))
{ {
p_program->replace_instruction( p_mod->replace_instruction(
ins, op::load{ins->get_shape(), offset}, scratch_param); ins,
make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}),
scratch_param);
} }
} }
} }
MIGRAPHX_DEBUG(dump("---After rewrite---")); MIGRAPHX_DEBUG(dump("---After rewrite---"));
MIGRAPHX_DEBUG(dump_program()); MIGRAPHX_DEBUG(dump_module());
} }
void memory_coloring_impl::verify() void memory_coloring_impl::verify()
...@@ -227,8 +249,8 @@ void memory_coloring_impl::verify() ...@@ -227,8 +249,8 @@ void memory_coloring_impl::verify()
if(segment.begin == invalid_offset) if(segment.begin == invalid_offset)
{ {
if(!interval.is_live_on_entry) // if(!interval.is_live_on_entry)
MIGRAPHX_THROW("interval is not live on entry"); // MIGRAPHX_THROW("interval is not live on entry");
continue; continue;
} }
...@@ -240,7 +262,7 @@ void memory_coloring_impl::verify() ...@@ -240,7 +262,7 @@ void memory_coloring_impl::verify()
if(conflict_table.find(vn) != conflict_table.end()) if(conflict_table.find(vn) != conflict_table.end())
{ {
std::set<int>& vn_set = conflict_table[vn]; std::set<int>& vn_set = conflict_table[vn];
for(auto& iter : vn_set) for(const auto& iter : vn_set)
{ {
live_range* range = live_ranges[iter]; live_range* range = live_ranges[iter];
if(range->offset == invalid_offset) if(range->offset == invalid_offset)
...@@ -257,7 +279,7 @@ void memory_coloring_impl::verify() ...@@ -257,7 +279,7 @@ void memory_coloring_impl::verify()
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; } void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
void memory_coloring_impl::dump_program() { std::cout << *p_program << std::endl; } void memory_coloring_impl::dump_module() { std::cout << *p_mod << std::endl; }
void memory_coloring_impl::dump_intervals() void memory_coloring_impl::dump_intervals()
{ {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp> #include <migraphx/pass_config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <set> #include <set>
...@@ -39,10 +40,6 @@ struct live_interval ...@@ -39,10 +40,6 @@ struct live_interval
{ {
live_interval() : segment({invalid_offset, invalid_offset, invalid_offset, invalid_offset, 0}) live_interval() : segment({invalid_offset, invalid_offset, invalid_offset, invalid_offset, 0})
{ {
id = invalid_offset;
def_point = invalid_offset;
is_literal = false;
is_live_on_entry = false;
} }
void add_use(std::size_t use) { use_points.push_front(use); } void add_use(std::size_t use) { use_points.push_front(use); }
...@@ -55,35 +52,27 @@ struct live_interval ...@@ -55,35 +52,27 @@ struct live_interval
#endif #endif
live_range segment; live_range segment;
std::size_t id; std::size_t id = invalid_offset;
std::list<std::size_t> use_points; std::list<std::size_t> use_points{};
std::size_t def_point; std::size_t def_point = invalid_offset;
shape result; shape result{};
bool is_literal; bool is_literal = false;
bool is_live_on_entry; bool is_live_on_entry = false;
}; };
using interval_ptr = live_interval*; using interval_ptr = live_interval*;
struct memory_coloring_impl struct memory_coloring_impl
{ {
memory_coloring_impl(program* p, std::string alloc_op, bool p_verify) memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
: p_program(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify) : p_mod(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{ {
instr2_live.clear();
live_ranges.clear();
conflict_table.clear();
num_of_lives = 0;
max_value_number = -1;
required_bytes = 0;
earliest_end_point = -1;
latest_end_point = -1;
unify_literals = false;
} }
bool allocate(interval_ptr); bool allocate(interval_ptr);
void add_conflicts(std::set<int>& live_set, int val) void add_conflicts(const std::set<int>& live_set, int val)
{ {
for(auto& iter : live_set) for(const auto& iter : live_set)
{ {
conflict_table[iter].insert(val); conflict_table[iter].insert(val);
conflict_table[val].insert(iter); conflict_table[val].insert(iter);
...@@ -97,7 +86,11 @@ struct memory_coloring_impl ...@@ -97,7 +86,11 @@ struct memory_coloring_impl
static bool is_param(const instruction_ref ins) { return ins->name() == "@param"; } static bool is_param(const instruction_ref ins) { return ins->name() == "@param"; }
static bool is_output_param(const instruction_ref ins) static bool is_output_param(const instruction_ref ins)
{ {
return is_param(ins) && any_cast<builtin::param>(ins->get_operator()).parameter == "output"; if(not is_param(ins))
return false;
auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
return contains(param_name, "#output_");
} }
bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; } bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; } static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }
...@@ -118,12 +111,12 @@ struct memory_coloring_impl ...@@ -118,12 +111,12 @@ struct memory_coloring_impl
void verify(); void verify();
#ifdef MIGRAPHX_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
void dump(const std::string&); void dump(const std::string&);
void dump_program(); void dump_module();
void dump_intervals(); void dump_intervals();
#endif #endif
struct ordering struct ordering
{ {
bool operator()(const interval_ptr i1, const interval_ptr i2) const bool operator()(const interval_ptr& i1, const interval_ptr& i2) const
{ {
auto len1 = i1->get_end() - i1->get_begin(); auto len1 = i1->get_end() - i1->get_begin();
auto len2 = i2->get_end() - i2->get_begin(); auto len2 = i2->get_end() - i2->get_begin();
...@@ -145,28 +138,31 @@ struct memory_coloring_impl ...@@ -145,28 +138,31 @@ struct memory_coloring_impl
return (i1->offset > i2->offset); return (i1->offset > i2->offset);
} }
}; };
program* p_program;
module* p_mod;
std::unordered_map<const instruction*, interval_ptr> instr2_live; std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals. // universe of live intervals.
std::vector<live_interval> live_intervals; std::vector<live_interval> live_intervals = {};
// Map live range value number to live range. // Map live range value number to live range.
std::unordered_map<int, live_range*> live_ranges; std::unordered_map<int, live_range*> live_ranges = {};
// Map live range value number to a set of conflicting live ranges' value numbers. // Map live range value number to a set of conflicting live ranges' value numbers.
std::unordered_map<int, std::set<int>> conflict_table; std::unordered_map<int, std::set<int>> conflict_table = {};
// Priority queue for coloring. // Priority queue for coloring.
std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue; std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue{};
int num_of_lives; int num_of_lives = 0;
int max_value_number; int max_value_number = -1;
std::size_t required_bytes; std::size_t required_bytes = 0;
// The earliest program point where an live interval ends. // The earliest program point where an live interval ends.
int earliest_end_point; int earliest_end_point = -1;
// The latest program point where an live interval ends. // The latest program point where an live interval ends.
int latest_end_point; int latest_end_point = -1;
// Whether to unify literals into coloring. // Whether to unify literals into coloring.
bool unify_literals; bool unify_literals = false;
std::string allocation_op{}; std::string allocation_op{};
bool enable_verify; bool enable_verify;
ins_dep_map mod_implicit_deps;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -15,25 +15,97 @@ ...@@ -15,25 +15,97 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
void validate_pass(module& mod, const pass& p, tracer trace)
{ {
for(auto& p : passes) (void)mod;
(void)p;
(void)trace;
#ifndef NDEBUG
trace("Validate ...");
auto invalid = mod.validate();
if(invalid != mod.end())
{ {
trace("Pass: ", p.name()); auto index = std::distance(mod.begin(), invalid);
p.apply(prog); MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
trace(prog); std::to_string(index) + ": " + invalid->name());
}
trace();
#endif
}
void run_pass(program& prog, const pass& p, tracer trace)
{
trace("Pass: ", p.name());
p.apply(prog);
trace(prog);
}
#ifndef NDEBUG struct module_pm : module_pass_manager
trace("Validate ..."); {
auto invalid = prog.validate(); module* mod;
if(invalid != prog.end()) program* prog;
tracer* t;
module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr)
: mod(pmod), prog(pprog), t(pt)
{
}
template <class... Ts>
void trace(Ts&&... xs) const
{
assert(t);
(*t)(xs...);
}
virtual module& get_module() override
{
assert(mod);
return *mod;
}
virtual module* create_module(const std::string& name) override
{
assert(prog);
return prog->create_module(name);
}
virtual void run_pass(const pass& p) override
{
assert(mod);
trace("Module: ", mod->name(), ", Pass: ", p.name());
assert(mod->validate() == mod->end());
p.apply(*this);
trace(*mod);
validate_pass(*mod, p, *t);
}
};
module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
for(const auto& p : passes)
{
module_pm{&mod, nullptr, &trace}.run_pass(p);
}
}
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
for(const auto& p : passes)
{
auto mods = prog.get_modules();
for(const auto& mod : reverse(mods))
{ {
auto index = std::distance(prog.begin(), invalid); if(mod->bypass())
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " + continue;
std::to_string(index) + ": " + invalid->name()); module_pm{mod, &prog, &trace}.run_pass(p);
} }
trace(); run_pass(prog, p, trace);
#endif
} }
} }
......
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <map>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// Builds a shape with the same element type as s whose lens and strides are both
/// passed through reorder_dims with the given permutation.
shape reorder_shape(const shape& s, const std::vector<int64_t>& permutation)
{
    auto permuted_lens    = reorder_dims(s.lens(), permutation);
    auto permuted_strides = reorder_dims(s.strides(), permutation);
    return {s.type(), permuted_lens, permuted_strides};
}
/// Returns sort_permutation(permutation, less-than).
/// NOTE(review): presumably sort_permutation yields the indices ordered by the value
/// stored at each index, which is the inverse permutation; confirm in algorithm.hpp.
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
{
    const std::less<> ascending{};
    return sort_permutation(permutation, ascending);
}
/// Returns the dimension indices of s ordered by descending (stride, length).
///
/// The sort is stable, so dimensions with equal stride and length keep their
/// original relative order.
std::vector<int64_t> find_permutation(const shape& s)
{
    std::vector<std::int64_t> order(s.lens().size());
    std::iota(order.begin(), order.end(), 0);

    // Sort key for one dimension index: stride first, length as the tie-breaker.
    auto stride_then_len = [&](auto dim) {
        return std::make_tuple(s.strides()[dim], s.lens()[dim]);
    };
    std::stable_sort(order.begin(), order.end(), by(std::greater<>{}, stride_then_len));
    return order;
}
/// Returns the permutation that occurs most often among the non-broadcasted shapes.
///
/// An empty input gives an empty result. When every shape is broadcasted, the
/// identity permutation sized to the first shape's dimension count is returned.
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
{
    if(shapes.empty())
        return {};

    // Tally how many non-broadcasted shapes share each permutation.
    std::map<std::vector<int64_t>, std::size_t> tally;
    for(auto&& s : shapes)
    {
        if(not s.broadcasted())
            tally[find_permutation(s)]++;
    }

    if(tally.empty())
    {
        std::vector<int64_t> identity(shapes.front().lens().size());
        std::iota(identity.begin(), identity.end(), 0);
        return identity;
    }

    auto by_votes = by(std::less<>{}, [](auto&& entry) { return entry.second; });
    auto winner   = std::max_element(tally.begin(), tally.end(), by_votes);
    assert(winner != tally.end());
    return winner->first;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment