Commit 11e155c2 authored by Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_eyelike : op_parser<parse_eyelike>
{
    std::vector<op_desc> operators() const { return {{"EyeLike"}}; }

    /// Parse the ONNX EyeLike operator: produce a rank-2 tensor shaped like the
    /// input, with ones on the k-th diagonal and zeros elsewhere.
    /// Attributes:
    ///   dtype - optional output element type (defaults to the input's type)
    ///   k     - diagonal offset; positive shifts above the main diagonal,
    ///           negative below (defaults to 0)
    /// Throws when the input is not rank 2 or k is out of bounds.
    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto input_shape = args[0]->get_shape();
        auto input_lens  = input_shape.lens();
        if(input_lens.size() != 2)
        {
            MIGRAPHX_THROW("EYELIKE: tensor input not of rank 2");
        }
        std::ptrdiff_t num_rows = input_lens.front();
        std::ptrdiff_t num_cols = input_lens.back();

        shape::type_t output_type = args[0]->get_shape().type();
        if(contains(info.attributes, "dtype"))
        {
            output_type = get_type(info.attributes.at("dtype").i());
        }

        std::ptrdiff_t k = 0;
        if(contains(info.attributes, "k"))
        {
            k = info.attributes.at("k").i();
        }
        if(k >= 0)
        {
            // a non-negative offset must land inside the column range of row 0
            if(k >= num_cols)
            {
                std::ostringstream oss;
                oss << "EYELIKE: positive k out of bounds, k = " << k
                    << " num_cols = " << num_cols;
                MIGRAPHX_THROW(oss.str());
            }
        }
        else
        {
            // a negative offset is bounded by the row count; report num_rows,
            // not num_cols (previous message printed the wrong variable)
            if(std::abs(k) >= num_rows)
            {
                std::ostringstream oss;
                oss << "EYELIKE: negative k out of bounds, k = " << k
                    << " num_rows = " << num_rows;
                MIGRAPHX_THROW(oss.str());
            }
        }

        // Build the matrix in row-major order: row i has its one at column i + k.
        std::vector<char> eyelike_mat(num_rows * num_cols, 0);
        for(std::ptrdiff_t i = 0; i < num_rows; ++i)
        {
            auto idx = i + k; // column index of the diagonal element in row i
            if(idx < num_cols and idx >= 0)
                eyelike_mat[num_cols * i + idx] = char{1};
        }
        return info.add_literal(
            migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat});
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
std::vector<op_desc> operators() const
{
// clang-format off
return {{"Abs", "abs"},
{"Acos", "acos"},
{"Acosh", "acosh"},
......@@ -27,7 +28,9 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Flatten", "flatten"},
{"Floor", "floor"},
{"Gather", "gather"},
{"GatherND", "gathernd"},
{"Identity", "identity"},
{"IsNaN", "isnan"},
{"LeakyRelu", "leaky_relu"},
{"Log", "log"},
{"LRN", "lrn"},
......@@ -36,8 +39,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
{"Scatter", "scatter"},
{"ScatterElements", "scatter"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"Sin", "sin"},
......@@ -46,6 +47,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Tan", "tan"},
{"Tanh", "tanh"},
{"Not", "not"}};
// clang-format on
}
bool needs_contiguous(const std::string& op_name) const
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for LpNormalization ONNX operator.
/*!
Normalizes a tensor by the L1 or L2 norms along a given axis.
Norms that evaluate to 0 are changed to 1 to prevent division by zero.
*/
//! Parser for LpNormalization ONNX operator.
/*!
  Normalizes a tensor by the L1 or L2 norms along a given axis.
  Norms that evaluate to 0 are changed to 1 to prevent division by zero.
*/
struct parse_lpnormalization : op_parser<parse_lpnormalization>
{
    std::vector<op_desc> operators() const { return {{"LpNormalization"}}; }

    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // "p" selects the norm order; only L1 and L2 are supported
        int p_order = 2;
        if(contains(info.attributes, "p"))
            p_order = info.attributes.at("p").i();
        if(p_order != 1 and p_order != 2)
            MIGRAPHX_THROW("LPNORMALIZATION: only L1 and L2 norm supported");

        auto data        = args.front();
        auto data_shape  = data->get_shape();
        const auto& dims = data_shape.lens();
        auto elem_type   = data_shape.type();

        std::ptrdiff_t rank = dims.size();
        std::ptrdiff_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = info.attributes.at("axis").i();
            // handled in normalize_attributes but throwing here might be clearer
            if(axis < -rank or axis >= rank)
                MIGRAPHX_THROW("LPNORMALIZATION: selected axis out of bounds");
        }

        // per-element magnitude: |x| for L1, x*x for L2
        auto magnitudes = (p_order == 1)
                              ? info.add_instruction(migraphx::make_op("abs"), data)
                              : info.add_instruction(migraphx::make_op("mul"), data, data);

        auto norms =
            info.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {axis}}}), magnitudes);
        if(p_order == 2)
            norms = info.add_instruction(migraphx::make_op("sqrt"), norms);

        // broadcast back to initial shape, negative axis option doesn't work with unidirectional
        norms = info.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}),
                                     norms);

        // need to check for zeros from lp norm to prevent division by zero;
        // replace them with 1 before the element-wise division
        auto zeros = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", dims}}),
            info.add_literal(migraphx::literal{migraphx::shape{elem_type}, {0.}}));
        auto ones = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", dims}}),
            info.add_literal(migraphx::literal{migraphx::shape{elem_type}, {1.}}));
        auto zero_mask  = info.add_instruction(migraphx::make_op("equal"), norms, zeros);
        auto safe_norms = info.add_instruction(migraphx::make_op("where"), zero_mask, ones, norms);
        return info.add_instruction(migraphx::make_op("div"), data, safe_norms);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -9,6 +10,9 @@ namespace onnx {
struct parse_mean : op_parser<parse_mean>
{
const std::set<shape::type_t> float_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"Mean"}}; }
/// Calculates the element-wise mean of n>=1 input tensors
......@@ -24,14 +28,29 @@ struct parse_mean : op_parser<parse_mean>
auto divisor = info.add_literal(
migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
return std::accumulate(args.begin(), args.end(), args[0], [&](auto& mean, auto& data_i) {
// Pre-divide each tensor element-wise by n to reduce risk of overflow during summation
data_i = info.add_broadcastable_binary_op("div", data_i, divisor);
if(contains(float_types, args[0]->get_shape().type()))
{
return std::accumulate(args.begin() + 1,
args.end(),
info.add_broadcastable_binary_op("div", args[0], divisor),
[&](auto mean, auto data_i) {
// Pre-divide each tensor element-wise by n to reduce risk of
// overflow during summation
auto div =
info.add_broadcastable_binary_op("div", data_i, divisor);
return info.add_broadcastable_binary_op("add", mean, div);
});
}
else
{
// Compute sum before division for integral types
auto sum = std::accumulate(
args.begin() + 1, args.end(), args[0], [&](auto accum, auto data_i) {
return info.add_broadcastable_binary_op("add", accum, data_i);
});
if(data_i != args[0])
return info.add_broadcastable_binary_op("add", mean, data_i);
return data_i;
});
return info.add_broadcastable_binary_op("div", sum, divisor);
}
}
};
......
......@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
......@@ -18,7 +19,9 @@ struct parse_pooling : op_parser<parse_pooling>
return {{"AveragePool", "average"},
{"GlobalAveragePool", "average"},
{"GlobalMaxPool", "max"},
{"MaxPool", "max"}};
{"MaxPool", "max"},
{"LpPool", "lpnorm"},
{"GlobalLpPool", "lpnorm"}};
}
instruction_ref parse(const op_desc& opd,
......@@ -26,11 +29,19 @@ struct parse_pooling : op_parser<parse_pooling>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
std::string mode = opd.op_name;
operation op = make_op("pooling", {{"mode", mode}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
......@@ -67,11 +78,18 @@ struct parse_pooling : op_parser<parse_pooling>
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}
// lp_order attribute
if(contains(info.attributes, "p"))
{
values["lp_order"] = info.attributes.at("p").i();
}
// ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING");
std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
if(contains(info.attributes, "pads"))
{
values["padding"].clear();
......@@ -110,7 +128,7 @@ struct parse_pooling : op_parser<parse_pooling>
std::fill_n(values["stride"].begin(), kdims, 1);
}
// used to calculate the supposed output shape
std::vector<int64_t> orig_padding(paddings.begin(), paddings.end());
std::vector<int64_t> orig_padding = paddings;
std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end;
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
rest of the sequence in the original order. Variable sequence_lens is not supported in this
version of MIGraphX. You can pass the sequence_lens either as a constant node or an attribute. The
batch axis and time axis must be [0, 1] and not the same.
*/
struct parse_reversesequence : op_parser<parse_reversesequence>
{
    std::vector<op_desc> operators() const { return {{"ReverseSequence"}}; }

    /// Parse ONNX ReverseSequence by decomposing it into per-batch
    /// slice / reverse / concat instructions (there is no dedicated op here).
    /// sequence_lens must be compile-time constant: either a constant second
    /// input or a "sequence_lens" attribute.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // batch_axis and time_axis must each be 0 or 1 and must differ
        int batch_axis = 1;
        if(contains(info.attributes, "batch_axis"))
        {
            batch_axis = info.attributes.at("batch_axis").i();
        }
        if(batch_axis != 0 and batch_axis != 1)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: batch axis not 0 or 1");
        }
        int time_axis = 0;
        if(contains(info.attributes, "time_axis"))
        {
            time_axis = info.attributes.at("time_axis").i();
        }
        if(time_axis != 0 and time_axis != 1)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis not 0 or 1");
        }
        if(time_axis == batch_axis)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis and batch axis are the same");
        }
        auto input      = args[0];
        auto input_lens = input->get_shape().lens();
        if(input_lens.size() < 2)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: input tensor must have rank >= 2");
        }
        // sequence_lens may come from a constant second input (preferred) or an attribute
        std::vector<int64_t> sequence_lens;
        if(args.size() == 2)
        {
            migraphx::argument seq_lens_arg = args.back()->eval();
            check_arg_empty(seq_lens_arg, "REVERSESEQUENCE: cannot handle variable sequence_lens");
            seq_lens_arg.visit([&](auto s) { sequence_lens.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "sequence_lens"))
        {
            literal s = parser.parse_value(info.attributes.at("sequence_lens"));
            s.visit([&](auto v) { sequence_lens.assign(v.begin(), v.end()); });
        }
        auto batch_size = input_lens[batch_axis];
        auto time_size  = input_lens[time_axis];
        // this condition may still work if sequence_len's shape was incorrect
        if(sequence_lens.size() != batch_size)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape");
        }
        instruction_ref ret;
        // helper: slice out batch b over time steps [t_start, t_end)
        auto add_slice = [&info, &input, batch_axis, time_axis](int b, int t_start, int t_end) {
            return info.add_instruction(make_op("slice",
                                                {{"axes", {batch_axis, time_axis}},
                                                 {"starts", {b, t_start}},
                                                 {"ends", {b + 1, t_end}}}),
                                        input);
        };
        for(int b = 0; b < batch_size; ++b)
        {
            instruction_ref s0;
            // a sequence length of 0 or 1 reverses to itself, so skip the reverse
            if(sequence_lens[b] > 1)
            {
                s0 = add_slice(b, 0, sequence_lens[b]);
                s0 = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), s0);
                // if reversed less than whole batch, concat rest of batch
                if(sequence_lens[b] < time_size)
                {
                    auto s1 = add_slice(b, sequence_lens[b], time_size);
                    s0 = info.add_instruction(make_op("concat", {{"axis", time_axis}}), s0, s1);
                }
            }
            else
            { // cases where nothing changes
                s0 = add_slice(b, 0, time_size);
            }
            // accumulate the per-batch results back along the batch axis
            if(b == 0)
            {
                ret = s0;
            }
            else
            {
                ret = info.add_instruction(make_op("concat", {{"axis", batch_axis}}), ret, s0);
            }
        }
        return ret;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/op/common.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
......@@ -28,10 +29,14 @@ struct parse_roialign : op_parser<parse_roialign>
"\": invalid value!");
}
std::string mode = "avg";
migraphx::op::pooling_mode rmode(migraphx::op::pooling_mode::average);
if(contains(info.attributes, "mode"))
{
mode = info.attributes.at("mode").s();
// read mode; default is "avg"
if(info.attributes.at("mode").s() == "max")
{
rmode = migraphx::op::pooling_mode::max;
}
}
int64_t output_height = 1;
......@@ -57,10 +62,9 @@ struct parse_roialign : op_parser<parse_roialign>
{
spatial_scale = info.attributes.at("spatial_scale").f();
}
return info.add_instruction(make_op("roialign",
{{"coordinate_transformation_mode", coord_trans_mode},
{"mode", mode},
{"mode", rmode},
{"output_height", output_height},
{"output_width", output_width},
{"sampling_ratio", sampling_ratio},
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_scatter : op_parser<parse_scatter>
{
    std::vector<op_desc> operators() const { return {{"ScatterElements"}, {"Scatter"}}; }

    // Dispatch Scatter/ScatterElements to a scatter_<reduction> operator,
    // defaulting to scatter_none when no "reduction" attribute is given.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        int scatter_axis = 0;
        if(contains(info.attributes, "axis"))
            scatter_axis = info.attributes.at("axis").i();

        std::string reduction = "none";
        if(contains(info.attributes, "reduction"))
        {
            reduction = info.attributes.at("reduction").s();
            // check for a valid reduction attribute. We have an operator for each one.
            // Future reduction op names should follow the scatter_<reduction> pattern
            // and must also be added to this check.
            if(not contains({"none", "add", "mul"}, reduction))
                MIGRAPHX_THROW("PARSE_SCATTER: unsupported reduction mode " + reduction);
        }

        auto scatter_op = migraphx::make_op("scatter_" + reduction, {{"axis", scatter_axis}});
        return info.add_instruction(scatter_op, args);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_scatternd : op_parser<parse_scatternd>
{
    std::vector<op_desc> operators() const { return {{"ScatterND"}}; }

    // Map ScatterND to scatternd_none/scatternd_add/scatternd_mul according to
    // the optional "reduction" attribute; anything else falls back to "none".
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref>& args) const
    {
        std::string op_name = "scatternd_none";
        if(contains(info.attributes, "reduction"))
        {
            const auto& reduction = info.attributes.at("reduction").s();
            if(reduction == "add")
                op_name = "scatternd_add";
            else if(reduction == "mul")
                op_name = "scatternd_mul";
        }
        return info.add_instruction(migraphx::make_op(op_name), args);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_size : op_parser<parse_size>
{
    std::vector<op_desc> operators() const { return {{"Size"}}; }

    // Size yields the total number of elements of the input as a scalar
    // int64 literal, computed entirely at parse time from the input shape.
    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        const auto element_count = args[0]->get_shape().elements();
        migraphx::shape scalar_shape{migraphx::shape::int64_type};
        return info.add_literal(migraphx::literal{scalar_shape, {element_count}});
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -30,11 +30,11 @@ struct parse_squeeze : op_parser<parse_squeeze>
std::vector<instruction_ref> args) const
{
auto op = parser.load(opd.op_name, info);
std::vector<int64_t> axes;
if(args.size() == 2)
{
auto arg_axes = args.at(1)->eval();
check_arg_empty(arg_axes, "PARSE_" + opd.op_name + ": cannot handle variable axes!");
std::vector<int64_t> axes;
arg_axes.visit([&](auto s) { axes.assign(s.begin(), s.end()); });
op = assign_axes(op, axes);
}
......
//
// Supporting functions for enum values used in operator parameters.
// These values are declared as "enum class" and should include << streaming operators
// to be able to write their values in human-readable format so users can
// save and edit model files.
//
#include <sstream>
#include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
std::ostream& operator<<(std::ostream& os, pooling_mode v)
{
    // the strings for the enum are the same as the values used for onnx parsing
    // but this enum is not onnx-specific: strings must be converted when parsing tf
    const char* name = "";
    switch(v)
    {
    case pooling_mode::average: name = "average"; break;
    case pooling_mode::max: name = "max"; break;
    case pooling_mode::lpnorm: name = "lpnorm"; break;
    }
    os << name;
    return os;
}
// Stream an rnn_direction enum as its human-readable name.
// Indexes a name table by the enum's underlying value, so the table order
// must match the enumerator declaration order in op/common.hpp.
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
    static const std::vector<std::string> rnn_direction_str = {
        "forward", "reverse", "bidirectional"};
    // NOTE(review): an out-of-range enum value would index past the table — assumed
    // unreachable for valid programs
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
    return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -4,11 +4,11 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(module& p) const
void memory_coloring::apply(module& m) const
{
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
memory_coloring_impl opt(&p, allocation_op, verify);
memory_coloring_impl opt(&m, allocation_op, verify);
opt.run();
}
}
......
......@@ -20,7 +20,6 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
int ec = 0;
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl;
std::array<char, 128> buffer;
auto closer = [&](FILE* stream) {
auto status = pclose(stream);
ec = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT
......@@ -30,6 +29,7 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
if(!pipe)
MIGRAPHX_THROW("popen() failed: " + cmd);
std::array<char, 128> buffer;
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
std_out(buffer.data());
}
......
......@@ -353,13 +353,20 @@ std::vector<argument> program::eval(parameter_map params) const
if(trace_level > 0)
{
std::unordered_map<instruction_ref, std::string> ins_out;
// get instruction names
this->print([&](auto x, auto ins_names) {
std::stringstream ss;
instruction::print(ss, x, ins_names);
ins_out[x] = ss.str();
});
return generic_eval(*this,
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish();
std::cout << "Run instruction: ";
this->debug_print(ins);
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{};
auto result = check_context(f);
double t1 = t.record<milliseconds>();
......@@ -742,6 +749,14 @@ void program::print(
}
}
// Convenience overload of print: invokes print_func for each instruction
// using a freshly-created (empty) name table, which the two-argument
// overload fills in as it walks the program.
void program::print(
    const std::function<void(instruction_ref ins,
                             std::unordered_map<instruction_ref, std::string>)>& print_func) const
{
    std::unordered_map<instruction_ref, std::string> names;
    this->print(names, print_func);
}
void program::print_graph(std::ostream& os, bool brief) const
{
const auto* mm = this->get_main_module();
......@@ -816,11 +831,12 @@ void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputItera
std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
return mod->name();
});
transform_if(m.begin(),
m.end(),
out,
[&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
transform_if(
m.begin(),
m.end(),
out,
[&](auto&& pp) { return not contains(used, pp.first); },
[](auto&& pp) { return &pp.second; });
}
std::vector<const module*> program::get_modules() const
......
......@@ -3,6 +3,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <unordered_set>
namespace migraphx {
......@@ -20,33 +21,42 @@ bool skip_propogate(instruction_ref ins)
return false;
}
void propagate_constant::apply(module& p) const
// An instruction is a constant-folding candidate when it can be evaluated at
// compile time and is not excluded by skip_propogate.
bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
void propagate_constant::apply(module& m) const
{
for(auto i : iterator_for(p))
std::unordered_set<instruction_ref> const_instrs;
auto last = std::prev(m.end());
// Find instructions that can be evaluated to a literal
for(auto i : iterator_for(m))
{
if(i->name() != "@literal")
if(is_const(i) and i != last)
continue;
if(i->outputs().empty())
continue;
fix([&](auto self, auto ins) {
std::unordered_set<instruction_ref> children(ins->outputs().begin(),
ins->outputs().end());
for(auto child : children)
{
if(child->name() == "@literal" or skip_propogate(child))
{
self(child);
continue;
}
auto r = child->eval();
if(not r.empty())
{
assert(r.get_shape() == child->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
self(p.replace_instruction(child, l));
}
}
})(i);
std::copy_if(
i->inputs().begin(),
i->inputs().end(),
std::inserter(const_instrs, const_instrs.begin()),
[&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; });
}
// Compute literals in parallel
std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
std::vector<argument> literals(const_instrs_vec.size());
par_for(const_instrs_vec.size(), 1, [&](const auto i) {
literals[i] = const_instrs_vec[i]->eval();
});
// Replace instructions in m
for(size_t i = 0; i < const_instrs_vec.size(); i++)
{
if(not literals[i].empty())
{
assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instrs_vec[i], l);
}
}
}
......
......@@ -3,8 +3,11 @@
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
......@@ -95,7 +98,6 @@ migraphx::value to_value(py::kwargs kwargs)
auto&& val = arg.second;
visit_py(val, [&](auto py_val) { v[key] = py_val; });
}
return v;
}
} // namespace migraphx
......@@ -211,12 +213,21 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
.def(py::init<>())
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float"));
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
else
return migraphx::shape(t, lens);
}))
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
......@@ -247,13 +258,46 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module>(m, "module")
py::class_<migraphx::instruction_ref>(m, "instruction_ref");
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
.def("__eq__", std::equal_to<migraphx::module>{})
.def("__ne__", std::not_equal_to<migraphx::module>{})
.def(
"add_instruction",
[](migraphx::module& mm,
const migraphx::operation& op,
std::vector<migraphx::instruction_ref>& args,
std::vector<migraphx::module*>& mod_args) {
return mm.add_instruction(op, args, mod_args);
},
py::arg("op"),
py::arg("args"),
py::arg("mod_args") = std::vector<migraphx::module*>{})
.def(
"add_literal",
[](migraphx::module& mm, py::buffer data) {
py::buffer_info info = data.request();
auto literal_shape = to_shape(info);
return mm.add_literal(literal_shape, reinterpret_cast<char*>(info.ptr));
},
py::arg("data"))
.def(
"add_parameter",
[](migraphx::module& mm, const std::string& name, const migraphx::shape shape) {
return mm.add_parameter(name, shape);
},
py::arg("name"),
py::arg("shape"))
.def(
"add_return",
[](migraphx::module& mm, std::vector<migraphx::instruction_ref>& args) {
return mm.add_return(args);
},
py::arg("args"))
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
py::class_<migraphx::program>(m, "program")
.def(py::init([]() { return migraphx::program(); }))
.def("get_parameter_names", &migraphx::program::get_parameter_names)
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_output_shapes", &migraphx::program::get_output_shapes)
......@@ -268,11 +312,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("t"),
py::arg("offload_copy") = true,
py::arg("fast_math") = true)
.def("get_main_module",
[](migraphx::program& p) {
auto* mm = p.get_main_module();
return *mm;
})
.def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
.def(
"create_module",
[](migraphx::program& p, const std::string& name) { return p.create_module(name); },
py::arg("name"))
.def("run",
[](migraphx::program& p, py::dict params) {
migraphx::parameter_map pm;
......@@ -303,89 +347,94 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("name", &migraphx::operation::name);
m.def("parse_tf",
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename,
migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def("parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
m.def("parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def("load",
[](const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::load(name, options);
},
"Load MIGraphX program",
py::arg("filename"),
py::arg("format") = "msgpack");
m.def("save",
[](const migraphx::program& p, const std::string& name, const std::string& format) {
migraphx::file_options options;
options.format = format;
return migraphx::save(p, name, options);
},
"Save MIGraphX program",
py::arg("p"),
py::arg("filename"),
py::arg("format") = "msgpack");
m.def(
"parse_tf",
[](const std::string& filename,
bool is_nhwc,
unsigned int batch_size,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::vector<std::string> output_names) {
return migraphx::parse_tf(
filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
},
"Parse tf protobuf (default format is nhwc)",
py::arg("filename"),
py::arg("is_nhwc") = true,
py::arg("batch_size") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("output_names") = std::vector<std::string>());
m.def(
"parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
// Register the ONNX front end (in-memory buffer variant, as opposed to
// parse_onnx above which reads from a file path).
m.def(
    "parse_onnx_buffer",
    [](const std::string& onnx_buffer,
       unsigned int default_dim_value,
       std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
       bool skip_unknown_operators,
       bool print_program_on_error) {
        migraphx::onnx_options options;
        options.default_dim_value      = default_dim_value;
        options.map_input_dims         = map_input_dims;
        options.skip_unknown_operators = skip_unknown_operators;
        options.print_program_on_error = print_program_on_error;
        return migraphx::parse_onnx_buffer(onnx_buffer, options);
    },
    // Fixed: the previous help text said "Parse onnx file", copied from the
    // parse_onnx binding, but this overload parses a buffer.
    "Parse onnx buffer",
    // NOTE(review): the Python keyword stays "filename" so existing callers
    // using filename=... keep working, even though the argument is an
    // in-memory buffer, not a path — consider deprecating/renaming later.
    py::arg("filename"),
    py::arg("default_dim_value") = 1,
    py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
    py::arg("skip_unknown_operators") = false,
    py::arg("print_program_on_error") = false);
// NOTE(review): "load" is also registered earlier in this file with an
// identical lambda and identical py::arg list — this second registration
// looks like a merge artifact (the commit is a merge). pybind11 will treat
// it as a redundant overload; confirm and remove one of the two copies.
m.def(
    "load",
    [](const std::string& name, const std::string& format) {
        migraphx::file_options options;
        options.format = format;
        return migraphx::load(name, options);
    },
    "Load MIGraphX program",
    py::arg("filename"),
    py::arg("format") = "msgpack");
// NOTE(review): "save" is also registered earlier in this file with an
// identical lambda and identical py::arg list — this second registration
// looks like a merge artifact (the commit is a merge). pybind11 will treat
// it as a redundant overload; confirm and remove one of the two copies.
m.def(
    "save",
    [](const migraphx::program& p, const std::string& name, const std::string& format) {
        migraphx::file_options options;
        options.format = format;
        return migraphx::save(p, name, options);
    },
    "Save MIGraphX program",
    py::arg("p"),
    py::arg("filename"),
    py::arg("format") = "msgpack");
// Look up a registered compilation target by name (delegates to make_target).
m.def("get_target", &migraphx::make_target);
// Create an argument for shape s — presumably filled with seeded
// pseudo-random data (it takes a seed); confirm against generate_argument.
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
// Create an argument for shape s with every element set to `value`.
m.def("fill_argument", &migraphx::fill_argument, py::arg("s"), py::arg("value"));
m.def("quantize_fp16",
&migraphx::quantize_fp16,
py::arg("prog"),
......
......@@ -16,10 +16,8 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
auto bstride = s.strides()[n + 1];
auto blen = s.lens()[n + 1];
if(astride == bstride * blen)
{
if(astride == bstride * blen or alen == 1)
new_lens.push_back(alen * blen);
}
}
if(new_lens.size() != shapes.size())
return false;
......@@ -37,12 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return true;
}
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
// Strip a trailing dimension of extent 1 from every shape, but only when all
// shapes agree: each must have rank >= 2 and a last dimension equal to 1.
// Otherwise the shapes are left untouched.
void reduce_dim1(std::vector<shape>& shapes)
{
    const bool all_reducible = std::all_of(shapes.begin(), shapes.end(), [](const auto& s) {
        return s.lens().size() >= 2 and s.lens().back() == 1;
    });
    if(not all_reducible)
        return;
    for(auto& s : shapes)
    {
        auto new_lens    = s.lens();
        auto new_strides = s.strides();
        new_lens.pop_back();
        new_strides.pop_back();
        s = shape{s.type(), new_lens, new_strides};
    }
}
// Repeatedly merges dimension n into its neighbor across all shapes (via
// reduce_dim) until no further merge succeeds, then returns the next
// dimension index to visit.
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
    // NOTE(review): the loop guard compares n against shapes.size() — the
    // number of shapes — while the caller below bounds n by the shapes' rank
    // (shapes.front().lens().size() - 1). Confirm shapes.size() is intended
    // here and not the rank.
    while(reduce_dim(shapes, n) and n < shapes.size()) {}
    return n + 1;
}
void reduce_dim_all(std::vector<shape>& shapes)
......@@ -50,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
std::size_t n = 0;
while(n < shapes.front().lens().size() - 1)
n = reduce_dim_all(shapes, n);
reduce_dim1(shapes);
}
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
......
#include <migraphx/register_target.hpp>
#include <unordered_map>
#include <migraphx/register_target.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -11,7 +11,17 @@ std::unordered_map<std::string, target>& target_map()
}
// Add (or replace) the entry for t in the global target registry, keyed by
// the target's own name.
void register_target(const target& t)
{
    auto& registry = target_map();
    registry[t.name()] = t;
}
target make_target(const std::string& name) { return target_map().at(name); }
// Resolve a target by name from the registry; throws with a descriptive
// message instead of the bare std::out_of_range that map::at would raise.
target make_target(const std::string& name)
{
    auto entry = target_map().find(name);
    if(entry != target_map().end())
        return entry->second;
    MIGRAPHX_THROW("Requested target '" + name + "' is not enabled or not supported");
}
std::vector<std::string> get_targets()
{
std::vector<std::string> result;
......
......@@ -14,9 +14,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(module& p) const
void rewrite_batchnorm::apply(module& m) const
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(ins->name() != "batch_norm_inference")
continue;
......@@ -46,13 +46,13 @@ void rewrite_batchnorm::apply(module& p) const
});
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
auto mul = p.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, make_op("add"), mul, b_broadcast);
p.replace_instruction(ins, add);
auto a_ins = m.add_literal({a.get_shape(), a.data()});
auto a_broadcast = m.insert_instruction(ins, broadcast, a_ins);
auto mul = m.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = m.add_literal({b.get_shape(), b.data()});
auto b_broadcast = m.insert_instruction(ins, broadcast, b_ins);
auto add = m.insert_instruction(ins, make_op("add"), mul, b_broadcast);
m.replace_instruction(ins, add);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment